diff --git a/.travis.yml b/.travis.yml index 77dd2ae55..ac9322211 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,12 +1,12 @@ -dist: xenial +dist: bionic language: python -sudo: false services: - docker python: - 2.7 - 3.6 - 3.7 +- 3.8 env: - TEST_SERVER_MODE=false - TEST_SERVER_MODE=true @@ -17,7 +17,14 @@ install: python setup.py sdist if [ "$TEST_SERVER_MODE" = "true" ]; then - docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh & + if [ "$TRAVIS_PYTHON_VERSION" = "3.8" ]; then + # Python 3.8 does not provide Stretch images yet [1] + # [1] https://github.com/docker-library/python/issues/428 + PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-buster + else + PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-stretch + fi + docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh & fi travis_retry pip install boto==2.45.0 travis_retry pip install boto3 @@ -28,8 +35,10 @@ install: if [ "$TEST_SERVER_MODE" = "true" ]; then python wait_for.py fi +before_script: +- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then make lint; fi script: -- make test +- make test-only after_success: - coveralls before_deploy: diff --git a/AUTHORS.md b/AUTHORS.md index 01b000182..0228ac665 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -57,3 +57,4 @@ Moto is written by Steve Pulec with contributions from: * [Bendeguz Acs](https://github.com/acsbendi) * [Craig Anderson](https://github.com/craiga) * [Robert Lewis](https://github.com/ralewis85) +* [Kyle Jones](https://github.com/Kerl1310) diff --git a/CHANGELOG.md b/CHANGELOG.md index f42619b33..732dad23a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,189 @@ Moto Changelog =================== +1.3.14 +----- + + General Changes: + * Support for Python 3.8 + * Linting: Black is now enforced. + + New Services: + * Athena + * Config + * DataSync + * Step Functions + + New methods: + * Athena: + * create_work_group() + * list_work_groups() + * API Gateway: + * delete_stage() + * update_api_key() + * CloudWatch Logs + * list_tags_log_group() + * tag_log_group() + * untag_log_group() + * Config + * batch_get_resource_config() + * delete_aggregation_authorization() + * delete_configuration_aggregator() + * describe_aggregation_authorizations() + * describe_configuration_aggregators() + * get_resource_config_history() + * list_aggregate_discovered_resources() (For S3) + * list_discovered_resources() (For S3) + * put_aggregation_authorization() + * put_configuration_aggregator() + * Cognito + * assume_role_with_web_identity() + * describe_identity_pool() + * get_open_id_token() + * update_user_pool_domain() + * DataSync: + * cancel_task_execution() + * create_location() + * create_task() + * start_task_execution() + * EC2: + * create_launch_template() + * create_launch_template_version() + * describe_launch_template_versions() + * describe_launch_templates() + * ECS + * decrypt() + * encrypt() + * generate_data_key_without_plaintext() + * generate_random() + * re_encrypt() + * Glue + * batch_get_partition() + * IAM + * create_open_id_connect_provider() + * create_virtual_mfa_device() + * delete_account_password_policy() + * delete_open_id_connect_provider() + * delete_policy() + * delete_virtual_mfa_device() + * get_account_password_policy() + * get_open_id_connect_provider() + * list_open_id_connect_providers() + * list_virtual_mfa_devices() + * update_account_password_policy() + * Lambda + * create_event_source_mapping() + * delete_event_source_mapping() + * get_event_source_mapping() + * list_event_source_mappings() + * update_configuration() + * update_event_source_mapping() + * update_function_code() + * KMS + * decrypt() + * encrypt() + * generate_data_key_without_plaintext() + * generate_random() + * re_encrypt() + * SES + * send_templated_email() + * SNS + * add_permission() + * list_tags_for_resource() + * remove_permission() + * tag_resource() + * untag_resource() + * SSM + * describe_parameters() + * get_parameter_history() + * Step Functions + * create_state_machine() + * delete_state_machine() + * describe_execution() + * describe_state_machine() + * describe_state_machine_for_execution() + * list_executions() + * list_state_machines() + * list_tags_for_resource() + * start_execution() + * stop_execution() + SQS + * list_queue_tags() + * send_message_batch() + + General updates: + * API Gateway: + * Now generates valid IDs + * API Keys, Usage Plans now support tags + * ACM: + * list_certificates() accepts the status parameter + * Batch: + * submit_job() can now be called with job name + * CloudWatch Events + * Multi-region support + * CloudWatch Logs + * get_log_events() now supports pagination + * Cognito: + * Now throws UsernameExistsException for known users + * DynamoDB + * update_item() now supports lists, the list_append-operator and removing nested items + * delete_item() now supports condition expressions + * get_item() now supports projection expression + * Enforces 400KB item size + * Validation on duplicate keys in batch_get_item() + * Validation on AttributeDefinitions on create_table() + * Validation on Query Key Expression + * Projection Expressions now support nested attributes + * EC2: + * Change DesiredCapacity behaviour for AutoScaling groups + * Extend list of supported EC2 ENI properties + * Create ASG from Instance now supported + * ASG attached to a terminated instance now recreate the instance of required + * Unify OwnerIDs + * ECS + * Task definition revision deregistration: remaining revisions now remain unchanged + * Fix created_at/updated_at format for deployments + * Support multiple regions + * ELB + * Return correct response then describing target health of stopped instances + * Target groups now longer show terminated instances + * 'fixed-response' now a supported action-type + * Now supports redirect: authenticate-cognito + * Kinesis FireHose + * Now supports ExtendedS3DestinationConfiguration + * KMS + * Now supports tags + * Organizations + * create_organization() now creates Master account + * Redshift + * Fix timezone problems when creating a cluster + * Support for enhanced_vpc_routing-parameter + * Route53 + * Implemented UPSERT for change_resource_records + * S3: + * Support partNumber for head_object + * Support for INTELLIGENT_TIERING, GLACIER and DEEP_ARCHIVE + * Fix KeyCount attribute + * list_objects now supports pagination (next_marker) + * Support tagging for versioned objects + * STS + * Implement validation on policy length + * Lambda + * Support EventSourceMappings for SQS, DynamoDB + * get_function(), delete_function() now both support ARNs as parameters + * IAM + * Roles now support tags + * Policy Validation: SID can be empty + * Validate roles have no attachments when deleting + * SecretsManager + * Now supports binary secrets + * IOT + * update_thing_shadow validation + * delete_thing now also removed principals + * SQS + * Tags supported for create_queue() + + 1.3.7 ----- diff --git a/CONFIG_README.md b/CONFIG_README.md new file mode 100644 index 000000000..356bb87a0 --- /dev/null +++ b/CONFIG_README.md @@ -0,0 +1,120 @@ +# AWS Config Querying Support in Moto + +An experimental feature for AWS Config has been developed to provide AWS Config capabilities in your unit tests. +This feature is experimental as there are many services that are not yet supported and will require the community to add them in +over time. This page details how the feature works and how you can use it. + +## What is this and why would I use this? + +AWS Config is an AWS service that describes your AWS resource types and can track their changes over time. At this time, moto does not +have support for handling the configuration history changes, but it does have a few methods mocked out that can be immensely useful +for unit testing. + +If you are developing automation that needs to pull against AWS Config, then this will help you write tests that can simulate your +code in production. + +## How does this work? + +The AWS Config capabilities in moto work by examining the state of resources that are created within moto, and then returning that data +in the way that AWS Config would return it (sans history). This will work by querying all of the moto backends (regions) for a given +resource type. + +However, this will only work on resource types that have this enabled. + +### Current enabled resource types: + +1. S3 + + +## Developer Guide + +There are several pieces to this for adding new capabilities to moto: + +1. Listing resources +1. Describing resources + +For both, there are a number of pre-requisites: + +### Base Components + +In the `moto/core/models.py` file is a class named `ConfigQueryModel`. This is a base class that keeps track of all the +resource type backends. + +At a minimum, resource types that have this enabled will have: + +1. A `config.py` file that will import the resource type backends (from the `__init__.py`) +1. In the resource's `config.py`, an implementation of the `ConfigQueryModel` class with logic unique to the resource type +1. An instantiation of the `ConfigQueryModel` +1. In the `moto/config/models.py` file, import the `ConfigQueryModel` instantiation, and update `RESOURCE_MAP` to have a mapping of the AWS Config resource type + to the instantiation on the previous step (just imported). + +An example of the above is implemented for S3. You can see that by looking at: + +1. `moto/s3/config.py` +1. `moto/config/models.py` + +As well as the corresponding unit tests in: + +1. `tests/s3/test_s3.py` +1. `tests/config/test_config.py` + +Note for unit testing, you will want to add a test to ensure that you can query all the resources effectively. For testing this feature, +the unit tests for the `ConfigQueryModel` will not make use of `boto` to create resources, such as S3 buckets. You will need to use the +backend model methods to provision the resources. This is to make tests compatible with the moto server. You should absolutely make tests +in the resource type to test listing and object fetching. + +### Listing +S3 is currently the model implementation, but it also odd in that S3 is a global resource type with regional resource residency. + +But for most resource types the following is true: + +1. There are regional backends with their own sets of data +1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account + +Implementing the listing capability will be different for each resource type. At a minimum, you will need to return a `List` of `Dict`s +that look like this: + +```python + [ + { + 'type': 'AWS::The AWS Config data type', + 'name': 'The name of the resource', + 'id': 'The ID of the resource', + 'region': 'The region of the resource -- if global, then you may want to have the calling logic pass in the + aggregator region in for the resource region -- or just us-east-1 :P' + } + , ... +] +``` + +It's recommended to read the comment for the `ConfigQueryModel`'s `list_config_service_resources` function in [base class here](moto/core/models.py). + +^^ The AWS Config code will see this and format it correct for both aggregated and non-aggregated calls. + +#### General implementation tips +The aggregation and non-aggregation querying can and should just use the same overall logic. The differences are: + +1. Non-aggregated listing will specify the region-name of the resource backend `backend_region` +1. Aggregated listing will need to be able to list resource types across ALL backends and filter optionally by passing in `resource_region`. + +An example of a working implementation of this is [S3](moto/s3/config.py). + +Pagination should generally be able to pull out the resource across any region so should be sharded by `region-item-name` -- not done for S3 +because S3 has a globally unique name space. + +### Describing Resources +Fetching a resource's configuration has some similarities to listing resources, but it requires more work (to implement). Due to the +various ways that a resource can be configured, some work will need to be done to ensure that the Config dict returned is correct. + +For most resource types the following is true: + +1. There are regional backends with their own sets of data +1. Config aggregation can pull data from any backend region -- we assume that everything lives in the same account + +The current implementation is for S3. S3 is very complex and depending on how the bucket is configured will depend on what Config will +return for it. + +When implementing resource config fetching, you will need to return at a minimum `None` if the resource is not found, or a `dict` that looks +like what AWS Config would return. + +It's recommended to read the comment for the `ConfigQueryModel` 's `get_config_resource` function in [base class here](moto/core/models.py). diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d149b0dd8..5e6ef1c9e 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -1,4 +1,25 @@ +## accessanalyzer +0% implemented +- [ ] create_analyzer +- [ ] create_archive_rule +- [ ] delete_analyzer +- [ ] delete_archive_rule +- [ ] get_analyzed_resource +- [ ] get_analyzer +- [ ] get_archive_rule +- [ ] get_finding +- [ ] list_analyzed_resources +- [ ] list_analyzers +- [ ] list_archive_rules +- [ ] list_findings +- [ ] list_tags_for_resource +- [ ] start_resource_scan +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_archive_rule +- [ ] update_findings + ## acm 38% implemented - [X] add_tags_to_certificate @@ -137,21 +158,28 @@ ## amplify 0% implemented - [ ] create_app +- [ ] create_backend_environment - [ ] create_branch - [ ] create_deployment - [ ] create_domain_association - [ ] create_webhook - [ ] delete_app +- [ ] delete_backend_environment - [ ] delete_branch - [ ] delete_domain_association - [ ] delete_job - [ ] delete_webhook +- [ ] generate_access_logs - [ ] get_app +- [ ] get_artifact_url +- [ ] get_backend_environment - [ ] get_branch - [ ] get_domain_association - [ ] get_job - [ ] get_webhook - [ ] list_apps +- [ ] list_artifacts +- [ ] list_backend_environments - [ ] list_branches - [ ] list_domain_associations - [ ] list_jobs @@ -168,7 +196,7 @@ - [ ] update_webhook ## apigateway -24% implemented +25% implemented - [ ] create_api_key - [ ] create_authorizer - [ ] create_base_path_mapping @@ -201,7 +229,7 @@ - [ ] delete_request_validator - [X] delete_resource - [X] delete_rest_api -- [ ] delete_stage +- [X] delete_stage - [X] delete_usage_plan - [X] delete_usage_plan_key - [ ] delete_vpc_link @@ -292,6 +320,8 @@ ## apigatewaymanagementapi 0% implemented +- [ ] delete_connection +- [ ] get_connection - [ ] post_to_connection ## apigatewayv2 @@ -310,6 +340,7 @@ - [ ] delete_api - [ ] delete_api_mapping - [ ] delete_authorizer +- [ ] delete_cors_configuration - [ ] delete_deployment - [ ] delete_domain_name - [ ] delete_integration @@ -317,6 +348,7 @@ - [ ] delete_model - [ ] delete_route - [ ] delete_route_response +- [ ] delete_route_settings - [ ] delete_stage - [ ] get_api - [ ] get_api_mapping @@ -342,6 +374,8 @@ - [ ] get_stage - [ ] get_stages - [ ] get_tags +- [ ] import_api +- [ ] reimport_api - [ ] tag_resource - [ ] untag_resource - [ ] update_api @@ -356,6 +390,38 @@ - [ ] update_route_response - [ ] update_stage +## appconfig +0% implemented +- [ ] create_application +- [ ] create_configuration_profile +- [ ] create_deployment_strategy +- [ ] create_environment +- [ ] delete_application +- [ ] delete_configuration_profile +- [ ] delete_deployment_strategy +- [ ] delete_environment +- [ ] get_application +- [ ] get_configuration +- [ ] get_configuration_profile +- [ ] get_deployment +- [ ] get_deployment_strategy +- [ ] get_environment +- [ ] list_applications +- [ ] list_configuration_profiles +- [ ] list_deployment_strategies +- [ ] list_deployments +- [ ] list_environments +- [ ] list_tags_for_resource +- [ ] start_deployment +- [ ] stop_deployment +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_application +- [ ] update_configuration_profile +- [ ] update_deployment_strategy +- [ ] update_environment +- [ ] validate_configuration + ## application-autoscaling 0% implemented - [ ] delete_scaling_policy @@ -373,20 +439,30 @@ 0% implemented - [ ] create_application - [ ] create_component +- [ ] create_log_pattern - [ ] delete_application - [ ] delete_component +- [ ] delete_log_pattern - [ ] describe_application - [ ] describe_component - [ ] describe_component_configuration - [ ] describe_component_configuration_recommendation +- [ ] describe_log_pattern - [ ] describe_observation - [ ] describe_problem - [ ] describe_problem_observations - [ ] list_applications - [ ] list_components +- [ ] list_log_pattern_sets +- [ ] list_log_patterns - [ ] list_problems +- [ ] list_tags_for_resource +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_application - [ ] update_component - [ ] update_component_configuration +- [ ] update_log_pattern ## appmesh 0% implemented @@ -471,18 +547,22 @@ ## appsync 0% implemented +- [ ] create_api_cache - [ ] create_api_key - [ ] create_data_source - [ ] create_function - [ ] create_graphql_api - [ ] create_resolver - [ ] create_type +- [ ] delete_api_cache - [ ] delete_api_key - [ ] delete_data_source - [ ] delete_function - [ ] delete_graphql_api - [ ] delete_resolver - [ ] delete_type +- [ ] flush_api_cache +- [ ] get_api_cache - [ ] get_data_source - [ ] get_function - [ ] get_graphql_api @@ -501,6 +581,7 @@ - [ ] start_schema_creation - [ ] tag_resource - [ ] untag_resource +- [ ] update_api_cache - [ ] update_api_key - [ ] update_data_source - [ ] update_function @@ -509,11 +590,11 @@ - [ ] update_type ## athena -0% implemented +10% implemented - [ ] batch_get_named_query - [ ] batch_get_query_execution - [ ] create_named_query -- [ ] create_work_group +- [X] create_work_group - [ ] delete_named_query - [ ] delete_work_group - [ ] get_named_query @@ -523,7 +604,7 @@ - [ ] list_named_queries - [ ] list_query_executions - [ ] list_tags_for_resource -- [ ] list_work_groups +- [X] list_work_groups - [ ] start_query_execution - [ ] stop_query_execution - [ ] tag_resource @@ -680,62 +761,101 @@ ## ce 0% implemented +- [ ] create_cost_category_definition +- [ ] delete_cost_category_definition +- [ ] describe_cost_category_definition - [ ] get_cost_and_usage +- [ ] get_cost_and_usage_with_resources - [ ] get_cost_forecast - [ ] get_dimension_values - [ ] get_reservation_coverage - [ ] get_reservation_purchase_recommendation - [ ] get_reservation_utilization - [ ] get_rightsizing_recommendation +- [ ] get_savings_plans_coverage +- [ ] get_savings_plans_purchase_recommendation +- [ ] get_savings_plans_utilization +- [ ] get_savings_plans_utilization_details - [ ] get_tags - [ ] get_usage_forecast +- [ ] list_cost_category_definitions +- [ ] update_cost_category_definition ## chime 0% implemented - [ ] associate_phone_number_with_user - [ ] associate_phone_numbers_with_voice_connector +- [ ] associate_phone_numbers_with_voice_connector_group +- [ ] batch_create_attendee +- [ ] batch_create_room_membership - [ ] batch_delete_phone_number - [ ] batch_suspend_user - [ ] batch_unsuspend_user - [ ] batch_update_phone_number - [ ] batch_update_user - [ ] create_account +- [ ] create_attendee - [ ] create_bot +- [ ] create_meeting - [ ] create_phone_number_order +- [ ] create_room +- [ ] create_room_membership - [ ] create_voice_connector +- [ ] create_voice_connector_group - [ ] delete_account +- [ ] delete_attendee - [ ] delete_events_configuration +- [ ] delete_meeting - [ ] delete_phone_number +- [ ] delete_room +- [ ] delete_room_membership - [ ] delete_voice_connector +- [ ] delete_voice_connector_group - [ ] delete_voice_connector_origination +- [ ] delete_voice_connector_streaming_configuration - [ ] delete_voice_connector_termination - [ ] delete_voice_connector_termination_credentials - [ ] disassociate_phone_number_from_user - [ ] disassociate_phone_numbers_from_voice_connector +- [ ] disassociate_phone_numbers_from_voice_connector_group - [ ] get_account - [ ] get_account_settings +- [ ] get_attendee - [ ] get_bot - [ ] get_events_configuration - [ ] get_global_settings +- [ ] get_meeting - [ ] get_phone_number - [ ] get_phone_number_order +- [ ] get_phone_number_settings +- [ ] get_room - [ ] get_user - [ ] get_user_settings - [ ] get_voice_connector +- [ ] get_voice_connector_group +- [ ] get_voice_connector_logging_configuration - [ ] get_voice_connector_origination +- [ ] get_voice_connector_streaming_configuration - [ ] get_voice_connector_termination - [ ] get_voice_connector_termination_health - [ ] invite_users - [ ] list_accounts +- [ ] list_attendees - [ ] list_bots +- [ ] list_meetings - [ ] list_phone_number_orders - [ ] list_phone_numbers +- [ ] list_room_memberships +- [ ] list_rooms - [ ] list_users +- [ ] list_voice_connector_groups - [ ] list_voice_connector_termination_credentials - [ ] list_voice_connectors - [ ] logout_user - [ ] put_events_configuration +- [ ] put_voice_connector_logging_configuration - [ ] put_voice_connector_origination +- [ ] put_voice_connector_streaming_configuration - [ ] put_voice_connector_termination - [ ] put_voice_connector_termination_credentials - [ ] regenerate_security_token @@ -747,9 +867,13 @@ - [ ] update_bot - [ ] update_global_settings - [ ] update_phone_number +- [ ] update_phone_number_settings +- [ ] update_room +- [ ] update_room_membership - [ ] update_user - [ ] update_user_settings - [ ] update_voice_connector +- [ ] update_voice_connector_group ## cloud9 0% implemented @@ -834,7 +958,7 @@ - [ ] upgrade_published_schema ## cloudformation -40% implemented +32% implemented - [ ] cancel_update_stack - [ ] continue_update_rollback - [X] create_change_set @@ -845,6 +969,7 @@ - [X] delete_stack - [X] delete_stack_instances - [X] delete_stack_set +- [ ] deregister_type - [ ] describe_account_limits - [X] describe_change_set - [ ] describe_stack_drift_detection_status @@ -856,8 +981,11 @@ - [ ] describe_stack_set - [ ] describe_stack_set_operation - [X] describe_stacks +- [ ] describe_type +- [ ] describe_type_registration - [ ] detect_stack_drift - [ ] detect_stack_resource_drift +- [ ] detect_stack_set_drift - [ ] estimate_template_cost - [X] execute_change_set - [ ] get_stack_policy @@ -872,7 +1000,13 @@ - [ ] list_stack_set_operations - [ ] list_stack_sets - [X] list_stacks +- [ ] list_type_registrations +- [ ] list_type_versions +- [ ] list_types +- [ ] record_handler_progress +- [ ] register_type - [ ] set_stack_policy +- [ ] set_type_default_version - [ ] signal_resource - [ ] stop_stack_set_operation - [X] update_stack @@ -983,6 +1117,7 @@ - [ ] delete_suggester - [ ] describe_analysis_schemes - [ ] describe_availability_options +- [ ] describe_domain_endpoint_options - [ ] describe_domains - [ ] describe_expressions - [ ] describe_index_fields @@ -992,6 +1127,7 @@ - [ ] index_documents - [ ] list_domain_names - [ ] update_availability_options +- [ ] update_domain_endpoint_options - [ ] update_scaling_parameters - [ ] update_service_access_policies @@ -1008,36 +1144,46 @@ - [ ] delete_trail - [ ] describe_trails - [ ] get_event_selectors +- [ ] get_insight_selectors +- [ ] get_trail - [ ] get_trail_status - [ ] list_public_keys - [ ] list_tags +- [ ] list_trails - [ ] lookup_events - [ ] put_event_selectors +- [ ] put_insight_selectors - [ ] remove_tags - [ ] start_logging - [ ] stop_logging - [ ] update_trail ## cloudwatch -39% implemented +34% implemented - [X] delete_alarms - [ ] delete_anomaly_detector - [X] delete_dashboards +- [ ] delete_insight_rules - [ ] describe_alarm_history - [ ] describe_alarms - [ ] describe_alarms_for_metric - [ ] describe_anomaly_detectors +- [ ] describe_insight_rules - [ ] disable_alarm_actions +- [ ] disable_insight_rules - [ ] enable_alarm_actions +- [ ] enable_insight_rules - [X] get_dashboard +- [ ] get_insight_rule_report - [ ] get_metric_data - [X] get_metric_statistics - [ ] get_metric_widget_image - [X] list_dashboards -- [ ] list_metrics +- [X] list_metrics - [ ] list_tags_for_resource - [ ] put_anomaly_detector - [X] put_dashboard +- [ ] put_insight_rule - [X] put_metric_alarm - [X] put_metric_data - [X] set_alarm_state @@ -1049,38 +1195,64 @@ - [ ] batch_delete_builds - [ ] batch_get_builds - [ ] batch_get_projects +- [ ] batch_get_report_groups +- [ ] batch_get_reports - [ ] create_project +- [ ] create_report_group - [ ] create_webhook - [ ] delete_project +- [ ] delete_report +- [ ] delete_report_group +- [ ] delete_resource_policy - [ ] delete_source_credentials - [ ] delete_webhook +- [ ] describe_test_cases +- [ ] get_resource_policy - [ ] import_source_credentials - [ ] invalidate_project_cache - [ ] list_builds - [ ] list_builds_for_project - [ ] list_curated_environment_images - [ ] list_projects +- [ ] list_report_groups +- [ ] list_reports +- [ ] list_reports_for_report_group +- [ ] list_shared_projects +- [ ] list_shared_report_groups - [ ] list_source_credentials +- [ ] put_resource_policy - [ ] start_build - [ ] stop_build - [ ] update_project +- [ ] update_report_group - [ ] update_webhook ## codecommit 0% implemented +- [ ] associate_approval_rule_template_with_repository +- [ ] batch_associate_approval_rule_template_with_repositories - [ ] batch_describe_merge_conflicts +- [ ] batch_disassociate_approval_rule_template_from_repositories +- [ ] batch_get_commits - [ ] batch_get_repositories +- [ ] create_approval_rule_template - [ ] create_branch - [ ] create_commit - [ ] create_pull_request +- [ ] create_pull_request_approval_rule - [ ] create_repository - [ ] create_unreferenced_merge_commit +- [ ] delete_approval_rule_template - [ ] delete_branch - [ ] delete_comment_content - [ ] delete_file +- [ ] delete_pull_request_approval_rule - [ ] delete_repository - [ ] describe_merge_conflicts - [ ] describe_pull_request_events +- [ ] disassociate_approval_rule_template_from_repository +- [ ] evaluate_pull_request_approval_rules +- [ ] get_approval_rule_template - [ ] get_blob - [ ] get_branch - [ ] get_comment @@ -1094,11 +1266,16 @@ - [ ] get_merge_conflicts - [ ] get_merge_options - [ ] get_pull_request +- [ ] get_pull_request_approval_states +- [ ] get_pull_request_override_state - [ ] get_repository - [ ] get_repository_triggers +- [ ] list_approval_rule_templates +- [ ] list_associated_approval_rule_templates_for_repository - [ ] list_branches - [ ] list_pull_requests - [ ] list_repositories +- [ ] list_repositories_for_approval_rule_template - [ ] list_tags_for_resource - [ ] merge_branches_by_fast_forward - [ ] merge_branches_by_squash @@ -1106,6 +1283,7 @@ - [ ] merge_pull_request_by_fast_forward - [ ] merge_pull_request_by_squash - [ ] merge_pull_request_by_three_way +- [ ] override_pull_request_approval_rules - [ ] post_comment_for_compared_commit - [ ] post_comment_for_pull_request - [ ] post_comment_reply @@ -1114,8 +1292,13 @@ - [ ] tag_resource - [ ] test_repository_triggers - [ ] untag_resource +- [ ] update_approval_rule_template_content +- [ ] update_approval_rule_template_description +- [ ] update_approval_rule_template_name - [ ] update_comment - [ ] update_default_branch +- [ ] update_pull_request_approval_rule_content +- [ ] update_pull_request_approval_state - [ ] update_pull_request_description - [ ] update_pull_request_status - [ ] update_pull_request_title @@ -1171,27 +1354,46 @@ - [ ] update_application - [ ] update_deployment_group -## codepipeline +## codeguru-reviewer 0% implemented +- [ ] associate_repository +- [ ] describe_repository_association +- [ ] disassociate_repository +- [ ] list_repository_associations + +## codeguruprofiler +0% implemented +- [ ] configure_agent +- [ ] create_profiling_group +- [ ] delete_profiling_group +- [ ] describe_profiling_group +- [ ] get_profile +- [ ] list_profile_times +- [ ] list_profiling_groups +- [ ] post_agent_profile +- [ ] update_profiling_group + +## codepipeline +13% implemented - [ ] acknowledge_job - [ ] acknowledge_third_party_job - [ ] create_custom_action_type -- [ ] create_pipeline +- [X] create_pipeline - [ ] delete_custom_action_type -- [ ] delete_pipeline +- [X] delete_pipeline - [ ] delete_webhook - [ ] deregister_webhook_with_third_party - [ ] disable_stage_transition - [ ] enable_stage_transition - [ ] get_job_details -- [ ] get_pipeline +- [X] get_pipeline - [ ] get_pipeline_execution - [ ] get_pipeline_state - [ ] get_third_party_job_details - [ ] list_action_executions - [ ] list_action_types - [ ] list_pipeline_executions -- [ ] list_pipelines +- [X] list_pipelines - [ ] list_tags_for_resource - [ ] list_webhooks - [ ] poll_for_jobs @@ -1208,7 +1410,7 @@ - [ ] start_pipeline_execution - [ ] tag_resource - [ ] untag_resource -- [ ] update_pipeline +- [X] update_pipeline ## codestar 0% implemented @@ -1231,13 +1433,29 @@ - [ ] update_team_member - [ ] update_user_profile +## codestar-notifications +0% implemented +- [ ] create_notification_rule +- [ ] delete_notification_rule +- [ ] delete_target +- [ ] describe_notification_rule +- [ ] list_event_types +- [ ] list_notification_rules +- [ ] list_tags_for_resource +- [ ] list_targets +- [ ] subscribe +- [ ] tag_resource +- [ ] unsubscribe +- [ ] untag_resource +- [ ] update_notification_rule + ## cognito-identity -23% implemented +28% implemented - [X] create_identity_pool - [ ] delete_identities - [ ] delete_identity_pool - [ ] describe_identity -- [ ] describe_identity_pool +- [X] describe_identity_pool - [X] get_credentials_for_identity - [X] get_id - [ ] get_identity_pool_roles @@ -1385,13 +1603,17 @@ - [ ] batch_detect_key_phrases - [ ] batch_detect_sentiment - [ ] batch_detect_syntax +- [ ] classify_document - [ ] create_document_classifier +- [ ] create_endpoint - [ ] create_entity_recognizer - [ ] delete_document_classifier +- [ ] delete_endpoint - [ ] delete_entity_recognizer - [ ] describe_document_classification_job - [ ] describe_document_classifier - [ ] describe_dominant_language_detection_job +- [ ] describe_endpoint - [ ] describe_entities_detection_job - [ ] describe_entity_recognizer - [ ] describe_key_phrases_detection_job @@ -1405,6 +1627,7 @@ - [ ] list_document_classification_jobs - [ ] list_document_classifiers - [ ] list_dominant_language_detection_jobs +- [ ] list_endpoints - [ ] list_entities_detection_jobs - [ ] list_entity_recognizers - [ ] list_key_phrases_detection_jobs @@ -1425,25 +1648,48 @@ - [ ] stop_training_entity_recognizer - [ ] tag_resource - [ ] untag_resource +- [ ] update_endpoint ## comprehendmedical 0% implemented +- [ ] describe_entities_detection_v2_job +- [ ] describe_phi_detection_job - [ ] detect_entities +- [ ] detect_entities_v2 - [ ] detect_phi +- [ ] list_entities_detection_v2_jobs +- [ ] list_phi_detection_jobs +- [ ] start_entities_detection_v2_job +- [ ] start_phi_detection_job +- [ ] stop_entities_detection_v2_job +- [ ] stop_phi_detection_job + +## compute-optimizer +0% implemented +- [ ] get_auto_scaling_group_recommendations +- [ ] get_ec2_instance_recommendations +- [ ] get_ec2_recommendation_projected_metrics +- [ ] get_enrollment_status +- [ ] get_recommendation_summaries +- [ ] update_enrollment_status ## config -24% implemented -- [ ] batch_get_aggregate_resource_config -- [ ] batch_get_resource_config +25% implemented +- [X] batch_get_aggregate_resource_config +- [X] batch_get_resource_config - [X] delete_aggregation_authorization - [ ] delete_config_rule - [X] delete_configuration_aggregator - [X] delete_configuration_recorder +- [ ] delete_conformance_pack - [X] delete_delivery_channel - [ ] delete_evaluation_results - [ ] delete_organization_config_rule +- [ ] delete_organization_conformance_pack - [ ] delete_pending_aggregation_request - [ ] delete_remediation_configuration +- [ ] delete_remediation_exceptions +- [ ] delete_resource_config - [ ] delete_retention_configuration - [ ] deliver_config_snapshot - [ ] describe_aggregate_compliance_by_config_rules @@ -1456,12 +1702,18 @@ - [X] describe_configuration_aggregators - [X] describe_configuration_recorder_status - [X] describe_configuration_recorders +- [ ] describe_conformance_pack_compliance +- [ ] describe_conformance_pack_status +- [ ] describe_conformance_packs - [ ] describe_delivery_channel_status - [X] describe_delivery_channels - [ ] describe_organization_config_rule_statuses - [ ] describe_organization_config_rules +- [ ] describe_organization_conformance_pack_statuses +- [ ] describe_organization_conformance_packs - [ ] describe_pending_aggregation_requests - [ ] describe_remediation_configurations +- [ ] describe_remediation_exceptions - [ ] describe_remediation_execution_status - [ ] describe_retention_configurations - [ ] get_aggregate_compliance_details_by_config_rule @@ -1472,20 +1724,27 @@ - [ ] get_compliance_details_by_resource - [ ] get_compliance_summary_by_config_rule - [ ] get_compliance_summary_by_resource_type +- [ ] get_conformance_pack_compliance_details +- [ ] get_conformance_pack_compliance_summary - [ ] get_discovered_resource_counts - [ ] get_organization_config_rule_detailed_status -- [ ] get_resource_config_history -- [ ] list_aggregate_discovered_resources -- [ ] list_discovered_resources +- [ ] get_organization_conformance_pack_detailed_status +- [X] get_resource_config_history +- [X] list_aggregate_discovered_resources +- [X] list_discovered_resources - [ ] list_tags_for_resource - [X] put_aggregation_authorization - [ ] put_config_rule - [X] put_configuration_aggregator - [X] put_configuration_recorder +- [ ] put_conformance_pack - [X] put_delivery_channel - [ ] put_evaluations - [ ] put_organization_config_rule +- [ ] put_organization_conformance_pack - [ ] put_remediation_configurations +- [ ] put_remediation_exceptions +- [ ] put_resource_config - [ ] put_retention_configuration - [ ] select_resource_config - [ ] start_config_rules_evaluation @@ -1506,12 +1765,20 @@ - [ ] get_current_metric_data - [ ] get_federation_token - [ ] get_metric_data +- [ ] list_contact_flows +- [ ] list_hours_of_operations +- [ ] list_phone_numbers +- [ ] list_queues - [ ] list_routing_profiles - [ ] list_security_profiles +- [ ] list_tags_for_resource - [ ] list_user_hierarchy_groups - [ ] list_users +- [ ] start_chat_contact - [ ] start_outbound_voice_contact - [ ] stop_contact +- [ ] tag_resource +- [ ] untag_resource - [ ] update_contact_attributes - [ ] update_user_hierarchy - [ ] update_user_identity_info @@ -1519,12 +1786,46 @@ - [ ] update_user_routing_profile - [ ] update_user_security_profiles +## connectparticipant +0% implemented +- [ ] create_participant_connection +- [ ] disconnect_participant +- [ ] get_transcript +- [ ] send_event +- [ ] send_message + ## cur 0% implemented - [ ] delete_report_definition - [ ] describe_report_definitions +- [ ] modify_report_definition - [ ] put_report_definition +## dataexchange +0% implemented +- [ ] cancel_job +- [ ] create_data_set +- [ ] create_job +- [ ] create_revision +- [ ] delete_asset +- [ ] delete_data_set +- [ ] delete_revision +- [ ] get_asset +- [ ] get_data_set +- [ ] get_job +- [ ] get_revision +- [ ] list_data_set_revisions +- [ ] list_data_sets +- [ ] list_jobs +- [ ] list_revision_assets +- [ ] list_tags_for_resource +- [ ] start_job +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_asset +- [ ] update_data_set +- [ ] update_revision + ## datapipeline 42% implemented - [X] activate_pipeline @@ -1548,20 +1849,22 @@ - [ ] validate_pipeline_definition ## datasync -0% implemented -- [ ] cancel_task_execution +22% implemented +- [X] cancel_task_execution - [ ] create_agent - [ ] create_location_efs - [ ] create_location_nfs - [ ] create_location_s3 -- [ ] create_task +- [ ] create_location_smb +- [X] create_task - [ ] delete_agent -- [ ] delete_location -- [ ] delete_task +- [X] delete_location +- [X] delete_task - [ ] describe_agent - [ ] describe_location_efs - [ ] describe_location_nfs - [ ] describe_location_s3 +- [ ] describe_location_smb - [ ] describe_task - [ ] describe_task_execution - [ ] list_agents @@ -1569,11 +1872,11 @@ - [ ] list_tags_for_resource - [ ] list_task_executions - [ ] list_tasks -- [ ] start_task_execution +- [X] start_task_execution - [ ] tag_resource - [ ] untag_resource - [ ] update_agent -- [ ] update_task +- [X] update_task ## dax 0% implemented @@ -1599,6 +1902,20 @@ - [ ] update_parameter_group - [ ] update_subnet_group +## detective +0% implemented +- [ ] accept_invitation +- [ ] create_graph +- [ ] create_members +- [ ] delete_graph +- [ ] delete_members +- [ ] disassociate_membership +- [ ] get_members +- [ ] list_graphs +- [ ] list_invitations +- [ ] list_members +- [ ] reject_invitation + ## devicefarm 0% implemented - [ ] create_device_pool @@ -1759,6 +2076,9 @@ - [ ] delete_lifecycle_policy - [ ] get_lifecycle_policies - [ ] get_lifecycle_policy +- [ ] list_tags_for_resource +- [ ] tag_resource +- [ ] untag_resource - [ ] update_lifecycle_policy ## dms @@ -1771,6 +2091,7 @@ - [ ] create_replication_subnet_group - [ ] create_replication_task - [ ] delete_certificate +- [ ] delete_connection - [ ] delete_endpoint - [ ] delete_event_subscription - [ ] delete_replication_instance @@ -1826,6 +2147,7 @@ - [ ] delete_db_cluster_snapshot - [ ] delete_db_instance - [ ] delete_db_subnet_group +- [ ] describe_certificates - [ ] describe_db_cluster_parameter_groups - [ ] describe_db_cluster_parameters - [ ] describe_db_cluster_snapshot_attributes @@ -1874,24 +2196,31 @@ - [ ] delete_log_subscription - [ ] delete_snapshot - [ ] delete_trust +- [ ] deregister_certificate - [ ] deregister_event_topic +- [ ] describe_certificate - [ ] describe_conditional_forwarders - [ ] describe_directories - [ ] describe_domain_controllers - [ ] describe_event_topics +- [ ] describe_ldaps_settings - [ ] describe_shared_directories - [ ] describe_snapshots - [ ] describe_trusts +- [ ] disable_ldaps - [ ] disable_radius - [ ] disable_sso +- [ ] enable_ldaps - [ ] enable_radius - [ ] enable_sso - [ ] get_directory_limits - [ ] get_snapshot_limits +- [ ] list_certificates - [ ] list_ip_routes - [ ] list_log_subscriptions - [ ] list_schema_extensions - [ ] list_tags_for_resource +- [ ] register_certificate - [ ] register_event_topic - [ ] reject_shared_directory - [ ] remove_ip_routes @@ -1908,7 +2237,7 @@ - [ ] verify_trust ## dynamodb -19% implemented +17% implemented - [ ] batch_get_item - [ ] batch_write_item - [ ] create_backup @@ -1919,14 +2248,17 @@ - [X] delete_table - [ ] describe_backup - [ ] describe_continuous_backups +- [ ] describe_contributor_insights - [ ] describe_endpoints - [ ] describe_global_table - [ ] describe_global_table_settings - [ ] describe_limits - [ ] describe_table +- [ ] describe_table_replica_auto_scaling - [ ] describe_time_to_live - [X] get_item - [ ] list_backups +- [ ] list_contributor_insights - [ ] list_global_tables - [ ] list_tables - [ ] list_tags_of_resource @@ -1940,10 +2272,12 @@ - [ ] transact_write_items - [ ] untag_resource - [ ] update_continuous_backups +- [ ] update_contributor_insights - [ ] update_global_table - [ ] update_global_table_settings - [ ] update_item - [ ] update_table +- [ ] update_table_replica_auto_scaling - [ ] update_time_to_live ## dynamodbstreams @@ -1953,9 +2287,16 @@ - [X] get_shard_iterator - [X] list_streams +## ebs +0% implemented +- [ ] get_snapshot_block +- [ ] list_changed_blocks +- [ ] list_snapshot_blocks + ## ec2 -28% implemented +26% implemented - [ ] accept_reserved_instances_exchange_quote +- [ ] accept_transit_gateway_peering_attachment - [ ] accept_transit_gateway_vpc_attachment - [ ] accept_vpc_endpoint_connections - [X] accept_vpc_peering_connection @@ -1971,6 +2312,7 @@ - [ ] associate_iam_instance_profile - [X] associate_route_table - [ ] associate_subnet_cidr_block +- [ ] associate_transit_gateway_multicast_domain - [ ] associate_transit_gateway_route_table - [X] associate_vpc_cidr_block - [ ] attach_classic_link_vpc @@ -2011,6 +2353,8 @@ - [X] create_key_pair - [X] create_launch_template - [ ] create_launch_template_version +- [ ] create_local_gateway_route +- [ ] create_local_gateway_route_table_vpc_association - [X] create_nat_gateway - [X] create_network_acl - [X] create_network_acl_entry @@ -2031,6 +2375,8 @@ - [ ] create_traffic_mirror_session - [ ] create_traffic_mirror_target - [ ] create_transit_gateway +- [ ] create_transit_gateway_multicast_domain +- [ ] create_transit_gateway_peering_attachment - [ ] create_transit_gateway_route - [ ] create_transit_gateway_route_table - [ ] create_transit_gateway_vpc_attachment @@ -2055,12 +2401,15 @@ - [X] delete_key_pair - [ ] delete_launch_template - [ ] delete_launch_template_versions +- [ ] delete_local_gateway_route +- [ ] delete_local_gateway_route_table_vpc_association - [X] delete_nat_gateway - [X] delete_network_acl - [X] delete_network_acl_entry - [X] delete_network_interface - [ ] delete_network_interface_permission - [ ] delete_placement_group +- [ ] delete_queued_reserved_instances - [X] delete_route - [X] delete_route_table - [X] delete_security_group @@ -2073,6 +2422,8 @@ - [ ] delete_traffic_mirror_session - [ ] delete_traffic_mirror_target - [ ] delete_transit_gateway +- [ ] delete_transit_gateway_multicast_domain +- [ ] delete_transit_gateway_peering_attachment - [ ] delete_transit_gateway_route - [ ] delete_transit_gateway_route_table - [ ] delete_transit_gateway_vpc_attachment @@ -2087,6 +2438,8 @@ - [X] delete_vpn_gateway - [ ] deprovision_byoip_cidr - [X] deregister_image +- [ ] deregister_transit_gateway_multicast_group_members +- [ ] deregister_transit_gateway_multicast_group_sources - [ ] describe_account_attributes - [X] describe_addresses - [ ] describe_aggregate_id_format @@ -2100,12 +2453,15 @@ - [ ] describe_client_vpn_endpoints - [ ] describe_client_vpn_routes - [ ] describe_client_vpn_target_networks +- [ ] describe_coip_pools - [ ] describe_conversion_tasks - [ ] describe_customer_gateways - [X] describe_dhcp_options - [ ] describe_egress_only_internet_gateways - [ ] describe_elastic_gpus +- [ ] describe_export_image_tasks - [ ] describe_export_tasks +- [ ] describe_fast_snapshot_restores - [ ] describe_fleet_history - [ ] describe_fleet_instances - [ ] describe_fleets @@ -2125,11 +2481,19 @@ - [X] describe_instance_attribute - [ ] describe_instance_credit_specifications - [ ] describe_instance_status +- [ ] describe_instance_type_offerings +- [ ] describe_instance_types - [ ] describe_instances - [X] describe_internet_gateways - [X] describe_key_pairs - [ ] describe_launch_template_versions - [ ] describe_launch_templates +- [ ] describe_local_gateway_route_table_virtual_interface_group_associations +- [ ] describe_local_gateway_route_table_vpc_associations +- [ ] describe_local_gateway_route_tables +- [ ] describe_local_gateway_virtual_interface_groups +- [ ] describe_local_gateway_virtual_interfaces +- [ ] describe_local_gateways - [ ] describe_moving_addresses - [ ] describe_nat_gateways - [ ] describe_network_acls @@ -2165,6 +2529,8 @@ - [ ] describe_traffic_mirror_sessions - [ ] describe_traffic_mirror_targets - [ ] describe_transit_gateway_attachments +- [ ] describe_transit_gateway_multicast_domains +- [ ] describe_transit_gateway_peering_attachments - [ ] describe_transit_gateway_route_tables - [ ] describe_transit_gateway_vpc_attachments - [ ] describe_transit_gateways @@ -2191,29 +2557,35 @@ - [X] detach_volume - [X] detach_vpn_gateway - [ ] disable_ebs_encryption_by_default +- [ ] disable_fast_snapshot_restores - [ ] disable_transit_gateway_route_table_propagation - [ ] disable_vgw_route_propagation -- [ ] disable_vpc_classic_link -- [ ] disable_vpc_classic_link_dns_support +- [X] disable_vpc_classic_link +- [X] disable_vpc_classic_link_dns_support - [X] disassociate_address - [ ] disassociate_client_vpn_target_network - [ ] disassociate_iam_instance_profile - [X] disassociate_route_table - [ ] disassociate_subnet_cidr_block +- [ ] disassociate_transit_gateway_multicast_domain - [ ] disassociate_transit_gateway_route_table - [X] disassociate_vpc_cidr_block - [ ] enable_ebs_encryption_by_default +- [ ] enable_fast_snapshot_restores - [ ] enable_transit_gateway_route_table_propagation - [ ] enable_vgw_route_propagation - [ ] enable_volume_io -- [ ] enable_vpc_classic_link -- [ ] enable_vpc_classic_link_dns_support +- [X] enable_vpc_classic_link +- [X] enable_vpc_classic_link_dns_support - [ ] export_client_vpn_client_certificate_revocation_list - [ ] export_client_vpn_client_configuration +- [ ] export_image - [ ] export_transit_gateway_routes - [ ] get_capacity_reservation_usage +- [ ] get_coip_pool_usage - [ ] get_console_output - [ ] get_console_screenshot +- [ ] get_default_credit_specification - [ ] get_ebs_default_kms_key_id - [ ] get_ebs_encryption_by_default - [ ] get_host_reservation_purchase_preview @@ -2221,6 +2593,7 @@ - [ ] get_password_data - [ ] get_reserved_instances_exchange_quote - [ ] get_transit_gateway_attachment_propagations +- [ ] get_transit_gateway_multicast_domain_associations - [ ] get_transit_gateway_route_table_associations - [ ] get_transit_gateway_route_table_propagations - [ ] import_client_vpn_client_certificate_revocation_list @@ -2231,6 +2604,7 @@ - [ ] import_volume - [ ] modify_capacity_reservation - [ ] modify_client_vpn_endpoint +- [ ] modify_default_credit_specification - [ ] modify_ebs_default_kms_key_id - [ ] modify_fleet - [ ] modify_fpga_image_attribute @@ -2242,6 +2616,7 @@ - [ ] modify_instance_capacity_reservation_attributes - [ ] modify_instance_credit_specification - [ ] modify_instance_event_start_time +- [ ] modify_instance_metadata_options - [ ] modify_instance_placement - [ ] modify_launch_template - [X] modify_network_interface_attribute @@ -2263,6 +2638,8 @@ - [ ] modify_vpc_peering_connection_options - [ ] modify_vpc_tenancy - [ ] modify_vpn_connection +- [ ] modify_vpn_tunnel_certificate +- [ ] modify_vpn_tunnel_options - [ ] monitor_instances - [ ] move_address_to_vpc - [ ] provision_byoip_cidr @@ -2271,6 +2648,9 @@ - [ ] purchase_scheduled_instances - [X] reboot_instances - [ ] register_image +- [ ] register_transit_gateway_multicast_group_members +- [ ] register_transit_gateway_multicast_group_sources +- [ ] reject_transit_gateway_peering_attachment - [ ] reject_transit_gateway_vpc_attachment - [ ] reject_vpc_endpoint_connections - [X] reject_vpc_peering_connection @@ -2297,7 +2677,10 @@ - [X] revoke_security_group_ingress - [ ] run_instances - [ ] run_scheduled_instances +- [ ] search_local_gateway_routes +- [ ] search_transit_gateway_multicast_groups - [ ] search_transit_gateway_routes +- [ ] send_diagnostic_interrupt - [X] start_instances - [X] stop_instances - [ ] terminate_client_vpn_connections @@ -2314,7 +2697,7 @@ - [ ] send_ssh_public_key ## ecr -30% implemented +27% implemented - [ ] batch_check_layer_availability - [X] batch_delete_image - [X] batch_get_image @@ -2323,6 +2706,7 @@ - [ ] delete_lifecycle_policy - [X] delete_repository - [ ] delete_repository_policy +- [ ] describe_image_scan_findings - [X] describe_images - [X] describe_repositories - [ ] get_authorization_token @@ -2334,16 +2718,19 @@ - [X] list_images - [ ] list_tags_for_resource - [X] put_image +- [ ] put_image_scanning_configuration - [ ] put_image_tag_mutability - [ ] put_lifecycle_policy - [ ] set_repository_policy +- [ ] start_image_scan - [ ] start_lifecycle_policy_preview - [ ] tag_resource - [ ] untag_resource - [ ] upload_layer_part ## ecs -63% implemented +62% implemented +- [ ] create_capacity_provider - [X] create_cluster - [X] create_service - [ ] create_task_set @@ -2354,6 +2741,7 @@ - [ ] delete_task_set - [X] deregister_container_instance - [X] deregister_task_definition +- [ ] describe_capacity_providers - [X] describe_clusters - [X] describe_container_instances - [X] describe_services @@ -2373,6 +2761,7 @@ - [ ] put_account_setting - [ ] put_account_setting_default - [X] put_attributes +- [ ] put_cluster_capacity_providers - [X] register_container_instance - [X] register_task_definition - [X] run_task @@ -2381,8 +2770,9 @@ - [ ] submit_attachment_state_changes - [ ] submit_container_state_change - [ ] submit_task_state_change -- [ ] tag_resource -- [ ] untag_resource +- [X] tag_resource +- [X] untag_resource +- [ ] update_cluster_settings - [ ] update_container_agent - [X] update_container_instances_state - [X] update_service @@ -2409,13 +2799,32 @@ ## eks 0% implemented - [ ] create_cluster +- [ ] create_fargate_profile +- [ ] create_nodegroup - [ ] delete_cluster +- [ ] delete_fargate_profile +- [ ] delete_nodegroup - [ ] describe_cluster +- [ ] describe_fargate_profile +- [ ] describe_nodegroup - [ ] describe_update - [ ] list_clusters +- [ ] list_fargate_profiles +- [ ] list_nodegroups +- [ ] list_tags_for_resource - [ ] list_updates +- [ ] tag_resource +- [ ] untag_resource - [ ] update_cluster_config - [ ] update_cluster_version +- [ ] update_nodegroup_config +- [ ] update_nodegroup_version + +## elastic-inference +0% implemented +- [ ] list_tags_for_resource +- [ ] tag_resource +- [ ] untag_resource ## elasticache 0% implemented @@ -2423,6 +2832,7 @@ - [ ] authorize_cache_security_group_ingress - [ ] batch_apply_update_action - [ ] batch_stop_update_action +- [ ] complete_migration - [ ] copy_snapshot - [ ] create_cache_cluster - [ ] create_cache_parameter_group @@ -2464,6 +2874,7 @@ - [ ] remove_tags_from_resource - [ ] reset_cache_parameter_group - [ ] revoke_cache_security_group_ingress +- [ ] start_migration - [ ] test_failover ## elasticbeanstalk @@ -2603,7 +3014,7 @@ - [X] set_subnets ## emr -55% implemented +50% implemented - [ ] add_instance_fleet - [X] add_instance_groups - [X] add_job_flow_steps @@ -2615,6 +3026,7 @@ - [X] describe_job_flows - [ ] describe_security_configuration - [X] describe_step +- [ ] get_block_public_access_configuration - [X] list_bootstrap_actions - [X] list_clusters - [ ] list_instance_fleets @@ -2622,9 +3034,11 @@ - [ ] list_instances - [ ] list_security_configurations - [X] list_steps +- [ ] modify_cluster - [ ] modify_instance_fleet - [X] modify_instance_groups - [ ] put_auto_scaling_policy +- [ ] put_block_public_access_configuration - [ ] remove_auto_scaling_policy - [X] remove_tags - [X] run_job_flow @@ -2659,12 +3073,12 @@ - [ ] upgrade_elasticsearch_domain ## events -48% implemented +58% implemented - [ ] activate_event_source -- [ ] create_event_bus +- [X] create_event_bus - [ ] create_partner_event_source - [ ] deactivate_event_source -- [ ] delete_event_bus +- [X] delete_event_bus - [ ] delete_partner_event_source - [X] delete_rule - [X] describe_event_bus @@ -2673,7 +3087,7 @@ - [X] describe_rule - [X] disable_rule - [X] enable_rule -- [ ] list_event_buses +- [X] list_event_buses - [ ] list_event_sources - [ ] list_partner_event_source_accounts - [ ] list_partner_event_sources @@ -2724,6 +3138,72 @@ - [ ] put_notification_channel - [ ] put_policy +## forecast +0% implemented +- [ ] create_dataset +- [ ] create_dataset_group +- [ ] create_dataset_import_job +- [ ] create_forecast +- [ ] create_forecast_export_job +- [ ] create_predictor +- [ ] delete_dataset +- [ ] delete_dataset_group +- [ ] delete_dataset_import_job +- [ ] delete_forecast +- [ ] delete_forecast_export_job +- [ ] delete_predictor +- [ ] describe_dataset +- [ ] describe_dataset_group +- [ ] describe_dataset_import_job +- [ ] describe_forecast +- [ ] describe_forecast_export_job +- [ ] describe_predictor +- [ ] get_accuracy_metrics +- [ ] list_dataset_groups +- [ ] list_dataset_import_jobs +- [ ] list_datasets +- [ ] list_forecast_export_jobs +- [ ] list_forecasts +- [ ] list_predictors +- [ ] update_dataset_group + +## forecastquery +0% implemented +- [ ] query_forecast + +## frauddetector +0% implemented +- [ ] batch_create_variable +- [ ] batch_get_variable +- [ ] create_detector_version +- [ ] create_model_version +- [ ] create_rule +- [ ] create_variable +- [ ] delete_detector_version +- [ ] delete_event +- [ ] describe_detector +- [ ] describe_model_versions +- [ ] get_detector_version +- [ ] get_detectors +- [ ] get_external_models +- [ ] get_model_version +- [ ] get_models +- [ ] get_outcomes +- [ ] get_prediction +- [ ] get_rules +- [ ] get_variables +- [ ] put_detector +- [ ] put_external_model +- [ ] put_model +- [ ] put_outcome +- [ ] update_detector_version +- [ ] update_detector_version_metadata +- [ ] update_detector_version_status +- [ ] update_model_version +- [ ] update_rule_metadata +- [ ] update_rule_version +- [ ] update_variable + ## fsx 0% implemented - [ ] create_backup @@ -2871,7 +3351,7 @@ - [ ] update_listener ## glue -5% implemented +4% implemented - [ ] batch_create_partition - [ ] batch_delete_connection - [ ] batch_delete_partition @@ -2884,12 +3364,14 @@ - [ ] batch_get_triggers - [ ] batch_get_workflows - [ ] batch_stop_job_run +- [ ] cancel_ml_task_run - [ ] create_classifier - [ ] create_connection - [ ] create_crawler - [X] create_database - [ ] create_dev_endpoint - [ ] create_job +- [ ] create_ml_transform - [ ] create_partition - [ ] create_script - [ ] create_security_configuration @@ -2903,6 +3385,7 @@ - [ ] delete_database - [ ] delete_dev_endpoint - [ ] delete_job +- [ ] delete_ml_transform - [ ] delete_partition - [ ] delete_resource_policy - [ ] delete_security_configuration @@ -2927,11 +3410,14 @@ - [ ] get_dev_endpoints - [ ] get_job - [ ] get_job_bookmark -- [ ] get_job_bookmarks - [ ] get_job_run - [ ] get_job_runs - [ ] get_jobs - [ ] get_mapping +- [ ] get_ml_task_run +- [ ] get_ml_task_runs +- [ ] get_ml_transform +- [ ] get_ml_transforms - [ ] get_partition - [ ] get_partitions - [ ] get_plan @@ -2961,9 +3447,14 @@ - [ ] put_resource_policy - [ ] put_workflow_run_properties - [ ] reset_job_bookmark +- [ ] search_tables - [ ] start_crawler - [ ] start_crawler_schedule +- [ ] start_export_labels_task_run +- [ ] start_import_labels_task_run - [ ] start_job_run +- [ ] start_ml_evaluation_task_run +- [ ] start_ml_labeling_set_generation_task_run - [ ] start_trigger - [ ] start_workflow_run - [ ] stop_crawler @@ -2978,6 +3469,7 @@ - [ ] update_database - [ ] update_dev_endpoint - [ ] update_job +- [ ] update_ml_transform - [ ] update_partition - [ ] update_table - [ ] update_trigger @@ -3113,6 +3605,7 @@ - [ ] create_filter - [ ] create_ip_set - [ ] create_members +- [ ] create_publishing_destination - [ ] create_sample_findings - [ ] create_threat_intel_set - [ ] decline_invitations @@ -3121,7 +3614,9 @@ - [ ] delete_invitations - [ ] delete_ip_set - [ ] delete_members +- [ ] delete_publishing_destination - [ ] delete_threat_intel_set +- [ ] describe_publishing_destination - [ ] disassociate_from_master_account - [ ] disassociate_members - [ ] get_detector @@ -3140,6 +3635,7 @@ - [ ] list_invitations - [ ] list_ip_sets - [ ] list_members +- [ ] list_publishing_destinations - [ ] list_tags_for_resource - [ ] list_threat_intel_sets - [ ] start_monitoring_members @@ -3151,6 +3647,7 @@ - [ ] update_filter - [ ] update_findings_feedback - [ ] update_ip_set +- [ ] update_publishing_destination - [ ] update_threat_intel_set ## health @@ -3163,7 +3660,7 @@ - [ ] describe_events ## iam -55% implemented +67% implemented - [ ] add_client_id_to_open_id_connect_provider - [X] add_role_to_instance_profile - [X] add_user_to_group @@ -3176,7 +3673,7 @@ - [X] create_group - [X] create_instance_profile - [X] create_login_profile -- [ ] create_open_id_connect_provider +- [X] create_open_id_connect_provider - [X] create_policy - [X] create_policy_version - [X] create_role @@ -3184,17 +3681,17 @@ - [ ] create_service_linked_role - [ ] create_service_specific_credential - [X] create_user -- [ ] create_virtual_mfa_device +- [X] create_virtual_mfa_device - [X] deactivate_mfa_device - [X] delete_access_key - [X] delete_account_alias -- [ ] delete_account_password_policy -- [ ] delete_group +- [X] delete_account_password_policy +- [X] delete_group - [ ] delete_group_policy - [ ] delete_instance_profile - [X] delete_login_profile -- [ ] delete_open_id_connect_provider -- [ ] delete_policy +- [X] delete_open_id_connect_provider +- [X] delete_policy - [X] delete_policy_version - [X] delete_role - [ ] delete_role_permissions_boundary @@ -3204,11 +3701,11 @@ - [ ] delete_service_linked_role - [ ] delete_service_specific_credential - [X] delete_signing_certificate -- [ ] delete_ssh_public_key +- [X] delete_ssh_public_key - [X] delete_user - [ ] delete_user_permissions_boundary - [X] delete_user_policy -- [ ] delete_virtual_mfa_device +- [X] delete_virtual_mfa_device - [X] detach_group_policy - [X] detach_role_policy - [X] detach_user_policy @@ -3218,8 +3715,8 @@ - [ ] generate_service_last_accessed_details - [X] get_access_key_last_used - [X] get_account_authorization_details -- [ ] get_account_password_policy -- [ ] get_account_summary +- [X] get_account_password_policy +- [X] get_account_summary - [ ] get_context_keys_for_custom_policy - [ ] get_context_keys_for_principal_policy - [X] get_credential_report @@ -3227,7 +3724,7 @@ - [X] get_group_policy - [X] get_instance_profile - [X] get_login_profile -- [ ] get_open_id_connect_provider +- [X] get_open_id_connect_provider - [ ] get_organizations_access_report - [X] get_policy - [X] get_policy_version @@ -3238,7 +3735,7 @@ - [ ] get_service_last_accessed_details - [ ] get_service_last_accessed_details_with_entities - [ ] get_service_linked_role_deletion_status -- [ ] get_ssh_public_key +- [X] get_ssh_public_key - [X] get_user - [X] get_user_policy - [ ] list_access_keys @@ -3253,7 +3750,7 @@ - [ ] list_instance_profiles - [ ] list_instance_profiles_for_role - [X] list_mfa_devices -- [ ] list_open_id_connect_providers +- [X] list_open_id_connect_providers - [X] list_policies - [ ] list_policies_granting_service_access - [X] list_policy_versions @@ -3268,7 +3765,7 @@ - [X] list_user_policies - [ ] list_user_tags - [X] list_users -- [ ] list_virtual_mfa_devices +- [X] list_virtual_mfa_devices - [X] put_group_policy - [ ] put_role_permissions_boundary - [X] put_role_policy @@ -3288,7 +3785,7 @@ - [X] untag_role - [ ] untag_user - [X] update_access_key -- [ ] update_account_password_policy +- [X] update_account_password_policy - [ ] update_assume_role_policy - [ ] update_group - [X] update_login_profile @@ -3299,11 +3796,56 @@ - [ ] update_server_certificate - [ ] update_service_specific_credential - [X] update_signing_certificate -- [ ] update_ssh_public_key +- [X] update_ssh_public_key - [X] update_user -- [ ] upload_server_certificate +- [X] upload_server_certificate - [X] upload_signing_certificate -- [ ] upload_ssh_public_key +- [X] upload_ssh_public_key + +## imagebuilder +0% implemented +- [ ] cancel_image_creation +- [ ] create_component +- [ ] create_distribution_configuration +- [ ] create_image +- [ ] create_image_pipeline +- [ ] create_image_recipe +- [ ] create_infrastructure_configuration +- [ ] delete_component +- [ ] delete_distribution_configuration +- [ ] delete_image +- [ ] delete_image_pipeline +- [ ] delete_image_recipe +- [ ] delete_infrastructure_configuration +- [ ] get_component +- [ ] get_component_policy +- [ ] get_distribution_configuration +- [ ] get_image +- [ ] get_image_pipeline +- [ ] get_image_policy +- [ ] get_image_recipe +- [ ] get_image_recipe_policy +- [ ] get_infrastructure_configuration +- [ ] import_component +- [ ] list_component_build_versions +- [ ] list_components +- [ ] list_distribution_configurations +- [ ] list_image_build_versions +- [ ] list_image_pipeline_images +- [ ] list_image_pipelines +- [ ] list_image_recipes +- [ ] list_images +- [ ] list_infrastructure_configurations +- [ ] list_tags_for_resource +- [ ] put_component_policy +- [ ] put_image_policy +- [ ] put_image_recipe_policy +- [ ] start_image_pipeline_execution +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_distribution_configuration +- [ ] update_image_pipeline +- [ ] update_infrastructure_configuration ## importexport 0% implemented @@ -3355,7 +3897,7 @@ - [ ] update_assessment_target ## iot -24% implemented +20% implemented - [ ] accept_certificate_transfer - [ ] add_thing_to_billing_group - [X] add_thing_to_thing_group @@ -3364,20 +3906,27 @@ - [X] attach_principal_policy - [ ] attach_security_profile - [X] attach_thing_principal +- [ ] cancel_audit_mitigation_actions_task - [ ] cancel_audit_task - [ ] cancel_certificate_transfer - [ ] cancel_job - [ ] cancel_job_execution - [ ] clear_default_authorizer +- [ ] confirm_topic_rule_destination - [ ] create_authorizer - [ ] create_billing_group - [ ] create_certificate_from_csr +- [ ] create_domain_configuration - [ ] create_dynamic_thing_group - [X] create_job - [X] create_keys_and_certificate +- [ ] create_mitigation_action - [ ] create_ota_update - [X] create_policy - [ ] create_policy_version +- [ ] create_provisioning_claim +- [ ] create_provisioning_template +- [ ] create_provisioning_template_version - [ ] create_role_alias - [ ] create_scheduled_audit - [ ] create_security_profile @@ -3386,17 +3935,22 @@ - [X] create_thing_group - [X] create_thing_type - [ ] create_topic_rule +- [ ] create_topic_rule_destination - [ ] delete_account_audit_configuration - [ ] delete_authorizer - [ ] delete_billing_group - [ ] delete_ca_certificate - [X] delete_certificate +- [ ] delete_domain_configuration - [ ] delete_dynamic_thing_group - [ ] delete_job - [ ] delete_job_execution +- [ ] delete_mitigation_action - [ ] delete_ota_update - [X] delete_policy - [ ] delete_policy_version +- [ ] delete_provisioning_template +- [ ] delete_provisioning_template_version - [ ] delete_registration_code - [ ] delete_role_alias - [ ] delete_scheduled_audit @@ -3406,20 +3960,27 @@ - [X] delete_thing_group - [X] delete_thing_type - [ ] delete_topic_rule +- [ ] delete_topic_rule_destination - [ ] delete_v2_logging_level - [ ] deprecate_thing_type - [ ] describe_account_audit_configuration +- [ ] describe_audit_finding +- [ ] describe_audit_mitigation_actions_task - [ ] describe_audit_task - [ ] describe_authorizer - [ ] describe_billing_group - [ ] describe_ca_certificate - [X] describe_certificate - [ ] describe_default_authorizer +- [ ] describe_domain_configuration - [ ] describe_endpoint - [ ] describe_event_configurations - [ ] describe_index - [X] describe_job - [ ] describe_job_execution +- [ ] describe_mitigation_action +- [ ] describe_provisioning_template +- [ ] describe_provisioning_template_version - [ ] describe_role_alias - [ ] describe_scheduled_audit - [ ] describe_security_profile @@ -3434,30 +3995,37 @@ - [X] detach_thing_principal - [ ] disable_topic_rule - [ ] enable_topic_rule +- [ ] get_cardinality - [ ] get_effective_policies - [ ] get_indexing_configuration - [ ] get_job_document - [ ] get_logging_options - [ ] get_ota_update +- [ ] get_percentiles - [X] get_policy - [ ] get_policy_version - [ ] get_registration_code - [ ] get_statistics - [ ] get_topic_rule +- [ ] get_topic_rule_destination - [ ] get_v2_logging_options - [ ] list_active_violations - [ ] list_attached_policies - [ ] list_audit_findings +- [ ] list_audit_mitigation_actions_executions +- [ ] list_audit_mitigation_actions_tasks - [ ] list_audit_tasks - [ ] list_authorizers - [ ] list_billing_groups - [ ] list_ca_certificates - [X] list_certificates - [ ] list_certificates_by_ca +- [ ] list_domain_configurations - [ ] list_indices - [ ] list_job_executions_for_job - [ ] list_job_executions_for_thing - [ ] list_jobs +- [ ] list_mitigation_actions - [ ] list_ota_updates - [ ] list_outgoing_certificates - [X] list_policies @@ -3465,6 +4033,8 @@ - [ ] list_policy_versions - [X] list_principal_policies - [X] list_principal_things +- [ ] list_provisioning_template_versions +- [ ] list_provisioning_templates - [ ] list_role_aliases - [ ] list_scheduled_audits - [ ] list_security_profiles @@ -3482,6 +4052,7 @@ - [X] list_things - [ ] list_things_in_billing_group - [X] list_things_in_thing_group +- [ ] list_topic_rule_destinations - [ ] list_topic_rules - [ ] list_v2_logging_levels - [ ] list_violation_events @@ -3498,6 +4069,7 @@ - [ ] set_logging_options - [ ] set_v2_logging_level - [ ] set_v2_logging_options +- [ ] start_audit_mitigation_actions_task - [ ] start_on_demand_audit_task - [ ] start_thing_registration_task - [ ] stop_thing_registration_task @@ -3511,10 +4083,13 @@ - [ ] update_billing_group - [ ] update_ca_certificate - [X] update_certificate +- [ ] update_domain_configuration - [ ] update_dynamic_thing_group - [ ] update_event_configurations - [ ] update_indexing_configuration - [ ] update_job +- [ ] update_mitigation_action +- [ ] update_provisioning_template - [ ] update_role_alias - [ ] update_scheduled_audit - [ ] update_security_profile @@ -3522,6 +4097,7 @@ - [X] update_thing - [X] update_thing_group - [X] update_thing_groups_for_thing +- [ ] update_topic_rule_destination - [ ] validate_security_profile_behaviors ## iot-data @@ -3636,6 +4212,16 @@ - [ ] describe_detector - [ ] list_detectors +## iotsecuretunneling +0% implemented +- [ ] close_tunnel +- [ ] describe_tunnel +- [ ] list_tags_for_resource +- [ ] list_tunnels +- [ ] open_tunnel +- [ ] tag_resource +- [ ] untag_resource + ## iotthingsgraph 0% implemented - [ ] associate_entity_to_thing @@ -3692,8 +4278,33 @@ - [ ] list_tags_for_resource - [ ] tag_resource - [ ] untag_resource +- [ ] update_broker_count - [ ] update_broker_storage - [ ] update_cluster_configuration +- [ ] update_monitoring + +## kendra +0% implemented +- [ ] batch_delete_document +- [ ] batch_put_document +- [ ] create_data_source +- [ ] create_faq +- [ ] create_index +- [ ] delete_faq +- [ ] delete_index +- [ ] describe_data_source +- [ ] describe_faq +- [ ] describe_index +- [ ] list_data_source_sync_jobs +- [ ] list_data_sources +- [ ] list_faqs +- [ ] list_indices +- [ ] query +- [ ] start_data_source_sync_job +- [ ] stop_data_source_sync_job +- [ ] submit_feedback +- [ ] update_data_source +- [ ] update_index ## kinesis 50% implemented @@ -3737,6 +4348,11 @@ 0% implemented - [ ] get_media +## kinesis-video-signaling +0% implemented +- [ ] get_ice_server_config +- [ ] send_alexa_offer_to_master + ## kinesisanalytics 0% implemented - [ ] add_application_cloud_watch_logging_option @@ -3767,6 +4383,7 @@ - [ ] add_application_input_processing_configuration - [ ] add_application_output - [ ] add_application_reference_data_source +- [ ] add_application_vpc_configuration - [ ] create_application - [ ] create_application_snapshot - [ ] delete_application @@ -3775,6 +4392,7 @@ - [ ] delete_application_output - [ ] delete_application_reference_data_source - [ ] delete_application_snapshot +- [ ] delete_application_vpc_configuration - [ ] describe_application - [ ] describe_application_snapshot - [ ] discover_input_schema @@ -3789,26 +4407,35 @@ ## kinesisvideo 0% implemented +- [ ] create_signaling_channel - [ ] create_stream +- [ ] delete_signaling_channel - [ ] delete_stream +- [ ] describe_signaling_channel - [ ] describe_stream - [ ] get_data_endpoint +- [ ] get_signaling_channel_endpoint +- [ ] list_signaling_channels - [ ] list_streams +- [ ] list_tags_for_resource - [ ] list_tags_for_stream +- [ ] tag_resource - [ ] tag_stream +- [ ] untag_resource - [ ] untag_stream - [ ] update_data_retention +- [ ] update_signaling_channel - [ ] update_stream ## kms -41% implemented +43% implemented - [X] cancel_key_deletion - [ ] connect_custom_key_store - [ ] create_alias - [ ] create_custom_key_store - [ ] create_grant - [X] create_key -- [ ] decrypt +- [X] decrypt - [X] delete_alias - [ ] delete_custom_key_store - [ ] delete_imported_key_material @@ -3819,13 +4446,16 @@ - [ ] disconnect_custom_key_store - [X] enable_key - [X] enable_key_rotation -- [ ] encrypt +- [X] encrypt - [X] generate_data_key +- [ ] generate_data_key_pair +- [ ] generate_data_key_pair_without_plaintext - [ ] generate_data_key_without_plaintext - [ ] generate_random - [X] get_key_policy - [X] get_key_rotation_status - [ ] get_parameters_for_import +- [ ] get_public_key - [ ] import_key_material - [ ] list_aliases - [ ] list_grants @@ -3834,57 +4464,85 @@ - [X] list_resource_tags - [ ] list_retirable_grants - [X] put_key_policy -- [ ] re_encrypt +- [X] re_encrypt - [ ] retire_grant - [ ] revoke_grant - [X] schedule_key_deletion +- [ ] sign - [X] tag_resource - [ ] untag_resource - [ ] update_alias - [ ] update_custom_key_store - [X] update_key_description +- [ ] verify + +## lakeformation +0% implemented +- [ ] batch_grant_permissions +- [ ] batch_revoke_permissions +- [ ] deregister_resource +- [ ] describe_resource +- [ ] get_data_lake_settings +- [ ] get_effective_permissions_for_path +- [ ] grant_permissions +- [ ] list_permissions +- [ ] list_resources +- [ ] put_data_lake_settings +- [ ] register_resource +- [ ] revoke_permissions +- [ ] update_resource ## lambda -0% implemented +32% implemented - [ ] add_layer_version_permission - [ ] add_permission - [ ] create_alias -- [ ] create_event_source_mapping -- [ ] create_function +- [X] create_event_source_mapping +- [X] create_function - [ ] delete_alias -- [ ] delete_event_source_mapping -- [ ] delete_function +- [X] delete_event_source_mapping +- [X] delete_function - [ ] delete_function_concurrency +- [ ] delete_function_event_invoke_config - [ ] delete_layer_version +- [ ] delete_provisioned_concurrency_config - [ ] get_account_settings - [ ] get_alias -- [ ] get_event_source_mapping -- [ ] get_function +- [X] get_event_source_mapping +- [X] get_function +- [ ] get_function_concurrency - [ ] get_function_configuration +- [ ] get_function_event_invoke_config - [ ] get_layer_version - [ ] get_layer_version_by_arn - [ ] get_layer_version_policy - [ ] get_policy -- [ ] invoke +- [ ] get_provisioned_concurrency_config +- [X] invoke - [ ] invoke_async - [ ] list_aliases -- [ ] list_event_source_mappings -- [ ] list_functions +- [X] list_event_source_mappings +- [ ] list_function_event_invoke_configs +- [X] list_functions - [ ] list_layer_versions - [ ] list_layers -- [ ] list_tags -- [ ] list_versions_by_function +- [ ] list_provisioned_concurrency_configs +- [X] list_tags +- [X] list_versions_by_function - [ ] publish_layer_version - [ ] publish_version - [ ] put_function_concurrency +- [ ] put_function_event_invoke_config +- [ ] put_provisioned_concurrency_config - [ ] remove_layer_version_permission - [ ] remove_permission -- [ ] tag_resource -- [ ] untag_resource +- [X] tag_resource +- [X] untag_resource - [ ] update_alias -- [ ] update_event_source_mapping -- [ ] update_function_code -- [ ] update_function_configuration +- [X] update_event_source_mapping +- [X] update_function_code +- [X] update_function_configuration +- [ ] update_function_event_invoke_config ## lex-models 0% implemented @@ -3927,8 +4585,11 @@ ## lex-runtime 0% implemented +- [ ] delete_session +- [ ] get_session - [ ] post_content - [ ] post_text +- [ ] put_session ## license-manager 0% implemented @@ -3937,6 +4598,7 @@ - [ ] get_license_configuration - [ ] get_service_settings - [ ] list_associations_for_license_configuration +- [ ] list_failures_for_license_configuration_operations - [ ] list_license_configurations - [ ] list_license_specifications_for_resource - [ ] list_resource_inventory @@ -3972,6 +4634,7 @@ - [ ] create_relational_database - [ ] create_relational_database_from_snapshot - [ ] create_relational_database_snapshot +- [ ] delete_auto_snapshot - [ ] delete_disk - [ ] delete_disk_snapshot - [ ] delete_domain @@ -3987,9 +4650,12 @@ - [ ] detach_disk - [ ] detach_instances_from_load_balancer - [ ] detach_static_ip +- [ ] disable_add_on - [ ] download_default_key_pair +- [ ] enable_add_on - [ ] export_snapshot - [ ] get_active_names +- [ ] get_auto_snapshots - [ ] get_blueprints - [ ] get_bundles - [ ] get_cloud_formation_stack_records @@ -4053,7 +4719,7 @@ - [ ] update_relational_database_parameters ## logs -28% implemented +35% implemented - [ ] associate_kms_key - [ ] cancel_export_task - [ ] create_export_task @@ -4080,7 +4746,7 @@ - [ ] get_log_group_fields - [ ] get_log_record - [ ] get_query_results -- [ ] list_tags_log_group +- [X] list_tags_log_group - [ ] put_destination - [ ] put_destination_policy - [X] put_log_events @@ -4090,9 +4756,9 @@ - [ ] put_subscription_filter - [ ] start_query - [ ] stop_query -- [ ] tag_log_group +- [X] tag_log_group - [ ] test_metric_filter -- [ ] untag_log_group +- [X] untag_log_group ## machinelearning 0% implemented @@ -4156,6 +4822,15 @@ - [ ] reject_invitation - [ ] vote_on_proposal +## marketplace-catalog +0% implemented +- [ ] cancel_change_set +- [ ] describe_change_set +- [ ] describe_entity +- [ ] list_change_sets +- [ ] list_entities +- [ ] start_change_set + ## marketplace-entitlement 0% implemented - [ ] get_entitlements @@ -4219,43 +4894,58 @@ - [ ] create_channel - [ ] create_input - [ ] create_input_security_group +- [ ] create_multiplex +- [ ] create_multiplex_program - [ ] create_tags - [ ] delete_channel - [ ] delete_input - [ ] delete_input_security_group +- [ ] delete_multiplex +- [ ] delete_multiplex_program - [ ] delete_reservation - [ ] delete_schedule - [ ] delete_tags - [ ] describe_channel - [ ] describe_input - [ ] describe_input_security_group +- [ ] describe_multiplex +- [ ] describe_multiplex_program - [ ] describe_offering - [ ] describe_reservation - [ ] describe_schedule - [ ] list_channels - [ ] list_input_security_groups - [ ] list_inputs +- [ ] list_multiplex_programs +- [ ] list_multiplexes - [ ] list_offerings - [ ] list_reservations - [ ] list_tags_for_resource - [ ] purchase_offering - [ ] start_channel +- [ ] start_multiplex - [ ] stop_channel +- [ ] stop_multiplex - [ ] update_channel - [ ] update_channel_class - [ ] update_input - [ ] update_input_security_group +- [ ] update_multiplex +- [ ] update_multiplex_program - [ ] update_reservation ## mediapackage 0% implemented - [ ] create_channel +- [ ] create_harvest_job - [ ] create_origin_endpoint - [ ] delete_channel - [ ] delete_origin_endpoint - [ ] describe_channel +- [ ] describe_harvest_job - [ ] describe_origin_endpoint - [ ] list_channels +- [ ] list_harvest_jobs - [ ] list_origin_endpoints - [ ] list_tags_for_resource - [ ] rotate_channel_credentials @@ -4345,6 +5035,12 @@ - [ ] notify_migration_task_state - [ ] put_resource_attributes +## migrationhub-config +0% implemented +- [ ] create_home_region_control +- [ ] describe_home_region_controls +- [ ] get_home_region + ## mobile 0% implemented - [ ] create_project @@ -4484,6 +5180,37 @@ - [ ] restore_db_cluster_from_snapshot - [ ] restore_db_cluster_to_point_in_time +## networkmanager +0% implemented +- [ ] associate_customer_gateway +- [ ] associate_link +- [ ] create_device +- [ ] create_global_network +- [ ] create_link +- [ ] create_site +- [ ] delete_device +- [ ] delete_global_network +- [ ] delete_link +- [ ] delete_site +- [ ] deregister_transit_gateway +- [ ] describe_global_networks +- [ ] disassociate_customer_gateway +- [ ] disassociate_link +- [ ] get_customer_gateway_associations +- [ ] get_devices +- [ ] get_link_associations +- [ ] get_links +- [ ] get_sites +- [ ] get_transit_gateway_registrations +- [ ] list_tags_for_resource +- [ ] register_transit_gateway +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_device +- [ ] update_global_network +- [ ] update_link +- [ ] update_site + ## opsworks 12% implemented - [ ] assign_instance @@ -4581,7 +5308,7 @@ - [ ] update_server_engine_attributes ## organizations -41% implemented +48% implemented - [ ] accept_handshake - [X] attach_policy - [ ] cancel_handshake @@ -4595,7 +5322,8 @@ - [ ] delete_organizational_unit - [ ] delete_policy - [X] describe_account -- [ ] describe_create_account_status +- [X] describe_create_account_status +- [ ] describe_effective_policy - [ ] describe_handshake - [X] describe_organization - [X] describe_organizational_unit @@ -4620,17 +5348,26 @@ - [X] list_policies - [X] list_policies_for_target - [X] list_roots -- [ ] list_tags_for_resource +- [X] list_tags_for_resource - [X] list_targets_for_policy - [X] move_account - [ ] remove_account_from_organization -- [ ] tag_resource -- [ ] untag_resource +- [X] tag_resource +- [X] untag_resource - [ ] update_organizational_unit - [ ] update_policy +## outposts +0% implemented +- [ ] create_outpost +- [ ] get_outpost +- [ ] get_outpost_instance_types +- [ ] list_outposts +- [ ] list_sites + ## personalize 0% implemented +- [ ] create_batch_inference_job - [ ] create_campaign - [ ] create_dataset - [ ] create_dataset_group @@ -4646,6 +5383,7 @@ - [ ] delete_schema - [ ] delete_solution - [ ] describe_algorithm +- [ ] describe_batch_inference_job - [ ] describe_campaign - [ ] describe_dataset - [ ] describe_dataset_group @@ -4657,6 +5395,7 @@ - [ ] describe_solution - [ ] describe_solution_version - [ ] get_solution_metrics +- [ ] list_batch_inference_jobs - [ ] list_campaigns - [ ] list_dataset_groups - [ ] list_dataset_import_jobs @@ -4686,9 +5425,14 @@ 0% implemented - [ ] create_app - [ ] create_campaign +- [ ] create_email_template - [ ] create_export_job - [ ] create_import_job +- [ ] create_journey +- [ ] create_push_template - [ ] create_segment +- [ ] create_sms_template +- [ ] create_voice_template - [ ] delete_adm_channel - [ ] delete_apns_channel - [ ] delete_apns_sandbox_channel @@ -4698,13 +5442,18 @@ - [ ] delete_baidu_channel - [ ] delete_campaign - [ ] delete_email_channel +- [ ] delete_email_template - [ ] delete_endpoint - [ ] delete_event_stream - [ ] delete_gcm_channel +- [ ] delete_journey +- [ ] delete_push_template - [ ] delete_segment - [ ] delete_sms_channel +- [ ] delete_sms_template - [ ] delete_user_endpoints - [ ] delete_voice_channel +- [ ] delete_voice_template - [ ] get_adm_channel - [ ] get_apns_channel - [ ] get_apns_sandbox_channel @@ -4723,6 +5472,7 @@ - [ ] get_campaigns - [ ] get_channels - [ ] get_email_channel +- [ ] get_email_template - [ ] get_endpoint - [ ] get_event_stream - [ ] get_export_job @@ -4730,6 +5480,11 @@ - [ ] get_gcm_channel - [ ] get_import_job - [ ] get_import_jobs +- [ ] get_journey +- [ ] get_journey_date_range_kpi +- [ ] get_journey_execution_activity_metrics +- [ ] get_journey_execution_metrics +- [ ] get_push_template - [ ] get_segment - [ ] get_segment_export_jobs - [ ] get_segment_import_jobs @@ -4737,9 +5492,13 @@ - [ ] get_segment_versions - [ ] get_segments - [ ] get_sms_channel +- [ ] get_sms_template - [ ] get_user_endpoints - [ ] get_voice_channel +- [ ] get_voice_template +- [ ] list_journeys - [ ] list_tags_for_resource +- [ ] list_templates - [ ] phone_number_validate - [ ] put_event_stream - [ ] put_events @@ -4757,12 +5516,18 @@ - [ ] update_baidu_channel - [ ] update_campaign - [ ] update_email_channel +- [ ] update_email_template - [ ] update_endpoint - [ ] update_endpoints_batch - [ ] update_gcm_channel +- [ ] update_journey +- [ ] update_journey_state +- [ ] update_push_template - [ ] update_segment - [ ] update_sms_channel +- [ ] update_sms_template - [ ] update_voice_channel +- [ ] update_voice_template ## pinpoint-email 0% implemented @@ -4837,39 +5602,117 @@ - [ ] get_attribute_values - [ ] get_products +## qldb +0% implemented +- [ ] create_ledger +- [ ] delete_ledger +- [ ] describe_journal_s3_export +- [ ] describe_ledger +- [ ] export_journal_to_s3 +- [ ] get_block +- [ ] get_digest +- [ ] get_revision +- [ ] list_journal_s3_exports +- [ ] list_journal_s3_exports_for_ledger +- [ ] list_ledgers +- [ ] list_tags_for_resource +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_ledger + +## qldb-session +0% implemented +- [ ] send_command + ## quicksight 0% implemented +- [ ] cancel_ingestion +- [ ] create_dashboard +- [ ] create_data_set +- [ ] create_data_source - [ ] create_group - [ ] create_group_membership +- [ ] create_iam_policy_assignment +- [ ] create_ingestion +- [ ] create_template +- [ ] create_template_alias +- [ ] delete_dashboard +- [ ] delete_data_set +- [ ] delete_data_source - [ ] delete_group - [ ] delete_group_membership +- [ ] delete_iam_policy_assignment +- [ ] delete_template +- [ ] delete_template_alias - [ ] delete_user - [ ] delete_user_by_principal_id +- [ ] describe_dashboard +- [ ] describe_dashboard_permissions +- [ ] describe_data_set +- [ ] describe_data_set_permissions +- [ ] describe_data_source +- [ ] describe_data_source_permissions - [ ] describe_group +- [ ] describe_iam_policy_assignment +- [ ] describe_ingestion +- [ ] describe_template +- [ ] describe_template_alias +- [ ] describe_template_permissions - [ ] describe_user - [ ] get_dashboard_embed_url +- [ ] list_dashboard_versions +- [ ] list_dashboards +- [ ] list_data_sets +- [ ] list_data_sources - [ ] list_group_memberships - [ ] list_groups +- [ ] list_iam_policy_assignments +- [ ] list_iam_policy_assignments_for_user +- [ ] list_ingestions +- [ ] list_tags_for_resource +- [ ] list_template_aliases +- [ ] list_template_versions +- [ ] list_templates - [ ] list_user_groups - [ ] list_users - [ ] register_user +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_dashboard +- [ ] update_dashboard_permissions +- [ ] update_dashboard_published_version +- [ ] update_data_set +- [ ] update_data_set_permissions +- [ ] update_data_source +- [ ] update_data_source_permissions - [ ] update_group +- [ ] update_iam_policy_assignment +- [ ] update_template +- [ ] update_template_alias +- [ ] update_template_permissions - [ ] update_user ## ram 0% implemented - [ ] accept_resource_share_invitation - [ ] associate_resource_share +- [ ] associate_resource_share_permission - [ ] create_resource_share - [ ] delete_resource_share - [ ] disassociate_resource_share +- [ ] disassociate_resource_share_permission - [ ] enable_sharing_with_aws_organization +- [ ] get_permission - [ ] get_resource_policies - [ ] get_resource_share_associations - [ ] get_resource_share_invitations - [ ] get_resource_shares +- [ ] list_pending_invitation_resources +- [ ] list_permissions - [ ] list_principals +- [ ] list_resource_share_permissions - [ ] list_resources +- [ ] promote_resource_share_created_from_policy - [ ] reject_resource_share_invitation - [ ] tag_resource - [ ] untag_resource @@ -4889,6 +5732,7 @@ - [ ] copy_db_parameter_group - [ ] copy_db_snapshot - [ ] copy_option_group +- [ ] create_custom_availability_zone - [ ] create_db_cluster - [ ] create_db_cluster_endpoint - [ ] create_db_cluster_parameter_group @@ -4896,12 +5740,14 @@ - [ ] create_db_instance - [ ] create_db_instance_read_replica - [ ] create_db_parameter_group +- [ ] create_db_proxy - [ ] create_db_security_group - [ ] create_db_snapshot - [ ] create_db_subnet_group - [ ] create_event_subscription - [ ] create_global_cluster - [ ] create_option_group +- [ ] delete_custom_availability_zone - [ ] delete_db_cluster - [ ] delete_db_cluster_endpoint - [ ] delete_db_cluster_parameter_group @@ -4909,14 +5755,18 @@ - [ ] delete_db_instance - [ ] delete_db_instance_automated_backup - [ ] delete_db_parameter_group +- [ ] delete_db_proxy - [ ] delete_db_security_group - [ ] delete_db_snapshot - [ ] delete_db_subnet_group - [ ] delete_event_subscription - [ ] delete_global_cluster +- [ ] delete_installation_media - [ ] delete_option_group +- [ ] deregister_db_proxy_targets - [ ] describe_account_attributes - [ ] describe_certificates +- [ ] describe_custom_availability_zones - [ ] describe_db_cluster_backtracks - [ ] describe_db_cluster_endpoints - [ ] describe_db_cluster_parameter_groups @@ -4930,6 +5780,9 @@ - [ ] describe_db_log_files - [ ] describe_db_parameter_groups - [ ] describe_db_parameters +- [ ] describe_db_proxies +- [ ] describe_db_proxy_target_groups +- [ ] describe_db_proxy_targets - [ ] describe_db_security_groups - [ ] describe_db_snapshot_attributes - [ ] describe_db_snapshots @@ -4940,6 +5793,7 @@ - [ ] describe_event_subscriptions - [ ] describe_events - [ ] describe_global_clusters +- [ ] describe_installation_media - [ ] describe_option_group_options - [ ] describe_option_groups - [ ] describe_orderable_db_instance_options @@ -4950,6 +5804,7 @@ - [ ] describe_valid_db_instance_modifications - [ ] download_db_log_file_portion - [ ] failover_db_cluster +- [ ] import_installation_media - [ ] list_tags_for_resource - [ ] modify_current_db_cluster_capacity - [ ] modify_db_cluster @@ -4958,6 +5813,8 @@ - [ ] modify_db_cluster_snapshot_attribute - [ ] modify_db_instance - [ ] modify_db_parameter_group +- [ ] modify_db_proxy +- [ ] modify_db_proxy_target_group - [ ] modify_db_snapshot - [ ] modify_db_snapshot_attribute - [ ] modify_db_subnet_group @@ -4968,6 +5825,7 @@ - [ ] promote_read_replica_db_cluster - [ ] purchase_reserved_db_instances_offering - [ ] reboot_db_instance +- [ ] register_db_proxy_targets - [ ] remove_from_global_cluster - [ ] remove_role_from_db_cluster - [ ] remove_role_from_db_instance @@ -4999,7 +5857,7 @@ - [ ] rollback_transaction ## redshift -32% implemented +30% implemented - [ ] accept_reserved_node_exchange - [ ] authorize_cluster_security_group_ingress - [ ] authorize_snapshot_access @@ -5015,6 +5873,7 @@ - [ ] create_event_subscription - [ ] create_hsm_client_certificate - [ ] create_hsm_configuration +- [ ] create_scheduled_action - [X] create_snapshot_copy_grant - [ ] create_snapshot_schedule - [X] create_tags @@ -5026,6 +5885,7 @@ - [ ] delete_event_subscription - [ ] delete_hsm_client_certificate - [ ] delete_hsm_configuration +- [ ] delete_scheduled_action - [X] delete_snapshot_copy_grant - [ ] delete_snapshot_schedule - [X] delete_tags @@ -5046,10 +5906,12 @@ - [ ] describe_hsm_client_certificates - [ ] describe_hsm_configurations - [ ] describe_logging_status +- [ ] describe_node_configuration_options - [ ] describe_orderable_cluster_options - [ ] describe_reserved_node_offerings - [ ] describe_reserved_nodes - [ ] describe_resize +- [ ] describe_scheduled_actions - [X] describe_snapshot_copy_grants - [ ] describe_snapshot_schedules - [ ] describe_storage @@ -5070,6 +5932,7 @@ - [ ] modify_cluster_snapshot_schedule - [ ] modify_cluster_subnet_group - [ ] modify_event_subscription +- [ ] modify_scheduled_action - [X] modify_snapshot_copy_retention_period - [ ] modify_snapshot_schedule - [ ] purchase_reserved_node_offering @@ -5086,12 +5949,17 @@ 0% implemented - [ ] compare_faces - [ ] create_collection +- [ ] create_project +- [ ] create_project_version - [ ] create_stream_processor - [ ] delete_collection - [ ] delete_faces - [ ] delete_stream_processor - [ ] describe_collection +- [ ] describe_project_versions +- [ ] describe_projects - [ ] describe_stream_processor +- [ ] detect_custom_labels - [ ] detect_faces - [ ] detect_labels - [ ] detect_moderation_labels @@ -5116,7 +5984,9 @@ - [ ] start_face_search - [ ] start_label_detection - [ ] start_person_tracking +- [ ] start_project_version - [ ] start_stream_processor +- [ ] stop_project_version - [ ] stop_stream_processor ## resource-groups @@ -5135,10 +6005,13 @@ - [X] update_group_query ## resourcegroupstaggingapi -60% implemented +37% implemented +- [ ] describe_report_creation +- [ ] get_compliance_summary - [X] get_resources - [X] get_tag_keys - [X] get_tag_values +- [ ] start_report_creation - [ ] tag_resources - [ ] untag_resources @@ -5304,63 +6177,63 @@ - [X] delete_bucket_cors - [ ] delete_bucket_encryption - [ ] delete_bucket_inventory_configuration -- [ ] delete_bucket_lifecycle +- [X] delete_bucket_lifecycle - [ ] delete_bucket_metrics_configuration - [X] delete_bucket_policy - [ ] delete_bucket_replication - [X] delete_bucket_tagging - [ ] delete_bucket_website -- [ ] delete_object +- [X] delete_object - [ ] delete_object_tagging -- [ ] delete_objects -- [ ] delete_public_access_block +- [X] delete_objects +- [X] delete_public_access_block - [ ] get_bucket_accelerate_configuration - [X] get_bucket_acl - [ ] get_bucket_analytics_configuration -- [ ] get_bucket_cors +- [X] get_bucket_cors - [ ] get_bucket_encryption - [ ] get_bucket_inventory_configuration -- [ ] get_bucket_lifecycle -- [ ] get_bucket_lifecycle_configuration -- [ ] get_bucket_location -- [ ] get_bucket_logging +- [X] get_bucket_lifecycle +- [X] get_bucket_lifecycle_configuration +- [X] get_bucket_location +- [X] get_bucket_logging - [ ] get_bucket_metrics_configuration - [ ] get_bucket_notification - [ ] get_bucket_notification_configuration - [X] get_bucket_policy -- [ ] get_bucket_policy_status +- [X] get_bucket_policy_status - [ ] get_bucket_replication - [ ] get_bucket_request_payment -- [ ] get_bucket_tagging +- [X] get_bucket_tagging - [X] get_bucket_versioning - [ ] get_bucket_website -- [ ] get_object -- [ ] get_object_acl +- [X] get_object +- [X] get_object_acl - [ ] get_object_legal_hold - [ ] get_object_lock_configuration - [ ] get_object_retention - [ ] get_object_tagging - [ ] get_object_torrent -- [ ] get_public_access_block +- [X] get_public_access_block - [ ] head_bucket - [ ] head_object - [ ] list_bucket_analytics_configurations - [ ] list_bucket_inventory_configurations - [ ] list_bucket_metrics_configurations -- [ ] list_buckets -- [ ] list_multipart_uploads +- [X] list_buckets +- [X] list_multipart_uploads - [ ] list_object_versions -- [ ] list_objects -- [ ] list_objects_v2 +- [X] list_objects +- [X] list_objects_v2 - [ ] list_parts - [X] put_bucket_accelerate_configuration -- [ ] put_bucket_acl +- [X] put_bucket_acl - [ ] put_bucket_analytics_configuration - [X] put_bucket_cors - [ ] put_bucket_encryption - [ ] put_bucket_inventory_configuration -- [ ] put_bucket_lifecycle -- [ ] put_bucket_lifecycle_configuration +- [X] put_bucket_lifecycle +- [X] put_bucket_lifecycle_configuration - [X] put_bucket_logging - [ ] put_bucket_metrics_configuration - [ ] put_bucket_notification @@ -5369,15 +6242,15 @@ - [ ] put_bucket_replication - [ ] put_bucket_request_payment - [X] put_bucket_tagging -- [ ] put_bucket_versioning +- [X] put_bucket_versioning - [ ] put_bucket_website -- [ ] put_object +- [X] put_object - [ ] put_object_acl - [ ] put_object_legal_hold - [ ] put_object_lock_configuration - [ ] put_object_retention - [ ] put_object_tagging -- [ ] put_public_access_block +- [X] put_public_access_block - [ ] restore_object - [ ] select_object_content - [ ] upload_part @@ -5385,11 +6258,19 @@ ## s3control 0% implemented +- [ ] create_access_point - [ ] create_job +- [ ] delete_access_point +- [ ] delete_access_point_policy - [ ] delete_public_access_block - [ ] describe_job +- [ ] get_access_point +- [ ] get_access_point_policy +- [ ] get_access_point_policy_status - [ ] get_public_access_block +- [ ] list_access_points - [ ] list_jobs +- [ ] put_access_point_policy - [ ] put_public_access_block - [ ] update_job_priority - [ ] update_job_status @@ -5397,85 +6278,192 @@ ## sagemaker 0% implemented - [ ] add_tags +- [ ] associate_trial_component - [ ] create_algorithm +- [ ] create_app +- [ ] create_auto_ml_job - [ ] create_code_repository - [ ] create_compilation_job +- [ ] create_domain - [ ] create_endpoint - [ ] create_endpoint_config +- [ ] create_experiment +- [ ] create_flow_definition +- [ ] create_human_task_ui - [ ] create_hyper_parameter_tuning_job - [ ] create_labeling_job - [ ] create_model - [ ] create_model_package +- [ ] create_monitoring_schedule - [ ] create_notebook_instance - [ ] create_notebook_instance_lifecycle_config +- [ ] create_presigned_domain_url - [ ] create_presigned_notebook_instance_url +- [ ] create_processing_job - [ ] create_training_job - [ ] create_transform_job +- [ ] create_trial +- [ ] create_trial_component +- [ ] create_user_profile - [ ] create_workteam - [ ] delete_algorithm +- [ ] delete_app - [ ] delete_code_repository +- [ ] delete_domain - [ ] delete_endpoint - [ ] delete_endpoint_config +- [ ] delete_experiment +- [ ] delete_flow_definition - [ ] delete_model - [ ] delete_model_package +- [ ] delete_monitoring_schedule - [ ] delete_notebook_instance - [ ] delete_notebook_instance_lifecycle_config - [ ] delete_tags +- [ ] delete_trial +- [ ] delete_trial_component +- [ ] delete_user_profile - [ ] delete_workteam - [ ] describe_algorithm +- [ ] describe_app +- [ ] describe_auto_ml_job - [ ] describe_code_repository - [ ] describe_compilation_job +- [ ] describe_domain - [ ] describe_endpoint - [ ] describe_endpoint_config +- [ ] describe_experiment +- [ ] describe_flow_definition +- [ ] describe_human_task_ui - [ ] describe_hyper_parameter_tuning_job - [ ] describe_labeling_job - [ ] describe_model - [ ] describe_model_package +- [ ] describe_monitoring_schedule - [ ] describe_notebook_instance - [ ] describe_notebook_instance_lifecycle_config +- [ ] describe_processing_job - [ ] describe_subscribed_workteam - [ ] describe_training_job - [ ] describe_transform_job +- [ ] describe_trial +- [ ] describe_trial_component +- [ ] describe_user_profile - [ ] describe_workteam +- [ ] disassociate_trial_component - [ ] get_search_suggestions - [ ] list_algorithms +- [ ] list_apps +- [ ] list_auto_ml_jobs +- [ ] list_candidates_for_auto_ml_job - [ ] list_code_repositories - [ ] list_compilation_jobs +- [ ] list_domains - [ ] list_endpoint_configs - [ ] list_endpoints +- [ ] list_experiments +- [ ] list_flow_definitions +- [ ] list_human_task_uis - [ ] list_hyper_parameter_tuning_jobs - [ ] list_labeling_jobs - [ ] list_labeling_jobs_for_workteam - [ ] list_model_packages - [ ] list_models +- [ ] list_monitoring_executions +- [ ] list_monitoring_schedules - [ ] list_notebook_instance_lifecycle_configs - [ ] list_notebook_instances +- [ ] list_processing_jobs - [ ] list_subscribed_workteams - [ ] list_tags - [ ] list_training_jobs - [ ] list_training_jobs_for_hyper_parameter_tuning_job - [ ] list_transform_jobs +- [ ] list_trial_components +- [ ] list_trials +- [ ] list_user_profiles - [ ] list_workteams - [ ] render_ui_template - [ ] search +- [ ] start_monitoring_schedule - [ ] start_notebook_instance +- [ ] stop_auto_ml_job - [ ] stop_compilation_job - [ ] stop_hyper_parameter_tuning_job - [ ] stop_labeling_job +- [ ] stop_monitoring_schedule - [ ] stop_notebook_instance +- [ ] stop_processing_job - [ ] stop_training_job - [ ] stop_transform_job - [ ] update_code_repository +- [ ] update_domain - [ ] update_endpoint - [ ] update_endpoint_weights_and_capacities +- [ ] update_experiment +- [ ] update_monitoring_schedule - [ ] update_notebook_instance - [ ] update_notebook_instance_lifecycle_config +- [ ] update_trial +- [ ] update_trial_component +- [ ] update_user_profile - [ ] update_workteam +## sagemaker-a2i-runtime +0% implemented +- [ ] delete_human_loop +- [ ] describe_human_loop +- [ ] list_human_loops +- [ ] start_human_loop +- [ ] stop_human_loop + ## sagemaker-runtime 0% implemented - [ ] invoke_endpoint +## savingsplans +0% implemented +- [ ] create_savings_plan +- [ ] describe_savings_plan_rates +- [ ] describe_savings_plans +- [ ] describe_savings_plans_offering_rates +- [ ] describe_savings_plans_offerings +- [ ] list_tags_for_resource +- [ ] tag_resource +- [ ] untag_resource + +## schemas +0% implemented +- [ ] create_discoverer +- [ ] create_registry +- [ ] create_schema +- [ ] delete_discoverer +- [ ] delete_registry +- [ ] delete_schema +- [ ] delete_schema_version +- [ ] describe_code_binding +- [ ] describe_discoverer +- [ ] describe_registry +- [ ] describe_schema +- [ ] get_code_binding_source +- [ ] get_discovered_schema +- [ ] list_discoverers +- [ ] list_registries +- [ ] list_schema_versions +- [ ] list_schemas +- [ ] list_tags_for_resource +- [ ] lock_service_linked_role +- [ ] put_code_binding +- [ ] search_schemas +- [ ] start_discoverer +- [ ] stop_discoverer +- [ ] tag_resource +- [ ] unlock_service_linked_role +- [ ] untag_resource +- [ ] update_discoverer +- [ ] update_registry +- [ ] update_schema + ## sdb 0% implemented - [ ] batch_delete_attributes @@ -5490,14 +6478,14 @@ - [ ] select ## secretsmanager -55% implemented +61% implemented - [ ] cancel_rotate_secret - [X] create_secret - [ ] delete_resource_policy - [X] delete_secret - [X] describe_secret - [X] get_random_password -- [ ] get_resource_policy +- [X] get_resource_policy - [X] get_secret_value - [X] list_secret_version_ids - [X] list_secrets @@ -5696,7 +6684,7 @@ - [ ] update_service ## ses -12% implemented +14% implemented - [ ] clone_receipt_rule_set - [ ] create_configuration_set - [ ] create_configuration_set_event_destination @@ -5747,7 +6735,7 @@ - [ ] send_custom_verification_email - [X] send_email - [X] send_raw_email -- [ ] send_templated_email +- [X] send_templated_email - [ ] set_active_receipt_rule_set - [ ] set_identity_dkim_enabled - [ ] set_identity_feedback_forwarding_enabled @@ -5769,6 +6757,58 @@ - [X] verify_email_address - [X] verify_email_identity +## sesv2 +0% implemented +- [ ] create_configuration_set +- [ ] create_configuration_set_event_destination +- [ ] create_dedicated_ip_pool +- [ ] create_deliverability_test_report +- [ ] create_email_identity +- [ ] delete_configuration_set +- [ ] delete_configuration_set_event_destination +- [ ] delete_dedicated_ip_pool +- [ ] delete_email_identity +- [ ] delete_suppressed_destination +- [ ] get_account +- [ ] get_blacklist_reports +- [ ] get_configuration_set +- [ ] get_configuration_set_event_destinations +- [ ] get_dedicated_ip +- [ ] get_dedicated_ips +- [ ] get_deliverability_dashboard_options +- [ ] get_deliverability_test_report +- [ ] get_domain_deliverability_campaign +- [ ] get_domain_statistics_report +- [ ] get_email_identity +- [ ] get_suppressed_destination +- [ ] list_configuration_sets +- [ ] list_dedicated_ip_pools +- [ ] list_deliverability_test_reports +- [ ] list_domain_deliverability_campaigns +- [ ] list_email_identities +- [ ] list_suppressed_destinations +- [ ] list_tags_for_resource +- [ ] put_account_dedicated_ip_warmup_attributes +- [ ] put_account_sending_attributes +- [ ] put_account_suppression_attributes +- [ ] put_configuration_set_delivery_options +- [ ] put_configuration_set_reputation_options +- [ ] put_configuration_set_sending_options +- [ ] put_configuration_set_suppression_options +- [ ] put_configuration_set_tracking_options +- [ ] put_dedicated_ip_in_pool +- [ ] put_dedicated_ip_warmup_attributes +- [ ] put_deliverability_dashboard_option +- [ ] put_email_identity_dkim_attributes +- [ ] put_email_identity_dkim_signing_attributes +- [ ] put_email_identity_feedback_attributes +- [ ] put_email_identity_mail_from_attributes +- [ ] put_suppressed_destination +- [ ] send_email +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_configuration_set_event_destination + ## shield 0% implemented - [ ] associate_drt_log_bucket @@ -5799,8 +6839,11 @@ - [ ] list_signing_jobs - [ ] list_signing_platforms - [ ] list_signing_profiles +- [ ] list_tags_for_resource - [ ] put_signing_profile - [ ] start_signing_job +- [ ] tag_resource +- [ ] untag_resource ## sms 0% implemented @@ -5858,6 +6901,7 @@ - [ ] get_job_manifest - [ ] get_job_unlock_code - [ ] get_snowball_usage +- [ ] get_software_updates - [ ] list_cluster_jobs - [ ] list_clusters - [ ] list_compatible_images @@ -5866,8 +6910,8 @@ - [ ] update_job ## sns -48% implemented -- [ ] add_permission +63% implemented +- [X] add_permission - [ ] check_if_phone_number_is_opted_out - [ ] confirm_subscription - [X] create_platform_application @@ -5886,23 +6930,23 @@ - [X] list_platform_applications - [X] list_subscriptions - [ ] list_subscriptions_by_topic -- [ ] list_tags_for_resource +- [X] list_tags_for_resource - [X] list_topics - [ ] opt_in_phone_number - [X] publish -- [ ] remove_permission +- [X] remove_permission - [X] set_endpoint_attributes - [ ] set_platform_application_attributes - [ ] set_sms_attributes - [X] set_subscription_attributes - [ ] set_topic_attributes - [X] subscribe -- [ ] tag_resource +- [X] tag_resource - [X] unsubscribe -- [ ] untag_resource +- [X] untag_resource ## sqs -65% implemented +85% implemented - [X] add_permission - [X] change_message_visibility - [ ] change_message_visibility_batch @@ -5910,22 +6954,22 @@ - [X] delete_message - [ ] delete_message_batch - [X] delete_queue -- [ ] get_queue_attributes -- [ ] get_queue_url +- [X] get_queue_attributes +- [X] get_queue_url - [X] list_dead_letter_source_queues -- [ ] list_queue_tags +- [X] list_queue_tags - [X] list_queues - [X] purge_queue - [ ] receive_message - [X] remove_permission - [X] send_message -- [ ] send_message_batch +- [X] send_message_batch - [X] set_queue_attributes - [X] tag_queue - [X] untag_queue ## ssm -10% implemented +11% implemented - [X] add_tags_to_resource - [ ] cancel_command - [ ] cancel_maintenance_window_execution @@ -5976,13 +7020,14 @@ - [ ] describe_maintenance_windows - [ ] describe_maintenance_windows_for_target - [ ] describe_ops_items -- [ ] describe_parameters +- [X] describe_parameters - [ ] describe_patch_baselines - [ ] describe_patch_group_state - [ ] describe_patch_groups - [ ] describe_patch_properties - [ ] describe_sessions - [ ] get_automation_execution +- [ ] get_calendar_state - [X] get_command_invocation - [ ] get_connection_status - [ ] get_default_patch_baseline @@ -5998,7 +7043,7 @@ - [ ] get_ops_item - [ ] get_ops_summary - [X] get_parameter -- [ ] get_parameter_history +- [X] get_parameter_history - [X] get_parameters - [X] get_parameters_by_path - [ ] get_patch_baseline @@ -6045,29 +7090,43 @@ - [ ] update_managed_instance_role - [ ] update_ops_item - [ ] update_patch_baseline +- [ ] update_resource_data_sync - [ ] update_service_setting -## stepfunctions +## sso 0% implemented +- [ ] get_role_credentials +- [ ] list_account_roles +- [ ] list_accounts +- [ ] logout + +## sso-oidc +0% implemented +- [ ] create_token +- [ ] register_client +- [ ] start_device_authorization + +## stepfunctions +36% implemented - [ ] create_activity -- [ ] create_state_machine +- [X] create_state_machine - [ ] delete_activity -- [ ] delete_state_machine +- [X] delete_state_machine - [ ] describe_activity -- [ ] describe_execution -- [ ] describe_state_machine +- [X] describe_execution +- [X] describe_state_machine - [ ] describe_state_machine_for_execution - [ ] get_activity_task - [ ] get_execution_history - [ ] list_activities -- [ ] list_executions -- [ ] list_state_machines +- [X] list_executions +- [X] list_state_machines - [ ] list_tags_for_resource - [ ] send_task_failure - [ ] send_task_heartbeat - [ ] send_task_success -- [ ] start_execution -- [ ] stop_execution +- [X] start_execution +- [X] stop_execution - [ ] tag_resource - [ ] untag_resource - [ ] update_state_machine @@ -6099,6 +7158,7 @@ - [ ] delete_tape - [ ] delete_tape_archive - [ ] delete_volume +- [ ] describe_availability_monitor_test - [ ] describe_bandwidth_rate_limit - [ ] describe_cache - [ ] describe_cached_iscsi_volumes @@ -6136,6 +7196,7 @@ - [ ] set_local_console_password - [ ] set_smb_guest_password - [ ] shutdown_gateway +- [ ] start_availability_monitor_test - [ ] start_gateway - [ ] update_bandwidth_rate_limit - [ ] update_chap_credentials @@ -6428,6 +7489,45 @@ - [ ] update_web_acl - [ ] update_xss_match_set +## wafv2 +0% implemented +- [ ] associate_web_acl +- [ ] check_capacity +- [ ] create_ip_set +- [ ] create_regex_pattern_set +- [ ] create_rule_group +- [ ] create_web_acl +- [ ] delete_ip_set +- [ ] delete_logging_configuration +- [ ] delete_regex_pattern_set +- [ ] delete_rule_group +- [ ] delete_web_acl +- [ ] describe_managed_rule_group +- [ ] disassociate_web_acl +- [ ] get_ip_set +- [ ] get_logging_configuration +- [ ] get_rate_based_statement_managed_keys +- [ ] get_regex_pattern_set +- [ ] get_rule_group +- [ ] get_sampled_requests +- [ ] get_web_acl +- [ ] get_web_acl_for_resource +- [ ] list_available_managed_rule_groups +- [ ] list_ip_sets +- [ ] list_logging_configurations +- [ ] list_regex_pattern_sets +- [ ] list_resources_for_web_acl +- [ ] list_rule_groups +- [ ] list_tags_for_resource +- [ ] list_web_acls +- [ ] put_logging_configuration +- [ ] tag_resource +- [ ] untag_resource +- [ ] update_ip_set +- [ ] update_regex_pattern_set +- [ ] update_rule_group +- [ ] update_web_acl + ## workdocs 0% implemented - [ ] abort_document_version_upload @@ -6541,6 +7641,10 @@ - [ ] update_primary_email_address - [ ] update_resource +## workmailmessageflow +0% implemented +- [ ] get_raw_message_content + ## workspaces 0% implemented - [ ] associate_ip_groups @@ -6552,6 +7656,7 @@ - [ ] delete_ip_group - [ ] delete_tags - [ ] delete_workspace_image +- [ ] deregister_workspace_directory - [ ] describe_account - [ ] describe_account_modifications - [ ] describe_client_properties @@ -6560,6 +7665,7 @@ - [ ] describe_workspace_bundles - [ ] describe_workspace_directories - [ ] describe_workspace_images +- [ ] describe_workspace_snapshots - [ ] describe_workspaces - [ ] describe_workspaces_connection_status - [ ] disassociate_ip_groups @@ -6567,10 +7673,15 @@ - [ ] list_available_management_cidr_ranges - [ ] modify_account - [ ] modify_client_properties +- [ ] modify_selfservice_permissions +- [ ] modify_workspace_access_properties +- [ ] modify_workspace_creation_properties - [ ] modify_workspace_properties - [ ] modify_workspace_state - [ ] reboot_workspaces - [ ] rebuild_workspaces +- [ ] register_workspace_directory +- [ ] restore_workspace - [ ] revoke_ip_rules - [ ] start_workspaces - [ ] stop_workspaces diff --git a/Makefile b/Makefile index 2a7249760..e84d036b7 100644 --- a/Makefile +++ b/Makefile @@ -14,12 +14,16 @@ init: lint: flake8 moto + black --check moto/ tests/ -test: lint +test-only: rm -f .coverage rm -rf cover @nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE) + +test: lint test-only + test_server: @TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/ @@ -27,7 +31,8 @@ aws_managed_policies: scripts/update_managed_policies.py upload_pypi_artifact: - python setup.py sdist bdist_wheel upload + python setup.py sdist bdist_wheel + twine upload dist/* push_dockerhub_image: docker build -t motoserver/moto . diff --git a/README.md b/README.md index 4e39ada35..4024328a9 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,9 @@ [![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org) ![PyPI](https://img.shields.io/pypi/v/moto.svg) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg) -![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) +![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -# In a nutshell +## In a nutshell Moto is a library that allows your tests to easily mock out AWS Services. @@ -297,6 +297,9 @@ def test_describe_instances_allowed(): See [the related test suite](https://github.com/spulec/moto/blob/master/tests/test_core/test_auth.py) for more examples. +## Experimental: AWS Config Querying +For details about the experimental AWS Config support please see the [AWS Config readme here](CONFIG_README.md). + ## Very Important -- Recommended Usage There are some important caveats to be aware of when using moto: diff --git a/docs/_build/html/_sources/index.rst.txt b/docs/_build/html/_sources/index.rst.txt index 0c4133048..fc5ed7652 100644 --- a/docs/_build/html/_sources/index.rst.txt +++ b/docs/_build/html/_sources/index.rst.txt @@ -30,6 +30,8 @@ Currently implemented Services: +-----------------------+---------------------+-----------------------------------+ | Data Pipeline | @mock_datapipeline | basic endpoints done | +-----------------------+---------------------+-----------------------------------+ +| DataSync | @mock_datasync | some endpoints done | ++-----------------------+---------------------+-----------------------------------+ | - DynamoDB | - @mock_dynamodb | - core endpoints done | | - DynamoDB2 | - @mock_dynamodb2 | - core endpoints + partial indexes| +-----------------------+---------------------+-----------------------------------+ diff --git a/docs/index.rst b/docs/index.rst index 4811fb797..6311597fe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -94,6 +94,8 @@ Currently implemented Services: +---------------------------+-----------------------+------------------------------------+ | SES | @mock_ses | all endpoints done | +---------------------------+-----------------------+------------------------------------+ +| SFN | @mock_stepfunctions | basic endpoints done | ++---------------------------+-----------------------+------------------------------------+ | SNS | @mock_sns | all endpoints done | +---------------------------+-----------------------+------------------------------------+ | SQS | @mock_sqs | core endpoints done | diff --git a/moto/__init__.py b/moto/__init__.py index 8594cedd2..a9f1bb8ba 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -1,62 +1,75 @@ from __future__ import unicode_literals -import logging + +from .acm import mock_acm # noqa +from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa +from .athena import mock_athena # noqa +from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa +from .awslambda import mock_lambda, mock_lambda_deprecated # noqa +from .batch import mock_batch # noqa +from .cloudformation import mock_cloudformation # noqa +from .cloudformation import mock_cloudformation_deprecated # noqa +from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa +from .codepipeline import mock_codepipeline # noqa +from .cognitoidentity import mock_cognitoidentity # noqa +from .cognitoidentity import mock_cognitoidentity_deprecated # noqa +from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa +from .config import mock_config # noqa +from .datapipeline import mock_datapipeline # noqa +from .datapipeline import mock_datapipeline_deprecated # noqa +from .datasync import mock_datasync # noqa +from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa +from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa +from .dynamodbstreams import mock_dynamodbstreams # noqa +from .ec2 import mock_ec2, mock_ec2_deprecated # noqa +from .ecr import mock_ecr, mock_ecr_deprecated # noqa +from .ecs import mock_ecs, mock_ecs_deprecated # noqa +from .elb import mock_elb, mock_elb_deprecated # noqa +from .elbv2 import mock_elbv2 # noqa +from .emr import mock_emr, mock_emr_deprecated # noqa +from .events import mock_events # noqa +from .glacier import mock_glacier, mock_glacier_deprecated # noqa +from .glue import mock_glue # noqa +from .iam import mock_iam, mock_iam_deprecated # noqa +from .iot import mock_iot # noqa +from .iotdata import mock_iotdata # noqa +from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa +from .kms import mock_kms, mock_kms_deprecated # noqa +from .logs import mock_logs, mock_logs_deprecated # noqa +from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa +from .organizations import mock_organizations # noqa +from .polly import mock_polly # noqa +from .rds import mock_rds, mock_rds_deprecated # noqa +from .rds2 import mock_rds2, mock_rds2_deprecated # noqa +from .redshift import mock_redshift, mock_redshift_deprecated # noqa +from .resourcegroups import mock_resourcegroups # noqa +from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa +from .route53 import mock_route53, mock_route53_deprecated # noqa +from .s3 import mock_s3, mock_s3_deprecated # noqa +from .secretsmanager import mock_secretsmanager # noqa +from .ses import mock_ses, mock_ses_deprecated # noqa +from .sns import mock_sns, mock_sns_deprecated # noqa +from .sqs import mock_sqs, mock_sqs_deprecated # noqa +from .ssm import mock_ssm # noqa +from .stepfunctions import mock_stepfunctions # noqa +from .sts import mock_sts, mock_sts_deprecated # noqa +from .swf import mock_swf, mock_swf_deprecated # noqa +from .xray import XRaySegment, mock_xray, mock_xray_client # noqa + +# import logging # logging.getLogger('boto').setLevel(logging.CRITICAL) -__title__ = 'moto' -__version__ = '1.3.14.dev' - -from .acm import mock_acm # flake8: noqa -from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa -from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa -from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa -from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # flake8: noqa -from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # flake8: noqa -from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # flake8: noqa -from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # flake8: noqa -from .config import mock_config # flake8: noqa -from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # flake8: noqa -from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa -from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa -from .dynamodbstreams import mock_dynamodbstreams # flake8: noqa -from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa -from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa -from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa -from .elb import mock_elb, mock_elb_deprecated # flake8: noqa -from .elbv2 import mock_elbv2 # flake8: noqa -from .emr import mock_emr, mock_emr_deprecated # flake8: noqa -from .events import mock_events # flake8: noqa -from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa -from .glue import mock_glue # flake8: noqa -from .iam import mock_iam, mock_iam_deprecated # flake8: noqa -from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa -from .kms import mock_kms, mock_kms_deprecated # flake8: noqa -from .organizations import mock_organizations # flake8: noqa -from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa -from .polly import mock_polly # flake8: noqa -from .rds import mock_rds, mock_rds_deprecated # flake8: noqa -from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa -from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa -from .resourcegroups import mock_resourcegroups # flake8: noqa -from .s3 import mock_s3, mock_s3_deprecated # flake8: noqa -from .ses import mock_ses, mock_ses_deprecated # flake8: noqa -from .secretsmanager import mock_secretsmanager # flake8: noqa -from .sns import mock_sns, mock_sns_deprecated # flake8: noqa -from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa -from .sts import mock_sts, mock_sts_deprecated # flake8: noqa -from .ssm import mock_ssm # flake8: noqa -from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa -from .swf import mock_swf, mock_swf_deprecated # flake8: noqa -from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa -from .logs import mock_logs, mock_logs_deprecated # flake8: noqa -from .batch import mock_batch # flake8: noqa -from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # flake8: noqa -from .iot import mock_iot # flake8: noqa -from .iotdata import mock_iotdata # flake8: noqa +__title__ = "moto" +__version__ = "1.3.15.dev" try: # Need to monkey-patch botocore requests back to underlying urllib3 classes - from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection + from botocore.awsrequest import ( + HTTPSConnectionPool, + HTTPConnectionPool, + HTTPConnection, + VerifiedHTTPSConnection, + ) except ImportError: pass else: diff --git a/moto/acm/__init__.py b/moto/acm/__init__.py index 6cd8a4aa5..07804282e 100644 --- a/moto/acm/__init__.py +++ b/moto/acm/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import acm_backends from ..core.models import base_decorator -acm_backend = acm_backends['us-east-1'] +acm_backend = acm_backends["us-east-1"] mock_acm = base_decorator(acm_backends) diff --git a/moto/acm/models.py b/moto/acm/models.py index b25dbcdff..3df541982 100644 --- a/moto/acm/models.py +++ b/moto/acm/models.py @@ -13,8 +13,9 @@ import cryptography.hazmat.primitives.asymmetric.rsa from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.backends import default_backend +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID + -DEFAULT_ACCOUNT_ID = 123456789012 GOOGLE_ROOT_CA = b"""-----BEGIN CERTIFICATE----- MIIEKDCCAxCgAwIBAgIQAQAhJYiw+lmnd+8Fe2Yn3zANBgkqhkiG9w0BAQsFADBC MQswCQYDVQQGEwJVUzEWMBQGA1UEChMNR2VvVHJ1c3QgSW5jLjEbMBkGA1UEAxMS @@ -57,20 +58,29 @@ class AWSError(Exception): self.message = message def response(self): - resp = {'__type': self.TYPE, 'message': self.message} + resp = {"__type": self.TYPE, "message": self.message} return json.dumps(resp), dict(status=self.STATUS) class AWSValidationException(AWSError): - TYPE = 'ValidationException' + TYPE = "ValidationException" class AWSResourceNotFoundException(AWSError): - TYPE = 'ResourceNotFoundException' + TYPE = "ResourceNotFoundException" class CertBundle(BaseModel): - def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'): + def __init__( + self, + certificate, + private_key, + chain=None, + region="us-east-1", + arn=None, + cert_type="IMPORTED", + cert_status="ISSUED", + ): self.created_at = datetime.datetime.now() self.cert = certificate self._cert = None @@ -87,7 +97,7 @@ class CertBundle(BaseModel): if self.chain is None: self.chain = GOOGLE_ROOT_CA else: - self.chain += b'\n' + GOOGLE_ROOT_CA + self.chain += b"\n" + GOOGLE_ROOT_CA # Takes care of PEM checking self.validate_pk() @@ -114,149 +124,209 @@ class CertBundle(BaseModel): sans.add(domain_name) sans = [cryptography.x509.DNSName(item) for item in sans] - key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) - subject = cryptography.x509.Name([ - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name), - ]) - issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"), - ]) - cert = cryptography.x509.CertificateBuilder().subject_name( - subject - ).issuer_name( - issuer - ).public_key( - key.public_key() - ).serial_number( - cryptography.x509.random_serial_number() - ).not_valid_before( - datetime.datetime.utcnow() - ).not_valid_after( - datetime.datetime.utcnow() + datetime.timedelta(days=365) - ).add_extension( - cryptography.x509.SubjectAlternativeName(sans), - critical=False, - ).sign(key, hashes.SHA512(), default_backend()) + key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + subject = cryptography.x509.Name( + [ + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COUNTRY_NAME, "US" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COMMON_NAME, domain_name + ), + ] + ) + issuer = cryptography.x509.Name( + [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COUNTRY_NAME, "US" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COMMON_NAME, "Amazon" + ), + ] + ) + cert = ( + cryptography.x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(cryptography.x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + .add_extension( + cryptography.x509.SubjectAlternativeName(sans), critical=False + ) + .sign(key, hashes.SHA512(), default_backend()) + ) cert_armored = cert.public_bytes(serialization.Encoding.PEM) private_key = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) - return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region) + return cls( + cert_armored, + private_key, + cert_type="AMAZON_ISSUED", + cert_status="PENDING_VALIDATION", + region=region, + ) def validate_pk(self): try: - self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend()) + self._key = serialization.load_pem_private_key( + self.key, password=None, backend=default_backend() + ) if self._key.key_size > 2048: - AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.') + AWSValidationException( + "The private key length is not supported. Only 1024-bit and 2048-bit are allowed." + ) except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The private key is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The private key is not PEM-encoded or is not valid." + ) def validate_certificate(self): try: - self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend()) + self._cert = cryptography.x509.load_pem_x509_certificate( + self.cert, default_backend() + ) now = datetime.datetime.utcnow() if self._cert.not_valid_after < now: - raise AWSValidationException('The certificate has expired, is not valid.') + raise AWSValidationException( + "The certificate has expired, is not valid." + ) if self._cert.not_valid_before > now: - raise AWSValidationException('The certificate is not in effect yet, is not valid.') + raise AWSValidationException( + "The certificate is not in effect yet, is not valid." + ) # Extracting some common fields for ease of use # Have to search through cert.subject for OIDs - self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value + self.common_name = self._cert.subject.get_attributes_for_oid( + cryptography.x509.OID_COMMON_NAME + )[0].value except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ) def validate_chain(self): try: self._chain = [] - for cert_armored in self.chain.split(b'-\n-'): + for cert_armored in self.chain.split(b"-\n-"): # Would leave encoded but Py2 does not have raw binary strings cert_armored = cert_armored.decode() # Fix missing -'s on split - cert_armored = re.sub(r'^----B', '-----B', cert_armored) - cert_armored = re.sub(r'E----$', 'E-----', cert_armored) - cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend()) + cert_armored = re.sub(r"^----B", "-----B", cert_armored) + cert_armored = re.sub(r"E----$", "E-----", cert_armored) + cert = cryptography.x509.load_pem_x509_certificate( + cert_armored.encode(), default_backend() + ) self._chain.append(cert) now = datetime.datetime.now() if self._cert.not_valid_after < now: - raise AWSValidationException('The certificate chain has expired, is not valid.') + raise AWSValidationException( + "The certificate chain has expired, is not valid." + ) if self._cert.not_valid_before > now: - raise AWSValidationException('The certificate chain is not in effect yet, is not valid.') + raise AWSValidationException( + "The certificate chain is not in effect yet, is not valid." + ) except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ) def check(self): # Basically, if the certificate is pending, and then checked again after 1 min # It will appear as if its been validated - if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \ - (datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min - self.status = 'ISSUED' + if ( + self.type == "AMAZON_ISSUED" + and self.status == "PENDING_VALIDATION" + and (datetime.datetime.now() - self.created_at).total_seconds() > 60 + ): # 1min + self.status = "ISSUED" def describe(self): # 'RenewalSummary': {}, # Only when cert is amazon issued if self._key.key_size == 1024: - key_algo = 'RSA_1024' + key_algo = "RSA_1024" elif self._key.key_size == 2048: - key_algo = 'RSA_2048' + key_algo = "RSA_2048" else: - key_algo = 'EC_prime256v1' + key_algo = "EC_prime256v1" # Look for SANs - san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME) + san_obj = self._cert.extensions.get_extension_for_oid( + cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME + ) sans = [] if san_obj is not None: sans = [item.value for item in san_obj.value] result = { - 'Certificate': { - 'CertificateArn': self.arn, - 'DomainName': self.common_name, - 'InUseBy': [], - 'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value, - 'KeyAlgorithm': key_algo, - 'NotAfter': datetime_to_epoch(self._cert.not_valid_after), - 'NotBefore': datetime_to_epoch(self._cert.not_valid_before), - 'Serial': self._cert.serial_number, - 'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''), - 'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. - 'Subject': 'CN={0}'.format(self.common_name), - 'SubjectAlternativeNames': sans, - 'Type': self.type # One of IMPORTED, AMAZON_ISSUED + "Certificate": { + "CertificateArn": self.arn, + "DomainName": self.common_name, + "InUseBy": [], + "Issuer": self._cert.issuer.get_attributes_for_oid( + cryptography.x509.OID_COMMON_NAME + )[0].value, + "KeyAlgorithm": key_algo, + "NotAfter": datetime_to_epoch(self._cert.not_valid_after), + "NotBefore": datetime_to_epoch(self._cert.not_valid_before), + "Serial": self._cert.serial_number, + "SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace( + "ENCRYPTION", "" + ), + "Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. + "Subject": "CN={0}".format(self.common_name), + "SubjectAlternativeNames": sans, + "Type": self.type, # One of IMPORTED, AMAZON_ISSUED } } - if self.type == 'IMPORTED': - result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at) + if self.type == "IMPORTED": + result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at) else: - result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at) - result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at) + result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at) + result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at) return result @@ -264,7 +334,7 @@ class CertBundle(BaseModel): return self.arn def __repr__(self): - return '' + return "" class AWSCertificateManagerBackend(BaseBackend): @@ -281,7 +351,9 @@ class AWSCertificateManagerBackend(BaseBackend): @staticmethod def _arn_not_found(arn): - msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID) + msg = "Certificate with arn {0} not found in account {1}".format( + arn, DEFAULT_ACCOUNT_ID + ) return AWSResourceNotFoundException(msg) def _get_arn_from_idempotency_token(self, token): @@ -298,17 +370,20 @@ class AWSCertificateManagerBackend(BaseBackend): """ now = datetime.datetime.now() if token in self._idempotency_tokens: - if self._idempotency_tokens[token]['expires'] < now: + if self._idempotency_tokens[token]["expires"] < now: # Token has expired, new request del self._idempotency_tokens[token] return None else: - return self._idempotency_tokens[token]['arn'] + return self._idempotency_tokens[token]["arn"] return None def _set_idempotency_token_arn(self, token, arn): - self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)} + self._idempotency_tokens[token] = { + "arn": arn, + "expires": datetime.datetime.now() + datetime.timedelta(hours=1), + } def import_cert(self, certificate, private_key, chain=None, arn=None): if arn is not None: @@ -316,7 +391,9 @@ class AWSCertificateManagerBackend(BaseBackend): raise self._arn_not_found(arn) else: # Will reuse provided ARN - bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn) + bundle = CertBundle( + certificate, private_key, chain=chain, region=region, arn=arn + ) else: # Will generate a random ARN bundle = CertBundle(certificate, private_key, chain=chain, region=region) @@ -351,13 +428,21 @@ class AWSCertificateManagerBackend(BaseBackend): del self._certificates[arn] - def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names): + def request_certificate( + self, + domain_name, + domain_validation_options, + idempotency_token, + subject_alt_names, + ): if idempotency_token is not None: arn = self._get_arn_from_idempotency_token(idempotency_token) if arn is not None: return arn - cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names) + cert = CertBundle.generate_cert( + domain_name, region=self.region, sans=subject_alt_names + ) if idempotency_token is not None: self._set_idempotency_token_arn(idempotency_token, cert.arn) self._certificates[cert.arn] = cert @@ -369,8 +454,8 @@ class AWSCertificateManagerBackend(BaseBackend): cert_bundle = self.get_certificate(arn) for tag in tags: - key = tag['Key'] - value = tag.get('Value', None) + key = tag["Key"] + value = tag.get("Value", None) cert_bundle.tags[key] = value def remove_tags_from_certificate(self, arn, tags): @@ -378,8 +463,8 @@ class AWSCertificateManagerBackend(BaseBackend): cert_bundle = self.get_certificate(arn) for tag in tags: - key = tag['Key'] - value = tag.get('Value', None) + key = tag["Key"] + value = tag.get("Value", None) try: # If value isnt provided, just delete key diff --git a/moto/acm/responses.py b/moto/acm/responses.py index 0d0ac640b..13b22fa95 100644 --- a/moto/acm/responses.py +++ b/moto/acm/responses.py @@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException class AWSCertificateManagerResponse(BaseResponse): - @property def acm_backend(self): """ @@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse): return self.request_params.get(param, default) def add_tags_to_certificate(self): - arn = self._get_param('CertificateArn') - tags = self._get_param('Tags') + arn = self._get_param("CertificateArn") + tags = self._get_param("Tags") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.add_tags_to_certificate(arn, tags) except AWSError as err: return err.response() - return '' + return "" def delete_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.delete_certificate(arn) except AWSError as err: return err.response() - return '' + return "" def describe_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) @@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps(cert_bundle.describe()) def get_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) @@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse): return err.response() result = { - 'Certificate': cert_bundle.cert.decode(), - 'CertificateChain': cert_bundle.chain.decode() + "Certificate": cert_bundle.cert.decode(), + "CertificateChain": cert_bundle.chain.decode(), } return json.dumps(result) @@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse): :return: str(JSON) for response """ - certificate = self._get_param('Certificate') - private_key = self._get_param('PrivateKey') - chain = self._get_param('CertificateChain') # Optional - current_arn = self._get_param('CertificateArn') # Optional + certificate = self._get_param("Certificate") + private_key = self._get_param("PrivateKey") + chain = self._get_param("CertificateChain") # Optional + current_arn = self._get_param("CertificateArn") # Optional # Simple parameter decoding. Rather do it here as its a data transport decision not part of the # actual data try: certificate = base64.standard_b64decode(certificate) except Exception: - return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ).response() try: private_key = base64.standard_b64decode(private_key) except Exception: - return AWSValidationException('The private key is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The private key is not PEM-encoded or is not valid." + ).response() if chain is not None: try: chain = base64.standard_b64decode(chain) except Exception: - return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The certificate chain is not PEM-encoded or is not valid." + ).response() try: - arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn) + arn = self.acm_backend.import_cert( + certificate, private_key, chain=chain, arn=current_arn + ) except AWSError as err: return err.response() - return json.dumps({'CertificateArn': arn}) + return json.dumps({"CertificateArn": arn}) def list_certificates(self): certs = [] - statuses = self._get_param('CertificateStatuses') + statuses = self._get_param("CertificateStatuses") for cert_bundle in self.acm_backend.get_certificates_list(statuses): - certs.append({ - 'CertificateArn': cert_bundle.arn, - 'DomainName': cert_bundle.common_name - }) + certs.append( + { + "CertificateArn": cert_bundle.arn, + "DomainName": cert_bundle.common_name, + } + ) - result = {'CertificateSummaryList': certs} + result = {"CertificateSummaryList": certs} return json.dumps(result) def list_tags_for_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return {'__type': 'MissingParameter', 'message': msg}, dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return {"__type": "MissingParameter", "message": msg}, dict(status=400) try: cert_bundle = self.acm_backend.get_certificate(arn) except AWSError as err: return err.response() - result = {'Tags': []} + result = {"Tags": []} # Tag "objects" can not contain the Value part for key, value in cert_bundle.tags.items(): - tag_dict = {'Key': key} + tag_dict = {"Key": key} if value is not None: - tag_dict['Value'] = value - result['Tags'].append(tag_dict) + tag_dict["Value"] = value + result["Tags"].append(tag_dict) return json.dumps(result) def remove_tags_from_certificate(self): - arn = self._get_param('CertificateArn') - tags = self._get_param('Tags') + arn = self._get_param("CertificateArn") + tags = self._get_param("Tags") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.remove_tags_from_certificate(arn, tags) except AWSError as err: return err.response() - return '' + return "" def request_certificate(self): - domain_name = self._get_param('DomainName') - domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm - idempotency_token = self._get_param('IdempotencyToken') - subject_alt_names = self._get_param('SubjectAlternativeNames') + domain_name = self._get_param("DomainName") + domain_validation_options = self._get_param( + "DomainValidationOptions" + ) # is ignored atm + idempotency_token = self._get_param("IdempotencyToken") + subject_alt_names = self._get_param("SubjectAlternativeNames") if subject_alt_names is not None and len(subject_alt_names) > 10: # There is initial AWS limit of 10 - msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised' - return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400) + msg = ( + "An ACM limit has been exceeded. Need to request SAN limit to be raised" + ) + return ( + json.dumps({"__type": "LimitExceededException", "message": msg}), + dict(status=400), + ) try: - arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names) + arn = self.acm_backend.request_certificate( + domain_name, + domain_validation_options, + idempotency_token, + subject_alt_names, + ) except AWSError as err: return err.response() - return json.dumps({'CertificateArn': arn}) + return json.dumps({"CertificateArn": arn}) def resend_validation_email(self): - arn = self._get_param('CertificateArn') - domain = self._get_param('Domain') + arn = self._get_param("CertificateArn") + domain = self._get_param("Domain") # ValidationDomain not used yet. # Contains domain which is equal to or a subset of Domain # that AWS will send validation emails to @@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse): # validation_domain = self._get_param('ValidationDomain') if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) if cert_bundle.common_name != domain: - msg = 'Parameter Domain does not match certificate domain' - _type = 'InvalidDomainValidationOptionsException' - return json.dumps({'__type': _type, 'message': msg}), dict(status=400) + msg = "Parameter Domain does not match certificate domain" + _type = "InvalidDomainValidationOptionsException" + return json.dumps({"__type": _type, "message": msg}), dict(status=400) except AWSError as err: return err.response() - return '' + return "" diff --git a/moto/acm/urls.py b/moto/acm/urls.py index 20acbb3f4..8a8d3e2ef 100644 --- a/moto/acm/urls.py +++ b/moto/acm/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import AWSCertificateManagerResponse -url_bases = [ - "https?://acm.(.+).amazonaws.com", -] +url_bases = ["https?://acm.(.+).amazonaws.com"] -url_paths = { - '{0}/$': AWSCertificateManagerResponse.dispatch, -} +url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch} diff --git a/moto/acm/utils.py b/moto/acm/utils.py index b3c441454..6d695d95c 100644 --- a/moto/acm/utils.py +++ b/moto/acm/utils.py @@ -4,4 +4,6 @@ import uuid def make_arn_for_certificate(account_id, region_name): # Example # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b - return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4()) + return "arn:aws:acm:{0}:{1}:certificate/{2}".format( + region_name, account_id, uuid.uuid4() + ) diff --git a/moto/apigateway/__init__.py b/moto/apigateway/__init__.py index 98b2058d9..42da3db53 100644 --- a/moto/apigateway/__init__.py +++ b/moto/apigateway/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import apigateway_backends from ..core.models import base_decorator, deprecated_base_decorator -apigateway_backend = apigateway_backends['us-east-1'] +apigateway_backend = apigateway_backends["us-east-1"] mock_apigateway = base_decorator(apigateway_backends) mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends) diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index 62fa24392..434ebc467 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -2,12 +2,96 @@ from __future__ import unicode_literals from moto.core.exceptions import RESTError +class BadRequestException(RESTError): + pass + + +class AwsProxyNotAllowed(BadRequestException): + def __init__(self): + super(AwsProxyNotAllowed, self).__init__( + "BadRequestException", + "Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations.", + ) + + +class CrossAccountNotAllowed(RESTError): + def __init__(self): + super(CrossAccountNotAllowed, self).__init__( + "AccessDeniedException", "Cross-account pass role is not allowed." + ) + + +class RoleNotSpecified(BadRequestException): + def __init__(self): + super(RoleNotSpecified, self).__init__( + "BadRequestException", "Role ARN must be specified for AWS integrations" + ) + + +class IntegrationMethodNotDefined(BadRequestException): + def __init__(self): + super(IntegrationMethodNotDefined, self).__init__( + "BadRequestException", "Enumeration value for HttpMethod must be non-empty" + ) + + +class InvalidResourcePathException(BadRequestException): + def __init__(self): + super(InvalidResourcePathException, self).__init__( + "BadRequestException", + "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end.", + ) + + +class InvalidHttpEndpoint(BadRequestException): + def __init__(self): + super(InvalidHttpEndpoint, self).__init__( + "BadRequestException", "Invalid HTTP endpoint specified for URI" + ) + + +class InvalidArn(BadRequestException): + def __init__(self): + super(InvalidArn, self).__init__( + "BadRequestException", "Invalid ARN specified in the request" + ) + + +class InvalidIntegrationArn(BadRequestException): + def __init__(self): + super(InvalidIntegrationArn, self).__init__( + "BadRequestException", "AWS ARN for integration must contain path or action" + ) + + +class InvalidRequestInput(BadRequestException): + def __init__(self): + super(InvalidRequestInput, self).__init__( + "BadRequestException", "Invalid request input" + ) + + +class NoIntegrationDefined(BadRequestException): + def __init__(self): + super(NoIntegrationDefined, self).__init__( + "BadRequestException", "No integration defined for method" + ) + + +class NoMethodDefined(BadRequestException): + def __init__(self): + super(NoMethodDefined, self).__init__( + "BadRequestException", "The REST API doesn't contain any methods" + ) + + class StageNotFoundException(RESTError): code = 404 def __init__(self): super(StageNotFoundException, self).__init__( - "NotFoundException", "Invalid stage identifier specified") + "NotFoundException", "Invalid stage identifier specified" + ) class ApiKeyNotFoundException(RESTError): @@ -15,4 +99,14 @@ class ApiKeyNotFoundException(RESTError): def __init__(self): super(ApiKeyNotFoundException, self).__init__( - "NotFoundException", "Invalid API Key identifier specified") + "NotFoundException", "Invalid API Key identifier specified" + ) + + +class ApiKeyAlreadyExists(RESTError): + code = 409 + + def __init__(self): + super(ApiKeyAlreadyExists, self).__init__( + "ConflictException", "API Key already exists" + ) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 6be062d7f..8b5fb787f 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -3,53 +3,69 @@ from __future__ import unicode_literals import random import string +import re import requests import time from boto3.session import Session + +try: + from urlparse import urlparse +except ImportError: + from urllib.parse import urlparse import responses from moto.core import BaseBackend, BaseModel from .utils import create_id from moto.core.utils import path_url -from .exceptions import StageNotFoundException, ApiKeyNotFoundException +from moto.sts.models import ACCOUNT_ID +from .exceptions import ( + ApiKeyNotFoundException, + AwsProxyNotAllowed, + CrossAccountNotAllowed, + IntegrationMethodNotDefined, + InvalidArn, + InvalidIntegrationArn, + InvalidHttpEndpoint, + InvalidResourcePathException, + InvalidRequestInput, + StageNotFoundException, + RoleNotSpecified, + NoIntegrationDefined, + NoMethodDefined, + ApiKeyAlreadyExists, +) STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" class Deployment(BaseModel, dict): - def __init__(self, deployment_id, name, description=""): super(Deployment, self).__init__() - self['id'] = deployment_id - self['stageName'] = name - self['description'] = description - self['createdDate'] = int(time.time()) + self["id"] = deployment_id + self["stageName"] = name + self["description"] = description + self["createdDate"] = int(time.time()) class IntegrationResponse(BaseModel, dict): - def __init__(self, status_code, selection_pattern=None): - self['responseTemplates'] = {"application/json": None} - self['statusCode'] = status_code + self["responseTemplates"] = {"application/json": None} + self["statusCode"] = status_code if selection_pattern: - self['selectionPattern'] = selection_pattern + self["selectionPattern"] = selection_pattern class Integration(BaseModel, dict): - def __init__(self, integration_type, uri, http_method, request_templates=None): super(Integration, self).__init__() - self['type'] = integration_type - self['uri'] = uri - self['httpMethod'] = http_method - self['requestTemplates'] = request_templates - self["integrationResponses"] = { - "200": IntegrationResponse(200) - } + self["type"] = integration_type + self["uri"] = uri + self["httpMethod"] = http_method + self["requestTemplates"] = request_templates + self["integrationResponses"] = {"200": IntegrationResponse(200)} def create_integration_response(self, status_code, selection_pattern): - integration_response = IntegrationResponse( - status_code, selection_pattern) + integration_response = IntegrationResponse(status_code, selection_pattern) self["integrationResponses"][status_code] = integration_response return integration_response @@ -61,25 +77,25 @@ class Integration(BaseModel, dict): class MethodResponse(BaseModel, dict): - def __init__(self, status_code): super(MethodResponse, self).__init__() - self['statusCode'] = status_code + self["statusCode"] = status_code class Method(BaseModel, dict): - def __init__(self, method_type, authorization_type): super(Method, self).__init__() - self.update(dict( - httpMethod=method_type, - authorizationType=authorization_type, - authorizerId=None, - apiKeyRequired=None, - requestParameters=None, - requestModels=None, - methodIntegration=None, - )) + self.update( + dict( + httpMethod=method_type, + authorizationType=authorization_type, + authorizerId=None, + apiKeyRequired=None, + requestParameters=None, + requestModels=None, + methodIntegration=None, + ) + ) self.method_responses = {} def create_response(self, response_code): @@ -95,16 +111,13 @@ class Method(BaseModel, dict): class Resource(BaseModel): - def __init__(self, id, region_name, api_id, path_part, parent_id): self.id = id self.region_name = region_name self.api_id = api_id self.path_part = path_part self.parent_id = parent_id - self.resource_methods = { - 'GET': {} - } + self.resource_methods = {"GET": {}} def to_dict(self): response = { @@ -113,8 +126,8 @@ class Resource(BaseModel): "resourceMethods": self.resource_methods, } if self.parent_id: - response['parentId'] = self.parent_id - response['pathPart'] = self.path_part + response["parentId"] = self.parent_id + response["pathPart"] = self.path_part return response def get_path(self): @@ -125,102 +138,112 @@ class Resource(BaseModel): backend = apigateway_backends[self.region_name] parent = backend.get_resource(self.api_id, self.parent_id) parent_path = parent.get_path() - if parent_path != '/': # Root parent - parent_path += '/' + if parent_path != "/": # Root parent + parent_path += "/" return parent_path else: - return '' + return "" def get_response(self, request): integration = self.get_integration(request.method) - integration_type = integration['type'] + integration_type = integration["type"] - if integration_type == 'HTTP': - uri = integration['uri'] - requests_func = getattr(requests, integration[ - 'httpMethod'].lower()) + if integration_type == "HTTP": + uri = integration["uri"] + requests_func = getattr(requests, integration["httpMethod"].lower()) response = requests_func(uri) else: raise NotImplementedError( - "The {0} type has not been implemented".format(integration_type)) + "The {0} type has not been implemented".format(integration_type) + ) return response.status_code, response.text def add_method(self, method_type, authorization_type): - method = Method(method_type=method_type, - authorization_type=authorization_type) + method = Method(method_type=method_type, authorization_type=authorization_type) self.resource_methods[method_type] = method return method def get_method(self, method_type): return self.resource_methods[method_type] - def add_integration(self, method_type, integration_type, uri, request_templates=None): + def add_integration( + self, method_type, integration_type, uri, request_templates=None + ): integration = Integration( - integration_type, uri, method_type, request_templates=request_templates) - self.resource_methods[method_type]['methodIntegration'] = integration + integration_type, uri, method_type, request_templates=request_templates + ) + self.resource_methods[method_type]["methodIntegration"] = integration return integration def get_integration(self, method_type): - return self.resource_methods[method_type]['methodIntegration'] + return self.resource_methods[method_type]["methodIntegration"] def delete_integration(self, method_type): - return self.resource_methods[method_type].pop('methodIntegration') + return self.resource_methods[method_type].pop("methodIntegration") class Stage(BaseModel, dict): - - def __init__(self, name=None, deployment_id=None, variables=None, - description='', cacheClusterEnabled=False, cacheClusterSize=None): + def __init__( + self, + name=None, + deployment_id=None, + variables=None, + description="", + cacheClusterEnabled=False, + cacheClusterSize=None, + ): super(Stage, self).__init__() if variables is None: variables = {} - self['stageName'] = name - self['deploymentId'] = deployment_id - self['methodSettings'] = {} - self['variables'] = variables - self['description'] = description - self['cacheClusterEnabled'] = cacheClusterEnabled - if self['cacheClusterEnabled']: - self['cacheClusterSize'] = str(0.5) + self["stageName"] = name + self["deploymentId"] = deployment_id + self["methodSettings"] = {} + self["variables"] = variables + self["description"] = description + self["cacheClusterEnabled"] = cacheClusterEnabled + if self["cacheClusterEnabled"]: + self["cacheClusterSize"] = str(0.5) if cacheClusterSize is not None: - self['cacheClusterSize'] = str(cacheClusterSize) + self["cacheClusterSize"] = str(cacheClusterSize) def apply_operations(self, patch_operations): for op in patch_operations: - if 'variables/' in op['path']: + if "variables/" in op["path"]: self._apply_operation_to_variables(op) - elif '/cacheClusterEnabled' in op['path']: - self['cacheClusterEnabled'] = self._str2bool(op['value']) - if 'cacheClusterSize' not in self and self['cacheClusterEnabled']: - self['cacheClusterSize'] = str(0.5) - elif '/cacheClusterSize' in op['path']: - self['cacheClusterSize'] = str(float(op['value'])) - elif '/description' in op['path']: - self['description'] = op['value'] - elif '/deploymentId' in op['path']: - self['deploymentId'] = op['value'] - elif op['op'] == 'replace': + elif "/cacheClusterEnabled" in op["path"]: + self["cacheClusterEnabled"] = self._str2bool(op["value"]) + if "cacheClusterSize" not in self and self["cacheClusterEnabled"]: + self["cacheClusterSize"] = str(0.5) + elif "/cacheClusterSize" in op["path"]: + self["cacheClusterSize"] = str(float(op["value"])) + elif "/description" in op["path"]: + self["description"] = op["value"] + elif "/deploymentId" in op["path"]: + self["deploymentId"] = op["value"] + elif op["op"] == "replace": # Method Settings drop into here # (e.g., path could be '/*/*/logging/loglevel') - split_path = op['path'].split('/', 3) + split_path = op["path"].split("/", 3) if len(split_path) != 4: continue self._patch_method_setting( - '/'.join(split_path[1:3]), split_path[3], op['value']) + "/".join(split_path[1:3]), split_path[3], op["value"] + ) else: - raise Exception( - 'Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) return self def _patch_method_setting(self, resource_path_and_method, key, value): updated_key = self._method_settings_translations(key) if updated_key is not None: - if resource_path_and_method not in self['methodSettings']: - self['methodSettings'][ - resource_path_and_method] = self._get_default_method_settings() - self['methodSettings'][resource_path_and_method][ - updated_key] = self._convert_to_type(updated_key, value) + if resource_path_and_method not in self["methodSettings"]: + self["methodSettings"][ + resource_path_and_method + ] = self._get_default_method_settings() + self["methodSettings"][resource_path_and_method][ + updated_key + ] = self._convert_to_type(updated_key, value) def _get_default_method_settings(self): return { @@ -232,21 +255,21 @@ class Stage(BaseModel, dict): "cacheDataEncrypted": True, "cachingEnabled": False, "throttlingBurstLimit": 2000, - "requireAuthorizationForCacheControl": True + "requireAuthorizationForCacheControl": True, } def _method_settings_translations(self, key): mappings = { - 'metrics/enabled': 'metricsEnabled', - 'logging/loglevel': 'loggingLevel', - 'logging/dataTrace': 'dataTraceEnabled', - 'throttling/burstLimit': 'throttlingBurstLimit', - 'throttling/rateLimit': 'throttlingRateLimit', - 'caching/enabled': 'cachingEnabled', - 'caching/ttlInSeconds': 'cacheTtlInSeconds', - 'caching/dataEncrypted': 'cacheDataEncrypted', - 'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl', - 'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy' + "metrics/enabled": "metricsEnabled", + "logging/loglevel": "loggingLevel", + "logging/dataTrace": "dataTraceEnabled", + "throttling/burstLimit": "throttlingBurstLimit", + "throttling/rateLimit": "throttlingRateLimit", + "caching/enabled": "cachingEnabled", + "caching/ttlInSeconds": "cacheTtlInSeconds", + "caching/dataEncrypted": "cacheDataEncrypted", + "caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl", + "caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy", } if key in mappings: @@ -259,26 +282,26 @@ class Stage(BaseModel, dict): def _convert_to_type(self, key, val): type_mappings = { - 'metricsEnabled': 'bool', - 'loggingLevel': 'str', - 'dataTraceEnabled': 'bool', - 'throttlingBurstLimit': 'int', - 'throttlingRateLimit': 'float', - 'cachingEnabled': 'bool', - 'cacheTtlInSeconds': 'int', - 'cacheDataEncrypted': 'bool', - 'requireAuthorizationForCacheControl': 'bool', - 'unauthorizedCacheControlHeaderStrategy': 'str' + "metricsEnabled": "bool", + "loggingLevel": "str", + "dataTraceEnabled": "bool", + "throttlingBurstLimit": "int", + "throttlingRateLimit": "float", + "cachingEnabled": "bool", + "cacheTtlInSeconds": "int", + "cacheDataEncrypted": "bool", + "requireAuthorizationForCacheControl": "bool", + "unauthorizedCacheControlHeaderStrategy": "str", } if key in type_mappings: type_value = type_mappings[key] - if type_value == 'bool': + if type_value == "bool": return self._str2bool(val) - elif type_value == 'int': + elif type_value == "int": return int(val) - elif type_value == 'float': + elif type_value == "float": return float(val) else: return str(val) @@ -286,43 +309,55 @@ class Stage(BaseModel, dict): return str(val) def _apply_operation_to_variables(self, op): - key = op['path'][op['path'].rindex("variables/") + 10:] - if op['op'] == 'remove': - self['variables'].pop(key, None) - elif op['op'] == 'replace': - self['variables'][key] = op['value'] + key = op["path"][op["path"].rindex("variables/") + 10 :] + if op["op"] == "remove": + self["variables"].pop(key, None) + elif op["op"] == "replace": + self["variables"][key] = op["value"] else: - raise Exception('Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) class ApiKey(BaseModel, dict): - - def __init__(self, name=None, description=None, enabled=True, - generateDistinctId=False, value=None, stageKeys=None, customerId=None): + def __init__( + self, + name=None, + description=None, + enabled=True, + generateDistinctId=False, + value=None, + stageKeys=None, + tags=None, + customerId=None, + ): super(ApiKey, self).__init__() - self['id'] = create_id() - self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) - self['name'] = name - self['customerId'] = customerId - self['description'] = description - self['enabled'] = enabled - self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) - self['stageKeys'] = stageKeys + self["id"] = create_id() + self["value"] = ( + value + if value + else "".join(random.sample(string.ascii_letters + string.digits, 40)) + ) + self["name"] = name + self["customerId"] = customerId + self["description"] = description + self["enabled"] = enabled + self["createdDate"] = self["lastUpdatedDate"] = int(time.time()) + self["stageKeys"] = stageKeys + self["tags"] = tags def update_operations(self, patch_operations): for op in patch_operations: - if op['op'] == 'replace': - if '/name' in op['path']: - self['name'] = op['value'] - elif '/customerId' in op['path']: - self['customerId'] = op['value'] - elif '/description' in op['path']: - self['description'] = op['value'] - elif '/enabled' in op['path']: - self['enabled'] = self._str2bool(op['value']) + if op["op"] == "replace": + if "/name" in op["path"]: + self["name"] = op["value"] + elif "/customerId" in op["path"]: + self["customerId"] = op["value"] + elif "/description" in op["path"]: + self["description"] = op["value"] + elif "/enabled" in op["path"]: + self["enabled"] = self._str2bool(op["value"]) else: - raise Exception( - 'Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) return self def _str2bool(self, v): @@ -330,30 +365,35 @@ class ApiKey(BaseModel, dict): class UsagePlan(BaseModel, dict): - - def __init__(self, name=None, description=None, apiStages=[], - throttle=None, quota=None): + def __init__( + self, + name=None, + description=None, + apiStages=None, + throttle=None, + quota=None, + tags=None, + ): super(UsagePlan, self).__init__() - self['id'] = create_id() - self['name'] = name - self['description'] = description - self['apiStages'] = apiStages - self['throttle'] = throttle - self['quota'] = quota + self["id"] = create_id() + self["name"] = name + self["description"] = description + self["apiStages"] = apiStages if apiStages else [] + self["throttle"] = throttle + self["quota"] = quota + self["tags"] = tags class UsagePlanKey(BaseModel, dict): - def __init__(self, id, type, name, value): super(UsagePlanKey, self).__init__() - self['id'] = id - self['name'] = name - self['type'] = type - self['value'] = value + self["id"] = id + self["name"] = name + self["type"] = type + self["value"] = value class RestAPI(BaseModel): - def __init__(self, id, region_name, name, description): self.id = id self.region_name = region_name @@ -365,7 +405,7 @@ class RestAPI(BaseModel): self.stages = {} self.resources = {} - self.add_child('/') # Add default child + self.add_child("/") # Add default child def __repr__(self): return str(self.id) @@ -380,8 +420,13 @@ class RestAPI(BaseModel): def add_child(self, path, parent_id=None): child_id = create_id() - child = Resource(id=child_id, region_name=self.region_name, - api_id=self.id, path_part=path, parent_id=parent_id) + child = Resource( + id=child_id, + region_name=self.region_name, + api_id=self.id, + path_part=path, + parent_id=parent_id, + ) self.resources[child_id] = child return child @@ -393,30 +438,53 @@ class RestAPI(BaseModel): def resource_callback(self, request): path = path_url(request.url) - path_after_stage_name = '/'.join(path.split("/")[2:]) + path_after_stage_name = "/".join(path.split("/")[2:]) if not path_after_stage_name: - path_after_stage_name = '/' + path_after_stage_name = "/" resource = self.get_resource_for_path(path_after_stage_name) status_code, response = resource.get_response(request) return status_code, {}, response def update_integration_mocks(self, stage_name): - stage_url_lower = STAGE_URL.format(api_id=self.id.lower(), - region_name=self.region_name, stage_name=stage_name) - stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), - region_name=self.region_name, stage_name=stage_name) + stage_url_lower = STAGE_URL.format( + api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name + ) + stage_url_upper = STAGE_URL.format( + api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name + ) - responses.add_callback(responses.GET, stage_url_lower, - callback=self.resource_callback) - responses.add_callback(responses.GET, stage_url_upper, - callback=self.resource_callback) + for url in [stage_url_lower, stage_url_upper]: + responses._default_mock._matches.insert( + 0, + responses.CallbackResponse( + url=url, + method=responses.GET, + callback=self.resource_callback, + content_type="text/plain", + match_querystring=False, + ), + ) - def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): + def create_stage( + self, + name, + deployment_id, + variables=None, + description="", + cacheClusterEnabled=None, + cacheClusterSize=None, + ): if variables is None: variables = {} - stage = Stage(name=name, deployment_id=deployment_id, variables=variables, - description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled) + stage = Stage( + name=name, + deployment_id=deployment_id, + variables=variables, + description=description, + cacheClusterSize=cacheClusterSize, + cacheClusterEnabled=cacheClusterEnabled, + ) self.stages[name] = stage self.update_integration_mocks(name) return stage @@ -428,7 +496,8 @@ class RestAPI(BaseModel): deployment = Deployment(deployment_id, name, description) self.deployments[deployment_id] = deployment self.stages[name] = Stage( - name=name, deployment_id=deployment_id, variables=stage_variables) + name=name, deployment_id=deployment_id, variables=stage_variables + ) self.update_integration_mocks(name) return deployment @@ -447,7 +516,6 @@ class RestAPI(BaseModel): class APIGatewayBackend(BaseBackend): - def __init__(self, region_name): super(APIGatewayBackend, self).__init__() self.apis = {} @@ -488,11 +556,10 @@ class APIGatewayBackend(BaseBackend): return resource def create_resource(self, function_id, parent_resource_id, path_part): + if not re.match("^\\{?[a-zA-Z0-9._-]+\\}?$", path_part): + raise InvalidResourcePathException() api = self.get_rest_api(function_id) - child = api.add_child( - path=path_part, - parent_id=parent_resource_id, - ) + child = api.add_child(path=path_part, parent_id=parent_resource_id) return child def delete_resource(self, function_id, resource_id): @@ -521,13 +588,27 @@ class APIGatewayBackend(BaseBackend): api = self.get_rest_api(function_id) return api.get_stages() - def create_stage(self, function_id, stage_name, deploymentId, - variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): + def create_stage( + self, + function_id, + stage_name, + deploymentId, + variables=None, + description="", + cacheClusterEnabled=None, + cacheClusterSize=None, + ): if variables is None: variables = {} api = self.get_rest_api(function_id) - api.create_stage(stage_name, deploymentId, variables=variables, - description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) + api.create_stage( + stage_name, + deploymentId, + variables=variables, + description=description, + cacheClusterEnabled=cacheClusterEnabled, + cacheClusterSize=cacheClusterSize, + ) return api.stages.get(stage_name) def update_stage(self, function_id, stage_name, patch_operations): @@ -537,26 +618,73 @@ class APIGatewayBackend(BaseBackend): stage = api.stages[stage_name] = Stage() return stage.apply_operations(patch_operations) + def delete_stage(self, function_id, stage_name): + api = self.get_rest_api(function_id) + del api.stages[stage_name] + def get_method_response(self, function_id, resource_id, method_type, response_code): method = self.get_method(function_id, resource_id, method_type) method_response = method.get_response(response_code) return method_response - def create_method_response(self, function_id, resource_id, method_type, response_code): + def create_method_response( + self, function_id, resource_id, method_type, response_code + ): method = self.get_method(function_id, resource_id, method_type) method_response = method.create_response(response_code) return method_response - def delete_method_response(self, function_id, resource_id, method_type, response_code): + def delete_method_response( + self, function_id, resource_id, method_type, response_code + ): method = self.get_method(function_id, resource_id, method_type) method_response = method.delete_response(response_code) return method_response - def create_integration(self, function_id, resource_id, method_type, integration_type, uri, - request_templates=None): + def create_integration( + self, + function_id, + resource_id, + method_type, + integration_type, + uri, + integration_method=None, + credentials=None, + request_templates=None, + ): resource = self.get_resource(function_id, resource_id) - integration = resource.add_integration(method_type, integration_type, uri, - request_templates=request_templates) + if credentials and not re.match( + "^arn:aws:iam::" + str(ACCOUNT_ID), credentials + ): + raise CrossAccountNotAllowed() + if not integration_method and integration_type in [ + "HTTP", + "HTTP_PROXY", + "AWS", + "AWS_PROXY", + ]: + raise IntegrationMethodNotDefined() + if integration_type in ["AWS_PROXY"] and re.match( + "^arn:aws:apigateway:[a-zA-Z0-9-]+:s3", uri + ): + raise AwsProxyNotAllowed() + if ( + integration_type in ["AWS"] + and re.match("^arn:aws:apigateway:[a-zA-Z0-9-]+:s3", uri) + and not credentials + ): + raise RoleNotSpecified() + if integration_type in ["HTTP", "HTTP_PROXY"] and not self._uri_validator(uri): + raise InvalidHttpEndpoint() + if integration_type in ["AWS", "AWS_PROXY"] and not re.match("^arn:aws:", uri): + raise InvalidArn() + if integration_type in ["AWS", "AWS_PROXY"] and not re.match( + "^arn:aws:apigateway:[a-zA-Z0-9-]+:[a-zA-Z0-9-]+:(path|action)/", uri + ): + raise InvalidIntegrationArn() + integration = resource.add_integration( + method_type, integration_type, uri, request_templates=request_templates + ) return integration def get_integration(self, function_id, resource_id, method_type): @@ -567,31 +695,55 @@ class APIGatewayBackend(BaseBackend): resource = self.get_resource(function_id, resource_id) return resource.delete_integration(method_type) - def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern): - integration = self.get_integration( - function_id, resource_id, method_type) + def create_integration_response( + self, + function_id, + resource_id, + method_type, + status_code, + selection_pattern, + response_templates, + ): + if response_templates is None: + raise InvalidRequestInput() + integration = self.get_integration(function_id, resource_id, method_type) integration_response = integration.create_integration_response( - status_code, selection_pattern) + status_code, selection_pattern + ) return integration_response - def get_integration_response(self, function_id, resource_id, method_type, status_code): - integration = self.get_integration( - function_id, resource_id, method_type) - integration_response = integration.get_integration_response( - status_code) + def get_integration_response( + self, function_id, resource_id, method_type, status_code + ): + integration = self.get_integration(function_id, resource_id, method_type) + integration_response = integration.get_integration_response(status_code) return integration_response - def delete_integration_response(self, function_id, resource_id, method_type, status_code): - integration = self.get_integration( - function_id, resource_id, method_type) - integration_response = integration.delete_integration_response( - status_code) + def delete_integration_response( + self, function_id, resource_id, method_type, status_code + ): + integration = self.get_integration(function_id, resource_id, method_type) + integration_response = integration.delete_integration_response(status_code) return integration_response - def create_deployment(self, function_id, name, description="", stage_variables=None): + def create_deployment( + self, function_id, name, description="", stage_variables=None + ): if stage_variables is None: stage_variables = {} api = self.get_rest_api(function_id) + methods = [ + list(res.resource_methods.values()) + for res in self.list_resources(function_id) + ][0] + if not any(methods): + raise NoMethodDefined() + method_integrations = [ + method["methodIntegration"] if "methodIntegration" in method else None + for method in methods + ] + if not any(method_integrations): + raise NoIntegrationDefined() deployment = api.create_deployment(name, description, stage_variables) return deployment @@ -608,8 +760,12 @@ class APIGatewayBackend(BaseBackend): return api.delete_deployment(deployment_id) def create_apikey(self, payload): + if payload.get("value") is not None: + for api_key in self.get_apikeys(): + if api_key.get("value") == payload["value"]: + raise ApiKeyAlreadyExists() key = ApiKey(**payload) - self.keys[key['id']] = key + self.keys[key["id"]] = key return key def get_apikeys(self): @@ -628,7 +784,7 @@ class APIGatewayBackend(BaseBackend): def create_usage_plan(self, payload): plan = UsagePlan(**payload) - self.usage_plans[plan['id']] = plan + self.usage_plans[plan["id"]] = plan return plan def get_usage_plans(self, api_key_id=None): @@ -637,7 +793,7 @@ class APIGatewayBackend(BaseBackend): plans = [ plan for plan in plans - if self.usage_plan_keys.get(plan['id'], {}).get(api_key_id, False) + if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False) ] return plans @@ -658,8 +814,13 @@ class APIGatewayBackend(BaseBackend): api_key = self.keys[key_id] - usage_plan_key = UsagePlanKey(id=key_id, type=payload["keyType"], name=api_key["name"], value=api_key["value"]) - self.usage_plan_keys[usage_plan_id][usage_plan_key['id']] = usage_plan_key + usage_plan_key = UsagePlanKey( + id=key_id, + type=payload["keyType"], + name=api_key["name"], + value=api_key["value"], + ) + self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key return usage_plan_key def get_usage_plan_keys(self, usage_plan_id): @@ -675,7 +836,14 @@ class APIGatewayBackend(BaseBackend): self.usage_plan_keys[usage_plan_id].pop(key_id) return {} + def _uri_validator(self, uri): + try: + result = urlparse(uri) + return all([result.scheme, result.netloc, result.path]) + except Exception: + return False + apigateway_backends = {} -for region_name in Session().get_available_regions('apigateway'): +for region_name in Session().get_available_regions("apigateway"): apigateway_backends[region_name] = APIGatewayBackend(region_name) diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index fa82705b1..c4c7b403e 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -4,13 +4,25 @@ import json from moto.core.responses import BaseResponse from .models import apigateway_backends -from .exceptions import StageNotFoundException, ApiKeyNotFoundException +from .exceptions import ( + ApiKeyNotFoundException, + BadRequestException, + CrossAccountNotAllowed, + StageNotFoundException, + ApiKeyAlreadyExists, +) class APIGatewayResponse(BaseResponse): + def error(self, type_, message, status=400): + return ( + status, + self.response_headers, + json.dumps({"__type": type_, "message": message}), + ) def _get_param(self, key): - return json.loads(self.body).get(key) + return json.loads(self.body).get(key) if self.body else None def _get_param_with_default_value(self, key, default): jsonbody = json.loads(self.body) @@ -27,14 +39,12 @@ class APIGatewayResponse(BaseResponse): def restapis(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'GET': + if self.method == "GET": apis = self.backend.list_apis() - return 200, {}, json.dumps({"item": [ - api.to_dict() for api in apis - ]}) - elif self.method == 'POST': - name = self._get_param('name') - description = self._get_param('description') + return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]}) + elif self.method == "POST": + name = self._get_param("name") + description = self._get_param("description") rest_api = self.backend.create_rest_api(name, description) return 200, {}, json.dumps(rest_api.to_dict()) @@ -42,10 +52,10 @@ class APIGatewayResponse(BaseResponse): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': + if self.method == "GET": rest_api = self.backend.get_rest_api(function_id) return 200, {}, json.dumps(rest_api.to_dict()) - elif self.method == 'DELETE': + elif self.method == "DELETE": rest_api = self.backend.delete_rest_api(function_id) return 200, {}, json.dumps(rest_api.to_dict()) @@ -53,26 +63,34 @@ class APIGatewayResponse(BaseResponse): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': + if self.method == "GET": resources = self.backend.list_resources(function_id) - return 200, {}, json.dumps({"item": [ - resource.to_dict() for resource in resources - ]}) + return ( + 200, + {}, + json.dumps({"item": [resource.to_dict() for resource in resources]}), + ) def resource_individual(self, request, full_url, headers): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] resource_id = self.path.split("/")[-1] - if self.method == 'GET': - resource = self.backend.get_resource(function_id, resource_id) - elif self.method == 'POST': - path_part = self._get_param("pathPart") - resource = self.backend.create_resource( - function_id, resource_id, path_part) - elif self.method == 'DELETE': - resource = self.backend.delete_resource(function_id, resource_id) - return 200, {}, json.dumps(resource.to_dict()) + try: + if self.method == "GET": + resource = self.backend.get_resource(function_id, resource_id) + elif self.method == "POST": + path_part = self._get_param("pathPart") + resource = self.backend.create_resource( + function_id, resource_id, path_part + ) + elif self.method == "DELETE": + resource = self.backend.delete_resource(function_id, resource_id) + return 200, {}, json.dumps(resource.to_dict()) + except BadRequestException as e: + return self.error( + "com.amazonaws.dynamodb.v20111205#BadRequestException", e.message + ) def resource_methods(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -81,14 +99,14 @@ class APIGatewayResponse(BaseResponse): resource_id = url_path_parts[4] method_type = url_path_parts[6] - if self.method == 'GET': - method = self.backend.get_method( - function_id, resource_id, method_type) + if self.method == "GET": + method = self.backend.get_method(function_id, resource_id, method_type) return 200, {}, json.dumps(method) - elif self.method == 'PUT': + elif self.method == "PUT": authorization_type = self._get_param("authorizationType") method = self.backend.create_method( - function_id, resource_id, method_type, authorization_type) + function_id, resource_id, method_type, authorization_type + ) return 200, {}, json.dumps(method) def resource_method_responses(self, request, full_url, headers): @@ -99,15 +117,18 @@ class APIGatewayResponse(BaseResponse): method_type = url_path_parts[6] response_code = url_path_parts[8] - if self.method == 'GET': + if self.method == "GET": method_response = self.backend.get_method_response( - function_id, resource_id, method_type, response_code) - elif self.method == 'PUT': + function_id, resource_id, method_type, response_code + ) + elif self.method == "PUT": method_response = self.backend.create_method_response( - function_id, resource_id, method_type, response_code) - elif self.method == 'DELETE': + function_id, resource_id, method_type, response_code + ) + elif self.method == "DELETE": method_response = self.backend.delete_method_response( - function_id, resource_id, method_type, response_code) + function_id, resource_id, method_type, response_code + ) return 200, {}, json.dumps(method_response) def restapis_stages(self, request, full_url, headers): @@ -115,21 +136,28 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") function_id = url_path_parts[2] - if self.method == 'POST': + if self.method == "POST": stage_name = self._get_param("stageName") deployment_id = self._get_param("deploymentId") - stage_variables = self._get_param_with_default_value( - 'variables', {}) - description = self._get_param_with_default_value('description', '') + stage_variables = self._get_param_with_default_value("variables", {}) + description = self._get_param_with_default_value("description", "") cacheClusterEnabled = self._get_param_with_default_value( - 'cacheClusterEnabled', False) + "cacheClusterEnabled", False + ) cacheClusterSize = self._get_param_with_default_value( - 'cacheClusterSize', None) + "cacheClusterSize", None + ) - stage_response = self.backend.create_stage(function_id, stage_name, deployment_id, - variables=stage_variables, description=description, - cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) - elif self.method == 'GET': + stage_response = self.backend.create_stage( + function_id, + stage_name, + deployment_id, + variables=stage_variables, + description=description, + cacheClusterEnabled=cacheClusterEnabled, + cacheClusterSize=cacheClusterSize, + ) + elif self.method == "GET": stages = self.backend.get_stages(function_id) return 200, {}, json.dumps({"item": stages}) @@ -141,16 +169,25 @@ class APIGatewayResponse(BaseResponse): function_id = url_path_parts[2] stage_name = url_path_parts[4] - if self.method == 'GET': + if self.method == "GET": try: - stage_response = self.backend.get_stage( - function_id, stage_name) + stage_response = self.backend.get_stage(function_id, stage_name) except StageNotFoundException as error: - return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) - elif self.method == 'PATCH': - patch_operations = self._get_param('patchOperations') + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) + elif self.method == "PATCH": + patch_operations = self._get_param("patchOperations") stage_response = self.backend.update_stage( - function_id, stage_name, patch_operations) + function_id, stage_name, patch_operations + ) + elif self.method == "DELETE": + self.backend.delete_stage(function_id, stage_name) + return 202, {}, "{}" return 200, {}, json.dumps(stage_response) def integrations(self, request, full_url, headers): @@ -160,19 +197,40 @@ class APIGatewayResponse(BaseResponse): resource_id = url_path_parts[4] method_type = url_path_parts[6] - if self.method == 'GET': - integration_response = self.backend.get_integration( - function_id, resource_id, method_type) - elif self.method == 'PUT': - integration_type = self._get_param('type') - uri = self._get_param('uri') - request_templates = self._get_param('requestTemplates') - integration_response = self.backend.create_integration( - function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) - elif self.method == 'DELETE': - integration_response = self.backend.delete_integration( - function_id, resource_id, method_type) - return 200, {}, json.dumps(integration_response) + try: + if self.method == "GET": + integration_response = self.backend.get_integration( + function_id, resource_id, method_type + ) + elif self.method == "PUT": + integration_type = self._get_param("type") + uri = self._get_param("uri") + integration_http_method = self._get_param("httpMethod") + creds = self._get_param("credentials") + request_templates = self._get_param("requestTemplates") + integration_response = self.backend.create_integration( + function_id, + resource_id, + method_type, + integration_type, + uri, + credentials=creds, + integration_method=integration_http_method, + request_templates=request_templates, + ) + elif self.method == "DELETE": + integration_response = self.backend.delete_integration( + function_id, resource_id, method_type + ) + return 200, {}, json.dumps(integration_response) + except BadRequestException as e: + return self.error( + "com.amazonaws.dynamodb.v20111205#BadRequestException", e.message + ) + except CrossAccountNotAllowed as e: + return self.error( + "com.amazonaws.dynamodb.v20111205#AccessDeniedException", e.message + ) def integration_responses(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -182,36 +240,52 @@ class APIGatewayResponse(BaseResponse): method_type = url_path_parts[6] status_code = url_path_parts[9] - if self.method == 'GET': - integration_response = self.backend.get_integration_response( - function_id, resource_id, method_type, status_code + try: + if self.method == "GET": + integration_response = self.backend.get_integration_response( + function_id, resource_id, method_type, status_code + ) + elif self.method == "PUT": + selection_pattern = self._get_param("selectionPattern") + response_templates = self._get_param("responseTemplates") + integration_response = self.backend.create_integration_response( + function_id, + resource_id, + method_type, + status_code, + selection_pattern, + response_templates, + ) + elif self.method == "DELETE": + integration_response = self.backend.delete_integration_response( + function_id, resource_id, method_type, status_code + ) + return 200, {}, json.dumps(integration_response) + except BadRequestException as e: + return self.error( + "com.amazonaws.dynamodb.v20111205#BadRequestException", e.message ) - elif self.method == 'PUT': - selection_pattern = self._get_param("selectionPattern") - integration_response = self.backend.create_integration_response( - function_id, resource_id, method_type, status_code, selection_pattern - ) - elif self.method == 'DELETE': - integration_response = self.backend.delete_integration_response( - function_id, resource_id, method_type, status_code - ) - return 200, {}, json.dumps(integration_response) def deployments(self, request, full_url, headers): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': - deployments = self.backend.get_deployments(function_id) - return 200, {}, json.dumps({"item": deployments}) - elif self.method == 'POST': - name = self._get_param("stageName") - description = self._get_param_with_default_value("description", "") - stage_variables = self._get_param_with_default_value( - 'variables', {}) - deployment = self.backend.create_deployment( - function_id, name, description, stage_variables) - return 200, {}, json.dumps(deployment) + try: + if self.method == "GET": + deployments = self.backend.get_deployments(function_id) + return 200, {}, json.dumps({"item": deployments}) + elif self.method == "POST": + name = self._get_param("stageName") + description = self._get_param_with_default_value("description", "") + stage_variables = self._get_param_with_default_value("variables", {}) + deployment = self.backend.create_deployment( + function_id, name, description, stage_variables + ) + return 200, {}, json.dumps(deployment) + except BadRequestException as e: + return self.error( + "com.amazonaws.dynamodb.v20111205#BadRequestException", e.message + ) def individual_deployment(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -219,20 +293,28 @@ class APIGatewayResponse(BaseResponse): function_id = url_path_parts[2] deployment_id = url_path_parts[4] - if self.method == 'GET': - deployment = self.backend.get_deployment( - function_id, deployment_id) - elif self.method == 'DELETE': - deployment = self.backend.delete_deployment( - function_id, deployment_id) + if self.method == "GET": + deployment = self.backend.get_deployment(function_id, deployment_id) + elif self.method == "DELETE": + deployment = self.backend.delete_deployment(function_id, deployment_id) return 200, {}, json.dumps(deployment) def apikeys(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'POST': - apikey_response = self.backend.create_apikey(json.loads(self.body)) - elif self.method == 'GET': + if self.method == "POST": + try: + apikey_response = self.backend.create_apikey(json.loads(self.body)) + except ApiKeyAlreadyExists as error: + return ( + error.code, + self.headers, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) + + elif self.method == "GET": apikeys_response = self.backend.get_apikeys() return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps(apikey_response) @@ -243,21 +325,21 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") apikey = url_path_parts[2] - if self.method == 'GET': + if self.method == "GET": apikey_response = self.backend.get_apikey(apikey) - elif self.method == 'PATCH': - patch_operations = self._get_param('patchOperations') + elif self.method == "PATCH": + patch_operations = self._get_param("patchOperations") apikey_response = self.backend.update_apikey(apikey, patch_operations) - elif self.method == 'DELETE': + elif self.method == "DELETE": apikey_response = self.backend.delete_apikey(apikey) return 200, {}, json.dumps(apikey_response) def usage_plans(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'POST': + if self.method == "POST": usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) - elif self.method == 'GET': + elif self.method == "GET": api_key_id = self.querystring.get("keyId", [None])[0] usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) return 200, {}, json.dumps({"item": usage_plans_response}) @@ -269,9 +351,9 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") usage_plan = url_path_parts[2] - if self.method == 'GET': + if self.method == "GET": usage_plan_response = self.backend.get_usage_plan(usage_plan) - elif self.method == 'DELETE': + elif self.method == "DELETE": usage_plan_response = self.backend.delete_usage_plan(usage_plan) return 200, {}, json.dumps(usage_plan_response) @@ -281,13 +363,21 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") usage_plan_id = url_path_parts[2] - if self.method == 'POST': + if self.method == "POST": try: - usage_plan_response = self.backend.create_usage_plan_key(usage_plan_id, json.loads(self.body)) + usage_plan_response = self.backend.create_usage_plan_key( + usage_plan_id, json.loads(self.body) + ) except ApiKeyNotFoundException as error: - return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) - elif self.method == 'GET': + elif self.method == "GET": usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) return 200, {}, json.dumps({"item": usage_plans_response}) @@ -300,8 +390,10 @@ class APIGatewayResponse(BaseResponse): usage_plan_id = url_path_parts[2] key_id = url_path_parts[4] - if self.method == 'GET': + if self.method == "GET": usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) - elif self.method == 'DELETE': - usage_plan_response = self.backend.delete_usage_plan_key(usage_plan_id, key_id) + elif self.method == "DELETE": + usage_plan_response = self.backend.delete_usage_plan_key( + usage_plan_id, key_id + ) return 200, {}, json.dumps(usage_plan_response) diff --git a/moto/apigateway/urls.py b/moto/apigateway/urls.py index 5c6d372fa..bb2b2d216 100644 --- a/moto/apigateway/urls.py +++ b/moto/apigateway/urls.py @@ -1,27 +1,25 @@ from __future__ import unicode_literals from .responses import APIGatewayResponse -url_bases = [ - "https?://apigateway.(.+).amazonaws.com" -] +url_bases = ["https?://apigateway.(.+).amazonaws.com"] url_paths = { - '{0}/restapis$': APIGatewayResponse().restapis, - '{0}/restapis/(?P[^/]+)/?$': APIGatewayResponse().restapis_individual, - '{0}/restapis/(?P[^/]+)/resources$': APIGatewayResponse().resources, - '{0}/restapis/(?P[^/]+)/stages$': APIGatewayResponse().restapis_stages, - '{0}/restapis/(?P[^/]+)/stages/(?P[^/]+)/?$': APIGatewayResponse().stages, - '{0}/restapis/(?P[^/]+)/deployments$': APIGatewayResponse().deployments, - '{0}/restapis/(?P[^/]+)/deployments/(?P[^/]+)/?$': APIGatewayResponse().individual_deployment, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/?$': APIGatewayResponse().resource_individual, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/?$': APIGatewayResponse().resource_methods, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/responses/(?P\d+)$': APIGatewayResponse().resource_method_responses, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/?$': APIGatewayResponse().integrations, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/responses/(?P\d+)/?$': APIGatewayResponse().integration_responses, - '{0}/apikeys$': APIGatewayResponse().apikeys, - '{0}/apikeys/(?P[^/]+)': APIGatewayResponse().apikey_individual, - '{0}/usageplans$': APIGatewayResponse().usage_plans, - '{0}/usageplans/(?P[^/]+)/?$': APIGatewayResponse().usage_plan_individual, - '{0}/usageplans/(?P[^/]+)/keys$': APIGatewayResponse().usage_plan_keys, - '{0}/usageplans/(?P[^/]+)/keys/(?P[^/]+)/?$': APIGatewayResponse().usage_plan_key_individual, + "{0}/restapis$": APIGatewayResponse().restapis, + "{0}/restapis/(?P[^/]+)/?$": APIGatewayResponse().restapis_individual, + "{0}/restapis/(?P[^/]+)/resources$": APIGatewayResponse().resources, + "{0}/restapis/(?P[^/]+)/stages$": APIGatewayResponse().restapis_stages, + "{0}/restapis/(?P[^/]+)/stages/(?P[^/]+)/?$": APIGatewayResponse().stages, + "{0}/restapis/(?P[^/]+)/deployments$": APIGatewayResponse().deployments, + "{0}/restapis/(?P[^/]+)/deployments/(?P[^/]+)/?$": APIGatewayResponse().individual_deployment, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/?$": APIGatewayResponse().resource_individual, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/?$": APIGatewayResponse().resource_methods, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/responses/(?P\d+)$": APIGatewayResponse().resource_method_responses, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/?$": APIGatewayResponse().integrations, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/responses/(?P\d+)/?$": APIGatewayResponse().integration_responses, + "{0}/apikeys$": APIGatewayResponse().apikeys, + "{0}/apikeys/(?P[^/]+)": APIGatewayResponse().apikey_individual, + "{0}/usageplans$": APIGatewayResponse().usage_plans, + "{0}/usageplans/(?P[^/]+)/?$": APIGatewayResponse().usage_plan_individual, + "{0}/usageplans/(?P[^/]+)/keys$": APIGatewayResponse().usage_plan_keys, + "{0}/usageplans/(?P[^/]+)/keys/(?P[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual, } diff --git a/moto/apigateway/utils.py b/moto/apigateway/utils.py index 31f8060b0..807848f66 100644 --- a/moto/apigateway/utils.py +++ b/moto/apigateway/utils.py @@ -7,4 +7,4 @@ import string def create_id(): size = 10 chars = list(range(10)) + list(string.ascii_lowercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) diff --git a/moto/athena/__init__.py b/moto/athena/__init__.py new file mode 100644 index 000000000..3c1dc15c5 --- /dev/null +++ b/moto/athena/__init__.py @@ -0,0 +1,7 @@ +from __future__ import unicode_literals +from .models import athena_backends +from ..core.models import base_decorator, deprecated_base_decorator + +athena_backend = athena_backends["us-east-1"] +mock_athena = base_decorator(athena_backends) +mock_athena_deprecated = deprecated_base_decorator(athena_backends) diff --git a/moto/athena/exceptions.py b/moto/athena/exceptions.py new file mode 100644 index 000000000..96b35556a --- /dev/null +++ b/moto/athena/exceptions.py @@ -0,0 +1,19 @@ +from __future__ import unicode_literals + +import json +from werkzeug.exceptions import BadRequest + + +class AthenaClientError(BadRequest): + def __init__(self, code, message): + super(AthenaClientError, self).__init__() + self.description = json.dumps( + { + "Error": { + "Code": code, + "Message": message, + "Type": "InvalidRequestException", + }, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) diff --git a/moto/athena/models.py b/moto/athena/models.py new file mode 100644 index 000000000..2f41046a9 --- /dev/null +++ b/moto/athena/models.py @@ -0,0 +1,81 @@ +from __future__ import unicode_literals +import time + +import boto3 +from moto.core import BaseBackend, BaseModel + +from moto.core import ACCOUNT_ID + + +class TaggableResourceMixin(object): + # This mixing was copied from Redshift when initially implementing + # Athena. TBD if it's worth the overhead. + + def __init__(self, region_name, resource_name, tags): + self.region = region_name + self.resource_name = resource_name + self.tags = tags or [] + + @property + def arn(self): + return "arn:aws:athena:{region}:{account_id}:{resource_name}".format( + region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name + ) + + def create_tags(self, tags): + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] + self.tags.extend(tags) + return self.tags + + def delete_tags(self, tag_keys): + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] + return self.tags + + +class WorkGroup(TaggableResourceMixin, BaseModel): + + resource_type = "workgroup" + state = "ENABLED" + + def __init__(self, athena_backend, name, configuration, description, tags): + self.region_name = athena_backend.region_name + super(WorkGroup, self).__init__( + self.region_name, "workgroup/{}".format(name), tags + ) + self.athena_backend = athena_backend + self.name = name + self.description = description + self.configuration = configuration + + +class AthenaBackend(BaseBackend): + region_name = None + + def __init__(self, region_name=None): + if region_name is not None: + self.region_name = region_name + self.work_groups = {} + + def create_work_group(self, name, configuration, description, tags): + if name in self.work_groups: + return None + work_group = WorkGroup(self, name, configuration, description, tags) + self.work_groups[name] = work_group + return work_group + + def list_work_groups(self): + return [ + { + "Name": wg.name, + "State": wg.state, + "Description": wg.description, + "CreationTime": time.time(), + } + for wg in self.work_groups.values() + ] + + +athena_backends = {} +for region in boto3.Session().get_available_regions("athena"): + athena_backends[region] = AthenaBackend(region) diff --git a/moto/athena/responses.py b/moto/athena/responses.py new file mode 100644 index 000000000..80cac5d62 --- /dev/null +++ b/moto/athena/responses.py @@ -0,0 +1,41 @@ +import json + +from moto.core.responses import BaseResponse +from .models import athena_backends + + +class AthenaResponse(BaseResponse): + @property + def athena_backend(self): + return athena_backends[self.region] + + def create_work_group(self): + name = self._get_param("Name") + description = self._get_param("Description") + configuration = self._get_param("Configuration") + tags = self._get_param("Tags") + work_group = self.athena_backend.create_work_group( + name, configuration, description, tags + ) + if not work_group: + return ( + json.dumps( + { + "__type": "InvalidRequestException", + "Message": "WorkGroup already exists", + } + ), + dict(status=400), + ) + return json.dumps( + { + "CreateWorkGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } + } + } + ) + + def list_work_groups(self): + return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()}) diff --git a/moto/athena/urls.py b/moto/athena/urls.py new file mode 100644 index 000000000..4f8fdf7ee --- /dev/null +++ b/moto/athena/urls.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .responses import AthenaResponse + +url_bases = ["https?://athena.(.+).amazonaws.com"] + +url_paths = {"{0}/$": AthenaResponse.dispatch} diff --git a/moto/athena/utils.py b/moto/athena/utils.py new file mode 100644 index 000000000..baffc4882 --- /dev/null +++ b/moto/athena/utils.py @@ -0,0 +1 @@ +from __future__ import unicode_literals diff --git a/moto/autoscaling/__init__.py b/moto/autoscaling/__init__.py index b2b8b0bae..13c1adb16 100644 --- a/moto/autoscaling/__init__.py +++ b/moto/autoscaling/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import autoscaling_backends from ..core.models import base_decorator, deprecated_base_decorator -autoscaling_backend = autoscaling_backends['us-east-1'] +autoscaling_backend = autoscaling_backends["us-east-1"] mock_autoscaling = base_decorator(autoscaling_backends) mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends) diff --git a/moto/autoscaling/exceptions.py b/moto/autoscaling/exceptions.py index 74f62241d..6f73eff8f 100644 --- a/moto/autoscaling/exceptions.py +++ b/moto/autoscaling/exceptions.py @@ -12,13 +12,12 @@ class ResourceContentionError(RESTError): def __init__(self): super(ResourceContentionError, self).__init__( "ResourceContentionError", - "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).") + "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).", + ) class InvalidInstanceError(AutoscalingClientError): - def __init__(self, instance_id): super(InvalidInstanceError, self).__init__( - "ValidationError", - "Instance [{0}] is invalid." - .format(instance_id)) + "ValidationError", "Instance [{0}] is invalid.".format(instance_id) + ) diff --git a/moto/autoscaling/models.py b/moto/autoscaling/models.py index 422075951..45ee7d192 100644 --- a/moto/autoscaling/models.py +++ b/moto/autoscaling/models.py @@ -12,7 +12,9 @@ from moto.elb import elb_backends from moto.elbv2 import elbv2_backends from moto.elb.exceptions import LoadBalancerNotFoundError from .exceptions import ( - AutoscalingClientError, ResourceContentionError, InvalidInstanceError + AutoscalingClientError, + ResourceContentionError, + InvalidInstanceError, ) # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown @@ -22,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName" class InstanceState(object): - def __init__(self, instance, lifecycle_state="InService", - health_status="Healthy", protected_from_scale_in=False): + def __init__( + self, + instance, + lifecycle_state="InService", + health_status="Healthy", + protected_from_scale_in=False, + ): self.instance = instance self.lifecycle_state = lifecycle_state self.health_status = health_status @@ -31,8 +38,16 @@ class InstanceState(object): class FakeScalingPolicy(BaseModel): - def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, - cooldown, autoscaling_backend): + def __init__( + self, + name, + policy_type, + adjustment_type, + as_name, + scaling_adjustment, + cooldown, + autoscaling_backend, + ): self.name = name self.policy_type = policy_type self.adjustment_type = adjustment_type @@ -45,21 +60,38 @@ class FakeScalingPolicy(BaseModel): self.autoscaling_backend = autoscaling_backend def execute(self): - if self.adjustment_type == 'ExactCapacity': + if self.adjustment_type == "ExactCapacity": self.autoscaling_backend.set_desired_capacity( - self.as_name, self.scaling_adjustment) - elif self.adjustment_type == 'ChangeInCapacity': + self.as_name, self.scaling_adjustment + ) + elif self.adjustment_type == "ChangeInCapacity": self.autoscaling_backend.change_capacity( - self.as_name, self.scaling_adjustment) - elif self.adjustment_type == 'PercentChangeInCapacity': + self.as_name, self.scaling_adjustment + ) + elif self.adjustment_type == "PercentChangeInCapacity": self.autoscaling_backend.change_capacity_percent( - self.as_name, self.scaling_adjustment) + self.as_name, self.scaling_adjustment + ) class FakeLaunchConfiguration(BaseModel): - def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, - instance_type, instance_monitoring, instance_profile_name, - spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): + def __init__( + self, + name, + image_id, + key_name, + ramdisk_id, + kernel_id, + security_groups, + user_data, + instance_type, + instance_monitoring, + instance_profile_name, + spot_price, + ebs_optimized, + associate_public_ip_address, + block_device_mapping_dict, + ): self.name = name self.image_id = image_id self.key_name = key_name @@ -80,8 +112,8 @@ class FakeLaunchConfiguration(BaseModel): config = backend.create_launch_configuration( name=name, image_id=instance.image_id, - kernel_id='', - ramdisk_id='', + kernel_id="", + ramdisk_id="", key_name=instance.key_name, security_groups=instance.security_groups, user_data=instance.user_data, @@ -91,13 +123,15 @@ class FakeLaunchConfiguration(BaseModel): spot_price=None, ebs_optimized=instance.ebs_optimized, associate_public_ip_address=instance.associate_public_ip, - block_device_mappings=instance.block_device_mapping + block_device_mappings=instance.block_device_mapping, ) return config @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] instance_profile_name = properties.get("IamInstanceProfile") @@ -115,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel): instance_profile_name=instance_profile_name, spot_price=properties.get("SpotPrice"), ebs_optimized=properties.get("EbsOptimized"), - associate_public_ip_address=properties.get( - "AssociatePublicIpAddress"), - block_device_mappings=properties.get("BlockDeviceMapping.member") + associate_public_ip_address=properties.get("AssociatePublicIpAddress"), + block_device_mappings=properties.get("BlockDeviceMapping.member"), ) return config @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = autoscaling_backends[region_name] try: backend.delete_launch_configuration(resource_name) @@ -153,34 +193,49 @@ class FakeLaunchConfiguration(BaseModel): @property def instance_monitoring_enabled(self): if self.instance_monitoring: - return 'true' - return 'false' + return "true" + return "false" def _parse_block_device_mappings(self): block_device_map = BlockDeviceMapping() for mapping in self.block_device_mapping_dict: block_type = BlockDeviceType() - mount_point = mapping.get('device_name') - if 'ephemeral' in mapping.get('virtual_name', ''): - block_type.ephemeral_name = mapping.get('virtual_name') + mount_point = mapping.get("device_name") + if "ephemeral" in mapping.get("virtual_name", ""): + block_type.ephemeral_name = mapping.get("virtual_name") else: - block_type.volume_type = mapping.get('ebs._volume_type') - block_type.snapshot_id = mapping.get('ebs._snapshot_id') + block_type.volume_type = mapping.get("ebs._volume_type") + block_type.snapshot_id = mapping.get("ebs._snapshot_id") block_type.delete_on_termination = mapping.get( - 'ebs._delete_on_termination') - block_type.size = mapping.get('ebs._volume_size') - block_type.iops = mapping.get('ebs._iops') + "ebs._delete_on_termination" + ) + block_type.size = mapping.get("ebs._volume_size") + block_type.iops = mapping.get("ebs._iops") block_device_map[mount_point] = block_type return block_device_map class FakeAutoScalingGroup(BaseModel): - def __init__(self, name, availability_zones, desired_capacity, max_size, - min_size, launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, health_check_type, - load_balancers, target_group_arns, placement_group, termination_policies, - autoscaling_backend, tags, - new_instances_protected_from_scale_in=False): + def __init__( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + load_balancers, + target_group_arns, + placement_group, + termination_policies, + autoscaling_backend, + tags, + new_instances_protected_from_scale_in=False, + ): self.autoscaling_backend = autoscaling_backend self.name = name @@ -190,17 +245,22 @@ class FakeAutoScalingGroup(BaseModel): self.min_size = min_size self.launch_config = self.autoscaling_backend.launch_configurations[ - launch_config_name] + launch_config_name + ] self.launch_config_name = launch_config_name - self.default_cooldown = default_cooldown if default_cooldown else DEFAULT_COOLDOWN + self.default_cooldown = ( + default_cooldown if default_cooldown else DEFAULT_COOLDOWN + ) self.health_check_period = health_check_period self.health_check_type = health_check_type if health_check_type else "EC2" self.load_balancers = load_balancers self.target_group_arns = target_group_arns self.placement_group = placement_group self.termination_policies = termination_policies - self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in + self.new_instances_protected_from_scale_in = ( + new_instances_protected_from_scale_in + ) self.suspended_processes = [] self.instance_states = [] @@ -215,8 +275,10 @@ class FakeAutoScalingGroup(BaseModel): if vpc_zone_identifier: # extract azs for vpcs - subnet_ids = vpc_zone_identifier.split(',') - subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(subnet_ids=subnet_ids) + subnet_ids = vpc_zone_identifier.split(",") + subnets = self.autoscaling_backend.ec2_backend.get_all_subnets( + subnet_ids=subnet_ids + ) vpc_zones = [subnet.availability_zone for subnet in subnets] if availability_zones and set(availability_zones) != set(vpc_zones): @@ -229,7 +291,7 @@ class FakeAutoScalingGroup(BaseModel): if not update: raise AutoscalingClientError( "ValidationError", - "At least one Availability Zone or VPC Subnet is required." + "At least one Availability Zone or VPC Subnet is required.", ) return @@ -237,8 +299,10 @@ class FakeAutoScalingGroup(BaseModel): self.vpc_zone_identifier = vpc_zone_identifier @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] launch_config_name = properties.get("LaunchConfigurationName") load_balancer_names = properties.get("LoadBalancerNames", []) @@ -253,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel): min_size=properties.get("MinSize"), launch_config_name=launch_config_name, vpc_zone_identifier=( - ','.join(properties.get("VPCZoneIdentifier", [])) or None), + ",".join(properties.get("VPCZoneIdentifier", [])) or None + ), default_cooldown=properties.get("Cooldown"), health_check_period=properties.get("HealthCheckGracePeriod"), health_check_type=properties.get("HealthCheckType"), @@ -263,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel): termination_policies=properties.get("TerminationPolicies", []), tags=properties.get("Tags", []), new_instances_protected_from_scale_in=properties.get( - "NewInstancesProtectedFromScaleIn", False) + "NewInstancesProtectedFromScaleIn", False + ), ) return group @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = autoscaling_backends[region_name] try: backend.delete_auto_scaling_group(resource_name) @@ -289,11 +362,21 @@ class FakeAutoScalingGroup(BaseModel): def physical_resource_id(self): return self.name - def update(self, availability_zones, desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, default_cooldown, - health_check_period, health_check_type, - placement_group, termination_policies, - new_instances_protected_from_scale_in=None): + def update( + self, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=None, + ): self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True) if max_size is not None: @@ -309,14 +392,17 @@ class FakeAutoScalingGroup(BaseModel): if launch_config_name: self.launch_config = self.autoscaling_backend.launch_configurations[ - launch_config_name] + launch_config_name + ] self.launch_config_name = launch_config_name if health_check_period is not None: self.health_check_period = health_check_period if health_check_type is not None: self.health_check_type = health_check_type if new_instances_protected_from_scale_in is not None: - self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in + self.new_instances_protected_from_scale_in = ( + new_instances_protected_from_scale_in + ) if desired_capacity is not None: self.set_desired_capacity(desired_capacity) @@ -342,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel): # Need to remove some instances count_to_remove = curr_instance_count - self.desired_capacity instances_to_remove = [ # only remove unprotected - state for state in self.instance_states + state + for state in self.instance_states if not state.protected_from_scale_in ][:count_to_remove] if instances_to_remove: # just in case not instances to remove instance_ids_to_remove = [ - instance.instance.id for instance in instances_to_remove] + instance.instance.id for instance in instances_to_remove + ] self.autoscaling_backend.ec2_backend.terminate_instances( - instance_ids_to_remove) - self.instance_states = list(set(self.instance_states) - set(instances_to_remove)) + instance_ids_to_remove + ) + self.instance_states = list( + set(self.instance_states) - set(instances_to_remove) + ) def get_propagated_tags(self): propagated_tags = {} for tag in self.tags: # boto uses 'propagate_at_launch # boto3 and cloudformation use PropagateAtLaunch - if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true': - propagated_tags[tag['key']] = tag['value'] - if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']: - propagated_tags[tag['Key']] = tag['Value'] + if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true": + propagated_tags[tag["key"]] = tag["value"] + if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]: + propagated_tags[tag["Key"]] = tag["Value"] return propagated_tags def replace_autoscaling_group_instances(self, count_needed, propagated_tags): @@ -371,15 +462,17 @@ class FakeAutoScalingGroup(BaseModel): self.launch_config.user_data, self.launch_config.security_groups, instance_type=self.launch_config.instance_type, - tags={'instance': propagated_tags}, + tags={"instance": propagated_tags}, placement=random.choice(self.availability_zones), ) for instance in reservation.instances: instance.autoscaling_group = self - self.instance_states.append(InstanceState( - instance, - protected_from_scale_in=self.new_instances_protected_from_scale_in, - )) + self.instance_states.append( + InstanceState( + instance, + protected_from_scale_in=self.new_instances_protected_from_scale_in, + ) + ) def append_target_groups(self, target_group_arns): append = [x for x in target_group_arns if x not in self.target_group_arns] @@ -402,10 +495,23 @@ class AutoScalingBackend(BaseBackend): self.__dict__ = {} self.__init__(ec2_backend, elb_backend, elbv2_backend) - def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id, - security_groups, user_data, instance_type, - instance_monitoring, instance_profile_name, - spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings): + def create_launch_configuration( + self, + name, + image_id, + key_name, + kernel_id, + ramdisk_id, + security_groups, + user_data, + instance_type, + instance_monitoring, + instance_profile_name, + spot_price, + ebs_optimized, + associate_public_ip_address, + block_device_mappings, + ): launch_configuration = FakeLaunchConfiguration( name=name, image_id=image_id, @@ -428,23 +534,37 @@ class AutoScalingBackend(BaseBackend): def describe_launch_configurations(self, names): configurations = self.launch_configurations.values() if names: - return [configuration for configuration in configurations if configuration.name in names] + return [ + configuration + for configuration in configurations + if configuration.name in names + ] else: return list(configurations) def delete_launch_configuration(self, launch_configuration_name): self.launch_configurations.pop(launch_configuration_name, None) - def create_auto_scaling_group(self, name, availability_zones, - desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, - health_check_type, load_balancers, - target_group_arns, placement_group, - termination_policies, tags, - new_instances_protected_from_scale_in=False, - instance_id=None): - + def create_auto_scaling_group( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + load_balancers, + target_group_arns, + placement_group, + termination_policies, + tags, + new_instances_protected_from_scale_in=False, + instance_id=None, + ): def make_int(value): return int(value) if value is not None else value @@ -460,7 +580,9 @@ class AutoScalingBackend(BaseBackend): try: instance = self.ec2_backend.get_instance(instance_id) launch_config_name = name - FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self) + FakeLaunchConfiguration.create_from_instance( + launch_config_name, instance, self + ) except InvalidInstanceIdError: raise InvalidInstanceError(instance_id) @@ -489,19 +611,37 @@ class AutoScalingBackend(BaseBackend): self.update_attached_target_groups(group.name) return group - def update_auto_scaling_group(self, name, availability_zones, - desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, - health_check_type, placement_group, - termination_policies, - new_instances_protected_from_scale_in=None): + def update_auto_scaling_group( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=None, + ): group = self.autoscaling_groups[name] - group.update(availability_zones, desired_capacity, max_size, - min_size, launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, health_check_type, - placement_group, termination_policies, - new_instances_protected_from_scale_in=new_instances_protected_from_scale_in) + group.update( + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=new_instances_protected_from_scale_in, + ) return group def describe_auto_scaling_groups(self, names): @@ -537,32 +677,48 @@ class AutoScalingBackend(BaseBackend): for x in instance_ids ] for instance in new_instances: - self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) + self.ec2_backend.create_tags( + [instance.instance.id], {ASG_NAME_TAG: group.name} + ) group.instance_states.extend(new_instances) self.update_attached_elbs(group.name) - def set_instance_health(self, instance_id, health_status, should_respect_grace_period): + def set_instance_health( + self, instance_id, health_status, should_respect_grace_period + ): instance = self.ec2_backend.get_instance(instance_id) - instance_state = next(instance_state for group in self.autoscaling_groups.values() - for instance_state in group.instance_states if instance_state.instance.id == instance.id) + instance_state = next( + instance_state + for group in self.autoscaling_groups.values() + for instance_state in group.instance_states + if instance_state.instance.id == instance.id + ) instance_state.health_status = health_status def detach_instances(self, group_name, instance_ids, should_decrement): group = self.autoscaling_groups[group_name] original_size = len(group.instance_states) - detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids] + detached_instances = [ + x for x in group.instance_states if x.instance.id in instance_ids + ] for instance in detached_instances: - self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) + self.ec2_backend.delete_tags( + [instance.instance.id], {ASG_NAME_TAG: group.name} + ) - new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids] + new_instance_state = [ + x for x in group.instance_states if x.instance.id not in instance_ids + ] group.instance_states = new_instance_state if should_decrement: group.desired_capacity = original_size - len(instance_ids) else: count_needed = len(instance_ids) - group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags()) + group.replace_autoscaling_group_instances( + count_needed, group.get_propagated_tags() + ) self.update_attached_elbs(group_name) return detached_instances @@ -593,19 +749,32 @@ class AutoScalingBackend(BaseBackend): desired_capacity = int(desired_capacity) self.set_desired_capacity(group_name, desired_capacity) - def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name, - scaling_adjustment, cooldown): - policy = FakeScalingPolicy(name, policy_type, adjustment_type, as_name, - scaling_adjustment, cooldown, self) + def create_autoscaling_policy( + self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown + ): + policy = FakeScalingPolicy( + name, + policy_type, + adjustment_type, + as_name, + scaling_adjustment, + cooldown, + self, + ) self.policies[name] = policy return policy - def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): - return [policy for policy in self.policies.values() - if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and - (not policy_names or policy.name in policy_names) and - (not policy_types or policy.policy_type in policy_types)] + def describe_policies( + self, autoscaling_group_name=None, policy_names=None, policy_types=None + ): + return [ + policy + for policy in self.policies.values() + if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) + and (not policy_names or policy.name in policy_names) + and (not policy_types or policy.policy_type in policy_types) + ] def delete_policy(self, group_name): self.policies.pop(group_name, None) @@ -616,16 +785,14 @@ class AutoScalingBackend(BaseBackend): def update_attached_elbs(self, group_name): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) # skip this if group.load_balancers is empty # otherwise elb_backend.describe_load_balancers returns all available load balancers if not group.load_balancers: return try: - elbs = self.elb_backend.describe_load_balancers( - names=group.load_balancers) + elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) except LoadBalancerNotFoundError: # ELBs can be deleted before their autoscaling group return @@ -633,14 +800,15 @@ class AutoScalingBackend(BaseBackend): for elb in elbs: elb_instace_ids = set(elb.instance_ids) self.elb_backend.register_instances( - elb.name, group_instance_ids - elb_instace_ids) + elb.name, group_instance_ids - elb_instace_ids + ) self.elb_backend.deregister_instances( - elb.name, elb_instace_ids - group_instance_ids) + elb.name, elb_instace_ids - group_instance_ids + ) def update_attached_target_groups(self, group_name): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) # no action necessary if target_group_arns is empty if not group.target_group_arns: @@ -649,10 +817,13 @@ class AutoScalingBackend(BaseBackend): target_groups = self.elbv2_backend.describe_target_groups( target_group_arns=group.target_group_arns, load_balancer_arn=None, - names=None) + names=None, + ) for target_group in target_groups: - asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids] + asg_targets = [ + {"id": x, "port": target_group.port} for x in group_instance_ids + ] self.elbv2_backend.register_targets(target_group.arn, (asg_targets)) def create_or_update_tags(self, tags): @@ -670,7 +841,7 @@ class AutoScalingBackend(BaseBackend): new_tags.append(old_tag) # if key was never in old_tag's add it (create tag) - if not any(new_tag['key'] == tag['key'] for new_tag in new_tags): + if not any(new_tag["key"] == tag["key"] for new_tag in new_tags): new_tags.append(tag) group.tags = new_tags @@ -678,7 +849,8 @@ class AutoScalingBackend(BaseBackend): def attach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] group.load_balancers.extend( - [x for x in load_balancer_names if x not in group.load_balancers]) + [x for x in load_balancer_names if x not in group.load_balancers] + ) self.update_attached_elbs(group_name) def describe_load_balancers(self, group_name): @@ -686,13 +858,13 @@ class AutoScalingBackend(BaseBackend): def detach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) for elb in elbs: - self.elb_backend.deregister_instances( - elb.name, group_instance_ids) - group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names] + self.elb_backend.deregister_instances(elb.name, group_instance_ids) + group.load_balancers = [ + x for x in group.load_balancers if x not in load_balancer_names + ] def attach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] @@ -704,36 +876,51 @@ class AutoScalingBackend(BaseBackend): def detach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] - group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns] + group.target_group_arns = [ + x for x in group.target_group_arns if x not in target_group_arns + ] for target_group in target_group_arns: - asg_targets = [{'id': x.instance.id} for x in group.instance_states] + asg_targets = [{"id": x.instance.id} for x in group.instance_states] self.elbv2_backend.deregister_targets(target_group, (asg_targets)) def suspend_processes(self, group_name, scaling_processes): group = self.autoscaling_groups[group_name] group.suspended_processes = scaling_processes or [] - def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in): + def set_instance_protection( + self, group_name, instance_ids, protected_from_scale_in + ): group = self.autoscaling_groups[group_name] protected_instances = [ - x for x in group.instance_states if x.instance.id in instance_ids] + x for x in group.instance_states if x.instance.id in instance_ids + ] for instance in protected_instances: instance.protected_from_scale_in = protected_from_scale_in def notify_terminate_instances(self, instance_ids): - for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items(): + for ( + autoscaling_group_name, + autoscaling_group, + ) in self.autoscaling_groups.items(): original_instance_count = len(autoscaling_group.instance_states) - autoscaling_group.instance_states = list(filter( - lambda i_state: i_state.instance.id not in instance_ids, + autoscaling_group.instance_states = list( + filter( + lambda i_state: i_state.instance.id not in instance_ids, + autoscaling_group.instance_states, + ) + ) + difference = original_instance_count - len( autoscaling_group.instance_states - )) - difference = original_instance_count - len(autoscaling_group.instance_states) + ) if difference > 0: - autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags()) + autoscaling_group.replace_autoscaling_group_instances( + difference, autoscaling_group.get_propagated_tags() + ) self.update_attached_elbs(autoscaling_group_name) autoscaling_backends = {} for region, ec2_backend in ec2_backends.items(): autoscaling_backends[region] = AutoScalingBackend( - ec2_backend, elb_backends[region], elbv2_backends[region]) + ec2_backend, elb_backends[region], elbv2_backends[region] + ) diff --git a/moto/autoscaling/responses.py b/moto/autoscaling/responses.py index 5e409aafb..83e2f7d5a 100644 --- a/moto/autoscaling/responses.py +++ b/moto/autoscaling/responses.py @@ -6,88 +6,88 @@ from .models import autoscaling_backends class AutoScalingResponse(BaseResponse): - @property def autoscaling_backend(self): return autoscaling_backends[self.region] def create_launch_configuration(self): - instance_monitoring_string = self._get_param( - 'InstanceMonitoring.Enabled') - if instance_monitoring_string == 'true': + instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled") + if instance_monitoring_string == "true": instance_monitoring = True else: instance_monitoring = False self.autoscaling_backend.create_launch_configuration( - name=self._get_param('LaunchConfigurationName'), - image_id=self._get_param('ImageId'), - key_name=self._get_param('KeyName'), - ramdisk_id=self._get_param('RamdiskId'), - kernel_id=self._get_param('KernelId'), - security_groups=self._get_multi_param('SecurityGroups.member'), - user_data=self._get_param('UserData'), - instance_type=self._get_param('InstanceType'), + name=self._get_param("LaunchConfigurationName"), + image_id=self._get_param("ImageId"), + key_name=self._get_param("KeyName"), + ramdisk_id=self._get_param("RamdiskId"), + kernel_id=self._get_param("KernelId"), + security_groups=self._get_multi_param("SecurityGroups.member"), + user_data=self._get_param("UserData"), + instance_type=self._get_param("InstanceType"), instance_monitoring=instance_monitoring, - instance_profile_name=self._get_param('IamInstanceProfile'), - spot_price=self._get_param('SpotPrice'), - ebs_optimized=self._get_param('EbsOptimized'), - associate_public_ip_address=self._get_param( - "AssociatePublicIpAddress"), - block_device_mappings=self._get_list_prefix( - 'BlockDeviceMappings.member') + instance_profile_name=self._get_param("IamInstanceProfile"), + spot_price=self._get_param("SpotPrice"), + ebs_optimized=self._get_param("EbsOptimized"), + associate_public_ip_address=self._get_param("AssociatePublicIpAddress"), + block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"), ) template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE) return template.render() def describe_launch_configurations(self): - names = self._get_multi_param('LaunchConfigurationNames.member') - all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names) - marker = self._get_param('NextToken') + names = self._get_multi_param("LaunchConfigurationNames.member") + all_launch_configurations = self.autoscaling_backend.describe_launch_configurations( + names + ) + marker = self._get_param("NextToken") all_names = [lc.name for lc in all_launch_configurations] if marker: start = all_names.index(marker) + 1 else: start = 0 - max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - launch_configurations_resp = all_launch_configurations[start:start + max_records] + max_records = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + launch_configurations_resp = all_launch_configurations[ + start : start + max_records + ] next_token = None if len(all_launch_configurations) > start + max_records: next_token = launch_configurations_resp[-1].name - template = self.response_template( - DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) - return template.render(launch_configurations=launch_configurations_resp, next_token=next_token) + template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) + return template.render( + launch_configurations=launch_configurations_resp, next_token=next_token + ) def delete_launch_configuration(self): - launch_configurations_name = self.querystring.get( - 'LaunchConfigurationName')[0] - self.autoscaling_backend.delete_launch_configuration( - launch_configurations_name) + launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0] + self.autoscaling_backend.delete_launch_configuration(launch_configurations_name) template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE) return template.render() def create_auto_scaling_group(self): self.autoscaling_backend.create_auto_scaling_group( - name=self._get_param('AutoScalingGroupName'), - availability_zones=self._get_multi_param( - 'AvailabilityZones.member'), - desired_capacity=self._get_int_param('DesiredCapacity'), - max_size=self._get_int_param('MaxSize'), - min_size=self._get_int_param('MinSize'), - instance_id=self._get_param('InstanceId'), - launch_config_name=self._get_param('LaunchConfigurationName'), - vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), - default_cooldown=self._get_int_param('DefaultCooldown'), - health_check_period=self._get_int_param('HealthCheckGracePeriod'), - health_check_type=self._get_param('HealthCheckType'), - load_balancers=self._get_multi_param('LoadBalancerNames.member'), - target_group_arns=self._get_multi_param('TargetGroupARNs.member'), - placement_group=self._get_param('PlacementGroup'), - termination_policies=self._get_multi_param( - 'TerminationPolicies.member'), - tags=self._get_list_prefix('Tags.member'), + name=self._get_param("AutoScalingGroupName"), + availability_zones=self._get_multi_param("AvailabilityZones.member"), + desired_capacity=self._get_int_param("DesiredCapacity"), + max_size=self._get_int_param("MaxSize"), + min_size=self._get_int_param("MinSize"), + instance_id=self._get_param("InstanceId"), + launch_config_name=self._get_param("LaunchConfigurationName"), + vpc_zone_identifier=self._get_param("VPCZoneIdentifier"), + default_cooldown=self._get_int_param("DefaultCooldown"), + health_check_period=self._get_int_param("HealthCheckGracePeriod"), + health_check_type=self._get_param("HealthCheckType"), + load_balancers=self._get_multi_param("LoadBalancerNames.member"), + target_group_arns=self._get_multi_param("TargetGroupARNs.member"), + placement_group=self._get_param("PlacementGroup"), + termination_policies=self._get_multi_param("TerminationPolicies.member"), + tags=self._get_list_prefix("Tags.member"), new_instances_protected_from_scale_in=self._get_bool_param( - 'NewInstancesProtectedFromScaleIn', False) + "NewInstancesProtectedFromScaleIn", False + ), ) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) return template.render() @@ -95,68 +95,73 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def attach_instances(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - self.autoscaling_backend.attach_instances( - group_name, instance_ids) + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + self.autoscaling_backend.attach_instances(group_name, instance_ids) template = self.response_template(ATTACH_INSTANCES_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def set_instance_health(self): - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") health_status = self._get_param("HealthStatus") - if health_status not in ['Healthy', 'Unhealthy']: - raise ValueError('Valid instance health states are: [Healthy, Unhealthy]') + if health_status not in ["Healthy", "Unhealthy"]: + raise ValueError("Valid instance health states are: [Healthy, Unhealthy]") should_respect_grace_period = self._get_param("ShouldRespectGracePeriod") - self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period) + self.autoscaling_backend.set_instance_health( + instance_id, health_status, should_respect_grace_period + ) template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def detach_instances(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity') - if should_decrement_string == 'true': + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity") + if should_decrement_string == "true": should_decrement = True else: should_decrement = False detached_instances = self.autoscaling_backend.detach_instances( - group_name, instance_ids, should_decrement) + group_name, instance_ids, should_decrement + ) template = self.response_template(DETACH_INSTANCES_TEMPLATE) return template.render(detached_instances=detached_instances) @amz_crc32 @amzn_request_id def attach_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') - target_group_arns = self._get_multi_param('TargetGroupARNs.member') + group_name = self._get_param("AutoScalingGroupName") + target_group_arns = self._get_multi_param("TargetGroupARNs.member") self.autoscaling_backend.attach_load_balancer_target_groups( - group_name, target_group_arns) + group_name, target_group_arns + ) template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def describe_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups( - group_name) + group_name + ) template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS) return template.render(target_group_arns=target_group_arns) @amz_crc32 @amzn_request_id def detach_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') - target_group_arns = self._get_multi_param('TargetGroupARNs.member') + group_name = self._get_param("AutoScalingGroupName") + target_group_arns = self._get_multi_param("TargetGroupARNs.member") self.autoscaling_backend.detach_load_balancer_target_groups( - group_name, target_group_arns) + group_name, target_group_arns + ) template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) return template.render() @@ -172,7 +177,7 @@ class AutoScalingResponse(BaseResponse): max_records = self._get_int_param("MaxRecords", 50) if max_records > 100: raise ValueError - groups = all_groups[start:start + max_records] + groups = all_groups[start : start + max_records] next_token = None if max_records and len(all_groups) > start + max_records: next_token = groups[-1].name @@ -181,42 +186,40 @@ class AutoScalingResponse(BaseResponse): def update_auto_scaling_group(self): self.autoscaling_backend.update_auto_scaling_group( - name=self._get_param('AutoScalingGroupName'), - availability_zones=self._get_multi_param( - 'AvailabilityZones.member'), - desired_capacity=self._get_int_param('DesiredCapacity'), - max_size=self._get_int_param('MaxSize'), - min_size=self._get_int_param('MinSize'), - launch_config_name=self._get_param('LaunchConfigurationName'), - vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), - default_cooldown=self._get_int_param('DefaultCooldown'), - health_check_period=self._get_int_param('HealthCheckGracePeriod'), - health_check_type=self._get_param('HealthCheckType'), - placement_group=self._get_param('PlacementGroup'), - termination_policies=self._get_multi_param( - 'TerminationPolicies.member'), + name=self._get_param("AutoScalingGroupName"), + availability_zones=self._get_multi_param("AvailabilityZones.member"), + desired_capacity=self._get_int_param("DesiredCapacity"), + max_size=self._get_int_param("MaxSize"), + min_size=self._get_int_param("MinSize"), + launch_config_name=self._get_param("LaunchConfigurationName"), + vpc_zone_identifier=self._get_param("VPCZoneIdentifier"), + default_cooldown=self._get_int_param("DefaultCooldown"), + health_check_period=self._get_int_param("HealthCheckGracePeriod"), + health_check_type=self._get_param("HealthCheckType"), + placement_group=self._get_param("PlacementGroup"), + termination_policies=self._get_multi_param("TerminationPolicies.member"), new_instances_protected_from_scale_in=self._get_bool_param( - 'NewInstancesProtectedFromScaleIn', None) + "NewInstancesProtectedFromScaleIn", None + ), ) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) return template.render() def delete_auto_scaling_group(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") self.autoscaling_backend.delete_auto_scaling_group(group_name) template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE) return template.render() def set_desired_capacity(self): - group_name = self._get_param('AutoScalingGroupName') - desired_capacity = self._get_int_param('DesiredCapacity') - self.autoscaling_backend.set_desired_capacity( - group_name, desired_capacity) + group_name = self._get_param("AutoScalingGroupName") + desired_capacity = self._get_int_param("DesiredCapacity") + self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity) template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE) return template.render() def create_or_update_tags(self): - tags = self._get_list_prefix('Tags.member') + tags = self._get_list_prefix("Tags.member") self.autoscaling_backend.create_or_update_tags(tags) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) @@ -224,38 +227,38 @@ class AutoScalingResponse(BaseResponse): def describe_auto_scaling_instances(self): instance_states = self.autoscaling_backend.describe_auto_scaling_instances() - template = self.response_template( - DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE) + template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE) return template.render(instance_states=instance_states) def put_scaling_policy(self): policy = self.autoscaling_backend.create_autoscaling_policy( - name=self._get_param('PolicyName'), - policy_type=self._get_param('PolicyType'), - adjustment_type=self._get_param('AdjustmentType'), - as_name=self._get_param('AutoScalingGroupName'), - scaling_adjustment=self._get_int_param('ScalingAdjustment'), - cooldown=self._get_int_param('Cooldown'), + name=self._get_param("PolicyName"), + policy_type=self._get_param("PolicyType"), + adjustment_type=self._get_param("AdjustmentType"), + as_name=self._get_param("AutoScalingGroupName"), + scaling_adjustment=self._get_int_param("ScalingAdjustment"), + cooldown=self._get_int_param("Cooldown"), ) template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE) return template.render(policy=policy) def describe_policies(self): policies = self.autoscaling_backend.describe_policies( - autoscaling_group_name=self._get_param('AutoScalingGroupName'), - policy_names=self._get_multi_param('PolicyNames.member'), - policy_types=self._get_multi_param('PolicyTypes.member')) + autoscaling_group_name=self._get_param("AutoScalingGroupName"), + policy_names=self._get_multi_param("PolicyNames.member"), + policy_types=self._get_multi_param("PolicyTypes.member"), + ) template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE) return template.render(policies=policies) def delete_policy(self): - group_name = self._get_param('PolicyName') + group_name = self._get_param("PolicyName") self.autoscaling_backend.delete_policy(group_name) template = self.response_template(DELETE_POLICY_TEMPLATE) return template.render() def execute_policy(self): - group_name = self._get_param('PolicyName') + group_name = self._get_param("PolicyName") self.autoscaling_backend.execute_policy(group_name) template = self.response_template(EXECUTE_POLICY_TEMPLATE) return template.render() @@ -263,17 +266,16 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def attach_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancer_names = self._get_multi_param("LoadBalancerNames.member") - self.autoscaling_backend.attach_load_balancers( - group_name, load_balancer_names) + self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names) template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def describe_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancers = self.autoscaling_backend.describe_load_balancers(group_name) template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE) return template.render(load_balancers=load_balancers) @@ -281,26 +283,28 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def detach_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancer_names = self._get_multi_param("LoadBalancerNames.member") - self.autoscaling_backend.detach_load_balancers( - group_name, load_balancer_names) + self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names) template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE) return template.render() def suspend_processes(self): - autoscaling_group_name = self._get_param('AutoScalingGroupName') - scaling_processes = self._get_multi_param('ScalingProcesses.member') - self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes) + autoscaling_group_name = self._get_param("AutoScalingGroupName") + scaling_processes = self._get_multi_param("ScalingProcesses.member") + self.autoscaling_backend.suspend_processes( + autoscaling_group_name, scaling_processes + ) template = self.response_template(SUSPEND_PROCESSES_TEMPLATE) return template.render() def set_instance_protection(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn') + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn") self.autoscaling_backend.set_instance_protection( - group_name, instance_ids, protected_from_scale_in) + group_name, instance_ids, protected_from_scale_in + ) template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE) return template.render() diff --git a/moto/autoscaling/urls.py b/moto/autoscaling/urls.py index 0743fdcf7..5fb33c25d 100644 --- a/moto/autoscaling/urls.py +++ b/moto/autoscaling/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import AutoScalingResponse -url_bases = [ - "https?://autoscaling.(.+).amazonaws.com", -] +url_bases = ["https?://autoscaling.(.+).amazonaws.com"] -url_paths = { - '{0}/$': AutoScalingResponse.dispatch, -} +url_paths = {"{0}/$": AutoScalingResponse.dispatch} diff --git a/moto/awslambda/__init__.py b/moto/awslambda/__init__.py index f0d694654..d40bf051a 100644 --- a/moto/awslambda/__init__.py +++ b/moto/awslambda/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import lambda_backends from ..core.models import base_decorator, deprecated_base_decorator -lambda_backend = lambda_backends['us-east-1'] +lambda_backend = lambda_backends["us-east-1"] mock_lambda = base_decorator(lambda_backends) mock_lambda_deprecated = deprecated_base_decorator(lambda_backends) diff --git a/moto/awslambda/exceptions.py b/moto/awslambda/exceptions.py new file mode 100644 index 000000000..1a82977c3 --- /dev/null +++ b/moto/awslambda/exceptions.py @@ -0,0 +1,31 @@ +from botocore.client import ClientError + + +class LambdaClientError(ClientError): + def __init__(self, error, message): + error_response = {"Error": {"Code": error, "Message": message}} + super(LambdaClientError, self).__init__(error_response, None) + + +class CrossAccountNotAllowed(LambdaClientError): + def __init__(self): + super(CrossAccountNotAllowed, self).__init__( + "AccessDeniedException", "Cross-account pass role is not allowed." + ) + + +class InvalidParameterValueException(LambdaClientError): + def __init__(self, message): + super(InvalidParameterValueException, self).__init__( + "InvalidParameterValueException", message + ) + + +class InvalidRoleFormat(LambdaClientError): + pattern = r"arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+" + + def __init__(self, role): + message = "1 validation error detected: Value '{0}' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: {1}".format( + role, InvalidRoleFormat.pattern + ) + super(InvalidRoleFormat, self).__init__("ValidationException", message) diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index acc7a5257..b1b8f57a8 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -26,39 +26,50 @@ import requests.adapters import boto.awslambda from moto.core import BaseBackend, BaseModel from moto.core.exceptions import RESTError +from moto.iam.models import iam_backend +from moto.iam.exceptions import IAMNotFoundException from moto.core.utils import unix_time_millis from moto.s3.models import s3_backend from moto.logs.models import logs_backends from moto.s3.exceptions import MissingBucket, MissingKey from moto import settings +from .exceptions import ( + CrossAccountNotAllowed, + InvalidRoleFormat, + InvalidParameterValueException, +) from .utils import make_function_arn, make_function_ver_arn from moto.sqs import sqs_backends +from moto.dynamodb2 import dynamodb_backends2 +from moto.dynamodbstreams import dynamodbstreams_backends +from moto.core import ACCOUNT_ID logger = logging.getLogger(__name__) -ACCOUNT_ID = '123456789012' - try: from tempfile import TemporaryDirectory except ImportError: from backports.tempfile import TemporaryDirectory - -_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*') +# The lambci container is returning a special escape character for the "RequestID" fields. Unicode 033: +# _stderr_regex = re.compile(r"START|END|REPORT RequestId: .*") +_stderr_regex = re.compile(r"\033\[\d+.*") _orig_adapter_send = requests.adapters.HTTPAdapter.send -docker_3 = docker.__version__[0] >= '3' +docker_3 = docker.__version__[0] >= "3" def zip2tar(zip_bytes): with TemporaryDirectory() as td: - tarname = os.path.join(td, 'data.tar') - timeshift = int((datetime.datetime.now() - - datetime.datetime.utcnow()).total_seconds()) - with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \ - tarfile.TarFile(tarname, 'w') as tarf: + tarname = os.path.join(td, "data.tar") + timeshift = int( + (datetime.datetime.now() - datetime.datetime.utcnow()).total_seconds() + ) + with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zipf, tarfile.TarFile( + tarname, "w" + ) as tarf: for zipinfo in zipf.infolist(): - if zipinfo.filename[-1] == '/': # is_dir() is py3.6+ + if zipinfo.filename[-1] == "/": # is_dir() is py3.6+ continue tarinfo = tarfile.TarInfo(name=zipinfo.filename) @@ -67,7 +78,7 @@ def zip2tar(zip_bytes): infile = zipf.open(zipinfo.filename) tarf.addfile(tarinfo, infile) - with open(tarname, 'rb') as f: + with open(tarname, "rb") as f: tar_data = f.read() return tar_data @@ -81,7 +92,9 @@ class _VolumeRefCount: class _DockerDataVolumeContext: - _data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount} + _data_vol_map = defaultdict( + lambda: _VolumeRefCount(0, None) + ) # {sha256: _VolumeRefCount} _lock = threading.Lock() def __init__(self, lambda_func): @@ -107,15 +120,19 @@ class _DockerDataVolumeContext: return self # It doesn't exist so we need to create it - self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256) + self._vol_ref.volume = self._lambda_func.docker_client.volumes.create( + self._lambda_func.code_sha_256 + ) if docker_3: - volumes = {self.name: {'bind': '/tmp/data', 'mode': 'rw'}} + volumes = {self.name: {"bind": "/tmp/data", "mode": "rw"}} else: - volumes = {self.name: '/tmp/data'} - container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes=volumes, detach=True) + volumes = {self.name: "/tmp/data"} + container = self._lambda_func.docker_client.containers.run( + "alpine", "sleep 100", volumes=volumes, detach=True + ) try: tar_bytes = zip2tar(self._lambda_func.code_bytes) - container.put_archive('/tmp/data', tar_bytes) + container.put_archive("/tmp/data", tar_bytes) finally: container.remove(force=True) @@ -138,13 +155,13 @@ class LambdaFunction(BaseModel): def __init__(self, spec, region, validate_s3=True, version=1): # required self.region = region - self.code = spec['Code'] - self.function_name = spec['FunctionName'] - self.handler = spec['Handler'] - self.role = spec['Role'] - self.run_time = spec['Runtime'] + self.code = spec["Code"] + self.function_name = spec["FunctionName"] + self.handler = spec["Handler"] + self.role = spec["Role"] + self.run_time = spec["Runtime"] self.logs_backend = logs_backends[self.region] - self.environment_vars = spec.get('Environment', {}).get('Variables', {}) + self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.docker_client = docker.from_env() self.policy = "" @@ -159,77 +176,81 @@ class LambdaFunction(BaseModel): if isinstance(adapter, requests.adapters.HTTPAdapter): adapter.send = functools.partial(_orig_adapter_send, adapter) return adapter + self.docker_client.api.get_adapter = replace_adapter_send # optional - self.description = spec.get('Description', '') - self.memory_size = spec.get('MemorySize', 128) - self.publish = spec.get('Publish', False) # this is ignored currently - self.timeout = spec.get('Timeout', 3) + self.description = spec.get("Description", "") + self.memory_size = spec.get("MemorySize", 128) + self.publish = spec.get("Publish", False) # this is ignored currently + self.timeout = spec.get("Timeout", 3) - self.logs_group_name = '/aws/lambda/{}'.format(self.function_name) + self.logs_group_name = "/aws/lambda/{}".format(self.function_name) self.logs_backend.ensure_log_group(self.logs_group_name, []) # this isn't finished yet. it needs to find out the VpcId value self._vpc_config = spec.get( - 'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []}) + "VpcConfig", {"SubnetIds": [], "SecurityGroupIds": []} + ) # auto-generated self.version = version - self.last_modified = datetime.datetime.utcnow().strftime( - '%Y-%m-%d %H:%M:%S') + self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") - if 'ZipFile' in self.code: + if "ZipFile" in self.code: # more hackery to handle unicode/bytes/str in python3 and python2 - # argh! try: - to_unzip_code = base64.b64decode( - bytes(self.code['ZipFile'], 'utf-8')) + to_unzip_code = base64.b64decode(bytes(self.code["ZipFile"], "utf-8")) except Exception: - to_unzip_code = base64.b64decode(self.code['ZipFile']) + to_unzip_code = base64.b64decode(self.code["ZipFile"]) self.code_bytes = to_unzip_code self.code_size = len(to_unzip_code) self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() # TODO: we should be putting this in a lambda bucket - self.code['UUID'] = str(uuid.uuid4()) - self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID']) + self.code["UUID"] = str(uuid.uuid4()) + self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"]) else: # validate s3 bucket and key key = None try: # FIXME: does not validate bucket region - key = s3_backend.get_key( - self.code['S3Bucket'], self.code['S3Key']) + key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"]) except MissingBucket: if do_validate_s3(): - raise ValueError( - "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") + raise InvalidParameterValueException( + "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist" + ) except MissingKey: if do_validate_s3(): raise ValueError( "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") + "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.", + ) if key: self.code_bytes = key.value self.code_size = key.size self.code_sha_256 = hashlib.sha256(key.value).hexdigest() - self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name) + self.function_arn = make_function_arn( + self.region, ACCOUNT_ID, self.function_name + ) self.tags = dict() def set_version(self, version): - self.function_arn = make_function_ver_arn(self.region, ACCOUNT_ID, self.function_name, version) + self.function_arn = make_function_ver_arn( + self.region, ACCOUNT_ID, self.function_name, version + ) self.version = version - self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') + self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") @property def vpc_config(self): config = self._vpc_config.copy() - if config['SecurityGroupIds']: + if config["SecurityGroupIds"]: config.update({"VpcId": "vpc-123abc"}) return config @@ -258,36 +279,101 @@ class LambdaFunction(BaseModel): } if self.environment_vars: - config['Environment'] = { - 'Variables': self.environment_vars - } + config["Environment"] = {"Variables": self.environment_vars} return config def get_code(self): return { "Code": { - "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']), - "RepositoryType": "S3" + "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format( + self.region, self.code["S3Key"] + ), + "RepositoryType": "S3", }, "Configuration": self.get_configuration(), } + def update_configuration(self, config_updates): + for key, value in config_updates.items(): + if key == "Description": + self.description = value + elif key == "Handler": + self.handler = value + elif key == "MemorySize": + self.memory_size = value + elif key == "Role": + self.role = value + elif key == "Runtime": + self.run_time = value + elif key == "Timeout": + self.timeout = value + elif key == "VpcConfig": + self.vpc_config = value + elif key == "Environment": + self.environment_vars = value["Variables"] + + return self.get_configuration() + + def update_function_code(self, updated_spec): + if "DryRun" in updated_spec and updated_spec["DryRun"]: + return self.get_configuration() + + if "ZipFile" in updated_spec: + self.code["ZipFile"] = updated_spec["ZipFile"] + + # using the "hackery" from __init__ because it seems to work + # TODOs and FIXMEs included, because they'll need to be fixed + # in both places now + try: + to_unzip_code = base64.b64decode( + bytes(updated_spec["ZipFile"], "utf-8") + ) + except Exception: + to_unzip_code = base64.b64decode(updated_spec["ZipFile"]) + + self.code_bytes = to_unzip_code + self.code_size = len(to_unzip_code) + self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() + + # TODO: we should be putting this in a lambda bucket + self.code["UUID"] = str(uuid.uuid4()) + self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"]) + elif "S3Bucket" in updated_spec and "S3Key" in updated_spec: + key = None + try: + # FIXME: does not validate bucket region + key = s3_backend.get_key( + updated_spec["S3Bucket"], updated_spec["S3Key"] + ) + except MissingBucket: + if do_validate_s3(): + raise ValueError( + "InvalidParameterValueException", + "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist", + ) + except MissingKey: + if do_validate_s3(): + raise ValueError( + "InvalidParameterValueException", + "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.", + ) + if key: + self.code_bytes = key.value + self.code_size = key.size + self.code_sha_256 = hashlib.sha256(key.value).hexdigest() + self.code["S3Bucket"] = updated_spec["S3Bucket"] + self.code["S3Key"] = updated_spec["S3Key"] + + return self.get_configuration() + @staticmethod def convert(s): try: - return str(s, encoding='utf-8') + return str(s, encoding="utf-8") except Exception: return s - @staticmethod - def is_json(test_str): - try: - response = json.loads(test_str) - except Exception: - response = test_str - return response - def _invoke_lambda(self, code, event=None, context=None): # TODO: context not yet implemented if event is None: @@ -312,12 +398,21 @@ class LambdaFunction(BaseModel): container = output = exit_code = None with _DockerDataVolumeContext(self) as data_vol: try: - run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {} + run_kwargs = ( + dict(links={"motoserver": "motoserver"}) + if settings.TEST_SERVER_MODE + else {} + ) container = self.docker_client.containers.run( "lambci/lambda:{}".format(self.run_time), - [self.handler, json.dumps(event)], remove=False, + [self.handler, json.dumps(event)], + remove=False, mem_limit="{}m".format(self.memory_size), - volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs) + volumes=["{}:/var/task".format(data_vol.name)], + environment=env_vars, + detach=True, + **run_kwargs + ) finally: if container: try: @@ -328,32 +423,43 @@ class LambdaFunction(BaseModel): container.kill() else: if docker_3: - exit_code = exit_code['StatusCode'] + exit_code = exit_code["StatusCode"] output = container.logs(stdout=False, stderr=True) output += container.logs(stdout=True, stderr=False) container.remove() - output = output.decode('utf-8') + output = output.decode("utf-8") # Send output to "logs" backend invoke_id = uuid.uuid4().hex log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format( - date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id + date=datetime.datetime.utcnow(), + version=self.version, + invoke_id=invoke_id, ) self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name) - log_events = [{'timestamp': unix_time_millis(), "message": line} - for line in output.splitlines()] - self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None) + log_events = [ + {"timestamp": unix_time_millis(), "message": line} + for line in output.splitlines() + ] + self.logs_backend.put_log_events( + self.logs_group_name, log_stream_name, log_events, None + ) if exit_code != 0: - raise Exception( - 'lambda invoke failed output: {}'.format(output)) + raise Exception("lambda invoke failed output: {}".format(output)) - # strip out RequestId lines - output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)]) + # strip out RequestId lines (TODO: This will return an additional '\n' in the response) + output = os.linesep.join( + [ + line + for line in self.convert(output).splitlines() + if not _stderr_regex.match(line) + ] + ) return output, False except BaseException as e: traceback.print_exc() @@ -368,31 +474,34 @@ class LambdaFunction(BaseModel): # Get the invocation type: res, errored = self._invoke_lambda(code=self.code, event=body) if request_headers.get("x-amz-invocation-type") == "RequestResponse": - encoded = base64.b64encode(res.encode('utf-8')) - response_headers["x-amz-log-result"] = encoded.decode('utf-8') - payload['result'] = response_headers["x-amz-log-result"] - result = res.encode('utf-8') + encoded = base64.b64encode(res.encode("utf-8")) + response_headers["x-amz-log-result"] = encoded.decode("utf-8") + payload["result"] = response_headers["x-amz-log-result"] + result = res.encode("utf-8") else: result = json.dumps(payload) if errored: - response_headers['x-amz-function-error'] = "Handled" + response_headers["x-amz-function-error"] = "Handled" return result @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] # required spec = { - 'Code': properties['Code'], - 'FunctionName': resource_name, - 'Handler': properties['Handler'], - 'Role': properties['Role'], - 'Runtime': properties['Runtime'], + "Code": properties["Code"], + "FunctionName": resource_name, + "Handler": properties["Handler"], + "Role": properties["Role"], + "Runtime": properties["Runtime"], } - optional_properties = 'Description MemorySize Publish Timeout VpcConfig Environment'.split() + optional_properties = ( + "Description MemorySize Publish Timeout VpcConfig Environment".split() + ) # NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the # default logic for prop in optional_properties: @@ -402,90 +511,107 @@ class LambdaFunction(BaseModel): # when ZipFile is present in CloudFormation, per the official docs, # the code it's a plaintext code snippet up to 4096 bytes. # this snippet converts this plaintext code to a proper base64-encoded ZIP file. - if 'ZipFile' in properties['Code']: - spec['Code']['ZipFile'] = base64.b64encode( - cls._create_zipfile_from_plaintext_code( - spec['Code']['ZipFile'])) + if "ZipFile" in properties["Code"]: + spec["Code"]["ZipFile"] = base64.b64encode( + cls._create_zipfile_from_plaintext_code(spec["Code"]["ZipFile"]) + ) backend = lambda_backends[region_name] fn = backend.create_function(spec) return fn def get_cfn_attribute(self, attribute_name): - from moto.cloudformation.exceptions import \ - UnformattedGetAttTemplateException - if attribute_name == 'Arn': + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "Arn": return make_function_arn(self.region, ACCOUNT_ID, self.function_name) raise UnformattedGetAttTemplateException() + @classmethod + def update_from_cloudformation_json( + cls, new_resource_name, cloudformation_json, original_resource, region_name + ): + updated_props = cloudformation_json["Properties"] + original_resource.update_configuration(updated_props) + original_resource.update_function_code(updated_props["Code"]) + return original_resource + @staticmethod def _create_zipfile_from_plaintext_code(code): zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.zip', code) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.zip", code) zip_file.close() zip_output.seek(0) return zip_output.read() + def delete(self, region): + lambda_backends[region].delete_function(self.function_name) + class EventSourceMapping(BaseModel): def __init__(self, spec): # required - self.function_arn = spec['FunctionArn'] - self.event_source_arn = spec['EventSourceArn'] + self.function_arn = spec["FunctionArn"] + self.event_source_arn = spec["EventSourceArn"] self.uuid = str(uuid.uuid4()) self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple()) # BatchSize service default/max mapping batch_size_map = { - 'kinesis': (100, 10000), - 'dynamodb': (100, 1000), - 'sqs': (10, 10), + "kinesis": (100, 10000), + "dynamodb": (100, 1000), + "sqs": (10, 10), } source_type = self.event_source_arn.split(":")[2].lower() batch_size_entry = batch_size_map.get(source_type) if batch_size_entry: # Use service default if not provided - batch_size = int(spec.get('BatchSize', batch_size_entry[0])) + batch_size = int(spec.get("BatchSize", batch_size_entry[0])) if batch_size > batch_size_entry[1]: - raise ValueError("InvalidParameterValueException", - "BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1])) + raise ValueError( + "InvalidParameterValueException", + "BatchSize {} exceeds the max of {}".format( + batch_size, batch_size_entry[1] + ), + ) else: self.batch_size = batch_size else: - raise ValueError("InvalidParameterValueException", - "Unsupported event source type") + raise ValueError( + "InvalidParameterValueException", "Unsupported event source type" + ) # optional - self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON') - self.enabled = spec.get('Enabled', True) - self.starting_position_timestamp = spec.get('StartingPositionTimestamp', - None) + self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON") + self.enabled = spec.get("Enabled", True) + self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None) def get_configuration(self): return { - 'UUID': self.uuid, - 'BatchSize': self.batch_size, - 'EventSourceArn': self.event_source_arn, - 'FunctionArn': self.function_arn, - 'LastModified': self.last_modified, - 'LastProcessingResult': '', - 'State': 'Enabled' if self.enabled else 'Disabled', - 'StateTransitionReason': 'User initiated' + "UUID": self.uuid, + "BatchSize": self.batch_size, + "EventSourceArn": self.event_source_arn, + "FunctionArn": self.function_arn, + "LastModified": self.last_modified, + "LastProcessingResult": "", + "State": "Enabled" if self.enabled else "Disabled", + "StateTransitionReason": "User initiated", } @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] - func = lambda_backends[region_name].get_function(properties['FunctionName']) + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + func = lambda_backends[region_name].get_function(properties["FunctionName"]) spec = { - 'FunctionArn': func.function_arn, - 'EventSourceArn': properties['EventSourceArn'], - 'StartingPosition': properties['StartingPosition'], - 'BatchSize': properties.get('BatchSize', 100) + "FunctionArn": func.function_arn, + "EventSourceArn": properties["EventSourceArn"], + "StartingPosition": properties["StartingPosition"], + "BatchSize": properties.get("BatchSize", 100), } - optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split() + optional_properties = "BatchSize Enabled StartingPositionTimestamp".split() for prop in optional_properties: if prop in properties: spec[prop] = properties[prop] @@ -494,20 +620,19 @@ class EventSourceMapping(BaseModel): class LambdaVersion(BaseModel): def __init__(self, spec): - self.version = spec['Version'] + self.version = spec["Version"] def __repr__(self): return str(self.logical_resource_id) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] - function_name = properties['FunctionName'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + function_name = properties["FunctionName"] func = lambda_backends[region_name].publish_function(function_name) - spec = { - 'Version': func.version - } + spec = {"Version": func.version} return LambdaVersion(spec) @@ -518,20 +643,20 @@ class LambdaStorage(object): self._arns = weakref.WeakValueDictionary() def _get_latest(self, name): - return self._functions[name]['latest'] + return self._functions[name]["latest"] def _get_version(self, name, version): index = version - 1 try: - return self._functions[name]['versions'][index] + return self._functions[name]["versions"][index] except IndexError: return None def _get_alias(self, name, alias): - return self._functions[name]['alias'].get(alias, None) + return self._functions[name]["alias"].get(alias, None) - def get_function(self, name, qualifier=None): + def get_function_by_name(self, name, qualifier=None): if name not in self._functions: return None @@ -541,34 +666,47 @@ class LambdaStorage(object): try: return self._get_version(name, int(qualifier)) except ValueError: - return self._functions[name]['latest'] + return self._functions[name]["latest"] def list_versions_by_function(self, name): if name not in self._functions: return None - latest = copy.copy(self._functions[name]['latest']) - latest.function_arn += ':$LATEST' - return [latest] + self._functions[name]['versions'] + latest = copy.copy(self._functions[name]["latest"]) + latest.function_arn += ":$LATEST" + return [latest] + self._functions[name]["versions"] def get_arn(self, arn): return self._arns.get(arn, None) - def get_function_by_name_or_arn(self, input): - return self.get_function(input) or self.get_arn(input) + def get_function_by_name_or_arn(self, input, qualifier=None): + return self.get_function_by_name(input, qualifier) or self.get_arn(input) def put_function(self, fn): """ :param fn: Function :type fn: LambdaFunction """ + valid_role = re.match(InvalidRoleFormat.pattern, fn.role) + if valid_role: + account = valid_role.group(2) + if account != ACCOUNT_ID: + raise CrossAccountNotAllowed() + try: + iam_backend.get_role_by_arn(fn.role) + except IAMNotFoundException: + raise InvalidParameterValueException( + "The role defined for the function cannot be assumed by Lambda." + ) + else: + raise InvalidRoleFormat(fn.role) if fn.function_name in self._functions: - self._functions[fn.function_name]['latest'] = fn + self._functions[fn.function_name]["latest"] = fn else: self._functions[fn.function_name] = { - 'latest': fn, - 'versions': [], - 'alias': weakref.WeakValueDictionary() + "latest": fn, + "versions": [], + "alias": weakref.WeakValueDictionary(), } self._arns[fn.function_arn] = fn @@ -576,47 +714,55 @@ class LambdaStorage(object): def publish_function(self, name): if name not in self._functions: return None - if not self._functions[name]['latest']: + if not self._functions[name]["latest"]: return None - new_version = len(self._functions[name]['versions']) + 1 - fn = copy.copy(self._functions[name]['latest']) + new_version = len(self._functions[name]["versions"]) + 1 + fn = copy.copy(self._functions[name]["latest"]) fn.set_version(new_version) - self._functions[name]['versions'].append(fn) + self._functions[name]["versions"].append(fn) self._arns[fn.function_arn] = fn return fn - def del_function(self, name, qualifier=None): - if name in self._functions: + def del_function(self, name_or_arn, qualifier=None): + function = self.get_function_by_name_or_arn(name_or_arn) + if function: + name = function.function_name if not qualifier: # Something is still reffing this so delete all arns - latest = self._functions[name]['latest'].function_arn + latest = self._functions[name]["latest"].function_arn del self._arns[latest] - for fn in self._functions[name]['versions']: + for fn in self._functions[name]["versions"]: del self._arns[fn.function_arn] del self._functions[name] return True - elif qualifier == '$LATEST': - self._functions[name]['latest'] = None + elif qualifier == "$LATEST": + self._functions[name]["latest"] = None # If theres no functions left - if not self._functions[name]['versions'] and not self._functions[name]['latest']: + if ( + not self._functions[name]["versions"] + and not self._functions[name]["latest"] + ): del self._functions[name] return True else: - fn = self.get_function(name, qualifier) + fn = self.get_function_by_name(name, qualifier) if fn: - self._functions[name]['versions'].remove(fn) + self._functions[name]["versions"].remove(fn) # If theres no functions left - if not self._functions[name]['versions'] and not self._functions[name]['latest']: + if ( + not self._functions[name]["versions"] + and not self._functions[name]["latest"] + ): del self._functions[name] return True @@ -627,10 +773,10 @@ class LambdaStorage(object): result = [] for function_group in self._functions.values(): - if function_group['latest'] is not None: - result.append(function_group['latest']) + if function_group["latest"] is not None: + result.append(function_group["latest"]) - result.extend(function_group['versions']) + result.extend(function_group["versions"]) return result @@ -647,44 +793,47 @@ class LambdaBackend(BaseBackend): self.__init__(region_name) def create_function(self, spec): - function_name = spec.get('FunctionName', None) + function_name = spec.get("FunctionName", None) if function_name is None: - raise RESTError('InvalidParameterValueException', 'Missing FunctionName') + raise RESTError("InvalidParameterValueException", "Missing FunctionName") - fn = LambdaFunction(spec, self.region_name, version='$LATEST') + fn = LambdaFunction(spec, self.region_name, version="$LATEST") self._lambdas.put_function(fn) - if spec.get('Publish'): + if spec.get("Publish"): ver = self.publish_function(function_name) fn.version = ver.version return fn def create_event_source_mapping(self, spec): - required = [ - 'EventSourceArn', - 'FunctionName', - ] + required = ["EventSourceArn", "FunctionName"] for param in required: if not spec.get(param): - raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param)) + raise RESTError( + "InvalidParameterValueException", "Missing {}".format(param) + ) # Validate function name - func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', '')) + func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", "")) if not func: - raise RESTError('ResourceNotFoundException', 'Invalid FunctionName') + raise RESTError("ResourceNotFoundException", "Invalid FunctionName") # Validate queue for queue in sqs_backends[self.region_name].queues.values(): - if queue.queue_arn == spec['EventSourceArn']: - if queue.lambda_event_source_mappings.get('func.function_arn'): + if queue.queue_arn == spec["EventSourceArn"]: + if queue.lambda_event_source_mappings.get("func.function_arn"): # TODO: Correct exception? - raise RESTError('ResourceConflictException', 'The resource already exists.') + raise RESTError( + "ResourceConflictException", "The resource already exists." + ) if queue.fifo_queue: - raise RESTError('InvalidParameterValueException', - '{} is FIFO'.format(queue.queue_arn)) + raise RESTError( + "InvalidParameterValueException", + "{} is FIFO".format(queue.queue_arn), + ) else: - spec.update({'FunctionArn': func.function_arn}) + spec.update({"FunctionArn": func.function_arn}) esm = EventSourceMapping(spec) self._event_source_mappings[esm.uuid] = esm @@ -692,13 +841,26 @@ class LambdaBackend(BaseBackend): queue.lambda_event_source_mappings[esm.function_arn] = esm return esm - raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn') + for stream in json.loads( + dynamodbstreams_backends[self.region_name].list_streams() + )["Streams"]: + if stream["StreamArn"] == spec["EventSourceArn"]: + spec.update({"FunctionArn": func.function_arn}) + esm = EventSourceMapping(spec) + self._event_source_mappings[esm.uuid] = esm + table_name = stream["TableName"] + table = dynamodb_backends2[self.region_name].get_table(table_name) + table.lambda_event_source_mappings[esm.function_arn] = esm + return esm + raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn") def publish_function(self, function_name): return self._lambdas.publish_function(function_name) - def get_function(self, function_name, qualifier=None): - return self._lambdas.get_function(function_name, qualifier) + def get_function(self, function_name_or_arn, qualifier=None): + return self._lambdas.get_function_by_name_or_arn( + function_name_or_arn, qualifier + ) def list_versions_by_function(self, function_name): return self._lambdas.list_versions_by_function(function_name) @@ -712,13 +874,15 @@ class LambdaBackend(BaseBackend): def update_event_source_mapping(self, uuid, spec): esm = self.get_event_source_mapping(uuid) if esm: - if spec.get('FunctionName'): - func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName')) + if spec.get("FunctionName"): + func = self._lambdas.get_function_by_name_or_arn( + spec.get("FunctionName") + ) esm.function_arn = func.function_arn - if 'BatchSize' in spec: - esm.batch_size = spec['BatchSize'] - if 'Enabled' in spec: - esm.enabled = spec['Enabled'] + if "BatchSize" in spec: + esm.batch_size = spec["BatchSize"] + if "Enabled" in spec: + esm.enabled = spec["Enabled"] return esm return False @@ -759,13 +923,13 @@ class LambdaBackend(BaseBackend): "ApproximateReceiveCount": "1", "SentTimestamp": "1545082649183", "SenderId": "AIDAIENQZJOLO23YVJ4VO", - "ApproximateFirstReceiveTimestamp": "1545082649185" + "ApproximateFirstReceiveTimestamp": "1545082649185", }, "messageAttributes": {}, "md5OfBody": "098f6bcd4621d373cade4e832627b4f6", "eventSource": "aws:sqs", "eventSourceARN": queue_arn, - "awsRegion": self.region_name + "awsRegion": self.region_name, } ] } @@ -773,7 +937,7 @@ class LambdaBackend(BaseBackend): request_headers = {} response_headers = {} func.invoke(json.dumps(event), request_headers, response_headers) - return 'x-amz-function-error' not in response_headers + return "x-amz-function-error" not in response_headers def send_sns_message(self, function_name, message, subject=None, qualifier=None): event = { @@ -790,25 +954,36 @@ class LambdaBackend(BaseBackend): "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", "Message": message, "MessageAttributes": { - "Test": { - "Type": "String", - "Value": "TestString" - }, - "TestBinary": { - "Type": "Binary", - "Value": "TestBinary" - } + "Test": {"Type": "String", "Value": "TestString"}, + "TestBinary": {"Type": "Binary", "Value": "TestBinary"}, }, "Type": "Notification", "UnsubscribeUrl": "EXAMPLE", "TopicArn": "arn:aws:sns:EXAMPLE", - "Subject": subject or "TestInvoke" - } + "Subject": subject or "TestInvoke", + }, } ] - } - func = self._lambdas.get_function(function_name, qualifier) + func = self._lambdas.get_function_by_name_or_arn(function_name, qualifier) + func.invoke(json.dumps(event), {}, {}) + + def send_dynamodb_items(self, function_arn, items, source): + event = { + "Records": [ + { + "eventID": item.to_json()["eventID"], + "eventName": "INSERT", + "eventVersion": item.to_json()["eventVersion"], + "eventSource": item.to_json()["eventSource"], + "awsRegion": self.region_name, + "dynamodb": item.to_json()["dynamodb"], + "eventSourceARN": source, + } + for item in items + ] + } + func = self._lambdas.get_arn(function_arn) func.invoke(json.dumps(event), {}, {}) def list_tags(self, resource): @@ -837,14 +1012,41 @@ class LambdaBackend(BaseBackend): def add_policy(self, function_name, policy): self.get_function(function_name).policy = policy + def update_function_code(self, function_name, qualifier, body): + fn = self.get_function(function_name, qualifier) + + if fn: + if body.get("Publish", False): + fn = self.publish_function(function_name) + + config = fn.update_function_code(body) + return config + else: + return None + + def update_function_configuration(self, function_name, qualifier, body): + fn = self.get_function(function_name, qualifier) + + return fn.update_configuration(body) if fn else None + + def invoke(self, function_name, qualifier, body, headers, response_headers): + fn = self.get_function(function_name, qualifier) + if fn: + payload = fn.invoke(body, headers, response_headers) + response_headers["Content-Length"] = str(len(payload)) + return response_headers, payload + else: + return response_headers, None + def do_validate_s3(): - return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true'] + return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"] # Handle us forgotten regions, unless Lambda truly only runs out of US and -lambda_backends = {_region.name: LambdaBackend(_region.name) - for _region in boto.awslambda.regions()} +lambda_backends = { + _region.name: LambdaBackend(_region.name) for _region in boto.awslambda.regions() +} -lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2') -lambda_backends['us-gov-west-1'] = LambdaBackend('us-gov-west-1') +lambda_backends["ap-southeast-2"] = LambdaBackend("ap-southeast-2") +lambda_backends["us-gov-west-1"] = LambdaBackend("us-gov-west-1") diff --git a/moto/awslambda/responses.py b/moto/awslambda/responses.py index 1e7feb0d0..46203c10d 100644 --- a/moto/awslambda/responses.py +++ b/moto/awslambda/responses.py @@ -32,57 +32,57 @@ class LambdaResponse(BaseResponse): def root(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._list_functions(request, full_url, headers) - elif request.method == 'POST': + elif request.method == "POST": return self._create_function(request, full_url, headers) else: raise ValueError("Cannot handle request") def event_source_mappings(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": querystring = self.querystring - event_source_arn = querystring.get('EventSourceArn', [None])[0] - function_name = querystring.get('FunctionName', [None])[0] + event_source_arn = querystring.get("EventSourceArn", [None])[0] + function_name = querystring.get("FunctionName", [None])[0] return self._list_event_source_mappings(event_source_arn, function_name) - elif request.method == 'POST': + elif request.method == "POST": return self._create_event_source_mapping(request, full_url, headers) else: raise ValueError("Cannot handle request") def event_source_mapping(self, request, full_url, headers): self.setup_class(request, full_url, headers) - path = request.path if hasattr(request, 'path') else path_url(request.url) - uuid = path.split('/')[-1] - if request.method == 'GET': + path = request.path if hasattr(request, "path") else path_url(request.url) + uuid = path.split("/")[-1] + if request.method == "GET": return self._get_event_source_mapping(uuid) - elif request.method == 'PUT': + elif request.method == "PUT": return self._update_event_source_mapping(uuid) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._delete_event_source_mapping(uuid) else: raise ValueError("Cannot handle request") def function(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._get_function(request, full_url, headers) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._delete_function(request, full_url, headers) else: raise ValueError("Cannot handle request") def versions(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": # This is ListVersionByFunction - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] return self._list_versions_by_function(function_name) - elif request.method == 'POST': + elif request.method == "POST": return self._publish_function(request, full_url, headers) else: raise ValueError("Cannot handle request") @@ -91,7 +91,7 @@ class LambdaResponse(BaseResponse): @amzn_request_id def invoke(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._invoke(request, full_url) else: raise ValueError("Cannot handle request") @@ -100,57 +100,78 @@ class LambdaResponse(BaseResponse): @amzn_request_id def invoke_async(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._invoke_async(request, full_url) else: raise ValueError("Cannot handle request") def tag(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._list_tags(request, full_url) - elif request.method == 'POST': + elif request.method == "POST": return self._tag_resource(request, full_url) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._untag_resource(request, full_url) else: raise ValueError("Cannot handle {0} request".format(request.method)) def policy(self, request, full_url, headers): - if request.method == 'GET': + self.setup_class(request, full_url, headers) + if request.method == "GET": return self._get_policy(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._add_policy(request, full_url, headers) + def configuration(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + if request.method == "PUT": + return self._put_configuration(request) + else: + raise ValueError("Cannot handle request") + + def code(self, request, full_url, headers): + self.setup_class(request, full_url, headers) + if request.method == "PUT": + return self._put_code() + else: + raise ValueError("Cannot handle request") + def _add_policy(self, request, full_url, headers): - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): - policy = request.body.decode('utf8') + policy = self.body self.lambda_backend.add_policy(function_name, policy) return 200, {}, json.dumps(dict(Statement=policy)) else: return 404, {}, "{}" def _get_policy(self, request, full_url, headers): - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): lambda_function = self.lambda_backend.get_function(function_name) - return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}")) + return ( + 200, + {}, + json.dumps( + dict(Policy='{"Statement":[' + lambda_function.policy + "]}") + ), + ) else: return 404, {}, "{}" def _invoke(self, request, full_url): response_headers = {} - function_name = self.path.rsplit('/', 2)[-2] - qualifier = self._get_param('qualifier') + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("qualifier") - fn = self.lambda_backend.get_function(function_name, qualifier) - if fn: - payload = fn.invoke(self.body, self.headers, response_headers) - response_headers['Content-Length'] = str(len(payload)) + response_header, payload = self.lambda_backend.invoke( + function_name, qualifier, self.body, self.headers, response_headers + ) + if payload: return 202, response_headers, payload else: return 404, response_headers, "{}" @@ -158,64 +179,52 @@ class LambdaResponse(BaseResponse): def _invoke_async(self, request, full_url): response_headers = {} - function_name = self.path.rsplit('/', 3)[-3] + function_name = self.path.rsplit("/", 3)[-3] fn = self.lambda_backend.get_function(function_name, None) if fn: payload = fn.invoke(self.body, self.headers, response_headers) - response_headers['Content-Length'] = str(len(payload)) + response_headers["Content-Length"] = str(len(payload)) return 202, response_headers, payload else: return 404, response_headers, "{}" def _list_functions(self, request, full_url, headers): - result = { - 'Functions': [] - } + result = {"Functions": []} for fn in self.lambda_backend.list_functions(): json_data = fn.get_configuration() - json_data['Version'] = '$LATEST' - result['Functions'].append(json_data) + json_data["Version"] = "$LATEST" + result["Functions"].append(json_data) return 200, {}, json.dumps(result) def _list_versions_by_function(self, function_name): - result = { - 'Versions': [] - } + result = {"Versions": []} functions = self.lambda_backend.list_versions_by_function(function_name) if functions: for fn in functions: json_data = fn.get_configuration() - result['Versions'].append(json_data) + result["Versions"].append(json_data) return 200, {}, json.dumps(result) def _create_function(self, request, full_url, headers): - try: - fn = self.lambda_backend.create_function(self.json_body) - except ValueError as e: - return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) - else: - config = fn.get_configuration() - return 201, {}, json.dumps(config) + fn = self.lambda_backend.create_function(self.json_body) + config = fn.get_configuration() + return 201, {}, json.dumps(config) def _create_event_source_mapping(self, request, full_url, headers): - try: - fn = self.lambda_backend.create_event_source_mapping(self.json_body) - except ValueError as e: - return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) - else: - config = fn.get_configuration() - return 201, {}, json.dumps(config) + fn = self.lambda_backend.create_event_source_mapping(self.json_body) + config = fn.get_configuration() + return 201, {}, json.dumps(config) def _list_event_source_mappings(self, event_source_arn, function_name): - esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name) - result = { - 'EventSourceMappings': [esm.get_configuration() for esm in esms] - } + esms = self.lambda_backend.list_event_source_mappings( + event_source_arn, function_name + ) + result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]} return 200, {}, json.dumps(result) def _get_event_source_mapping(self, uuid): @@ -236,13 +245,13 @@ class LambdaResponse(BaseResponse): esm = self.lambda_backend.delete_event_source_mapping(uuid) if esm: json_result = esm.get_configuration() - json_result.update({'State': 'Deleting'}) + json_result.update({"State": "Deleting"}) return 202, {}, json.dumps(json_result) else: return 404, {}, "{}" def _publish_function(self, request, full_url, headers): - function_name = self.path.rsplit('/', 2)[-2] + function_name = self.path.rsplit("/", 2)[-2] fn = self.lambda_backend.publish_function(function_name) if fn: @@ -252,8 +261,8 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _delete_function(self, request, full_url, headers): - function_name = self.path.rsplit('/', 1)[-1] - qualifier = self._get_param('Qualifier', None) + function_name = unquote(self.path.rsplit("/", 1)[-1]) + qualifier = self._get_param("Qualifier", None) if self.lambda_backend.delete_function(function_name, qualifier): return 204, {}, "" @@ -261,17 +270,17 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _get_function(self, request, full_url, headers): - function_name = self.path.rsplit('/', 1)[-1] - qualifier = self._get_param('Qualifier', None) + function_name = unquote(self.path.rsplit("/", 1)[-1]) + qualifier = self._get_param("Qualifier", None) fn = self.lambda_backend.get_function(function_name, qualifier) if fn: code = fn.get_code() - if qualifier is None or qualifier == '$LATEST': - code['Configuration']['Version'] = '$LATEST' - if qualifier == '$LATEST': - code['Configuration']['FunctionArn'] += ':$LATEST' + if qualifier is None or qualifier == "$LATEST": + code["Configuration"]["Version"] = "$LATEST" + if qualifier == "$LATEST": + code["Configuration"]["FunctionArn"] += ":$LATEST" return 200, {}, json.dumps(code) else: return 404, {}, "{}" @@ -284,27 +293,51 @@ class LambdaResponse(BaseResponse): return self.default_region def _list_tags(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) + function_arn = unquote(self.path.rsplit("/", 1)[-1]) fn = self.lambda_backend.get_function_by_arn(function_arn) if fn: - return 200, {}, json.dumps({'Tags': fn.tags}) + return 200, {}, json.dumps({"Tags": fn.tags}) else: return 404, {}, "{}" def _tag_resource(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) + function_arn = unquote(self.path.rsplit("/", 1)[-1]) - if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']): + if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]): return 200, {}, "{}" else: return 404, {}, "{}" def _untag_resource(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) - tag_keys = self.querystring['tagKeys'] + function_arn = unquote(self.path.rsplit("/", 1)[-1]) + tag_keys = self.querystring["tagKeys"] if self.lambda_backend.untag_resource(function_arn, tag_keys): return 204, {}, "{}" else: return 404, {}, "{}" + + def _put_configuration(self, request): + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("Qualifier", None) + resp = self.lambda_backend.update_function_configuration( + function_name, qualifier, body=self.json_body + ) + + if resp: + return 200, {}, json.dumps(resp) + else: + return 404, {}, "{}" + + def _put_code(self): + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("Qualifier", None) + resp = self.lambda_backend.update_function_code( + function_name, qualifier, body=self.json_body + ) + + if resp: + return 200, {}, json.dumps(resp) + else: + return 404, {}, "{}" diff --git a/moto/awslambda/urls.py b/moto/awslambda/urls.py index fb2c6ee7e..da7346817 100644 --- a/moto/awslambda/urls.py +++ b/moto/awslambda/urls.py @@ -1,20 +1,20 @@ from __future__ import unicode_literals from .responses import LambdaResponse -url_bases = [ - "https?://lambda.(.+).amazonaws.com", -] +url_bases = ["https?://lambda.(.+).amazonaws.com"] response = LambdaResponse() url_paths = { - '{0}/(?P[^/]+)/functions/?$': response.root, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/?$': response.function, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/versions/?$': response.versions, - r'{0}/(?P[^/]+)/event-source-mappings/?$': response.event_source_mappings, - r'{0}/(?P[^/]+)/event-source-mappings/(?P[\w_-]+)/?$': response.event_source_mapping, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invocations/?$': response.invoke, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invoke-async/?$': response.invoke_async, - r'{0}/(?P[^/]+)/tags/(?P.+)': response.tag, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/?$': response.policy + "{0}/(?P[^/]+)/functions/?$": response.root, + r"{0}/(?P[^/]+)/functions/(?P[\w_:%-]+)/?$": response.function, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/versions/?$": response.versions, + r"{0}/(?P[^/]+)/event-source-mappings/?$": response.event_source_mappings, + r"{0}/(?P[^/]+)/event-source-mappings/(?P[\w_-]+)/?$": response.event_source_mapping, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invocations/?$": response.invoke, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invoke-async/?$": response.invoke_async, + r"{0}/(?P[^/]+)/tags/(?P.+)": response.tag, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/?$": response.policy, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/configuration/?$": response.configuration, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/code/?$": response.code, } diff --git a/moto/awslambda/utils.py b/moto/awslambda/utils.py index 82027cb2f..e024b7b9b 100644 --- a/moto/awslambda/utils.py +++ b/moto/awslambda/utils.py @@ -1,20 +1,20 @@ from collections import namedtuple -ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version']) +ARN = namedtuple("ARN", ["region", "account", "function_name", "version"]) def make_function_arn(region, account, name): - return 'arn:aws:lambda:{0}:{1}:function:{2}'.format(region, account, name) + return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name) -def make_function_ver_arn(region, account, name, version='1'): +def make_function_ver_arn(region, account, name, version="1"): arn = make_function_arn(region, account, name) - return '{0}:{1}'.format(arn, version) + return "{0}:{1}".format(arn, version) def split_function_arn(arn): - arn = arn.replace('arn:aws:lambda:') + arn = arn.replace("arn:aws:lambda:") - region, account, _, name, version = arn.split(':') + region, account, _, name, version = arn.split(":") return ARN(region, account, name, version) diff --git a/moto/backends.py b/moto/backends.py index 6ea85093d..9295bc758 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -2,14 +2,19 @@ from __future__ import unicode_literals from moto.acm import acm_backends from moto.apigateway import apigateway_backends +from moto.athena import athena_backends from moto.autoscaling import autoscaling_backends from moto.awslambda import lambda_backends +from moto.batch import batch_backends from moto.cloudformation import cloudformation_backends from moto.cloudwatch import cloudwatch_backends +from moto.codepipeline import codepipeline_backends from moto.cognitoidentity import cognitoidentity_backends from moto.cognitoidp import cognitoidp_backends +from moto.config import config_backends from moto.core import moto_api_backends from moto.datapipeline import datapipeline_backends +from moto.datasync import datasync_backends from moto.dynamodb import dynamodb_backends from moto.dynamodb2 import dynamodb_backends2 from moto.dynamodbstreams import dynamodbstreams_backends @@ -24,6 +29,8 @@ from moto.glacier import glacier_backends from moto.glue import glue_backends from moto.iam import iam_backends from moto.instance_metadata import instance_metadata_backends +from moto.iot import iot_backends +from moto.iotdata import iotdata_backends from moto.kinesis import kinesis_backends from moto.kms import kms_backends from moto.logs import logs_backends @@ -33,72 +40,73 @@ from moto.polly import polly_backends from moto.rds2 import rds2_backends from moto.redshift import redshift_backends from moto.resourcegroups import resourcegroups_backends +from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends from moto.route53 import route53_backends from moto.s3 import s3_backends -from moto.ses import ses_backends from moto.secretsmanager import secretsmanager_backends +from moto.ses import ses_backends from moto.sns import sns_backends from moto.sqs import sqs_backends from moto.ssm import ssm_backends +from moto.stepfunctions import stepfunction_backends from moto.sts import sts_backends from moto.swf import swf_backends from moto.xray import xray_backends -from moto.iot import iot_backends -from moto.iotdata import iotdata_backends -from moto.batch import batch_backends -from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends -from moto.config import config_backends BACKENDS = { - 'acm': acm_backends, - 'apigateway': apigateway_backends, - 'autoscaling': autoscaling_backends, - 'batch': batch_backends, - 'cloudformation': cloudformation_backends, - 'cloudwatch': cloudwatch_backends, - 'cognito-identity': cognitoidentity_backends, - 'cognito-idp': cognitoidp_backends, - 'config': config_backends, - 'datapipeline': datapipeline_backends, - 'dynamodb': dynamodb_backends, - 'dynamodb2': dynamodb_backends2, - 'dynamodbstreams': dynamodbstreams_backends, - 'ec2': ec2_backends, - 'ecr': ecr_backends, - 'ecs': ecs_backends, - 'elb': elb_backends, - 'elbv2': elbv2_backends, - 'events': events_backends, - 'emr': emr_backends, - 'glacier': glacier_backends, - 'glue': glue_backends, - 'iam': iam_backends, - 'moto_api': moto_api_backends, - 'instance_metadata': instance_metadata_backends, - 'logs': logs_backends, - 'kinesis': kinesis_backends, - 'kms': kms_backends, - 'opsworks': opsworks_backends, - 'organizations': organizations_backends, - 'polly': polly_backends, - 'redshift': redshift_backends, - 'resource-groups': resourcegroups_backends, - 'rds': rds2_backends, - 's3': s3_backends, - 's3bucket_path': s3_backends, - 'ses': ses_backends, - 'secretsmanager': secretsmanager_backends, - 'sns': sns_backends, - 'sqs': sqs_backends, - 'ssm': ssm_backends, - 'sts': sts_backends, - 'swf': swf_backends, - 'route53': route53_backends, - 'lambda': lambda_backends, - 'xray': xray_backends, - 'resourcegroupstaggingapi': resourcegroupstaggingapi_backends, - 'iot': iot_backends, - 'iot-data': iotdata_backends, + "acm": acm_backends, + "apigateway": apigateway_backends, + "athena": athena_backends, + "autoscaling": autoscaling_backends, + "batch": batch_backends, + "cloudformation": cloudformation_backends, + "cloudwatch": cloudwatch_backends, + "codepipeline": codepipeline_backends, + "cognito-identity": cognitoidentity_backends, + "cognito-idp": cognitoidp_backends, + "config": config_backends, + "datapipeline": datapipeline_backends, + "datasync": datasync_backends, + "dynamodb": dynamodb_backends, + "dynamodb2": dynamodb_backends2, + "dynamodbstreams": dynamodbstreams_backends, + "ec2": ec2_backends, + "ecr": ecr_backends, + "ecs": ecs_backends, + "elb": elb_backends, + "elbv2": elbv2_backends, + "events": events_backends, + "emr": emr_backends, + "glacier": glacier_backends, + "glue": glue_backends, + "iam": iam_backends, + "moto_api": moto_api_backends, + "instance_metadata": instance_metadata_backends, + "logs": logs_backends, + "kinesis": kinesis_backends, + "kms": kms_backends, + "opsworks": opsworks_backends, + "organizations": organizations_backends, + "polly": polly_backends, + "redshift": redshift_backends, + "resource-groups": resourcegroups_backends, + "rds": rds2_backends, + "s3": s3_backends, + "s3bucket_path": s3_backends, + "ses": ses_backends, + "secretsmanager": secretsmanager_backends, + "sns": sns_backends, + "sqs": sqs_backends, + "ssm": ssm_backends, + "stepfunctions": stepfunction_backends, + "sts": sts_backends, + "swf": swf_backends, + "route53": route53_backends, + "lambda": lambda_backends, + "xray": xray_backends, + "resourcegroupstaggingapi": resourcegroupstaggingapi_backends, + "iot": iot_backends, + "iot-data": iotdata_backends, } @@ -106,6 +114,6 @@ def get_model(name, region_name): for backends in BACKENDS.values(): for region, backend in backends.items(): if region == region_name: - models = getattr(backend.__class__, '__models__', {}) + models = getattr(backend.__class__, "__models__", {}) if name in models: return list(getattr(backend, models[name])()) diff --git a/moto/batch/__init__.py b/moto/batch/__init__.py index 6002b6fc7..40144d35d 100644 --- a/moto/batch/__init__.py +++ b/moto/batch/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import batch_backends from ..core.models import base_decorator -batch_backend = batch_backends['us-east-1'] +batch_backend = batch_backends["us-east-1"] mock_batch = base_decorator(batch_backends) diff --git a/moto/batch/exceptions.py b/moto/batch/exceptions.py index a71e54ce3..c411f3fce 100644 --- a/moto/batch/exceptions.py +++ b/moto/batch/exceptions.py @@ -12,26 +12,29 @@ class AWSError(Exception): self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + return ( + json.dumps({"__type": self.code, "message": self.message}), + dict(status=self.status), + ) class InvalidRequestException(AWSError): - CODE = 'InvalidRequestException' + CODE = "InvalidRequestException" class InvalidParameterValueException(AWSError): - CODE = 'InvalidParameterValue' + CODE = "InvalidParameterValue" class ValidationError(AWSError): - CODE = 'ValidationError' + CODE = "ValidationError" class InternalFailure(AWSError): - CODE = 'InternalFailure' + CODE = "InternalFailure" STATUS = 500 class ClientException(AWSError): - CODE = 'ClientException' + CODE = "ClientException" STATUS = 400 diff --git a/moto/batch/models.py b/moto/batch/models.py index caa442802..e12cc8f84 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -19,16 +19,22 @@ from moto.ecs import ecs_backends from moto.logs import logs_backends from .exceptions import InvalidParameterValueException, InternalFailure, ClientException -from .utils import make_arn_for_compute_env, make_arn_for_job_queue, make_arn_for_task_def, lowercase_first_key +from .utils import ( + make_arn_for_compute_env, + make_arn_for_job_queue, + make_arn_for_task_def, + lowercase_first_key, +) from moto.ec2.exceptions import InvalidSubnetIdError from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES from moto.iam.exceptions import IAMNotFoundException - +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID _orig_adapter_send = requests.adapters.HTTPAdapter.send logger = logging.getLogger(__name__) -DEFAULT_ACCOUNT_ID = 123456789012 -COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(r'^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$') +COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( + r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" +) def datetime2int(date): @@ -36,13 +42,23 @@ def datetime2int(date): class ComputeEnvironment(BaseModel): - def __init__(self, compute_environment_name, _type, state, compute_resources, service_role, region_name): + def __init__( + self, + compute_environment_name, + _type, + state, + compute_resources, + service_role, + region_name, + ): self.name = compute_environment_name self.env_type = _type self.state = state self.compute_resources = compute_resources self.service_role = service_role - self.arn = make_arn_for_compute_env(DEFAULT_ACCOUNT_ID, compute_environment_name, region_name) + self.arn = make_arn_for_compute_env( + DEFAULT_ACCOUNT_ID, compute_environment_name, region_name + ) self.instances = [] self.ecs_arn = None @@ -60,16 +76,18 @@ class ComputeEnvironment(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] env = backend.create_compute_environment( resource_name, - properties['Type'], - properties.get('State', 'ENABLED'), - lowercase_first_key(properties['ComputeResources']), - properties['ServiceRole'] + properties["Type"], + properties.get("State", "ENABLED"), + lowercase_first_key(properties["ComputeResources"]), + properties["ServiceRole"], ) arn = env[1] @@ -77,7 +95,9 @@ class ComputeEnvironment(BaseModel): class JobQueue(BaseModel): - def __init__(self, name, priority, state, environments, env_order_json, region_name): + def __init__( + self, name, priority, state, environments, env_order_json, region_name + ): """ :param name: Job queue name :type name: str @@ -98,18 +118,18 @@ class JobQueue(BaseModel): self.environments = environments self.env_order_json = env_order_json self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name) - self.status = 'VALID' + self.status = "VALID" self.jobs = [] def describe(self): result = { - 'computeEnvironmentOrder': self.env_order_json, - 'jobQueueArn': self.arn, - 'jobQueueName': self.name, - 'priority': self.priority, - 'state': self.state, - 'status': self.status + "computeEnvironmentOrder": self.env_order_json, + "jobQueueArn": self.arn, + "jobQueueName": self.name, + "priority": self.priority, + "state": self.state, + "status": self.status, } return result @@ -119,19 +139,24 @@ class JobQueue(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] # Need to deal with difference case from cloudformation compute_resources, e.g. instanceRole vs InstanceRole # Hacky fix to normalise keys, is making me think I want to start spamming cAsEiNsEnSiTiVe dictionaries - compute_envs = [lowercase_first_key(dict_item) for dict_item in properties['ComputeEnvironmentOrder']] + compute_envs = [ + lowercase_first_key(dict_item) + for dict_item in properties["ComputeEnvironmentOrder"] + ] queue = backend.create_job_queue( queue_name=resource_name, - priority=properties['Priority'], - state=properties.get('State', 'ENABLED'), - compute_env_order=compute_envs + priority=properties["Priority"], + state=properties.get("State", "ENABLED"), + compute_env_order=compute_envs, ) arn = queue[1] @@ -139,7 +164,16 @@ class JobQueue(BaseModel): class JobDefinition(BaseModel): - def __init__(self, name, parameters, _type, container_properties, region_name, revision=0, retry_strategy=0): + def __init__( + self, + name, + parameters, + _type, + container_properties, + region_name, + revision=0, + retry_strategy=0, + ): self.name = name self.retries = retry_strategy self.type = _type @@ -147,7 +181,7 @@ class JobDefinition(BaseModel): self._region = region_name self.container_properties = container_properties self.arn = None - self.status = 'INACTIVE' + self.status = "ACTIVE" if parameters is None: parameters = {} @@ -158,31 +192,33 @@ class JobDefinition(BaseModel): def _update_arn(self): self.revision += 1 - self.arn = make_arn_for_task_def(DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region) + self.arn = make_arn_for_task_def( + DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region + ) def _validate(self): - if self.type not in ('container',): + if self.type not in ("container",): raise ClientException('type must be one of "container"') # For future use when containers arnt the only thing in batch - if self.type != 'container': + if self.type != "container": raise NotImplementedError() if not isinstance(self.parameters, dict): - raise ClientException('parameters must be a string to string map') + raise ClientException("parameters must be a string to string map") - if 'image' not in self.container_properties: - raise ClientException('containerProperties must contain image') + if "image" not in self.container_properties: + raise ClientException("containerProperties must contain image") - if 'memory' not in self.container_properties: - raise ClientException('containerProperties must contain memory') - if self.container_properties['memory'] < 4: - raise ClientException('container memory limit must be greater than 4') + if "memory" not in self.container_properties: + raise ClientException("containerProperties must contain memory") + if self.container_properties["memory"] < 4: + raise ClientException("container memory limit must be greater than 4") - if 'vcpus' not in self.container_properties: - raise ClientException('containerProperties must contain vcpus') - if self.container_properties['vcpus'] < 1: - raise ClientException('container vcpus limit must be greater than 0') + if "vcpus" not in self.container_properties: + raise ClientException("containerProperties must contain vcpus") + if self.container_properties["vcpus"] < 1: + raise ClientException("container vcpus limit must be greater than 0") def update(self, parameters, _type, container_properties, retry_strategy): if parameters is None: @@ -197,21 +233,29 @@ class JobDefinition(BaseModel): if retry_strategy is None: retry_strategy = self.retries - return JobDefinition(self.name, parameters, _type, container_properties, region_name=self._region, revision=self.revision, retry_strategy=retry_strategy) + return JobDefinition( + self.name, + parameters, + _type, + container_properties, + region_name=self._region, + revision=self.revision, + retry_strategy=retry_strategy, + ) def describe(self): result = { - 'jobDefinitionArn': self.arn, - 'jobDefinitionName': self.name, - 'parameters': self.parameters, - 'revision': self.revision, - 'status': self.status, - 'type': self.type + "jobDefinitionArn": self.arn, + "jobDefinitionName": self.name, + "parameters": self.parameters, + "revision": self.revision, + "status": self.status, + "type": self.type, } if self.container_properties is not None: - result['containerProperties'] = self.container_properties + result["containerProperties"] = self.container_properties if self.retries is not None and self.retries > 0: - result['retryStrategy'] = {'attempts': self.retries} + result["retryStrategy"] = {"attempts": self.retries} return result @@ -220,16 +264,18 @@ class JobDefinition(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] res = backend.register_job_definition( def_name=resource_name, - parameters=lowercase_first_key(properties.get('Parameters', {})), - _type='container', - retry_strategy=lowercase_first_key(properties['RetryStrategy']), - container_properties=lowercase_first_key(properties['ContainerProperties']) + parameters=lowercase_first_key(properties.get("Parameters", {})), + _type="container", + retry_strategy=lowercase_first_key(properties["RetryStrategy"]), + container_properties=lowercase_first_key(properties["ContainerProperties"]), ) arn = res[1] @@ -238,7 +284,7 @@ class JobDefinition(BaseModel): class Job(threading.Thread, BaseModel): - def __init__(self, name, job_def, job_queue, log_backend): + def __init__(self, name, job_def, job_queue, log_backend, container_overrides): """ Docker Job @@ -254,8 +300,9 @@ class Job(threading.Thread, BaseModel): self.job_name = name self.job_id = str(uuid.uuid4()) self.job_definition = job_def + self.container_overrides = container_overrides self.job_queue = job_queue - self.job_state = 'SUBMITTED' # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED + self.job_state = "SUBMITTED" # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED self.job_queue.jobs.append(self) self.job_started_at = datetime.datetime(1970, 1, 1) self.job_stopped_at = datetime.datetime(1970, 1, 1) @@ -265,7 +312,7 @@ class Job(threading.Thread, BaseModel): self.stop = False self.daemon = True - self.name = 'MOTO-BATCH-' + self.job_id + self.name = "MOTO-BATCH-" + self.job_id self.docker_client = docker.from_env() self._log_backend = log_backend @@ -281,32 +328,40 @@ class Job(threading.Thread, BaseModel): if isinstance(adapter, requests.adapters.HTTPAdapter): adapter.send = functools.partial(_orig_adapter_send, adapter) return adapter + self.docker_client.api.get_adapter = replace_adapter_send def describe(self): result = { - 'jobDefinition': self.job_definition.arn, - 'jobId': self.job_id, - 'jobName': self.job_name, - 'jobQueue': self.job_queue.arn, - 'startedAt': datetime2int(self.job_started_at), - 'status': self.job_state, - 'dependsOn': [] + "jobDefinition": self.job_definition.arn, + "jobId": self.job_id, + "jobName": self.job_name, + "jobQueue": self.job_queue.arn, + "startedAt": datetime2int(self.job_started_at), + "status": self.job_state, + "dependsOn": [], } if self.job_stopped: - result['stoppedAt'] = datetime2int(self.job_stopped_at) - result['container'] = {} - result['container']['command'] = ['/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'] - result['container']['privileged'] = False - result['container']['readonlyRootFilesystem'] = False - result['container']['ulimits'] = {} - result['container']['vcpus'] = 1 - result['container']['volumes'] = '' - result['container']['logStreamName'] = self.log_stream_name + result["stoppedAt"] = datetime2int(self.job_stopped_at) + result["container"] = {} + result["container"]["command"] = [ + '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"' + ] + result["container"]["privileged"] = False + result["container"]["readonlyRootFilesystem"] = False + result["container"]["ulimits"] = {} + result["container"]["vcpus"] = 1 + result["container"]["volumes"] = "" + result["container"]["logStreamName"] = self.log_stream_name if self.job_stopped_reason is not None: - result['statusReason'] = self.job_stopped_reason + result["statusReason"] = self.job_stopped_reason return result + def _get_container_property(self, p, default): + return self.container_overrides.get( + p, self.job_definition.container_properties.get(p, default) + ) + def run(self): """ Run the container. @@ -322,24 +377,55 @@ class Job(threading.Thread, BaseModel): :return: """ try: - self.job_state = 'PENDING' + self.job_state = "PENDING" time.sleep(1) - image = 'alpine:latest' - cmd = '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"' - name = '{0}-{1}'.format(self.job_name, self.job_id) + image = self.job_definition.container_properties.get( + "image", "alpine:latest" + ) + privileged = self.job_definition.container_properties.get( + "privileged", False + ) + cmd = self._get_container_property( + "command", + '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"', + ) + environment = { + e["name"]: e["value"] + for e in self._get_container_property("environment", []) + } + volumes = { + v["name"]: v["host"] + for v in self._get_container_property("volumes", []) + } + mounts = [ + docker.types.Mount( + m["containerPath"], + volumes[m["sourceVolume"]]["sourcePath"], + type="bind", + read_only=m["readOnly"], + ) + for m in self._get_container_property("mountPoints", []) + ] + name = "{0}-{1}".format(self.job_name, self.job_id) - self.job_state = 'RUNNABLE' + self.job_state = "RUNNABLE" # TODO setup ecs container instance time.sleep(1) - self.job_state = 'STARTING' + self.job_state = "STARTING" + log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) container = self.docker_client.containers.run( - image, cmd, + image, + cmd, detach=True, - name=name + name=name, + log_config=log_config, + environment=environment, + mounts=mounts, + privileged=privileged, ) - self.job_state = 'RUNNING' + self.job_state = "RUNNING" self.job_started_at = datetime.datetime.now() try: # Log collection @@ -353,53 +439,99 @@ class Job(threading.Thread, BaseModel): # events seem to be duplicated. now = datetime.datetime.now() i = 1 - while container.status == 'running' and not self.stop: + while container.status == "running" and not self.stop: time.sleep(0.15) if i % 10 == 0: - logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n')) - logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n')) + logs_stderr.extend( + container.logs( + stdout=False, + stderr=True, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) + logs_stdout.extend( + container.logs( + stdout=True, + stderr=False, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) now = datetime.datetime.now() container.reload() i += 1 # Container should be stopped by this point... unless asked to stop - if container.status == 'running': + if container.status == "running": container.kill() self.job_stopped_at = datetime.datetime.now() # Get final logs - logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n')) - logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n')) + logs_stderr.extend( + container.logs( + stdout=False, + stderr=True, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) + logs_stdout.extend( + container.logs( + stdout=True, + stderr=False, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) - self.job_state = 'SUCCEEDED' if not self.stop else 'FAILED' + self.job_state = "SUCCEEDED" if not self.stop else "FAILED" # Process logs logs_stdout = [x for x in logs_stdout if len(x) > 0] logs_stderr = [x for x in logs_stderr if len(x) > 0] logs = [] for line in logs_stdout + logs_stderr: - date, line = line.split(' ', 1) + date, line = line.split(" ", 1) date = dateutil.parser.parse(date) date = int(date.timestamp()) - logs.append({'timestamp': date, 'message': line.strip()}) + logs.append({"timestamp": date, "message": line.strip()}) # Send to cloudwatch - log_group = '/aws/batch/job' - stream_name = '{0}/default/{1}'.format(self.job_definition.name, self.job_id) + log_group = "/aws/batch/job" + stream_name = "{0}/default/{1}".format( + self.job_definition.name, self.job_id + ) self.log_stream_name = stream_name self._log_backend.ensure_log_group(log_group, None) self._log_backend.create_log_stream(log_group, stream_name) self._log_backend.put_log_events(log_group, stream_name, logs, None) except Exception as err: - logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err)) - self.job_state = 'FAILED' + logger.error( + "Failed to run AWS Batch container {0}. Error {1}".format( + self.name, err + ) + ) + self.job_state = "FAILED" container.kill() finally: container.remove() except Exception as err: - logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err)) - self.job_state = 'FAILED' + logger.error( + "Failed to run AWS Batch container {0}. Error {1}".format( + self.name, err + ) + ) + self.job_state = "FAILED" self.job_stopped = True self.job_stopped_at = datetime.datetime.now() @@ -426,7 +558,7 @@ class BatchBackend(BaseBackend): :return: IAM Backend :rtype: moto.iam.models.IAMBackend """ - return iam_backends['global'] + return iam_backends["global"] @property def ec2_backend(self): @@ -456,7 +588,7 @@ class BatchBackend(BaseBackend): region_name = self.region_name for job in self._jobs.values(): - if job.job_state not in ('FAILED', 'SUCCEEDED'): + if job.job_state not in ("FAILED", "SUCCEEDED"): job.stop = True # Try to join job.join(0.2) @@ -530,7 +662,7 @@ class BatchBackend(BaseBackend): def get_job_definition(self, identifier): """ - Get job defintiion by name or ARN + Get job definitions by name or ARN :param identifier: Name or ARN :type identifier: str @@ -539,15 +671,17 @@ class BatchBackend(BaseBackend): """ job_def = self.get_job_definition_by_arn(identifier) if job_def is None: - if ':' in identifier: - job_def = self.get_job_definition_by_name_revision(*identifier.split(':', 1)) + if ":" in identifier: + job_def = self.get_job_definition_by_name_revision( + *identifier.split(":", 1) + ) else: job_def = self.get_job_definition_by_name(identifier) return job_def def get_job_definitions(self, identifier): """ - Get job defintiion by name or ARN + Get job definitions by name or ARN :param identifier: Name or ARN :type identifier: str @@ -579,7 +713,9 @@ class BatchBackend(BaseBackend): except KeyError: return None - def describe_compute_environments(self, environments=None, max_results=None, next_token=None): + def describe_compute_environments( + self, environments=None, max_results=None, next_token=None + ): envs = set() if environments is not None: envs = set(environments) @@ -591,82 +727,107 @@ class BatchBackend(BaseBackend): continue json_part = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': environment.name, - 'ecsClusterArn': environment.ecs_arn, - 'serviceRole': environment.service_role, - 'state': environment.state, - 'type': environment.env_type, - 'status': 'VALID' + "computeEnvironmentArn": arn, + "computeEnvironmentName": environment.name, + "ecsClusterArn": environment.ecs_arn, + "serviceRole": environment.service_role, + "state": environment.state, + "type": environment.env_type, + "status": "VALID", } - if environment.env_type == 'MANAGED': - json_part['computeResources'] = environment.compute_resources + if environment.env_type == "MANAGED": + json_part["computeResources"] = environment.compute_resources result.append(json_part) return result - def create_compute_environment(self, compute_environment_name, _type, state, compute_resources, service_role): + def create_compute_environment( + self, compute_environment_name, _type, state, compute_resources, service_role + ): # Validate if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: - raise InvalidParameterValueException('Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$') + raise InvalidParameterValueException( + "Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" + ) if self.get_compute_environment_by_name(compute_environment_name) is not None: - raise InvalidParameterValueException('A compute environment already exists with the name {0}'.format(compute_environment_name)) + raise InvalidParameterValueException( + "A compute environment already exists with the name {0}".format( + compute_environment_name + ) + ) # Look for IAM role try: self.iam_backend.get_role_by_arn(service_role) except IAMNotFoundException: - raise InvalidParameterValueException('Could not find IAM role {0}'.format(service_role)) + raise InvalidParameterValueException( + "Could not find IAM role {0}".format(service_role) + ) - if _type not in ('MANAGED', 'UNMANAGED'): - raise InvalidParameterValueException('type {0} must be one of MANAGED | UNMANAGED'.format(service_role)) + if _type not in ("MANAGED", "UNMANAGED"): + raise InvalidParameterValueException( + "type {0} must be one of MANAGED | UNMANAGED".format(service_role) + ) - if state is not None and state not in ('ENABLED', 'DISABLED'): - raise InvalidParameterValueException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state is not None and state not in ("ENABLED", "DISABLED"): + raise InvalidParameterValueException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) - if compute_resources is None and _type == 'MANAGED': - raise InvalidParameterValueException('computeResources must be specified when creating a MANAGED environment'.format(state)) + if compute_resources is None and _type == "MANAGED": + raise InvalidParameterValueException( + "computeResources must be specified when creating a MANAGED environment".format( + state + ) + ) elif compute_resources is not None: self._validate_compute_resources(compute_resources) # By here, all values except SPOT ones have been validated new_comp_env = ComputeEnvironment( - compute_environment_name, _type, state, - compute_resources, service_role, - region_name=self.region_name + compute_environment_name, + _type, + state, + compute_resources, + service_role, + region_name=self.region_name, ) self._compute_environments[new_comp_env.arn] = new_comp_env # Ok by this point, everything is legit, so if its Managed then start some instances - if _type == 'MANAGED': - cpus = int(compute_resources.get('desiredvCpus', compute_resources['minvCpus'])) - instance_types = compute_resources['instanceTypes'] - needed_instance_types = self.find_min_instances_to_meet_vcpus(instance_types, cpus) + if _type == "MANAGED": + cpus = int( + compute_resources.get("desiredvCpus", compute_resources["minvCpus"]) + ) + instance_types = compute_resources["instanceTypes"] + needed_instance_types = self.find_min_instances_to_meet_vcpus( + instance_types, cpus + ) # Create instances # Will loop over and over so we get decent subnet coverage - subnet_cycle = cycle(compute_resources['subnets']) + subnet_cycle = cycle(compute_resources["subnets"]) for instance_type in needed_instance_types: reservation = self.ec2_backend.add_instances( - image_id='ami-ecs-optimised', # Todo import AMIs + image_id="ami-ecs-optimised", # Todo import AMIs count=1, user_data=None, security_group_names=[], instance_type=instance_type, region_name=self.region_name, subnet_id=six.next(subnet_cycle), - key_name=compute_resources.get('ec2KeyPair', 'AWS_OWNED'), - security_group_ids=compute_resources['securityGroupIds'] + key_name=compute_resources.get("ec2KeyPair", "AWS_OWNED"), + security_group_ids=compute_resources["securityGroupIds"], ) new_comp_env.add_instance(reservation.instances[0]) # Create ECS cluster # Should be of format P2OnDemand_Batch_UUID - cluster_name = 'OnDemand_Batch_' + str(uuid.uuid4()) + cluster_name = "OnDemand_Batch_" + str(uuid.uuid4()) ecs_cluster = self.ecs_backend.create_cluster(cluster_name) new_comp_env.set_ecs(ecs_cluster.arn, cluster_name) @@ -679,47 +840,75 @@ class BatchBackend(BaseBackend): :param cr: computeResources :type cr: dict """ - for param in ('instanceRole', 'maxvCpus', 'minvCpus', 'instanceTypes', 'securityGroupIds', 'subnets', 'type'): + for param in ( + "instanceRole", + "maxvCpus", + "minvCpus", + "instanceTypes", + "securityGroupIds", + "subnets", + "type", + ): if param not in cr: - raise InvalidParameterValueException('computeResources must contain {0}'.format(param)) + raise InvalidParameterValueException( + "computeResources must contain {0}".format(param) + ) + for profile in self.iam_backend.get_instance_profiles(): + if profile.arn == cr["instanceRole"]: + break + else: + raise InvalidParameterValueException( + "could not find instanceRole {0}".format(cr["instanceRole"]) + ) - if self.iam_backend.get_role_by_arn(cr['instanceRole']) is None: - raise InvalidParameterValueException('could not find instanceRole {0}'.format(cr['instanceRole'])) + if cr["maxvCpus"] < 0: + raise InvalidParameterValueException("maxVCpus must be positive") + if cr["minvCpus"] < 0: + raise InvalidParameterValueException("minVCpus must be positive") + if cr["maxvCpus"] < cr["minvCpus"]: + raise InvalidParameterValueException( + "maxVCpus must be greater than minvCpus" + ) - if cr['maxvCpus'] < 0: - raise InvalidParameterValueException('maxVCpus must be positive') - if cr['minvCpus'] < 0: - raise InvalidParameterValueException('minVCpus must be positive') - if cr['maxvCpus'] < cr['minvCpus']: - raise InvalidParameterValueException('maxVCpus must be greater than minvCpus') - - if len(cr['instanceTypes']) == 0: - raise InvalidParameterValueException('At least 1 instance type must be provided') - for instance_type in cr['instanceTypes']: - if instance_type == 'optimal': + if len(cr["instanceTypes"]) == 0: + raise InvalidParameterValueException( + "At least 1 instance type must be provided" + ) + for instance_type in cr["instanceTypes"]: + if instance_type == "optimal": pass # Optimal should pick from latest of current gen elif instance_type not in EC2_INSTANCE_TYPES: - raise InvalidParameterValueException('Instance type {0} does not exist'.format(instance_type)) + raise InvalidParameterValueException( + "Instance type {0} does not exist".format(instance_type) + ) - for sec_id in cr['securityGroupIds']: + for sec_id in cr["securityGroupIds"]: if self.ec2_backend.get_security_group_from_id(sec_id) is None: - raise InvalidParameterValueException('security group {0} does not exist'.format(sec_id)) - if len(cr['securityGroupIds']) == 0: - raise InvalidParameterValueException('At least 1 security group must be provided') + raise InvalidParameterValueException( + "security group {0} does not exist".format(sec_id) + ) + if len(cr["securityGroupIds"]) == 0: + raise InvalidParameterValueException( + "At least 1 security group must be provided" + ) - for subnet_id in cr['subnets']: + for subnet_id in cr["subnets"]: try: self.ec2_backend.get_subnet(subnet_id) except InvalidSubnetIdError: - raise InvalidParameterValueException('subnet {0} does not exist'.format(subnet_id)) - if len(cr['subnets']) == 0: - raise InvalidParameterValueException('At least 1 subnet must be provided') + raise InvalidParameterValueException( + "subnet {0} does not exist".format(subnet_id) + ) + if len(cr["subnets"]) == 0: + raise InvalidParameterValueException("At least 1 subnet must be provided") - if cr['type'] not in ('EC2', 'SPOT'): - raise InvalidParameterValueException('computeResources.type must be either EC2 | SPOT') + if cr["type"] not in ("EC2", "SPOT"): + raise InvalidParameterValueException( + "computeResources.type must be either EC2 | SPOT" + ) - if cr['type'] == 'SPOT': - raise InternalFailure('SPOT NOT SUPPORTED YET') + if cr["type"] == "SPOT": + raise InternalFailure("SPOT NOT SUPPORTED YET") @staticmethod def find_min_instances_to_meet_vcpus(instance_types, target): @@ -738,11 +927,11 @@ class BatchBackend(BaseBackend): instances = [] for instance_type in instance_types: - if instance_type == 'optimal': - instance_type = 'm4.4xlarge' + if instance_type == "optimal": + instance_type = "m4.4xlarge" instance_vcpus.append( - (EC2_INSTANCE_TYPES[instance_type]['vcpus'], instance_type) + (EC2_INSTANCE_TYPES[instance_type]["vcpus"], instance_type) ) instance_vcpus = sorted(instance_vcpus, key=lambda item: item[0], reverse=True) @@ -773,7 +962,7 @@ class BatchBackend(BaseBackend): def delete_compute_environment(self, compute_environment_name): if compute_environment_name is None: - raise InvalidParameterValueException('Missing computeEnvironment parameter') + raise InvalidParameterValueException("Missing computeEnvironment parameter") compute_env = self.get_compute_environment(compute_environment_name) @@ -784,29 +973,35 @@ class BatchBackend(BaseBackend): # Delete ECS cluster self.ecs_backend.delete_cluster(compute_env.ecs_name) - if compute_env.env_type == 'MANAGED': - # Delete compute envrionment + if compute_env.env_type == "MANAGED": + # Delete compute environment instance_ids = [instance.id for instance in compute_env.instances] self.ec2_backend.terminate_instances(instance_ids) - def update_compute_environment(self, compute_environment_name, state, compute_resources, service_role): + def update_compute_environment( + self, compute_environment_name, state, compute_resources, service_role + ): # Validate compute_env = self.get_compute_environment(compute_environment_name) if compute_env is None: - raise ClientException('Compute environment {0} does not exist') + raise ClientException("Compute environment {0} does not exist") # Look for IAM role if service_role is not None: try: role = self.iam_backend.get_role_by_arn(service_role) except IAMNotFoundException: - raise InvalidParameterValueException('Could not find IAM role {0}'.format(service_role)) + raise InvalidParameterValueException( + "Could not find IAM role {0}".format(service_role) + ) compute_env.service_role = role if state is not None: - if state not in ('ENABLED', 'DISABLED'): - raise InvalidParameterValueException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise InvalidParameterValueException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) compute_env.state = state @@ -832,32 +1027,51 @@ class BatchBackend(BaseBackend): :return: Tuple of Name, ARN :rtype: tuple of str """ - for variable, var_name in ((queue_name, 'jobQueueName'), (priority, 'priority'), (state, 'state'), (compute_env_order, 'computeEnvironmentOrder')): + for variable, var_name in ( + (queue_name, "jobQueueName"), + (priority, "priority"), + (state, "state"), + (compute_env_order, "computeEnvironmentOrder"), + ): if variable is None: - raise ClientException('{0} must be provided'.format(var_name)) + raise ClientException("{0} must be provided".format(var_name)) - if state not in ('ENABLED', 'DISABLED'): - raise ClientException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise ClientException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) if self.get_job_queue_by_name(queue_name) is not None: - raise ClientException('Job queue {0} already exists'.format(queue_name)) + raise ClientException("Job queue {0} already exists".format(queue_name)) if len(compute_env_order) == 0: - raise ClientException('At least 1 compute environment must be provided') + raise ClientException("At least 1 compute environment must be provided") try: # orders and extracts computeEnvironment names - ordered_compute_environments = [item['computeEnvironment'] for item in sorted(compute_env_order, key=lambda x: x['order'])] + ordered_compute_environments = [ + item["computeEnvironment"] + for item in sorted(compute_env_order, key=lambda x: x["order"]) + ] env_objects = [] # Check each ARN exists, then make a list of compute env's for arn in ordered_compute_environments: env = self.get_compute_environment_by_arn(arn) if env is None: - raise ClientException('Compute environment {0} does not exist'.format(arn)) + raise ClientException( + "Compute environment {0} does not exist".format(arn) + ) env_objects.append(env) except Exception: - raise ClientException('computeEnvironmentOrder is malformed') + raise ClientException("computeEnvironmentOrder is malformed") # Create new Job Queue - queue = JobQueue(queue_name, priority, state, env_objects, compute_env_order, self.region_name) + queue = JobQueue( + queue_name, + priority, + state, + env_objects, + compute_env_order, + self.region_name, + ) self._job_queues[queue.arn] = queue return queue_name, queue.arn @@ -893,33 +1107,40 @@ class BatchBackend(BaseBackend): :rtype: tuple of str """ if queue_name is None: - raise ClientException('jobQueueName must be provided') + raise ClientException("jobQueueName must be provided") job_queue = self.get_job_queue(queue_name) if job_queue is None: - raise ClientException('Job queue {0} does not exist'.format(queue_name)) + raise ClientException("Job queue {0} does not exist".format(queue_name)) if state is not None: - if state not in ('ENABLED', 'DISABLED'): - raise ClientException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise ClientException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) job_queue.state = state if compute_env_order is not None: if len(compute_env_order) == 0: - raise ClientException('At least 1 compute environment must be provided') + raise ClientException("At least 1 compute environment must be provided") try: # orders and extracts computeEnvironment names - ordered_compute_environments = [item['computeEnvironment'] for item in sorted(compute_env_order, key=lambda x: x['order'])] + ordered_compute_environments = [ + item["computeEnvironment"] + for item in sorted(compute_env_order, key=lambda x: x["order"]) + ] env_objects = [] # Check each ARN exists, then make a list of compute env's for arn in ordered_compute_environments: env = self.get_compute_environment_by_arn(arn) if env is None: - raise ClientException('Compute environment {0} does not exist'.format(arn)) + raise ClientException( + "Compute environment {0} does not exist".format(arn) + ) env_objects.append(env) except Exception: - raise ClientException('computeEnvironmentOrder is malformed') + raise ClientException("computeEnvironmentOrder is malformed") job_queue.env_order_json = compute_env_order job_queue.environments = env_objects @@ -935,22 +1156,33 @@ class BatchBackend(BaseBackend): if job_queue is not None: del self._job_queues[job_queue.arn] - def register_job_definition(self, def_name, parameters, _type, retry_strategy, container_properties): + def register_job_definition( + self, def_name, parameters, _type, retry_strategy, container_properties + ): if def_name is None: - raise ClientException('jobDefinitionName must be provided') + raise ClientException("jobDefinitionName must be provided") job_def = self.get_job_definition_by_name(def_name) if retry_strategy is not None: try: - retry_strategy = retry_strategy['attempts'] + retry_strategy = retry_strategy["attempts"] except Exception: - raise ClientException('retryStrategy is malformed') + raise ClientException("retryStrategy is malformed") if job_def is None: - job_def = JobDefinition(def_name, parameters, _type, container_properties, region_name=self.region_name, retry_strategy=retry_strategy) + job_def = JobDefinition( + def_name, + parameters, + _type, + container_properties, + region_name=self.region_name, + retry_strategy=retry_strategy, + ) else: # Make new jobdef - job_def = job_def.update(parameters, _type, container_properties, retry_strategy) + job_def = job_def.update( + parameters, _type, container_properties, retry_strategy + ) self._job_definitions[job_def.arn] = job_def @@ -958,14 +1190,21 @@ class BatchBackend(BaseBackend): def deregister_job_definition(self, def_name): job_def = self.get_job_definition_by_arn(def_name) - if job_def is None and ':' in def_name: - name, revision = def_name.split(':', 1) + if job_def is None and ":" in def_name: + name, revision = def_name.split(":", 1) job_def = self.get_job_definition_by_name_revision(name, revision) if job_def is not None: del self._job_definitions[job_def.arn] - def describe_job_definitions(self, job_def_name=None, job_def_list=None, status=None, max_results=None, next_token=None): + def describe_job_definitions( + self, + job_def_name=None, + job_def_list=None, + status=None, + max_results=None, + next_token=None, + ): jobs = [] # As a job name can reference multiple revisions, we get a list of them @@ -986,19 +1225,36 @@ class BatchBackend(BaseBackend): return [job for job in jobs if job.status == status] return jobs - def submit_job(self, job_name, job_def_id, job_queue, parameters=None, retries=None, depends_on=None, container_overrides=None): - # TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now + def submit_job( + self, + job_name, + job_def_id, + job_queue, + parameters=None, + retries=None, + depends_on=None, + container_overrides=None, + ): + # TODO parameters, retries (which is a dict raw from request), job dependencies and container overrides are ignored for now # Look for job definition job_def = self.get_job_definition(job_def_id) if job_def is None: - raise ClientException('Job definition {0} does not exist'.format(job_def_id)) + raise ClientException( + "Job definition {0} does not exist".format(job_def_id) + ) queue = self.get_job_queue(job_queue) if queue is None: - raise ClientException('Job queue {0} does not exist'.format(job_queue)) + raise ClientException("Job queue {0} does not exist".format(job_queue)) - job = Job(job_name, job_def, queue, log_backend=self.logs_backend) + job = Job( + job_name, + job_def, + queue, + log_backend=self.logs_backend, + container_overrides=container_overrides, + ) self._jobs[job.job_id] = job # Here comes the fun @@ -1025,10 +1281,20 @@ class BatchBackend(BaseBackend): job_queue = self.get_job_queue(job_queue) if job_queue is None: - raise ClientException('Job queue {0} does not exist'.format(job_queue)) + raise ClientException("Job queue {0} does not exist".format(job_queue)) - if job_status is not None and job_status not in ('SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'): - raise ClientException('Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED') + if job_status is not None and job_status not in ( + "SUBMITTED", + "PENDING", + "RUNNABLE", + "STARTING", + "RUNNING", + "SUCCEEDED", + "FAILED", + ): + raise ClientException( + "Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED" + ) for job in job_queue.jobs: if job_status is not None and job.job_state != job_status: @@ -1040,16 +1306,18 @@ class BatchBackend(BaseBackend): def terminate_job(self, job_id, reason): if job_id is None: - raise ClientException('Job ID does not exist') + raise ClientException("Job ID does not exist") if reason is None: - raise ClientException('Reason does not exist') + raise ClientException("Reason does not exist") job = self.get_job_by_id(job_id) if job is None: - raise ClientException('Job not found') + raise ClientException("Job not found") job.terminate(reason) available_regions = boto3.session.Session().get_available_regions("batch") -batch_backends = {region: BatchBackend(region_name=region) for region in available_regions} +batch_backends = { + region: BatchBackend(region_name=region) for region in available_regions +} diff --git a/moto/batch/responses.py b/moto/batch/responses.py index 7fb606184..61b00e9c9 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -10,7 +10,7 @@ import json class BatchResponse(BaseResponse): def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) @property def batch_backend(self): @@ -22,9 +22,9 @@ class BatchResponse(BaseResponse): @property def json(self): - if self.body is None or self.body == '': + if self.body is None or self.body == "": self._json = {} - elif not hasattr(self, '_json'): + elif not hasattr(self, "_json"): try: self._json = json.loads(self.body) except ValueError: @@ -39,153 +39,146 @@ class BatchResponse(BaseResponse): def _get_action(self): # Return element after the /v1/* - return urlsplit(self.uri).path.lstrip('/').split('/')[1] + return urlsplit(self.uri).path.lstrip("/").split("/")[1] # CreateComputeEnvironment def createcomputeenvironment(self): - compute_env_name = self._get_param('computeEnvironmentName') - compute_resource = self._get_param('computeResources') - service_role = self._get_param('serviceRole') - state = self._get_param('state') - _type = self._get_param('type') + compute_env_name = self._get_param("computeEnvironmentName") + compute_resource = self._get_param("computeResources") + service_role = self._get_param("serviceRole") + state = self._get_param("state") + _type = self._get_param("type") try: name, arn = self.batch_backend.create_compute_environment( compute_environment_name=compute_env_name, - _type=_type, state=state, + _type=_type, + state=state, compute_resources=compute_resource, - service_role=service_role + service_role=service_role, ) except AWSError as err: return err.response() - result = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': name - } + result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} return json.dumps(result) # DescribeComputeEnvironments def describecomputeenvironments(self): - compute_environments = self._get_param('computeEnvironments') - max_results = self._get_param('maxResults') # Ignored, should be int - next_token = self._get_param('nextToken') # Ignored + compute_environments = self._get_param("computeEnvironments") + max_results = self._get_param("maxResults") # Ignored, should be int + next_token = self._get_param("nextToken") # Ignored - envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token) + envs = self.batch_backend.describe_compute_environments( + compute_environments, max_results=max_results, next_token=next_token + ) - result = {'computeEnvironments': envs} + result = {"computeEnvironments": envs} return json.dumps(result) # DeleteComputeEnvironment def deletecomputeenvironment(self): - compute_environment = self._get_param('computeEnvironment') + compute_environment = self._get_param("computeEnvironment") try: self.batch_backend.delete_compute_environment(compute_environment) except AWSError as err: return err.response() - return '' + return "" # UpdateComputeEnvironment def updatecomputeenvironment(self): - compute_env_name = self._get_param('computeEnvironment') - compute_resource = self._get_param('computeResources') - service_role = self._get_param('serviceRole') - state = self._get_param('state') + compute_env_name = self._get_param("computeEnvironment") + compute_resource = self._get_param("computeResources") + service_role = self._get_param("serviceRole") + state = self._get_param("state") try: name, arn = self.batch_backend.update_compute_environment( compute_environment_name=compute_env_name, compute_resources=compute_resource, service_role=service_role, - state=state + state=state, ) except AWSError as err: return err.response() - result = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': name - } + result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} return json.dumps(result) # CreateJobQueue def createjobqueue(self): - compute_env_order = self._get_param('computeEnvironmentOrder') - queue_name = self._get_param('jobQueueName') - priority = self._get_param('priority') - state = self._get_param('state') + compute_env_order = self._get_param("computeEnvironmentOrder") + queue_name = self._get_param("jobQueueName") + priority = self._get_param("priority") + state = self._get_param("state") try: name, arn = self.batch_backend.create_job_queue( queue_name=queue_name, priority=priority, state=state, - compute_env_order=compute_env_order + compute_env_order=compute_env_order, ) except AWSError as err: return err.response() - result = { - 'jobQueueArn': arn, - 'jobQueueName': name - } + result = {"jobQueueArn": arn, "jobQueueName": name} return json.dumps(result) # DescribeJobQueues def describejobqueues(self): - job_queues = self._get_param('jobQueues') - max_results = self._get_param('maxResults') # Ignored, should be int - next_token = self._get_param('nextToken') # Ignored + job_queues = self._get_param("jobQueues") + max_results = self._get_param("maxResults") # Ignored, should be int + next_token = self._get_param("nextToken") # Ignored - queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token) + queues = self.batch_backend.describe_job_queues( + job_queues, max_results=max_results, next_token=next_token + ) - result = {'jobQueues': queues} + result = {"jobQueues": queues} return json.dumps(result) # UpdateJobQueue def updatejobqueue(self): - compute_env_order = self._get_param('computeEnvironmentOrder') - queue_name = self._get_param('jobQueue') - priority = self._get_param('priority') - state = self._get_param('state') + compute_env_order = self._get_param("computeEnvironmentOrder") + queue_name = self._get_param("jobQueue") + priority = self._get_param("priority") + state = self._get_param("state") try: name, arn = self.batch_backend.update_job_queue( queue_name=queue_name, priority=priority, state=state, - compute_env_order=compute_env_order + compute_env_order=compute_env_order, ) except AWSError as err: return err.response() - result = { - 'jobQueueArn': arn, - 'jobQueueName': name - } + result = {"jobQueueArn": arn, "jobQueueName": name} return json.dumps(result) # DeleteJobQueue def deletejobqueue(self): - queue_name = self._get_param('jobQueue') + queue_name = self._get_param("jobQueue") self.batch_backend.delete_job_queue(queue_name) - return '' + return "" # RegisterJobDefinition def registerjobdefinition(self): - container_properties = self._get_param('containerProperties') - def_name = self._get_param('jobDefinitionName') - parameters = self._get_param('parameters') - retry_strategy = self._get_param('retryStrategy') - _type = self._get_param('type') + container_properties = self._get_param("containerProperties") + def_name = self._get_param("jobDefinitionName") + parameters = self._get_param("parameters") + retry_strategy = self._get_param("retryStrategy") + _type = self._get_param("type") try: name, arn, revision = self.batch_backend.register_job_definition( @@ -193,104 +186,113 @@ class BatchResponse(BaseResponse): parameters=parameters, _type=_type, retry_strategy=retry_strategy, - container_properties=container_properties + container_properties=container_properties, ) except AWSError as err: return err.response() result = { - 'jobDefinitionArn': arn, - 'jobDefinitionName': name, - 'revision': revision + "jobDefinitionArn": arn, + "jobDefinitionName": name, + "revision": revision, } return json.dumps(result) # DeregisterJobDefinition def deregisterjobdefinition(self): - queue_name = self._get_param('jobDefinition') + queue_name = self._get_param("jobDefinition") self.batch_backend.deregister_job_definition(queue_name) - return '' + return "" # DescribeJobDefinitions def describejobdefinitions(self): - job_def_name = self._get_param('jobDefinitionName') - job_def_list = self._get_param('jobDefinitions') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') - status = self._get_param('status') + job_def_name = self._get_param("jobDefinitionName") + job_def_list = self._get_param("jobDefinitions") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") + status = self._get_param("status") - job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token) + job_defs = self.batch_backend.describe_job_definitions( + job_def_name, job_def_list, status, max_results, next_token + ) - result = {'jobDefinitions': [job.describe() for job in job_defs]} + result = {"jobDefinitions": [job.describe() for job in job_defs]} return json.dumps(result) # SubmitJob def submitjob(self): - container_overrides = self._get_param('containerOverrides') - depends_on = self._get_param('dependsOn') - job_def = self._get_param('jobDefinition') - job_name = self._get_param('jobName') - job_queue = self._get_param('jobQueue') - parameters = self._get_param('parameters') - retries = self._get_param('retryStrategy') + container_overrides = self._get_param("containerOverrides") + depends_on = self._get_param("dependsOn") + job_def = self._get_param("jobDefinition") + job_name = self._get_param("jobName") + job_queue = self._get_param("jobQueue") + parameters = self._get_param("parameters") + retries = self._get_param("retryStrategy") try: name, job_id = self.batch_backend.submit_job( - job_name, job_def, job_queue, + job_name, + job_def, + job_queue, parameters=parameters, retries=retries, depends_on=depends_on, - container_overrides=container_overrides + container_overrides=container_overrides, ) except AWSError as err: return err.response() - result = { - 'jobId': job_id, - 'jobName': name, - } + result = {"jobId": job_id, "jobName": name} return json.dumps(result) # DescribeJobs def describejobs(self): - jobs = self._get_param('jobs') + jobs = self._get_param("jobs") try: - return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)}) + return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)}) except AWSError as err: return err.response() # ListJobs def listjobs(self): - job_queue = self._get_param('jobQueue') - job_status = self._get_param('jobStatus') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + job_queue = self._get_param("jobQueue") + job_status = self._get_param("jobStatus") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") try: - jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token) + jobs = self.batch_backend.list_jobs( + job_queue, job_status, max_results, next_token + ) except AWSError as err: return err.response() - result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]} + result = { + "jobSummaryList": [ + {"jobId": job.job_id, "jobName": job.job_name} for job in jobs + ] + } return json.dumps(result) # TerminateJob def terminatejob(self): - job_id = self._get_param('jobId') - reason = self._get_param('reason') + job_id = self._get_param("jobId") + reason = self._get_param("reason") try: self.batch_backend.terminate_job(job_id, reason) except AWSError as err: return err.response() - return '' + return "" # CancelJob - def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-) + def canceljob( + self, + ): # Theres some AWS semantics on the differences but for us they're identical ;-) return self.terminatejob() diff --git a/moto/batch/urls.py b/moto/batch/urls.py index c64086ef2..9dc507416 100644 --- a/moto/batch/urls.py +++ b/moto/batch/urls.py @@ -1,25 +1,23 @@ from __future__ import unicode_literals from .responses import BatchResponse -url_bases = [ - "https?://batch.(.+).amazonaws.com", -] +url_bases = ["https?://batch.(.+).amazonaws.com"] url_paths = { - '{0}/v1/createcomputeenvironment$': BatchResponse.dispatch, - '{0}/v1/describecomputeenvironments$': BatchResponse.dispatch, - '{0}/v1/deletecomputeenvironment': BatchResponse.dispatch, - '{0}/v1/updatecomputeenvironment': BatchResponse.dispatch, - '{0}/v1/createjobqueue': BatchResponse.dispatch, - '{0}/v1/describejobqueues': BatchResponse.dispatch, - '{0}/v1/updatejobqueue': BatchResponse.dispatch, - '{0}/v1/deletejobqueue': BatchResponse.dispatch, - '{0}/v1/registerjobdefinition': BatchResponse.dispatch, - '{0}/v1/deregisterjobdefinition': BatchResponse.dispatch, - '{0}/v1/describejobdefinitions': BatchResponse.dispatch, - '{0}/v1/submitjob': BatchResponse.dispatch, - '{0}/v1/describejobs': BatchResponse.dispatch, - '{0}/v1/listjobs': BatchResponse.dispatch, - '{0}/v1/terminatejob': BatchResponse.dispatch, - '{0}/v1/canceljob': BatchResponse.dispatch, + "{0}/v1/createcomputeenvironment$": BatchResponse.dispatch, + "{0}/v1/describecomputeenvironments$": BatchResponse.dispatch, + "{0}/v1/deletecomputeenvironment": BatchResponse.dispatch, + "{0}/v1/updatecomputeenvironment": BatchResponse.dispatch, + "{0}/v1/createjobqueue": BatchResponse.dispatch, + "{0}/v1/describejobqueues": BatchResponse.dispatch, + "{0}/v1/updatejobqueue": BatchResponse.dispatch, + "{0}/v1/deletejobqueue": BatchResponse.dispatch, + "{0}/v1/registerjobdefinition": BatchResponse.dispatch, + "{0}/v1/deregisterjobdefinition": BatchResponse.dispatch, + "{0}/v1/describejobdefinitions": BatchResponse.dispatch, + "{0}/v1/submitjob": BatchResponse.dispatch, + "{0}/v1/describejobs": BatchResponse.dispatch, + "{0}/v1/listjobs": BatchResponse.dispatch, + "{0}/v1/terminatejob": BatchResponse.dispatch, + "{0}/v1/canceljob": BatchResponse.dispatch, } diff --git a/moto/batch/utils.py b/moto/batch/utils.py index 829a55f12..ce9b2ffe8 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -2,7 +2,9 @@ from __future__ import unicode_literals def make_arn_for_compute_env(account_id, name, region_name): - return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name) + return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format( + region_name, account_id, name + ) def make_arn_for_job_queue(account_id, name, region_name): @@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name): def make_arn_for_task_def(account_id, name, revision, region_name): - return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision) + return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format( + region_name, account_id, name, revision + ) def lowercase_first_key(some_dict): diff --git a/moto/cloudformation/__init__.py b/moto/cloudformation/__init__.py index b73e3ab6c..351af146c 100644 --- a/moto/cloudformation/__init__.py +++ b/moto/cloudformation/__init__.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals from .models import cloudformation_backends from ..core.models import base_decorator, deprecated_base_decorator -cloudformation_backend = cloudformation_backends['us-east-1'] +cloudformation_backend = cloudformation_backends["us-east-1"] mock_cloudformation = base_decorator(cloudformation_backends) -mock_cloudformation_deprecated = deprecated_base_decorator( - cloudformation_backends) +mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends) diff --git a/moto/cloudformation/exceptions.py b/moto/cloudformation/exceptions.py index 6ea15c5ca..10669ca56 100644 --- a/moto/cloudformation/exceptions.py +++ b/moto/cloudformation/exceptions.py @@ -4,26 +4,23 @@ from jinja2 import Template class UnformattedGetAttTemplateException(Exception): - description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt' + description = ( + "Template error: resource {0} does not support attribute type {1} in Fn::GetAtt" + ) status_code = 400 class ValidationError(BadRequest): - def __init__(self, name_or_id, message=None): if message is None: message = "Stack with id {0} does not exist".format(name_or_id) template = Template(ERROR_RESPONSE) super(ValidationError, self).__init__() - self.description = template.render( - code="ValidationError", - message=message, - ) + self.description = template.render(code="ValidationError", message=message) class MissingParameterError(BadRequest): - def __init__(self, parameter_name): template = Template(ERROR_RESPONSE) super(MissingParameterError, self).__init__() @@ -40,8 +37,8 @@ class ExportNotFound(BadRequest): template = Template(ERROR_RESPONSE) super(ExportNotFound, self).__init__() self.description = template.render( - code='ExportNotFound', - message="No export named {0} found.".format(export_name) + code="ExportNotFound", + message="No export named {0} found.".format(export_name), ) diff --git a/moto/cloudformation/models.py b/moto/cloudformation/models.py index 01e3113dd..71ceaf168 100644 --- a/moto/cloudformation/models.py +++ b/moto/cloudformation/models.py @@ -21,11 +21,19 @@ from .exceptions import ValidationError class FakeStackSet(BaseModel): - - def __init__(self, stackset_id, name, template, region='us-east-1', - status='ACTIVE', description=None, parameters=None, tags=None, - admin_role='AWSCloudFormationStackSetAdministrationRole', - execution_role='AWSCloudFormationStackSetExecutionRole'): + def __init__( + self, + stackset_id, + name, + template, + region="us-east-1", + status="ACTIVE", + description=None, + parameters=None, + tags=None, + admin_role="AWSCloudFormationStackSetAdministrationRole", + execution_role="AWSCloudFormationStackSetExecutionRole", + ): self.id = stackset_id self.arn = generate_stackset_arn(stackset_id, region) self.name = name @@ -42,12 +50,14 @@ class FakeStackSet(BaseModel): def _create_operation(self, operation_id, action, status, accounts=[], regions=[]): operation = { - 'OperationId': str(operation_id), - 'Action': action, - 'Status': status, - 'CreationTimestamp': datetime.now(), - 'EndTimestamp': datetime.now() + timedelta(minutes=2), - 'Instances': [{account: region} for account in accounts for region in regions], + "OperationId": str(operation_id), + "Action": action, + "Status": status, + "CreationTimestamp": datetime.now(), + "EndTimestamp": datetime.now() + timedelta(minutes=2), + "Instances": [ + {account: region} for account in accounts for region in regions + ], } self.operations += [operation] @@ -55,20 +65,30 @@ class FakeStackSet(BaseModel): def get_operation(self, operation_id): for operation in self.operations: - if operation_id == operation['OperationId']: + if operation_id == operation["OperationId"]: return operation raise ValidationError(operation_id) def update_operation(self, operation_id, status): operation = self.get_operation(operation_id) - operation['Status'] = status + operation["Status"] = status return operation_id def delete(self): - self.status = 'DELETED' + self.status = "DELETED" - def update(self, template, description, parameters, tags, admin_role, - execution_role, accounts, regions, operation_id=None): + def update( + self, + template, + description, + parameters, + tags, + admin_role, + execution_role, + accounts, + regions, + operation_id=None, + ): if not operation_id: operation_id = uuid.uuid4() @@ -82,9 +102,13 @@ class FakeStackSet(BaseModel): if accounts and regions: self.update_instances(accounts, regions, self.parameters) - operation = self._create_operation(operation_id=operation_id, - action='UPDATE', status='SUCCEEDED', accounts=accounts, - regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="UPDATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation def create_stack_instances(self, accounts, regions, parameters, operation_id=None): @@ -94,8 +118,13 @@ class FakeStackSet(BaseModel): parameters = self.parameters self.instances.create_instances(accounts, regions, parameters, operation_id) - self._create_operation(operation_id=operation_id, action='CREATE', - status='SUCCEEDED', accounts=accounts, regions=regions) + self._create_operation( + operation_id=operation_id, + action="CREATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) def delete_stack_instances(self, accounts, regions, operation_id=None): if not operation_id: @@ -103,8 +132,13 @@ class FakeStackSet(BaseModel): self.instances.delete(accounts, regions) - operation = self._create_operation(operation_id=operation_id, action='DELETE', - status='SUCCEEDED', accounts=accounts, regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="DELETE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation def update_instances(self, accounts, regions, parameters, operation_id=None): @@ -112,9 +146,13 @@ class FakeStackSet(BaseModel): operation_id = uuid.uuid4() self.instances.update(accounts, regions, parameters) - operation = self._create_operation(operation_id=operation_id, - action='UPDATE', status='SUCCEEDED', accounts=accounts, - regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="UPDATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation @@ -131,12 +169,12 @@ class FakeStackInstances(BaseModel): for region in regions: for account in accounts: instance = { - 'StackId': generate_stack_id(self.stack_name, region, account), - 'StackSetId': self.stackset_id, - 'Region': region, - 'Account': account, - 'Status': "CURRENT", - 'ParameterOverrides': parameters if parameters else [], + "StackId": generate_stack_id(self.stack_name, region, account), + "StackSetId": self.stackset_id, + "Region": region, + "Account": account, + "Status": "CURRENT", + "ParameterOverrides": parameters if parameters else [], } new_instances.append(instance) self.stack_instances += new_instances @@ -147,24 +185,35 @@ class FakeStackInstances(BaseModel): for region in regions: instance = self.get_instance(account, region) if parameters: - instance['ParameterOverrides'] = parameters + instance["ParameterOverrides"] = parameters else: - instance['ParameterOverrides'] = [] + instance["ParameterOverrides"] = [] def delete(self, accounts, regions): for i, instance in enumerate(self.stack_instances): - if instance['Region'] in regions and instance['Account'] in accounts: + if instance["Region"] in regions and instance["Account"] in accounts: self.stack_instances.pop(i) def get_instance(self, account, region): for i, instance in enumerate(self.stack_instances): - if instance['Region'] == region and instance['Account'] == account: + if instance["Region"] == region and instance["Account"] == account: return self.stack_instances[i] class FakeStack(BaseModel): - - def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False): + def __init__( + self, + stack_id, + name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + cross_stack_resources=None, + create_change_set=False, + ): self.stack_id = stack_id self.name = name self.template = template @@ -176,22 +225,31 @@ class FakeStack(BaseModel): self.tags = tags if tags else {} self.events = [] if create_change_set: - self._add_stack_event("REVIEW_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated" + ) else: - self._add_stack_event("CREATE_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "CREATE_IN_PROGRESS", resource_status_reason="User Initiated" + ) - self.description = self.template_dict.get('Description') + self.description = self.template_dict.get("Description") self.cross_stack_resources = cross_stack_resources or {} self.resource_map = self._create_resource_map() self.output_map = self._create_output_map() self._add_stack_event("CREATE_COMPLETE") - self.status = 'CREATE_COMPLETE' + self.status = "CREATE_COMPLETE" def _create_resource_map(self): resource_map = ResourceMap( - self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources) + self.stack_id, + self.name, + self.parameters, + self.tags, + self.region_name, + self.template_dict, + self.cross_stack_resources, + ) resource_map.create() return resource_map @@ -200,34 +258,46 @@ class FakeStack(BaseModel): output_map.create() return output_map - def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None): - self.events.append(FakeEvent( - stack_id=self.stack_id, - stack_name=self.name, - logical_resource_id=self.name, - physical_resource_id=self.stack_id, - resource_type="AWS::CloudFormation::Stack", - resource_status=resource_status, - resource_status_reason=resource_status_reason, - resource_properties=resource_properties, - )) + def _add_stack_event( + self, resource_status, resource_status_reason=None, resource_properties=None + ): + self.events.append( + FakeEvent( + stack_id=self.stack_id, + stack_name=self.name, + logical_resource_id=self.name, + physical_resource_id=self.stack_id, + resource_type="AWS::CloudFormation::Stack", + resource_status=resource_status, + resource_status_reason=resource_status_reason, + resource_properties=resource_properties, + ) + ) - def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None): + def _add_resource_event( + self, + logical_resource_id, + resource_status, + resource_status_reason=None, + resource_properties=None, + ): # not used yet... feel free to help yourself resource = self.resource_map[logical_resource_id] - self.events.append(FakeEvent( - stack_id=self.stack_id, - stack_name=self.name, - logical_resource_id=logical_resource_id, - physical_resource_id=resource.physical_resource_id, - resource_type=resource.type, - resource_status=resource_status, - resource_status_reason=resource_status_reason, - resource_properties=resource_properties, - )) + self.events.append( + FakeEvent( + stack_id=self.stack_id, + stack_name=self.name, + logical_resource_id=logical_resource_id, + physical_resource_id=resource.physical_resource_id, + resource_type=resource.type, + resource_status=resource_status, + resource_status_reason=resource_status_reason, + resource_properties=resource_properties, + ) + ) def _parse_template(self): - yaml.add_multi_constructor('', yaml_tag_constructor) + yaml.add_multi_constructor("", yaml_tag_constructor) try: self.template_dict = yaml.load(self.template, Loader=yaml.Loader) except yaml.parser.ParserError: @@ -250,7 +320,9 @@ class FakeStack(BaseModel): return self.output_map.exports def update(self, template, role_arn=None, parameters=None, tags=None): - self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") + self._add_stack_event( + "UPDATE_IN_PROGRESS", resource_status_reason="User Initiated" + ) self.template = template self._parse_template() self.resource_map.update(self.template_dict, parameters) @@ -264,15 +336,15 @@ class FakeStack(BaseModel): # TODO: update tags in the resource map def delete(self): - self._add_stack_event("DELETE_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "DELETE_IN_PROGRESS", resource_status_reason="User Initiated" + ) self.resource_map.delete() self._add_stack_event("DELETE_COMPLETE") self.status = "DELETE_COMPLETE" class FakeChange(BaseModel): - def __init__(self, action, logical_resource_id, resource_type): self.action = action self.logical_resource_id = logical_resource_id @@ -280,8 +352,21 @@ class FakeChange(BaseModel): class FakeChangeSet(FakeStack): - - def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None): + def __init__( + self, + stack_id, + stack_name, + stack_template, + change_set_id, + change_set_name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + cross_stack_resources=None, + ): super(FakeChangeSet, self).__init__( stack_id, stack_name, @@ -306,17 +391,28 @@ class FakeChangeSet(FakeStack): resources_by_action = self.resource_map.diff(self.template_dict, parameters) for action, resources in resources_by_action.items(): for resource_name, resource in resources.items(): - changes.append(FakeChange( - action=action, - logical_resource_id=resource_name, - resource_type=resource['ResourceType'], - )) + changes.append( + FakeChange( + action=action, + logical_resource_id=resource_name, + resource_type=resource["ResourceType"], + ) + ) return changes class FakeEvent(BaseModel): - - def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None): + def __init__( + self, + stack_id, + stack_name, + logical_resource_id, + physical_resource_id, + resource_type, + resource_status, + resource_status_reason=None, + resource_properties=None, + ): self.stack_id = stack_id self.stack_name = stack_name self.logical_resource_id = logical_resource_id @@ -330,7 +426,6 @@ class FakeEvent(BaseModel): class CloudFormationBackend(BaseBackend): - def __init__(self): self.stacks = OrderedDict() self.stacksets = OrderedDict() @@ -338,7 +433,17 @@ class CloudFormationBackend(BaseBackend): self.exports = OrderedDict() self.change_sets = OrderedDict() - def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None): + def create_stack_set( + self, + name, + template, + parameters, + tags=None, + description=None, + region="us-east-1", + admin_role=None, + execution_role=None, + ): stackset_id = generate_stackset_id(name) new_stackset = FakeStackSet( stackset_id=stackset_id, @@ -366,7 +471,9 @@ class CloudFormationBackend(BaseBackend): if self.stacksets[stackset].name == name: self.stacksets[stackset].delete() - def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None): + def create_stack_instances( + self, stackset_name, accounts, regions, parameters, operation_id=None + ): stackset = self.get_stack_set(stackset_name) stackset.create_stack_instances( @@ -377,9 +484,19 @@ class CloudFormationBackend(BaseBackend): ) return stackset - def update_stack_set(self, stackset_name, template=None, description=None, - parameters=None, tags=None, admin_role=None, execution_role=None, - accounts=None, regions=None, operation_id=None): + def update_stack_set( + self, + stackset_name, + template=None, + description=None, + parameters=None, + tags=None, + admin_role=None, + execution_role=None, + accounts=None, + regions=None, + operation_id=None, + ): stackset = self.get_stack_set(stackset_name) update = stackset.update( template=template, @@ -390,16 +507,28 @@ class CloudFormationBackend(BaseBackend): execution_role=execution_role, accounts=accounts, regions=regions, - operation_id=operation_id + operation_id=operation_id, ) return update - def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None): + def delete_stack_instances( + self, stackset_name, accounts, regions, operation_id=None + ): stackset = self.get_stack_set(stackset_name) stackset.delete_stack_instances(accounts, regions, operation_id) return stackset - def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False): + def create_stack( + self, + name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + create_change_set=False, + ): stack_id = generate_stack_id(name) new_stack = FakeStack( stack_id=stack_id, @@ -419,10 +548,21 @@ class CloudFormationBackend(BaseBackend): self.exports[export.name] = export return new_stack - def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None): + def create_change_set( + self, + stack_name, + change_set_name, + template, + parameters, + region_name, + change_set_type, + notification_arns=None, + tags=None, + role_arn=None, + ): stack_id = None stack_template = None - if change_set_type == 'UPDATE': + if change_set_type == "UPDATE": stacks = self.stacks.values() stack = None for s in stacks: @@ -449,7 +589,7 @@ class CloudFormationBackend(BaseBackend): notification_arns=notification_arns, tags=tags, role_arn=role_arn, - cross_stack_resources=self.exports + cross_stack_resources=self.exports, ) self.change_sets[change_set_id] = new_change_set self.stacks[stack_id] = new_change_set @@ -488,11 +628,11 @@ class CloudFormationBackend(BaseBackend): stack = self.change_sets[cs] if stack is None: raise ValidationError(stack_name) - if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS': - stack._add_stack_event('CREATE_COMPLETE') + if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS": + stack._add_stack_event("CREATE_COMPLETE") else: - stack._add_stack_event('UPDATE_IN_PROGRESS') - stack._add_stack_event('UPDATE_COMPLETE') + stack._add_stack_event("UPDATE_IN_PROGRESS") + stack._add_stack_event("UPDATE_COMPLETE") return True def describe_stacks(self, name_or_stack_id): @@ -514,9 +654,7 @@ class CloudFormationBackend(BaseBackend): return self.change_sets.values() def list_stacks(self): - return [ - v for v in self.stacks.values() - ] + [ + return [v for v in self.stacks.values()] + [ v for v in self.deleted_stacks.values() ] @@ -558,10 +696,10 @@ class CloudFormationBackend(BaseBackend): all_exports = list(self.exports.values()) if token is None: exports = all_exports[0:100] - next_token = '100' if len(all_exports) > 100 else None + next_token = "100" if len(all_exports) > 100 else None else: token = int(token) - exports = all_exports[token:token + 100] + exports = all_exports[token : token + 100] next_token = str(token + 100) if len(all_exports) > token + 100 else None return exports, next_token @@ -572,7 +710,10 @@ class CloudFormationBackend(BaseBackend): new_stack_export_names = [x.name for x in stack.exports] export_names = self.exports.keys() if not set(export_names).isdisjoint(new_stack_export_names): - raise ValidationError(stack.stack_id, message='Export names must be unique across a given region') + raise ValidationError( + stack.stack_id, + message="Export names must be unique across a given region", + ) cloudformation_backends = {} diff --git a/moto/cloudformation/parsing.py b/moto/cloudformation/parsing.py index f2e03bd81..34d96acc6 100644 --- a/moto/cloudformation/parsing.py +++ b/moto/cloudformation/parsing.py @@ -1,5 +1,4 @@ from __future__ import unicode_literals -import collections import functools import logging import copy @@ -11,6 +10,7 @@ from moto.awslambda import models as lambda_models from moto.batch import models as batch_models from moto.cloudwatch import models as cloudwatch_models from moto.cognitoidentity import models as cognitoidentity_models +from moto.compat import collections_abc from moto.datapipeline import models as datapipeline_models from moto.dynamodb2 import models as dynamodb2_models from moto.ec2 import models as ec2_models @@ -27,8 +27,14 @@ from moto.route53 import models as route53_models from moto.s3 import models as s3_models from moto.sns import models as sns_models from moto.sqs import models as sqs_models +from moto.core import ACCOUNT_ID from .utils import random_suffix -from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError +from .exceptions import ( + ExportNotFound, + MissingParameterError, + UnformattedGetAttTemplateException, + ValidationError, +) from boto.cloudformation.stack import Output MODEL_MAP = { @@ -100,7 +106,7 @@ NAME_TYPE_MAP = { "AWS::RDS::DBInstance": "DBInstanceIdentifier", "AWS::S3::Bucket": "BucketName", "AWS::SNS::Topic": "TopicName", - "AWS::SQS::Queue": "QueueName" + "AWS::SQS::Queue": "QueueName", } # Just ignore these models types for now @@ -109,13 +115,12 @@ NULL_MODELS = [ "AWS::CloudFormation::WaitConditionHandle", ] -DEFAULT_REGION = 'us-east-1' +DEFAULT_REGION = "us-east-1" logger = logging.getLogger("moto") class LazyDict(dict): - def __getitem__(self, key): val = dict.__getitem__(self, key) if callable(val): @@ -132,10 +137,10 @@ def clean_json(resource_json, resources_map): Eventually, this is where we would add things like function parsing (fn::) """ if isinstance(resource_json, dict): - if 'Ref' in resource_json: + if "Ref" in resource_json: # Parse resource reference - resource = resources_map[resource_json['Ref']] - if hasattr(resource, 'physical_resource_id'): + resource = resources_map[resource_json["Ref"]] + if hasattr(resource, "physical_resource_id"): return resource.physical_resource_id else: return resource @@ -148,74 +153,92 @@ def clean_json(resource_json, resources_map): result = result[clean_json(path, resources_map)] return result - if 'Fn::GetAtt' in resource_json: - resource = resources_map.get(resource_json['Fn::GetAtt'][0]) + if "Fn::GetAtt" in resource_json: + resource = resources_map.get(resource_json["Fn::GetAtt"][0]) if resource is None: return resource_json try: - return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) + return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1]) except NotImplementedError as n: - logger.warning(str(n).format( - resource_json['Fn::GetAtt'][0])) + logger.warning(str(n).format(resource_json["Fn::GetAtt"][0])) except UnformattedGetAttTemplateException: raise ValidationError( - 'Bad Request', + "Bad Request", UnformattedGetAttTemplateException.description.format( - resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1])) + resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1] + ), + ) - if 'Fn::If' in resource_json: - condition_name, true_value, false_value = resource_json['Fn::If'] + if "Fn::If" in resource_json: + condition_name, true_value, false_value = resource_json["Fn::If"] if resources_map.lazy_condition_map[condition_name]: return clean_json(true_value, resources_map) else: return clean_json(false_value, resources_map) - if 'Fn::Join' in resource_json: - join_list = clean_json(resource_json['Fn::Join'][1], resources_map) - return resource_json['Fn::Join'][0].join([str(x) for x in join_list]) + if "Fn::Join" in resource_json: + join_list = clean_json(resource_json["Fn::Join"][1], resources_map) + return resource_json["Fn::Join"][0].join([str(x) for x in join_list]) - if 'Fn::Split' in resource_json: - to_split = clean_json(resource_json['Fn::Split'][1], resources_map) - return to_split.split(resource_json['Fn::Split'][0]) + if "Fn::Split" in resource_json: + to_split = clean_json(resource_json["Fn::Split"][1], resources_map) + return to_split.split(resource_json["Fn::Split"][0]) - if 'Fn::Select' in resource_json: - select_index = int(resource_json['Fn::Select'][0]) - select_list = clean_json(resource_json['Fn::Select'][1], resources_map) + if "Fn::Select" in resource_json: + select_index = int(resource_json["Fn::Select"][0]) + select_list = clean_json(resource_json["Fn::Select"][1], resources_map) return select_list[select_index] - if 'Fn::Sub' in resource_json: - if isinstance(resource_json['Fn::Sub'], list): + if "Fn::Sub" in resource_json: + if isinstance(resource_json["Fn::Sub"], list): warnings.warn( - "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation") + "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation" + ) else: - fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map) + fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map) to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value) literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value) for sub in to_sub: - if '.' in sub: - cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map) + if "." in sub: + cleaned_ref = clean_json( + { + "Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[ + 0 + ].split(".") + }, + resources_map, + ) else: - cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map) + cleaned_ref = clean_json( + {"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, + resources_map, + ) fn_sub_value = fn_sub_value.replace(sub, cleaned_ref) for literal in literals: - fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', '')) + fn_sub_value = fn_sub_value.replace( + literal, literal.replace("!", "") + ) return fn_sub_value pass - if 'Fn::ImportValue' in resource_json: - cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map) - values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val] + if "Fn::ImportValue" in resource_json: + cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map) + values = [ + x.value + for x in resources_map.cross_stack_resources.values() + if x.name == cleaned_val + ] if any(values): return values[0] else: raise ExportNotFound(cleaned_val) - if 'Fn::GetAZs' in resource_json: - region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION + if "Fn::GetAZs" in resource_json: + region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION result = [] # TODO: make this configurable, to reflect the real AWS AZs - for az in ('a', 'b', 'c', 'd'): - result.append('%s%s' % (region, az)) + for az in ("a", "b", "c", "d"): + result.append("%s%s" % (region, az)) return result cleaned_json = {} @@ -246,58 +269,69 @@ def resource_name_property_from_type(resource_type): def generate_resource_name(resource_type, stack_name, logical_id): - if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup", - "AWS::ElasticLoadBalancingV2::LoadBalancer"]: + if resource_type in [ + "AWS::ElasticLoadBalancingV2::TargetGroup", + "AWS::ElasticLoadBalancingV2::LoadBalancer", + ]: # Target group names need to be less than 32 characters, so when cloudformation creates a name for you # it makes sure to stay under that limit - name_prefix = '{0}-{1}'.format(stack_name, logical_id) + name_prefix = "{0}-{1}".format(stack_name, logical_id) my_random_suffix = random_suffix() - truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)] + truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)] # if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is # not allowed - if truncated_name_prefix.endswith('-'): + if truncated_name_prefix.endswith("-"): truncated_name_prefix = truncated_name_prefix[:-1] - return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix) + return "{0}-{1}".format(truncated_name_prefix, my_random_suffix) else: - return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix()) + return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix()) def parse_resource(logical_id, resource_json, resources_map): - resource_type = resource_json['Type'] + resource_type = resource_json["Type"] resource_class = resource_class_from_type(resource_type) if not resource_class: warnings.warn( - "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type)) + "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format( + resource_type + ) + ) return None resource_json = clean_json(resource_json, resources_map) resource_name_property = resource_name_property_from_type(resource_type) if resource_name_property: - if 'Properties' not in resource_json: - resource_json['Properties'] = dict() - if resource_name_property not in resource_json['Properties']: - resource_json['Properties'][resource_name_property] = generate_resource_name( - resource_type, resources_map.get('AWS::StackName'), logical_id) - resource_name = resource_json['Properties'][resource_name_property] + if "Properties" not in resource_json: + resource_json["Properties"] = dict() + if resource_name_property not in resource_json["Properties"]: + resource_json["Properties"][ + resource_name_property + ] = generate_resource_name( + resource_type, resources_map.get("AWS::StackName"), logical_id + ) + resource_name = resource_json["Properties"][resource_name_property] else: - resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id) + resource_name = generate_resource_name( + resource_type, resources_map.get("AWS::StackName"), logical_id + ) return resource_class, resource_json, resource_name def parse_and_create_resource(logical_id, resource_json, resources_map, region_name): - condition = resource_json.get('Condition') + condition = resource_json.get("Condition") if condition and not resources_map.lazy_condition_map[condition]: # If this has a False condition, don't create the resource return None - resource_type = resource_json['Type'] + resource_type = resource_json["Type"] resource_tuple = parse_resource(logical_id, resource_json, resources_map) if not resource_tuple: return None resource_class, resource_json, resource_name = resource_tuple resource = resource_class.create_from_cloudformation_json( - resource_name, resource_json, region_name) + resource_name, resource_json, region_name + ) resource.type = resource_type resource.logical_resource_id = logical_id return resource @@ -305,24 +339,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): resource_class, new_resource_json, new_resource_name = parse_resource( - logical_id, resource_json, resources_map) + logical_id, resource_json, resources_map + ) original_resource = resources_map[logical_id] new_resource = resource_class.update_from_cloudformation_json( original_resource=original_resource, new_resource_name=new_resource_name, cloudformation_json=new_resource_json, - region_name=region_name + region_name=region_name, ) - new_resource.type = resource_json['Type'] + new_resource.type = resource_json["Type"] new_resource.logical_resource_id = logical_id return new_resource def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name): resource_class, resource_json, resource_name = parse_resource( - logical_id, resource_json, resources_map) + logical_id, resource_json, resources_map + ) resource_class.delete_from_cloudformation_json( - resource_name, resource_json, region_name) + resource_name, resource_json, region_name + ) def parse_condition(condition, resources_map, condition_map): @@ -334,8 +371,8 @@ def parse_condition(condition, resources_map, condition_map): condition_values = [] for value in list(condition.values())[0]: # Check if we are referencing another Condition - if 'Condition' in value: - condition_values.append(condition_map[value['Condition']]) + if "Condition" in value: + condition_values.append(condition_map[value["Condition"]]) else: condition_values.append(clean_json(value, resources_map)) @@ -344,36 +381,49 @@ def parse_condition(condition, resources_map, condition_map): elif condition_operator == "Fn::Not": return not parse_condition(condition_values[0], resources_map, condition_map) elif condition_operator == "Fn::And": - return all([ - parse_condition(condition_value, resources_map, condition_map) - for condition_value - in condition_values]) + return all( + [ + parse_condition(condition_value, resources_map, condition_map) + for condition_value in condition_values + ] + ) elif condition_operator == "Fn::Or": - return any([ - parse_condition(condition_value, resources_map, condition_map) - for condition_value - in condition_values]) + return any( + [ + parse_condition(condition_value, resources_map, condition_map) + for condition_value in condition_values + ] + ) def parse_output(output_logical_id, output_json, resources_map): output_json = clean_json(output_json, resources_map) output = Output() output.key = output_logical_id - output.value = clean_json(output_json['Value'], resources_map) - output.description = output_json.get('Description') + output.value = clean_json(output_json["Value"], resources_map) + output.description = output_json.get("Description") return output -class ResourceMap(collections.Mapping): +class ResourceMap(collections_abc.Mapping): """ This is a lazy loading map for resources. This allows us to create resources without needing to create a full dependency tree. Upon creation, each each resources is passed this lazy map that it can grab dependencies from. """ - def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources): + def __init__( + self, + stack_id, + stack_name, + parameters, + tags, + region_name, + template, + cross_stack_resources, + ): self._template = template - self._resource_json_map = template['Resources'] + self._resource_json_map = template["Resources"] self._region_name = region_name self.input_parameters = parameters self.tags = copy.deepcopy(tags) @@ -382,7 +432,7 @@ class ResourceMap(collections.Mapping): # Create the default resources self._parsed_resources = { - "AWS::AccountId": "123456789012", + "AWS::AccountId": ACCOUNT_ID, "AWS::Region": self._region_name, "AWS::StackId": stack_id, "AWS::StackName": stack_name, @@ -401,7 +451,8 @@ class ResourceMap(collections.Mapping): if not resource_json: raise KeyError(resource_logical_id) new_resource = parse_and_create_resource( - resource_logical_id, resource_json, self, self._region_name) + resource_logical_id, resource_json, self, self._region_name + ) if new_resource is not None: self._parsed_resources[resource_logical_id] = new_resource return new_resource @@ -417,13 +468,13 @@ class ResourceMap(collections.Mapping): return self._resource_json_map.keys() def load_mapping(self): - self._parsed_resources.update(self._template.get('Mappings', {})) + self._parsed_resources.update(self._template.get("Mappings", {})) def load_parameters(self): - parameter_slots = self._template.get('Parameters', {}) + parameter_slots = self._template.get("Parameters", {}) for parameter_name, parameter in parameter_slots.items(): # Set the default values. - self.resolved_parameters[parameter_name] = parameter.get('Default') + self.resolved_parameters[parameter_name] = parameter.get("Default") # Set any input parameters that were passed self.no_echo_parameter_keys = [] @@ -431,11 +482,11 @@ class ResourceMap(collections.Mapping): if key in self.resolved_parameters: parameter_slot = parameter_slots[key] - value_type = parameter_slot.get('Type', 'String') - if value_type == 'CommaDelimitedList' or value_type.startswith("List"): - value = value.split(',') + value_type = parameter_slot.get("Type", "String") + if value_type == "CommaDelimitedList" or value_type.startswith("List"): + value = value.split(",") - if parameter_slot.get('NoEcho'): + if parameter_slot.get("NoEcho"): self.no_echo_parameter_keys.append(key) self.resolved_parameters[key] = value @@ -449,11 +500,15 @@ class ResourceMap(collections.Mapping): self._parsed_resources.update(self.resolved_parameters) def load_conditions(self): - conditions = self._template.get('Conditions', {}) + conditions = self._template.get("Conditions", {}) self.lazy_condition_map = LazyDict() for condition_name, condition in conditions.items(): - self.lazy_condition_map[condition_name] = functools.partial(parse_condition, - condition, self._parsed_resources, self.lazy_condition_map) + self.lazy_condition_map[condition_name] = functools.partial( + parse_condition, + condition, + self._parsed_resources, + self.lazy_condition_map, + ) for condition_name in self.lazy_condition_map: self.lazy_condition_map[condition_name] @@ -465,13 +520,18 @@ class ResourceMap(collections.Mapping): # Since this is a lazy map, to create every object we just need to # iterate through self. - self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'), - 'aws:cloudformation:stack-id': self.get('AWS::StackId')}) + self.tags.update( + { + "aws:cloudformation:stack-name": self.get("AWS::StackName"), + "aws:cloudformation:stack-id": self.get("AWS::StackId"), + } + ) for resource in self.resources: if isinstance(self[resource], ec2_models.TaggedEC2Resource): - self.tags['aws:cloudformation:logical-id'] = resource + self.tags["aws:cloudformation:logical-id"] = resource ec2_models.ec2_backends[self._region_name].create_tags( - [self[resource].physical_resource_id], self.tags) + [self[resource].physical_resource_id], self.tags + ) def diff(self, template, parameters=None): if parameters: @@ -481,36 +541,35 @@ class ResourceMap(collections.Mapping): self.load_conditions() old_template = self._resource_json_map - new_template = template['Resources'] + new_template = template["Resources"] resource_names_by_action = { - 'Add': set(new_template) - set(old_template), - 'Modify': set(name for name in new_template if name in old_template and new_template[ - name] != old_template[name]), - 'Remove': set(old_template) - set(new_template) - } - resources_by_action = { - 'Add': {}, - 'Modify': {}, - 'Remove': {}, + "Add": set(new_template) - set(old_template), + "Modify": set( + name + for name in new_template + if name in old_template and new_template[name] != old_template[name] + ), + "Remove": set(old_template) - set(new_template), } + resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}} - for resource_name in resource_names_by_action['Add']: - resources_by_action['Add'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': new_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Add"]: + resources_by_action["Add"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": new_template[resource_name]["Type"], } - for resource_name in resource_names_by_action['Modify']: - resources_by_action['Modify'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': new_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Modify"]: + resources_by_action["Modify"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": new_template[resource_name]["Type"], } - for resource_name in resource_names_by_action['Remove']: - resources_by_action['Remove'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': old_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Remove"]: + resources_by_action["Remove"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": old_template[resource_name]["Type"], } return resources_by_action @@ -519,35 +578,38 @@ class ResourceMap(collections.Mapping): resources_by_action = self.diff(template, parameters) old_template = self._resource_json_map - new_template = template['Resources'] + new_template = template["Resources"] self._resource_json_map = new_template - for resource_name, resource in resources_by_action['Add'].items(): + for resource_name, resource in resources_by_action["Add"].items(): resource_json = new_template[resource_name] new_resource = parse_and_create_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) self._parsed_resources[resource_name] = new_resource - for resource_name, resource in resources_by_action['Remove'].items(): + for resource_name, resource in resources_by_action["Remove"].items(): resource_json = old_template[resource_name] parse_and_delete_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) self._parsed_resources.pop(resource_name) tries = 1 - while resources_by_action['Modify'] and tries < 5: - for resource_name, resource in resources_by_action['Modify'].copy().items(): + while resources_by_action["Modify"] and tries < 5: + for resource_name, resource in resources_by_action["Modify"].copy().items(): resource_json = new_template[resource_name] try: changed_resource = parse_and_update_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) except Exception as e: # skip over dependency violations, and try again in a # second pass last_exception = e else: self._parsed_resources[resource_name] = changed_resource - del resources_by_action['Modify'][resource_name] + del resources_by_action["Modify"][resource_name] tries += 1 if tries == 5: raise last_exception @@ -559,7 +621,7 @@ class ResourceMap(collections.Mapping): for resource in remaining_resources.copy(): parsed_resource = self._parsed_resources.get(resource) try: - if parsed_resource and hasattr(parsed_resource, 'delete'): + if parsed_resource and hasattr(parsed_resource, "delete"): parsed_resource.delete(self._region_name) except Exception as e: # skip over dependency violations, and try again in a @@ -572,12 +634,11 @@ class ResourceMap(collections.Mapping): raise last_exception -class OutputMap(collections.Mapping): - +class OutputMap(collections_abc.Mapping): def __init__(self, resources, template, stack_id): self._template = template self._stack_id = stack_id - self._output_json_map = template.get('Outputs') + self._output_json_map = template.get("Outputs") # Create the default resources self._resource_map = resources @@ -591,7 +652,8 @@ class OutputMap(collections.Mapping): else: output_json = self._output_json_map.get(output_logical_id) new_output = parse_output( - output_logical_id, output_json, self._resource_map) + output_logical_id, output_json, self._resource_map + ) self._parsed_outputs[output_logical_id] = new_output return new_output @@ -610,9 +672,11 @@ class OutputMap(collections.Mapping): exports = [] if self.outputs: for key, value in self._output_json_map.items(): - if value.get('Export'): - cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map) - cleaned_value = clean_json(value.get('Value'), self._resource_map) + if value.get("Export"): + cleaned_name = clean_json( + value["Export"].get("Name"), self._resource_map + ) + cleaned_value = clean_json(value.get("Value"), self._resource_map) exports.append(Export(self._stack_id, cleaned_name, cleaned_value)) return exports @@ -622,7 +686,6 @@ class OutputMap(collections.Mapping): class Export(object): - def __init__(self, exporting_stack_id, name, value): self._exporting_stack_id = exporting_stack_id self._name = name diff --git a/moto/cloudformation/responses.py b/moto/cloudformation/responses.py index 80970262f..bf68a6325 100644 --- a/moto/cloudformation/responses.py +++ b/moto/cloudformation/responses.py @@ -7,12 +7,12 @@ from six.moves.urllib.parse import urlparse from moto.core.responses import BaseResponse from moto.core.utils import amzn_request_id from moto.s3 import s3_backend +from moto.core import ACCOUNT_ID from .models import cloudformation_backends from .exceptions import ValidationError class CloudFormationResponse(BaseResponse): - @property def cloudformation_backend(self): return cloudformation_backends[self.region] @@ -20,17 +20,18 @@ class CloudFormationResponse(BaseResponse): def _get_stack_from_s3_url(self, template_url): template_url_parts = urlparse(template_url) if "localhost" in template_url: - bucket_name, key_name = template_url_parts.path.lstrip( - "/").split("/", 1) + bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1) else: - if template_url_parts.netloc.endswith('amazonaws.com') \ - and template_url_parts.netloc.startswith('s3'): + if template_url_parts.netloc.endswith( + "amazonaws.com" + ) and template_url_parts.netloc.startswith("s3"): # Handle when S3 url uses amazon url with bucket in path # Also handles getting region as technically s3 is region'd # region = template_url.netloc.split('.')[1] - bucket_name, key_name = template_url_parts.path.lstrip( - "/").split("/", 1) + bucket_name, key_name = template_url_parts.path.lstrip("/").split( + "/", 1 + ) else: bucket_name = template_url_parts.netloc.split(".")[0] key_name = template_url_parts.path.lstrip("/") @@ -39,24 +40,26 @@ class CloudFormationResponse(BaseResponse): return key.value.decode("utf-8") def create_stack(self): - stack_name = self._get_param('StackName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') - role_arn = self._get_param('RoleARN') + stack_name = self._get_param("StackName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") + role_arn = self._get_param("RoleARN") parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # Hack dict-comprehension - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) if template_url: stack_body = self._get_stack_from_s3_url(template_url) - stack_notification_arns = self._get_multi_param( - 'NotificationARNs.member') + stack_notification_arns = self._get_multi_param("NotificationARNs.member") stack = self.cloudformation_backend.create_stack( name=stack_name, @@ -68,34 +71,37 @@ class CloudFormationResponse(BaseResponse): role_arn=role_arn, ) if self.request_json: - return json.dumps({ - 'CreateStackResponse': { - 'CreateStackResult': { - 'StackId': stack.stack_id, + return json.dumps( + { + "CreateStackResponse": { + "CreateStackResult": {"StackId": stack.stack_id} } } - }) + ) else: template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE) return template.render(stack=stack) @amzn_request_id def create_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') - role_arn = self._get_param('RoleARN') - update_or_create = self._get_param('ChangeSetType', 'CREATE') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") + role_arn = self._get_param("RoleARN") + update_or_create = self._get_param("ChangeSetType", "CREATE") parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) - parameters = {param['parameter_key']: param['parameter_value'] - for param in parameters_list} + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) + parameters = { + param["parameter_key"]: param["parameter_value"] + for param in parameters_list + } if template_url: stack_body = self._get_stack_from_s3_url(template_url) - stack_notification_arns = self._get_multi_param( - 'NotificationARNs.member') + stack_notification_arns = self._get_multi_param("NotificationARNs.member") change_set_id, stack_id = self.cloudformation_backend.create_change_set( stack_name=stack_name, change_set_name=change_set_name, @@ -108,66 +114,64 @@ class CloudFormationResponse(BaseResponse): change_set_type=update_or_create, ) if self.request_json: - return json.dumps({ - 'CreateChangeSetResponse': { - 'CreateChangeSetResult': { - 'Id': change_set_id, - 'StackId': stack_id, + return json.dumps( + { + "CreateChangeSetResponse": { + "CreateChangeSetResult": { + "Id": change_set_id, + "StackId": stack_id, + } } } - }) + ) else: template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render(stack_id=stack_id, change_set_id=change_set_id) def delete_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") - self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name) + self.cloudformation_backend.delete_change_set( + change_set_name=change_set_name, stack_name=stack_name + ) if self.request_json: - return json.dumps({ - 'DeleteChangeSetResponse': { - 'DeleteChangeSetResult': {}, - } - }) + return json.dumps( + {"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}} + ) else: template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render() def describe_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") change_set = self.cloudformation_backend.describe_change_set( - change_set_name=change_set_name, - stack_name=stack_name, + change_set_name=change_set_name, stack_name=stack_name ) template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render(change_set=change_set) @amzn_request_id def execute_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") self.cloudformation_backend.execute_change_set( - stack_name=stack_name, - change_set_name=change_set_name, + stack_name=stack_name, change_set_name=change_set_name ) if self.request_json: - return json.dumps({ - 'ExecuteChangeSetResponse': { - 'ExecuteChangeSetResult': {}, - } - }) + return json.dumps( + {"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}} + ) else: template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render() def describe_stacks(self): stack_name_or_id = None - if self._get_param('StackName'): - stack_name_or_id = self.querystring.get('StackName')[0] - token = self._get_param('NextToken') + if self._get_param("StackName"): + stack_name_or_id = self.querystring.get("StackName")[0] + token = self._get_param("NextToken") stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id) stack_ids = [stack.stack_id for stack in stacks] if token: @@ -175,7 +179,7 @@ class CloudFormationResponse(BaseResponse): else: start = 0 max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB - stacks_resp = stacks[start:start + max_results] + stacks_resp = stacks[start : start + max_results] next_token = None if len(stacks) > (start + max_results): next_token = stacks_resp[-1].stack_id @@ -183,9 +187,9 @@ class CloudFormationResponse(BaseResponse): return template.render(stacks=stacks_resp, next_token=next_token) def describe_stack_resource(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) - logical_resource_id = self._get_param('LogicalResourceId') + logical_resource_id = self._get_param("LogicalResourceId") for stack_resource in stack.stack_resources: if stack_resource.logical_resource_id == logical_resource_id: @@ -194,19 +198,18 @@ class CloudFormationResponse(BaseResponse): else: raise ValidationError(logical_resource_id) - template = self.response_template( - DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE) + template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE) return template.render(stack=stack, resource=resource) def describe_stack_resources(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE) return template.render(stack=stack) def describe_stack_events(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE) @@ -223,68 +226,82 @@ class CloudFormationResponse(BaseResponse): return template.render(stacks=stacks) def list_stack_resources(self): - stack_name_or_id = self._get_param('StackName') - resources = self.cloudformation_backend.list_stack_resources( - stack_name_or_id) + stack_name_or_id = self._get_param("StackName") + resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) return template.render(resources=resources) def get_template(self): - name_or_stack_id = self.querystring.get('StackName')[0] + name_or_stack_id = self.querystring.get("StackName")[0] stack = self.cloudformation_backend.get_stack(name_or_stack_id) if self.request_json: - return json.dumps({ - "GetTemplateResponse": { - "GetTemplateResult": { - "TemplateBody": stack.template, - "ResponseMetadata": { - "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + return json.dumps( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": stack.template, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) else: template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE) return template.render(stack=stack) def update_stack(self): - stack_name = self._get_param('StackName') - role_arn = self._get_param('RoleARN') - template_url = self._get_param('TemplateURL') - stack_body = self._get_param('TemplateBody') + stack_name = self._get_param("StackName") + role_arn = self._get_param("RoleARN") + template_url = self._get_param("TemplateURL") + stack_body = self._get_param("TemplateBody") stack = self.cloudformation_backend.get_stack(stack_name) - if self._get_param('UsePreviousTemplate') == "true": + if self._get_param("UsePreviousTemplate") == "true": stack_body = stack.template elif not stack_body and template_url: stack_body = self._get_stack_from_s3_url(template_url) incoming_params = self._get_list_prefix("Parameters.member") - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in incoming_params if 'parameter_value' in parameter - ]) - previous = dict([ - (parameter['parameter_key'], stack.parameters[parameter['parameter_key']]) - for parameter - in incoming_params if 'use_previous_value' in parameter - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in incoming_params + if "parameter_value" in parameter + ] + ) + previous = dict( + [ + ( + parameter["parameter_key"], + stack.parameters[parameter["parameter_key"]], + ) + for parameter in incoming_params + if "use_previous_value" in parameter + ] + ) parameters.update(previous) # boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't # end up containing anything we can use to differentiate between passing an empty value versus not # passing anything. so until that changes, moto won't be able to clear tags, only update them. - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # so that if we don't pass the parameter, we don't clear all the tags accidentally if not tags: tags = None stack = self.cloudformation_backend.get_stack(stack_name) - if stack.status == 'ROLLBACK_COMPLETE': + if stack.status == "ROLLBACK_COMPLETE": raise ValidationError( - stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id)) + stack.stack_id, + message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format( + stack.stack_id + ), + ) stack = self.cloudformation_backend.update_stack( name=stack_name, @@ -295,11 +312,7 @@ class CloudFormationResponse(BaseResponse): ) if self.request_json: stack_body = { - 'UpdateStackResponse': { - 'UpdateStackResult': { - 'StackId': stack.name, - } - } + "UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}} } return json.dumps(stack_body) else: @@ -307,56 +320,57 @@ class CloudFormationResponse(BaseResponse): return template.render(stack=stack) def delete_stack(self): - name_or_stack_id = self.querystring.get('StackName')[0] + name_or_stack_id = self.querystring.get("StackName")[0] self.cloudformation_backend.delete_stack(name_or_stack_id) if self.request_json: - return json.dumps({ - 'DeleteStackResponse': { - 'DeleteStackResult': {}, - } - }) + return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}}) else: template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE) return template.render() def list_exports(self): - token = self._get_param('NextToken') + token = self._get_param("NextToken") exports, next_token = self.cloudformation_backend.list_exports(token=token) template = self.response_template(LIST_EXPORTS_RESPONSE) return template.render(exports=exports, next_token=next_token) def validate_template(self): - cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody')) + cfn_lint = self.cloudformation_backend.validate_template( + self._get_param("TemplateBody") + ) if cfn_lint: raise ValidationError(cfn_lint[0].message) description = "" try: - description = json.loads(self._get_param('TemplateBody'))['Description'] + description = json.loads(self._get_param("TemplateBody"))["Description"] except (ValueError, KeyError): pass try: - description = yaml.load(self._get_param('TemplateBody'))['Description'] + description = yaml.load(self._get_param("TemplateBody"))["Description"] except (yaml.ParserError, KeyError): pass template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE) return template.render(description=description) def create_stack_set(self): - stackset_name = self._get_param('StackSetName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') + stackset_name = self._get_param("StackSetName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") # role_arn = self._get_param('RoleARN') parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # Copy-Pasta - Hack dict-comprehension - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) if template_url: stack_body = self._get_stack_from_s3_url(template_url) @@ -368,59 +382,67 @@ class CloudFormationResponse(BaseResponse): # role_arn=role_arn, ) if self.request_json: - return json.dumps({ - 'CreateStackSetResponse': { - 'CreateStackSetResult': { - 'StackSetId': stackset.stackset_id, + return json.dumps( + { + "CreateStackSetResponse": { + "CreateStackSetResult": {"StackSetId": stackset.stackset_id} } } - }) + ) else: template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def create_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - parameters = self._get_multi_param('ParameterOverrides.member') - self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + parameters = self._get_multi_param("ParameterOverrides.member") + self.cloudformation_backend.create_stack_instances( + stackset_name, accounts, regions, parameters + ) template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE) return template.render() def delete_stack_set(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") self.cloudformation_backend.delete_stack_set(stackset_name) template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE) return template.render() def delete_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + operation = self.cloudformation_backend.delete_stack_instances( + stackset_name, accounts, regions + ) template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE) return template.render(operation=operation) def describe_stack_set(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) if not stackset.admin_role: - stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole' + stackset.admin_role = "arn:aws:iam::{AccountId}:role/AWSCloudFormationStackSetAdministrationRole".format( + AccountId=ACCOUNT_ID + ) if not stackset.execution_role: - stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole' + stackset.execution_role = "AWSCloudFormationStackSetExecutionRole" template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def describe_stack_instance(self): - stackset_name = self._get_param('StackSetName') - account = self._get_param('StackInstanceAccount') - region = self._get_param('StackInstanceRegion') + stackset_name = self._get_param("StackSetName") + account = self._get_param("StackInstanceAccount") + region = self._get_param("StackInstanceRegion") - instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region) + instance = self.cloudformation_backend.get_stack_set( + stackset_name + ).instances.get_instance(account, region) template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE) rendered = template.render(instance=instance) return rendered @@ -431,61 +453,66 @@ class CloudFormationResponse(BaseResponse): return template.render(stacksets=stacksets) def list_stack_instances(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE) return template.render(stackset=stackset) def list_stack_set_operations(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def stop_stack_set_operation(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) - stackset.update_operation(operation_id, 'STOPPED') + stackset.update_operation(operation_id, "STOPPED") template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE) return template.render() def describe_stack_set_operation(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) operation = stackset.get_operation(operation_id) template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE) return template.render(stackset=stackset, operation=operation) def list_stack_set_operation_results(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) operation = stackset.get_operation(operation_id) - template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE) + template = self.response_template( + LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE + ) return template.render(operation=operation) def update_stack_set(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') - description = self._get_param('Description') - execution_role = self._get_param('ExecutionRoleName') - admin_role = self._get_param('AdministrationRoleARN') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - template_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") + description = self._get_param("Description") + execution_role = self._get_param("ExecutionRoleName") + admin_role = self._get_param("AdministrationRoleARN") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + template_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") if template_url: template_body = self._get_stack_from_s3_url(template_url) - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) parameters_list = self._get_list_prefix("Parameters.member") - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) operation = self.cloudformation_backend.update_stack_set( stackset_name=stackset_name, template=template_body, @@ -496,18 +523,20 @@ class CloudFormationResponse(BaseResponse): execution_role=execution_role, accounts=accounts, regions=regions, - operation_id=operation_id + operation_id=operation_id, ) template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE) return template.render(operation=operation) def update_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - parameters = self._get_multi_param('ParameterOverrides.member') - operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + parameters = self._get_multi_param("ParameterOverrides.member") + operation = self.cloudformation_backend.get_stack_set( + stackset_name + ).update_instances(accounts, regions, parameters) template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE) return template.render(operation=operation) @@ -1025,11 +1054,14 @@ STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE = """ """ -DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """ +DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = ( + """ {{ stackset.execution_role }} - arn:aws:iam::123456789012:role/{{ stackset.admin_role }} + arn:aws:iam::""" + + ACCOUNT_ID + + """:role/{{ stackset.admin_role }} {{ stackset.id }} {{ operation.CreationTimestamp }} {{ operation.OperationId }} @@ -1046,15 +1078,19 @@ DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE = """ """ +) -LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """ +LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = ( + """ {% for instance in operation.Instances %} {% for account, region in instance.items() %} - Function not found: arn:aws:lambda:us-west-2:123456789012:function:AWSCloudFormationStackSetAccountGate + Function not found: arn:aws:lambda:us-west-2:""" + + ACCOUNT_ID + + """:function:AWSCloudFormationStackSetAccountGate SKIPPED {{ region }} @@ -1070,3 +1106,4 @@ LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE = """ """ +) diff --git a/moto/cloudformation/urls.py b/moto/cloudformation/urls.py index 468c68d98..84251e82b 100644 --- a/moto/cloudformation/urls.py +++ b/moto/cloudformation/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import CloudFormationResponse -url_bases = [ - "https?://cloudformation.(.+).amazonaws.com", -] +url_bases = ["https?://cloudformation.(.+).amazonaws.com"] -url_paths = { - '{0}/$': CloudFormationResponse.dispatch, -} +url_paths = {"{0}/$": CloudFormationResponse.dispatch} diff --git a/moto/cloudformation/utils.py b/moto/cloudformation/utils.py index e4290ce1a..cd8481002 100644 --- a/moto/cloudformation/utils.py +++ b/moto/cloudformation/utils.py @@ -7,48 +7,56 @@ import os import string from cfnlint import decode, core +from moto.core import ACCOUNT_ID def generate_stack_id(stack_name, region="us-east-1", account="123456789"): random_id = uuid.uuid4() - return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id) + return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format( + region, account, stack_name, random_id + ) def generate_changeset_id(changeset_name, region_name): random_id = uuid.uuid4() - return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id) + return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format( + region_name, changeset_name, random_id + ) def generate_stackset_id(stackset_name): random_id = uuid.uuid4() - return '{}:{}'.format(stackset_name, random_id) + return "{}:{}".format(stackset_name, random_id) def generate_stackset_arn(stackset_id, region_name): - return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id) + return "arn:aws:cloudformation:{}:{}:stackset/{}".format( + region_name, ACCOUNT_ID, stackset_id + ) def random_suffix(): size = 12 chars = list(range(10)) + list(string.ascii_uppercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def yaml_tag_constructor(loader, tag, node): """convert shorthand intrinsic function to full name """ + def _f(loader, tag, node): - if tag == '!GetAtt': - return node.value.split('.') + if tag == "!GetAtt": + return node.value.split(".") elif type(node) == yaml.SequenceNode: return loader.construct_sequence(node) else: return node.value - if tag == '!Ref': - key = 'Ref' + if tag == "!Ref": + key = "Ref" else: - key = 'Fn::{}'.format(tag[1:]) + key = "Fn::{}".format(tag[1:]) return {key: _f(loader, tag, node)} @@ -71,13 +79,9 @@ def validate_template_cfn_lint(template): rules = core.get_rules([], [], []) # Use us-east-1 region (spec file) for validation - regions = ['us-east-1'] + regions = ["us-east-1"] # Process all the rules and gather the errors - matches = core.run_checks( - abs_filename, - template, - rules, - regions) + matches = core.run_checks(abs_filename, template, rules, regions) return matches diff --git a/moto/cloudwatch/__init__.py b/moto/cloudwatch/__init__.py index 861fb703a..86a774933 100644 --- a/moto/cloudwatch/__init__.py +++ b/moto/cloudwatch/__init__.py @@ -1,6 +1,6 @@ from .models import cloudwatch_backends from ..core.models import base_decorator, deprecated_base_decorator -cloudwatch_backend = cloudwatch_backends['us-east-1'] +cloudwatch_backend = cloudwatch_backends["us-east-1"] mock_cloudwatch = base_decorator(cloudwatch_backends) mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends) diff --git a/moto/cloudwatch/models.py b/moto/cloudwatch/models.py index ed644f874..662005237 100644 --- a/moto/cloudwatch/models.py +++ b/moto/cloudwatch/models.py @@ -1,4 +1,3 @@ - import json from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core import BaseBackend, BaseModel @@ -6,15 +5,15 @@ from moto.core.exceptions import RESTError import boto.ec2.cloudwatch from datetime import datetime, timedelta from dateutil.tz import tzutc +from uuid import uuid4 from .utils import make_arn_for_dashboard -DEFAULT_ACCOUNT_ID = 123456789012 +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID _EMPTY_LIST = tuple() class Dimension(object): - def __init__(self, name, value): self.name = name self.value = value @@ -49,10 +48,23 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False): class FakeAlarm(BaseModel): - - def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods, - period, threshold, statistic, description, dimensions, alarm_actions, - ok_actions, insufficient_data_actions, unit): + def __init__( + self, + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ): self.name = name self.namespace = namespace self.metric_name = metric_name @@ -62,8 +74,9 @@ class FakeAlarm(BaseModel): self.threshold = threshold self.statistic = statistic self.description = description - self.dimensions = [Dimension(dimension['name'], dimension[ - 'value']) for dimension in dimensions] + self.dimensions = [ + Dimension(dimension["name"], dimension["value"]) for dimension in dimensions + ] self.alarm_actions = alarm_actions self.ok_actions = ok_actions self.insufficient_data_actions = insufficient_data_actions @@ -72,15 +85,21 @@ class FakeAlarm(BaseModel): self.history = [] - self.state_reason = '' - self.state_reason_data = '{}' - self.state_value = 'OK' + self.state_reason = "" + self.state_reason_data = "{}" + self.state_value = "OK" self.state_updated_timestamp = datetime.utcnow() def update_state(self, reason, reason_data, state_value): # History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action self.history.append( - ('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp) + ( + "StateUpdate", + self.state_reason, + self.state_reason_data, + self.state_value, + self.state_updated_timestamp, + ) ) self.state_reason = reason @@ -90,14 +109,14 @@ class FakeAlarm(BaseModel): class MetricDatum(BaseModel): - def __init__(self, namespace, name, value, dimensions, timestamp): self.namespace = namespace self.name = name self.value = value self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc()) - self.dimensions = [Dimension(dimension['Name'], dimension[ - 'Value']) for dimension in dimensions] + self.dimensions = [ + Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions + ] class Dashboard(BaseModel): @@ -120,7 +139,7 @@ class Dashboard(BaseModel): return len(self.body) def __repr__(self): - return ''.format(self.name) + return "".format(self.name) class Statistics: @@ -131,7 +150,7 @@ class Statistics: @property def sample_count(self): - if 'SampleCount' not in self.stats: + if "SampleCount" not in self.stats: return None return len(self.values) @@ -142,28 +161,28 @@ class Statistics: @property def sum(self): - if 'Sum' not in self.stats: + if "Sum" not in self.stats: return None return sum(self.values) @property def minimum(self): - if 'Minimum' not in self.stats: + if "Minimum" not in self.stats: return None return min(self.values) @property def maximum(self): - if 'Maximum' not in self.stats: + if "Maximum" not in self.stats: return None return max(self.values) @property def average(self): - if 'Average' not in self.stats: + if "Average" not in self.stats: return None # when moto is 3.4+ we can switch to the statistics module @@ -171,18 +190,45 @@ class Statistics: class CloudWatchBackend(BaseBackend): - def __init__(self): self.alarms = {} self.dashboards = {} self.metric_data = [] + self.paged_metric_data = {} - def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods, - period, threshold, statistic, description, dimensions, - alarm_actions, ok_actions, insufficient_data_actions, unit): - alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period, - threshold, statistic, description, dimensions, alarm_actions, - ok_actions, insufficient_data_actions, unit) + def put_metric_alarm( + self, + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ): + alarm = FakeAlarm( + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ) self.alarms[name] = alarm return alarm @@ -214,14 +260,12 @@ class CloudWatchBackend(BaseBackend): ] def get_alarms_by_alarm_names(self, alarm_names): - return [ - alarm - for alarm in self.alarms.values() - if alarm.name in alarm_names - ] + return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names] def get_alarms_by_state_value(self, target_state): - return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values()) + return filter( + lambda alarm: alarm.state_value == target_state, self.alarms.values() + ) def delete_alarms(self, alarm_names): for alarm_name in alarm_names: @@ -230,17 +274,31 @@ class CloudWatchBackend(BaseBackend): def put_metric_data(self, namespace, metric_data): for metric_member in metric_data: # Preserve "datetime" for get_metric_statistics comparisons - timestamp = metric_member.get('Timestamp') + timestamp = metric_member.get("Timestamp") if timestamp is not None and type(timestamp) != datetime: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=tzutc()) - self.metric_data.append(MetricDatum( - namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp)) + self.metric_data.append( + MetricDatum( + namespace, + metric_member["MetricName"], + float(metric_member.get("Value", 0)), + metric_member.get("Dimensions.member", _EMPTY_LIST), + timestamp, + ) + ) - def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats): + def get_metric_statistics( + self, namespace, metric_name, start_time, end_time, period, stats + ): period_delta = timedelta(seconds=period) - filtered_data = [md for md in self.metric_data if - md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time] + filtered_data = [ + md + for md in self.metric_data + if md.namespace == namespace + and md.name == metric_name + and start_time <= md.timestamp <= end_time + ] # earliest to oldest filtered_data = sorted(filtered_data, key=lambda x: x.timestamp) @@ -249,9 +307,15 @@ class CloudWatchBackend(BaseBackend): idx = 0 data = list() - for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta): + for dt in daterange( + filtered_data[0].timestamp, + filtered_data[-1].timestamp + period_delta, + period_delta, + ): s = Statistics(stats, dt) - while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta): + while idx < len(filtered_data) and filtered_data[idx].timestamp < ( + dt + period_delta + ): s.values.append(filtered_data[idx].value) idx += 1 @@ -268,7 +332,7 @@ class CloudWatchBackend(BaseBackend): def put_dashboard(self, name, body): self.dashboards[name] = Dashboard(name, body) - def list_dashboards(self, prefix=''): + def list_dashboards(self, prefix=""): for key, value in self.dashboards.items(): if key.startswith(prefix): yield value @@ -280,7 +344,12 @@ class CloudWatchBackend(BaseBackend): left_over = to_delete - all_dashboards if len(left_over) > 0: # Some dashboards are not found - return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over)) + return ( + False, + "The specified dashboard does not exist. [{0}]".format( + ", ".join(left_over) + ), + ) for dashboard in to_delete: del self.dashboards[dashboard] @@ -295,32 +364,66 @@ class CloudWatchBackend(BaseBackend): if reason_data is not None: json.loads(reason_data) except ValueError: - raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON') + raise RESTError("InvalidFormat", "StateReasonData is invalid JSON") if alarm_name not in self.alarms: - raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404) + raise RESTError( + "ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404 + ) - if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'): - raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA') + if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"): + raise RESTError( + "InvalidParameterValue", + "StateValue is not one of OK | ALARM | INSUFFICIENT_DATA", + ) self.alarms[alarm_name].update_state(reason, reason_data, state_value) + def list_metrics(self, next_token, namespace, metric_name): + if next_token: + if next_token not in self.paged_metric_data: + raise RESTError( + "PaginationException", "Request parameter NextToken is invalid" + ) + else: + metrics = self.paged_metric_data[next_token] + del self.paged_metric_data[next_token] # Cant reuse same token twice + return self._get_paginated(metrics) + else: + metrics = self.get_filtered_metrics(metric_name, namespace) + return self._get_paginated(metrics) + + def get_filtered_metrics(self, metric_name, namespace): + metrics = self.get_all_metrics() + if namespace: + metrics = [md for md in metrics if md.namespace == namespace] + if metric_name: + metrics = [md for md in metrics if md.name == metric_name] + return metrics + + def _get_paginated(self, metrics): + if len(metrics) > 500: + next_token = str(uuid4()) + self.paged_metric_data[next_token] = metrics[500:] + return next_token, metrics[0:500] + else: + return None, metrics + class LogGroup(BaseModel): - def __init__(self, spec): # required - self.name = spec['LogGroupName'] + self.name = spec["LogGroupName"] # optional - self.tags = spec.get('Tags', []) + self.tags = spec.get("Tags", []) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - spec = { - 'LogGroupName': properties['LogGroupName'] - } - optional_properties = 'Tags'.split() + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + spec = {"LogGroupName": properties["LogGroupName"]} + optional_properties = "Tags".split() for prop in optional_properties: if prop in properties: spec[prop] = properties[prop] diff --git a/moto/cloudwatch/responses.py b/moto/cloudwatch/responses.py index bf176e1be..7872e71fd 100644 --- a/moto/cloudwatch/responses.py +++ b/moto/cloudwatch/responses.py @@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse class CloudWatchResponse(BaseResponse): - @property def cloudwatch_backend(self): return cloudwatch_backends[self.region] @@ -17,45 +16,54 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def put_metric_alarm(self): - name = self._get_param('AlarmName') - namespace = self._get_param('Namespace') - metric_name = self._get_param('MetricName') - comparison_operator = self._get_param('ComparisonOperator') - evaluation_periods = self._get_param('EvaluationPeriods') - period = self._get_param('Period') - threshold = self._get_param('Threshold') - statistic = self._get_param('Statistic') - description = self._get_param('AlarmDescription') - dimensions = self._get_list_prefix('Dimensions.member') - alarm_actions = self._get_multi_param('AlarmActions.member') - ok_actions = self._get_multi_param('OKActions.member') + name = self._get_param("AlarmName") + namespace = self._get_param("Namespace") + metric_name = self._get_param("MetricName") + comparison_operator = self._get_param("ComparisonOperator") + evaluation_periods = self._get_param("EvaluationPeriods") + period = self._get_param("Period") + threshold = self._get_param("Threshold") + statistic = self._get_param("Statistic") + description = self._get_param("AlarmDescription") + dimensions = self._get_list_prefix("Dimensions.member") + alarm_actions = self._get_multi_param("AlarmActions.member") + ok_actions = self._get_multi_param("OKActions.member") insufficient_data_actions = self._get_multi_param( - "InsufficientDataActions.member") - unit = self._get_param('Unit') - alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, - comparison_operator, - evaluation_periods, period, - threshold, statistic, - description, dimensions, - alarm_actions, ok_actions, - insufficient_data_actions, - unit) + "InsufficientDataActions.member" + ) + unit = self._get_param("Unit") + alarm = self.cloudwatch_backend.put_metric_alarm( + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) return template.render(alarm=alarm) @amzn_request_id def describe_alarms(self): - action_prefix = self._get_param('ActionPrefix') - alarm_name_prefix = self._get_param('AlarmNamePrefix') - alarm_names = self._get_multi_param('AlarmNames.member') - state_value = self._get_param('StateValue') + action_prefix = self._get_param("ActionPrefix") + alarm_name_prefix = self._get_param("AlarmNamePrefix") + alarm_names = self._get_multi_param("AlarmNames.member") + state_value = self._get_param("StateValue") if action_prefix: - alarms = self.cloudwatch_backend.get_alarms_by_action_prefix( - action_prefix) + alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix) elif alarm_name_prefix: alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix( - alarm_name_prefix) + alarm_name_prefix + ) elif alarm_names: alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) elif state_value: @@ -68,15 +76,15 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def delete_alarms(self): - alarm_names = self._get_multi_param('AlarmNames.member') + alarm_names = self._get_multi_param("AlarmNames.member") self.cloudwatch_backend.delete_alarms(alarm_names) template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE) return template.render() @amzn_request_id def put_metric_data(self): - namespace = self._get_param('Namespace') - metric_data = self._get_multi_param('MetricData.member') + namespace = self._get_param("Namespace") + metric_data = self._get_multi_param("MetricData.member") self.cloudwatch_backend.put_metric_data(namespace, metric_data) template = self.response_template(PUT_METRIC_DATA_TEMPLATE) @@ -84,43 +92,52 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def get_metric_statistics(self): - namespace = self._get_param('Namespace') - metric_name = self._get_param('MetricName') - start_time = dtparse(self._get_param('StartTime')) - end_time = dtparse(self._get_param('EndTime')) - period = int(self._get_param('Period')) + namespace = self._get_param("Namespace") + metric_name = self._get_param("MetricName") + start_time = dtparse(self._get_param("StartTime")) + end_time = dtparse(self._get_param("EndTime")) + period = int(self._get_param("Period")) statistics = self._get_multi_param("Statistics.member") # Unsupported Parameters (To Be Implemented) - unit = self._get_param('Unit') - extended_statistics = self._get_param('ExtendedStatistics') - dimensions = self._get_param('Dimensions') + unit = self._get_param("Unit") + extended_statistics = self._get_param("ExtendedStatistics") + dimensions = self._get_param("Dimensions") if unit or extended_statistics or dimensions: - raise NotImplemented() + raise NotImplementedError() # TODO: this should instead throw InvalidParameterCombination if not statistics: - raise NotImplemented("Must specify either Statistics or ExtendedStatistics") + raise NotImplementedError( + "Must specify either Statistics or ExtendedStatistics" + ) - datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics) + datapoints = self.cloudwatch_backend.get_metric_statistics( + namespace, metric_name, start_time, end_time, period, statistics + ) template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE) return template.render(label=metric_name, datapoints=datapoints) @amzn_request_id def list_metrics(self): - metrics = self.cloudwatch_backend.get_all_metrics() + namespace = self._get_param("Namespace") + metric_name = self._get_param("MetricName") + next_token = self._get_param("NextToken") + next_token, metrics = self.cloudwatch_backend.list_metrics( + next_token, namespace, metric_name + ) template = self.response_template(LIST_METRICS_TEMPLATE) - return template.render(metrics=metrics) + return template.render(metrics=metrics, next_token=next_token) @amzn_request_id def delete_dashboards(self): - dashboards = self._get_multi_param('DashboardNames.member') + dashboards = self._get_multi_param("DashboardNames.member") if dashboards is None: - return self._error('InvalidParameterValue', 'Need at least 1 dashboard') + return self._error("InvalidParameterValue", "Need at least 1 dashboard") status, error = self.cloudwatch_backend.delete_dashboards(dashboards) if not status: - return self._error('ResourceNotFound', error) + return self._error("ResourceNotFound", error) template = self.response_template(DELETE_DASHBOARD_TEMPLATE) return template.render() @@ -143,18 +160,18 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def get_dashboard(self): - dashboard_name = self._get_param('DashboardName') + dashboard_name = self._get_param("DashboardName") dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name) if dashboard is None: - return self._error('ResourceNotFound', 'Dashboard does not exist') + return self._error("ResourceNotFound", "Dashboard does not exist") template = self.response_template(GET_DASHBOARD_TEMPLATE) return template.render(dashboard=dashboard) @amzn_request_id def list_dashboards(self): - prefix = self._get_param('DashboardNamePrefix', '') + prefix = self._get_param("DashboardNamePrefix", "") dashboards = self.cloudwatch_backend.list_dashboards(prefix) @@ -163,13 +180,13 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def put_dashboard(self): - name = self._get_param('DashboardName') - body = self._get_param('DashboardBody') + name = self._get_param("DashboardName") + body = self._get_param("DashboardBody") try: json.loads(body) except ValueError: - return self._error('InvalidParameterInput', 'Body is invalid JSON') + return self._error("InvalidParameterInput", "Body is invalid JSON") self.cloudwatch_backend.put_dashboard(name, body) @@ -178,12 +195,14 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def set_alarm_state(self): - alarm_name = self._get_param('AlarmName') - reason = self._get_param('StateReason') - reason_data = self._get_param('StateReasonData') - state_value = self._get_param('StateValue') + alarm_name = self._get_param("AlarmName") + reason = self._get_param("StateReason") + reason_data = self._get_param("StateReasonData") + state_value = self._get_param("StateValue") - self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value) + self.cloudwatch_backend.set_alarm_state( + alarm_name, reason, reason_data, state_value + ) template = self.response_template(SET_ALARM_STATE_TEMPLATE) return template.render() @@ -326,9 +345,11 @@ LIST_METRICS_TEMPLATE = """ 0: - extra_data.update({ - attribute[0]['Name']: attribute[0]['Value'] - }) + extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]}) return extra_data class CognitoIdpUserPoolDomain(BaseModel): - def __init__(self, user_pool_id, domain, custom_domain_config=None): self.user_pool_id = user_pool_id self.domain = domain self.custom_domain_config = custom_domain_config or {} def _distribution_name(self): - if self.custom_domain_config and \ - 'CertificateArn' in self.custom_domain_config: + if self.custom_domain_config and "CertificateArn" in self.custom_domain_config: hash = hashlib.md5( - self.custom_domain_config['CertificateArn'].encode('utf-8') + self.custom_domain_config["CertificateArn"].encode("utf-8") ).hexdigest() return "{hash}.cloudfront.net".format(hash=hash[:16]) return None @@ -182,14 +194,11 @@ class CognitoIdpUserPoolDomain(BaseModel): "Version": None, } elif distribution: - return { - "CloudFrontDomain": distribution, - } + return {"CloudFrontDomain": distribution} return None class CognitoIdpUserPoolClient(BaseModel): - def __init__(self, user_pool_id, extended_config): self.user_pool_id = user_pool_id self.id = str(uuid.uuid4()) @@ -211,11 +220,10 @@ class CognitoIdpUserPoolClient(BaseModel): return user_pool_client_json def get_readable_fields(self): - return self.extended_config.get('ReadAttributes', []) + return self.extended_config.get("ReadAttributes", []) class CognitoIdpIdentityProvider(BaseModel): - def __init__(self, name, extended_config): self.name = name self.extended_config = extended_config or {} @@ -239,7 +247,6 @@ class CognitoIdpIdentityProvider(BaseModel): class CognitoIdpGroup(BaseModel): - def __init__(self, user_pool_id, group_name, description, role_arn, precedence): self.user_pool_id = user_pool_id self.group_name = group_name @@ -266,7 +273,6 @@ class CognitoIdpGroup(BaseModel): class CognitoIdpUser(BaseModel): - def __init__(self, user_pool_id, username, password, status, attributes): self.id = str(uuid.uuid4()) self.user_pool_id = user_pool_id @@ -299,19 +305,18 @@ class CognitoIdpUser(BaseModel): { "Enabled": self.enabled, attributes_key: self.attributes, - "MFAOptions": [] + "MFAOptions": [], } ) return user_json def update_attributes(self, new_attributes): - def flatten_attrs(attrs): - return {attr['Name']: attr['Value'] for attr in attrs} + return {attr["Name"]: attr["Value"] for attr in attrs} def expand_attrs(attrs): - return [{'Name': k, 'Value': v} for k, v in attrs.items()] + return [{"Name": k, "Value": v} for k, v in attrs.items()] flat_attributes = flatten_attrs(self.attributes) flat_attributes.update(flatten_attrs(new_attributes)) @@ -319,7 +324,6 @@ class CognitoIdpUser(BaseModel): class CognitoIdpBackend(BaseBackend): - def __init__(self, region): super(CognitoIdpBackend, self).__init__() self.region = region @@ -495,7 +499,9 @@ class CognitoIdpBackend(BaseBackend): if not user_pool: raise ResourceNotFoundError(user_pool_id) - group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence) + group = CognitoIdpGroup( + user_pool_id, group_name, description, role_arn, precedence + ) if group.group_name in user_pool.groups: raise GroupExistsException("A group with the name already exists") user_pool.groups[group.group_name] = group @@ -561,7 +567,16 @@ class CognitoIdpBackend(BaseBackend): if not user_pool: raise ResourceNotFoundError(user_pool_id) - user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes) + if username in user_pool.users: + raise UsernameExistsException(username) + + user = CognitoIdpUser( + user_pool_id, + username, + temporary_password, + UserStatus["FORCE_CHANGE_PASSWORD"], + attributes, + ) user_pool.users[user.username] = user return user @@ -607,7 +622,9 @@ class CognitoIdpBackend(BaseBackend): def _log_user_in(self, user_pool, client, username): refresh_token = user_pool.create_refresh_token(client.id, username) - access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) + access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token( + refresh_token + ) return { "AuthenticationResult": { @@ -650,7 +667,11 @@ class CognitoIdpBackend(BaseBackend): return self._log_user_in(user_pool, client, username) elif auth_flow == "REFRESH_TOKEN": refresh_token = auth_parameters.get("REFRESH_TOKEN") - id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) + ( + id_token, + access_token, + expires_in, + ) = user_pool.create_tokens_from_refresh_token(refresh_token) return { "AuthenticationResult": { @@ -662,7 +683,9 @@ class CognitoIdpBackend(BaseBackend): else: return {} - def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses): + def respond_to_auth_challenge( + self, session, client_id, challenge_name, challenge_responses + ): user_pool = self.sessions.get(session) if not user_pool: raise ResourceNotFoundError(session) diff --git a/moto/cognitoidp/responses.py b/moto/cognitoidp/responses.py index 75dd8c181..80247b076 100644 --- a/moto/cognitoidp/responses.py +++ b/moto/cognitoidp/responses.py @@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value class CognitoIdpResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse): # User pool def create_user_pool(self): name = self.parameters.pop("PoolName") - user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters) - return json.dumps({ - "UserPool": user_pool.to_json(extended=True) - }) + user_pool = cognitoidp_backends[self.region].create_user_pool( + name, self.parameters + ) + return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def list_user_pools(self): max_results = self._get_param("MaxResults") @@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse): user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( max_results=max_results, next_token=next_token ) - response = { - "UserPools": [user_pool.to_json() for user_pool in user_pools], - } + response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]} if next_token: response["NextToken"] = str(next_token) return json.dumps(response) @@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool(self): user_pool_id = self._get_param("UserPoolId") user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) - return json.dumps({ - "UserPool": user_pool.to_json(extended=True) - }) + return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def delete_user_pool(self): user_pool_id = self._get_param("UserPoolId") @@ -61,14 +56,14 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_domain(self): domain = self._get_param("Domain") - user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain) + user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain( + domain + ) domain_description = {} if user_pool_domain: domain_description = user_pool_domain.to_json() - return json.dumps({ - "DomainDescription": domain_description - }) + return json.dumps({"DomainDescription": domain_description}) def delete_user_pool_domain(self): domain = self._get_param("Domain") @@ -89,19 +84,24 @@ class CognitoIdpResponse(BaseResponse): # User pool client def create_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") - user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].create_user_pool_client( + user_pool_id, self.parameters + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def list_user_pool_clients(self): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken", "0") - user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id, - max_results=max_results, next_token=next_token) + user_pool_clients, next_token = cognitoidp_backends[ + self.region + ].list_user_pool_clients( + user_pool_id, max_results=max_results, next_token=next_token + ) response = { - "UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients] + "UserPoolClients": [ + user_pool_client.to_json() for user_pool_client in user_pool_clients + ] } if next_token: response["NextToken"] = str(next_token) @@ -110,43 +110,51 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client( + user_pool_id, client_id + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def update_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") client_id = self.parameters.pop("ClientId") - user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].update_user_pool_client( + user_pool_id, client_id, self.parameters + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def delete_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id) + cognitoidp_backends[self.region].delete_user_pool_client( + user_pool_id, client_id + ) return "" # Identity provider def create_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self.parameters.pop("ProviderName") - identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].create_identity_provider( + user_pool_id, name, self.parameters + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def list_identity_providers(self): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken", "0") - identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers( + identity_providers, next_token = cognitoidp_backends[ + self.region + ].list_identity_providers( user_pool_id, max_results=max_results, next_token=next_token ) response = { - "Providers": [identity_provider.to_json() for identity_provider in identity_providers] + "Providers": [ + identity_provider.to_json() for identity_provider in identity_providers + ] } if next_token: response["NextToken"] = str(next_token) @@ -155,18 +163,22 @@ class CognitoIdpResponse(BaseResponse): def describe_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].describe_identity_provider( + user_pool_id, name + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def update_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].update_identity_provider( + user_pool_id, name, self.parameters + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def delete_identity_provider(self): user_pool_id = self._get_param("UserPoolId") @@ -183,31 +195,21 @@ class CognitoIdpResponse(BaseResponse): precedence = self._get_param("Precedence") group = cognitoidp_backends[self.region].create_group( - user_pool_id, - group_name, - description, - role_arn, - precedence, + user_pool_id, group_name, description, role_arn, precedence ) - return json.dumps({ - "Group": group.to_json(), - }) + return json.dumps({"Group": group.to_json()}) def get_group(self): group_name = self._get_param("GroupName") user_pool_id = self._get_param("UserPoolId") group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) - return json.dumps({ - "Group": group.to_json(), - }) + return json.dumps({"Group": group.to_json()}) def list_groups(self): user_pool_id = self._get_param("UserPoolId") groups = cognitoidp_backends[self.region].list_groups(user_pool_id) - return json.dumps({ - "Groups": [group.to_json() for group in groups], - }) + return json.dumps({"Groups": [group.to_json() for group in groups]}) def delete_group(self): group_name = self._get_param("GroupName") @@ -221,9 +223,7 @@ class CognitoIdpResponse(BaseResponse): group_name = self._get_param("GroupName") cognitoidp_backends[self.region].admin_add_user_to_group( - user_pool_id, - group_name, - username, + user_pool_id, group_name, username ) return "" @@ -231,18 +231,18 @@ class CognitoIdpResponse(BaseResponse): def list_users_in_group(self): user_pool_id = self._get_param("UserPoolId") group_name = self._get_param("GroupName") - users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name) - return json.dumps({ - "Users": [user.to_json(extended=True) for user in users], - }) + users = cognitoidp_backends[self.region].list_users_in_group( + user_pool_id, group_name + ) + return json.dumps({"Users": [user.to_json(extended=True) for user in users]}) def admin_list_groups_for_user(self): username = self._get_param("Username") user_pool_id = self._get_param("UserPoolId") - groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username) - return json.dumps({ - "Groups": [group.to_json() for group in groups], - }) + groups = cognitoidp_backends[self.region].admin_list_groups_for_user( + user_pool_id, username + ) + return json.dumps({"Groups": [group.to_json() for group in groups]}) def admin_remove_user_from_group(self): user_pool_id = self._get_param("UserPoolId") @@ -250,9 +250,7 @@ class CognitoIdpResponse(BaseResponse): group_name = self._get_param("GroupName") cognitoidp_backends[self.region].admin_remove_user_from_group( - user_pool_id, - group_name, - username, + user_pool_id, group_name, username ) return "" @@ -266,28 +264,24 @@ class CognitoIdpResponse(BaseResponse): user_pool_id, username, temporary_password, - self._get_param("UserAttributes", []) + self._get_param("UserAttributes", []), ) - return json.dumps({ - "User": user.to_json(extended=True) - }) + return json.dumps({"User": user.to_json(extended=True)}) def admin_get_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) - return json.dumps( - user.to_json(extended=True, attributes_key="UserAttributes") - ) + return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes")) def list_users(self): user_pool_id = self._get_param("UserPoolId") limit = self._get_param("Limit") token = self._get_param("PaginationToken") - users, token = cognitoidp_backends[self.region].list_users(user_pool_id, - limit=limit, - pagination_token=token) + users, token = cognitoidp_backends[self.region].list_users( + user_pool_id, limit=limit, pagination_token=token + ) response = {"Users": [user.to_json(extended=True) for user in users]} if token: response["PaginationToken"] = str(token) @@ -318,10 +312,7 @@ class CognitoIdpResponse(BaseResponse): auth_parameters = self._get_param("AuthParameters") auth_result = cognitoidp_backends[self.region].admin_initiate_auth( - user_pool_id, - client_id, - auth_flow, - auth_parameters, + user_pool_id, client_id, auth_flow, auth_parameters ) return json.dumps(auth_result) @@ -332,21 +323,15 @@ class CognitoIdpResponse(BaseResponse): challenge_name = self._get_param("ChallengeName") challenge_responses = self._get_param("ChallengeResponses") auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge( - session, - client_id, - challenge_name, - challenge_responses, + session, client_id, challenge_name, challenge_responses ) return json.dumps(auth_result) def forgot_password(self): - return json.dumps({ - "CodeDeliveryDetails": { - "DeliveryMedium": "EMAIL", - "Destination": "...", - } - }) + return json.dumps( + {"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}} + ) # This endpoint receives no authorization header, so if moto-server is listening # on localhost (doesn't get a region in the host header), it doesn't know what @@ -357,7 +342,9 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") password = self._get_param("Password") region = find_region_by_value("client_id", client_id) - cognitoidp_backends[region].confirm_forgot_password(client_id, username, password) + cognitoidp_backends[region].confirm_forgot_password( + client_id, username, password + ) return "" # Ditto the comment on confirm_forgot_password. @@ -366,21 +353,26 @@ class CognitoIdpResponse(BaseResponse): previous_password = self._get_param("PreviousPassword") proposed_password = self._get_param("ProposedPassword") region = find_region_by_value("access_token", access_token) - cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password) + cognitoidp_backends[region].change_password( + access_token, previous_password, proposed_password + ) return "" def admin_update_user_attributes(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") attributes = self._get_param("UserAttributes") - cognitoidp_backends[self.region].admin_update_user_attributes(user_pool_id, username, attributes) + cognitoidp_backends[self.region].admin_update_user_attributes( + user_pool_id, username, attributes + ) return "" class CognitoIdpJsonWebKeyResponse(BaseResponse): - def __init__(self): - with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f: + with open( + os.path.join(os.path.dirname(__file__), "resources/jwks-public.json") + ) as f: self.json_web_key = f.read() def serve_json_web_key(self, request, full_url, headers): diff --git a/moto/cognitoidp/urls.py b/moto/cognitoidp/urls.py index 77441ed5e..5d1dff1d0 100644 --- a/moto/cognitoidp/urls.py +++ b/moto/cognitoidp/urls.py @@ -1,11 +1,9 @@ from __future__ import unicode_literals from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse -url_bases = [ - "https?://cognito-idp.(.+).amazonaws.com", -] +url_bases = ["https?://cognito-idp.(.+).amazonaws.com"] url_paths = { - '{0}/$': CognitoIdpResponse.dispatch, - '{0}//.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key, + "{0}/$": CognitoIdpResponse.dispatch, + "{0}//.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key, } diff --git a/moto/compat.py b/moto/compat.py index a92a5f67b..c0acd28a6 100644 --- a/moto/compat.py +++ b/moto/compat.py @@ -1,5 +1,10 @@ try: - from collections import OrderedDict # flake8: noqa + from collections import OrderedDict # noqa except ImportError: # python 2.6 or earlier, use backport - from ordereddict import OrderedDict # flake8: noqa + from ordereddict import OrderedDict # noqa + +try: + import collections.abc as collections_abc # noqa +except ImportError: + import collections as collections_abc # noqa diff --git a/moto/config/exceptions.py b/moto/config/exceptions.py index 25749200f..4a0dc0d73 100644 --- a/moto/config/exceptions.py +++ b/moto/config/exceptions.py @@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError): code = 400 def __init__(self, name, location): - message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \ - ' constraint: Member must have length less than or equal to 256'.format(name=name, location=location) + message = ( + "1 validation error detected: Value '{name}' at '{location}' failed to satisfy" + " constraint: Member must have length less than or equal to 256".format( + name=name, location=location + ) + ) super(NameTooLongException, self).__init__("ValidationException", message) @@ -15,41 +19,54 @@ class InvalidConfigurationRecorderNameException(JsonRESTError): code = 400 def __init__(self, name): - message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name) - super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException", - message) + message = "The configuration recorder name '{name}' is not valid, blank string.".format( + name=name + ) + super(InvalidConfigurationRecorderNameException, self).__init__( + "InvalidConfigurationRecorderNameException", message + ) class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \ - 'configuration recorders: 1 is reached.'.format(name=name) + message = ( + "Failed to put configuration recorder '{name}' because the maximum number of " + "configuration recorders: 1 is reached.".format(name=name) + ) super(MaxNumberOfConfigurationRecordersExceededException, self).__init__( - "MaxNumberOfConfigurationRecordersExceededException", message) + "MaxNumberOfConfigurationRecordersExceededException", message + ) class InvalidRecordingGroupException(JsonRESTError): code = 400 def __init__(self): - message = 'The recording group provided is not valid' - super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message) + message = "The recording group provided is not valid" + super(InvalidRecordingGroupException, self).__init__( + "InvalidRecordingGroupException", message + ) class InvalidResourceTypeException(JsonRESTError): code = 400 def __init__(self, bad_list, good_list): - message = '{num} validation error detected: Value \'{bad_list}\' at ' \ - '\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \ - 'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format( - num=len(bad_list), bad_list=bad_list, good_list=good_list) + message = ( + "{num} validation error detected: Value '{bad_list}' at " + "'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: " + "Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format( + num=len(bad_list), bad_list=bad_list, good_list=good_list + ) + ) # For PY2: message = str(message) - super(InvalidResourceTypeException, self).__init__("ValidationException", message) + super(InvalidResourceTypeException, self).__init__( + "ValidationException", message + ) class NoSuchConfigurationAggregatorException(JsonRESTError): @@ -57,36 +74,48 @@ class NoSuchConfigurationAggregatorException(JsonRESTError): def __init__(self, number=1): if number == 1: - message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.' + message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again." else: - message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \ - ' names and try again.' - super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message) + message = ( + "At least one of the configuration aggregators does not exist. Check the configuration aggregator" + " names and try again." + ) + super(NoSuchConfigurationAggregatorException, self).__init__( + "NoSuchConfigurationAggregatorException", message + ) class NoSuchConfigurationRecorderException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name) - super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message) + message = "Cannot find configuration recorder with the specified name '{name}'.".format( + name=name + ) + super(NoSuchConfigurationRecorderException, self).__init__( + "NoSuchConfigurationRecorderException", message + ) class InvalidDeliveryChannelNameException(JsonRESTError): code = 400 def __init__(self, name): - message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name) - super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException", - message) + message = "The delivery channel name '{name}' is not valid, blank string.".format( + name=name + ) + super(InvalidDeliveryChannelNameException, self).__init__( + "InvalidDeliveryChannelNameException", message + ) class NoSuchBucketException(JsonRESTError): """We are *only* validating that there is value that is not '' here.""" + code = 400 def __init__(self): - message = 'Cannot find a S3 bucket with an empty bucket name.' + message = "Cannot find a S3 bucket with an empty bucket name." super(NoSuchBucketException, self).__init__("NoSuchBucketException", message) @@ -94,89 +123,120 @@ class InvalidNextTokenException(JsonRESTError): code = 400 def __init__(self): - message = 'The nextToken provided is invalid' - super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message) + message = "The nextToken provided is invalid" + super(InvalidNextTokenException, self).__init__( + "InvalidNextTokenException", message + ) class InvalidS3KeyPrefixException(JsonRESTError): code = 400 def __init__(self): - message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.' - super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message) + message = "The s3 key prefix '' is not valid, empty s3 key prefix." + super(InvalidS3KeyPrefixException, self).__init__( + "InvalidS3KeyPrefixException", message + ) class InvalidSNSTopicARNException(JsonRESTError): """We are *only* validating that there is value that is not '' here.""" + code = 400 def __init__(self): - message = 'The sns topic arn \'\' is not valid.' - super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message) + message = "The sns topic arn '' is not valid." + super(InvalidSNSTopicARNException, self).__init__( + "InvalidSNSTopicARNException", message + ) class InvalidDeliveryFrequency(JsonRESTError): code = 400 def __init__(self, value, good_list): - message = '1 validation error detected: Value \'{value}\' at ' \ - '\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \ - 'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list) - super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message) + message = ( + "1 validation error detected: Value '{value}' at " + "'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy " + "constraint: Member must satisfy enum value set: {good_list}".format( + value=value, good_list=good_list + ) + ) + super(InvalidDeliveryFrequency, self).__init__( + "InvalidDeliveryFrequency", message + ) class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \ - 'delivery channels: 1 is reached.'.format(name=name) + message = ( + "Failed to put delivery channel '{name}' because the maximum number of " + "delivery channels: 1 is reached.".format(name=name) + ) super(MaxNumberOfDeliveryChannelsExceededException, self).__init__( - "MaxNumberOfDeliveryChannelsExceededException", message) + "MaxNumberOfDeliveryChannelsExceededException", message + ) class NoSuchDeliveryChannelException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name) - super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message) + message = "Cannot find delivery channel with specified name '{name}'.".format( + name=name + ) + super(NoSuchDeliveryChannelException, self).__init__( + "NoSuchDeliveryChannelException", message + ) class NoAvailableConfigurationRecorderException(JsonRESTError): code = 400 def __init__(self): - message = 'Configuration recorder is not available to put delivery channel.' - super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException", - message) + message = "Configuration recorder is not available to put delivery channel." + super(NoAvailableConfigurationRecorderException, self).__init__( + "NoAvailableConfigurationRecorderException", message + ) class NoAvailableDeliveryChannelException(JsonRESTError): code = 400 def __init__(self): - message = 'Delivery channel is not available to start configuration recorder.' - super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message) + message = "Delivery channel is not available to start configuration recorder." + super(NoAvailableDeliveryChannelException, self).__init__( + "NoAvailableDeliveryChannelException", message + ) class LastDeliveryChannelDeleteFailedException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \ - 'because there is a running configuration recorder.'.format(name=name) - super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message) + message = ( + "Failed to delete last specified delivery channel with name '{name}', because there, " + "because there is a running configuration recorder.".format(name=name) + ) + super(LastDeliveryChannelDeleteFailedException, self).__init__( + "LastDeliveryChannelDeleteFailedException", message + ) class TooManyAccountSources(JsonRESTError): code = 400 def __init__(self, length): - locations = ['com.amazonaws.xyz'] * length + locations = ["com.amazonaws.xyz"] * length - message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \ - 'Member must have length less than or equal to 1'.format(locations=', '.join(locations)) + message = ( + "Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: " + "Member must have length less than or equal to 1".format( + locations=", ".join(locations) + ) + ) super(TooManyAccountSources, self).__init__("ValidationException", message) @@ -185,16 +245,22 @@ class DuplicateTags(JsonRESTError): def __init__(self): super(DuplicateTags, self).__init__( - 'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.') + "InvalidInput", + "Duplicate tag keys found. Please note that Tag keys are case insensitive.", + ) class TagKeyTooBig(JsonRESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): + def __init__(self, tag, param="tags.X.member.key"): super(TagKeyTooBig, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 128".format(tag, param)) + "ValidationException", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 128".format( + tag, param + ), + ) class TagValueTooBig(JsonRESTError): @@ -202,31 +268,101 @@ class TagValueTooBig(JsonRESTError): def __init__(self, tag): super(TagValueTooBig, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " - "constraint: Member must have length less than or equal to 256".format(tag)) + "ValidationException", + "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " + "constraint: Member must have length less than or equal to 256".format(tag), + ) class InvalidParameterValueException(JsonRESTError): code = 400 def __init__(self, message): - super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message) + super(InvalidParameterValueException, self).__init__( + "InvalidParameterValueException", message + ) class InvalidTagCharacters(JsonRESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): - message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param) - message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+' + def __init__(self, tag, param="tags.X.member.key"): + message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format( + tag, param + ) + message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+" - super(InvalidTagCharacters, self).__init__('ValidationException', message) + super(InvalidTagCharacters, self).__init__("ValidationException", message) class TooManyTags(JsonRESTError): code = 400 - def __init__(self, tags, param='tags'): + def __init__(self, tags, param="tags"): super(TooManyTags, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 50.".format(tags, param)) + "ValidationException", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 50.".format( + tags, param + ), + ) + + +class InvalidResourceParameters(JsonRESTError): + code = 400 + + def __init__(self): + super(InvalidResourceParameters, self).__init__( + "ValidationException", + "Both Resource ID and Resource Name " "cannot be specified in the request", + ) + + +class InvalidLimit(JsonRESTError): + code = 400 + + def __init__(self, value): + super(InvalidLimit, self).__init__( + "ValidationException", + "Value '{value}' at 'limit' failed to satisify constraint: Member" + " must have value less than or equal to 100".format(value=value), + ) + + +class TooManyResourceIds(JsonRESTError): + code = 400 + + def __init__(self): + super(TooManyResourceIds, self).__init__( + "ValidationException", + "The specified list had more than 20 resource ID's. " + "It must have '20' or less items", + ) + + +class ResourceNotDiscoveredException(JsonRESTError): + code = 400 + + def __init__(self, type, resource): + super(ResourceNotDiscoveredException, self).__init__( + "ResourceNotDiscoveredException", + "Resource {resource} of resourceType:{type} is unknown or has not been " + "discovered".format(resource=resource, type=type), + ) + + +class TooManyResourceKeys(JsonRESTError): + code = 400 + + def __init__(self, bad_list): + message = ( + "1 validation error detected: Value '{bad_list}' at " + "'resourceKeys' failed to satisfy constraint: " + "Member must have length less than or equal to 100".format( + bad_list=bad_list + ) + ) + # For PY2: + message = str(message) + + super(TooManyResourceKeys, self).__init__("ValidationException", message) diff --git a/moto/config/models.py b/moto/config/models.py index 6541fc981..9015762fe 100644 --- a/moto/config/models.py +++ b/moto/config/models.py @@ -9,41 +9,71 @@ from datetime import datetime from boto3 import Session -from moto.config.exceptions import InvalidResourceTypeException, InvalidDeliveryFrequency, \ - InvalidConfigurationRecorderNameException, NameTooLongException, \ - MaxNumberOfConfigurationRecordersExceededException, InvalidRecordingGroupException, \ - NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \ - InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \ - InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \ - NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \ - TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \ - NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags +from moto.config.exceptions import ( + InvalidResourceTypeException, + InvalidDeliveryFrequency, + InvalidConfigurationRecorderNameException, + NameTooLongException, + MaxNumberOfConfigurationRecordersExceededException, + InvalidRecordingGroupException, + NoSuchConfigurationRecorderException, + NoAvailableConfigurationRecorderException, + InvalidDeliveryChannelNameException, + NoSuchBucketException, + InvalidS3KeyPrefixException, + InvalidSNSTopicARNException, + MaxNumberOfDeliveryChannelsExceededException, + NoAvailableDeliveryChannelException, + NoSuchDeliveryChannelException, + LastDeliveryChannelDeleteFailedException, + TagKeyTooBig, + TooManyTags, + TagValueTooBig, + TooManyAccountSources, + InvalidParameterValueException, + InvalidNextTokenException, + NoSuchConfigurationAggregatorException, + InvalidTagCharacters, + DuplicateTags, + InvalidLimit, + InvalidResourceParameters, + TooManyResourceIds, + ResourceNotDiscoveredException, + TooManyResourceKeys, +) from moto.core import BaseBackend, BaseModel +from moto.s3.config import s3_config_query + +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID -DEFAULT_ACCOUNT_ID = 123456789012 POP_STRINGS = [ - 'capitalizeStart', - 'CapitalizeStart', - 'capitalizeArn', - 'CapitalizeArn', - 'capitalizeARN', - 'CapitalizeARN' + "capitalizeStart", + "CapitalizeStart", + "capitalizeArn", + "CapitalizeArn", + "capitalizeARN", + "CapitalizeARN", ] DEFAULT_PAGE_SIZE = 100 +# Map the Config resource type to a backend: +RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query} + def datetime2int(date): return int(time.mktime(date.timetuple())) def snake_to_camels(original, cap_start, cap_arn): - parts = original.split('_') + parts = original.split("_") - camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:]) + camel_cased = parts[0].lower() + "".join(p.title() for p in parts[1:]) if cap_arn: - camel_cased = camel_cased.replace('Arn', 'ARN') # Some config services use 'ARN' instead of 'Arn' + camel_cased = camel_cased.replace( + "Arn", "ARN" + ) # Some config services use 'ARN' instead of 'Arn' if cap_start: camel_cased = camel_cased[0].upper() + camel_cased[1::] @@ -60,7 +90,7 @@ def random_string(): return "".join(chars) -def validate_tag_key(tag_key, exception_param='tags.X.member.key'): +def validate_tag_key(tag_key, exception_param="tags.X.member.key"): """Validates the tag key. :param tag_key: The tag key to check against. @@ -74,7 +104,7 @@ def validate_tag_key(tag_key, exception_param='tags.X.member.key'): # Validate that the tag key fits the proper Regex: # [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+ - match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key) + match = re.findall(r"[\w\s_.:/=+\-@]+", tag_key) # Kudos if you can come up with a better way of doing a global search :) if not len(match) or len(match[0]) < len(tag_key): raise InvalidTagCharacters(tag_key, param=exception_param) @@ -99,14 +129,14 @@ def validate_tags(tags): for tag in tags: # Validate the Key: - validate_tag_key(tag['Key']) - check_tag_duplicate(proper_tags, tag['Key']) + validate_tag_key(tag["Key"]) + check_tag_duplicate(proper_tags, tag["Key"]) # Validate the Value: - if len(tag['Value']) > 256: - raise TagValueTooBig(tag['Value']) + if len(tag["Value"]) > 256: + raise TagValueTooBig(tag["Value"]) - proper_tags[tag['Key']] = tag['Value'] + proper_tags[tag["Key"]] = tag["Value"] return proper_tags @@ -127,9 +157,17 @@ class ConfigEmptyDictable(BaseModel): for item, value in self.__dict__.items(): if value is not None: if isinstance(value, ConfigEmptyDictable): - data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict() + data[ + snake_to_camels( + item, self.capitalize_start, self.capitalize_arn + ) + ] = value.to_dict() else: - data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value + data[ + snake_to_camels( + item, self.capitalize_start, self.capitalize_arn + ) + ] = value # Cleanse the extra properties: for prop in POP_STRINGS: @@ -139,7 +177,6 @@ class ConfigEmptyDictable(BaseModel): class ConfigRecorderStatus(ConfigEmptyDictable): - def __init__(self, name): super(ConfigRecorderStatus, self).__init__() @@ -154,7 +191,7 @@ class ConfigRecorderStatus(ConfigEmptyDictable): def start(self): self.recording = True - self.last_status = 'PENDING' + self.last_status = "PENDING" self.last_start_time = datetime2int(datetime.utcnow()) self.last_status_change_time = datetime2int(datetime.utcnow()) @@ -165,7 +202,6 @@ class ConfigRecorderStatus(ConfigEmptyDictable): class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): - def __init__(self, delivery_frequency): super(ConfigDeliverySnapshotProperties, self).__init__() @@ -173,8 +209,9 @@ class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): class ConfigDeliveryChannel(ConfigEmptyDictable): - - def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None): + def __init__( + self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None + ): super(ConfigDeliveryChannel, self).__init__() self.name = name @@ -185,8 +222,12 @@ class ConfigDeliveryChannel(ConfigEmptyDictable): class RecordingGroup(ConfigEmptyDictable): - - def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None): + def __init__( + self, + all_supported=True, + include_global_resource_types=False, + resource_types=None, + ): super(RecordingGroup, self).__init__() self.all_supported = all_supported @@ -195,8 +236,7 @@ class RecordingGroup(ConfigEmptyDictable): class ConfigRecorder(ConfigEmptyDictable): - - def __init__(self, role_arn, recording_group, name='default', status=None): + def __init__(self, role_arn, recording_group, name="default", status=None): super(ConfigRecorder, self).__init__() self.name = name @@ -210,18 +250,21 @@ class ConfigRecorder(ConfigEmptyDictable): class AccountAggregatorSource(ConfigEmptyDictable): - def __init__(self, account_ids, aws_regions=None, all_aws_regions=None): super(AccountAggregatorSource, self).__init__(capitalize_start=True) # Can't have both the regions and all_regions flag present -- also can't have them both missing: if aws_regions and all_aws_regions: - raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' - 'the use of all regions. You must choose one of these options.') + raise InvalidParameterValueException( + "Your configuration aggregator contains a list of regions and also specifies " + "the use of all regions. You must choose one of these options." + ) if not (aws_regions or all_aws_regions): - raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' - 'regions and try again.') + raise InvalidParameterValueException( + "Your request does not specify any regions. Select AWS Config-supported " + "regions and try again." + ) self.account_ids = account_ids self.aws_regions = aws_regions @@ -233,18 +276,23 @@ class AccountAggregatorSource(ConfigEmptyDictable): class OrganizationAggregationSource(ConfigEmptyDictable): - def __init__(self, role_arn, aws_regions=None, all_aws_regions=None): - super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False) + super(OrganizationAggregationSource, self).__init__( + capitalize_start=True, capitalize_arn=False + ) # Can't have both the regions and all_regions flag present -- also can't have them both missing: if aws_regions and all_aws_regions: - raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' - 'the use of all regions. You must choose one of these options.') + raise InvalidParameterValueException( + "Your configuration aggregator contains a list of regions and also specifies " + "the use of all regions. You must choose one of these options." + ) if not (aws_regions or all_aws_regions): - raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' - 'regions and try again.') + raise InvalidParameterValueException( + "Your request does not specify any regions. Select AWS Config-supported " + "regions and try again." + ) self.role_arn = role_arn self.aws_regions = aws_regions @@ -256,15 +304,14 @@ class OrganizationAggregationSource(ConfigEmptyDictable): class ConfigAggregator(ConfigEmptyDictable): - def __init__(self, name, region, account_sources=None, org_source=None, tags=None): - super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False) + super(ConfigAggregator, self).__init__( + capitalize_start=True, capitalize_arn=False + ) self.configuration_aggregator_name = name - self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format( - region=region, - id=DEFAULT_ACCOUNT_ID, - random=random_string() + self.configuration_aggregator_arn = "arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}".format( + region=region, id=DEFAULT_ACCOUNT_ID, random=random_string() ) self.account_aggregation_sources = account_sources self.organization_aggregation_source = org_source @@ -280,7 +327,9 @@ class ConfigAggregator(ConfigEmptyDictable): # Override the account aggregation sources if present: if self.account_aggregation_sources: - result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources] + result["AccountAggregationSources"] = [ + a.to_dict() for a in self.account_aggregation_sources + ] # Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to! # if self.tags: @@ -290,15 +339,22 @@ class ConfigAggregator(ConfigEmptyDictable): class ConfigAggregationAuthorization(ConfigEmptyDictable): + def __init__( + self, current_region, authorized_account_id, authorized_aws_region, tags=None + ): + super(ConfigAggregationAuthorization, self).__init__( + capitalize_start=True, capitalize_arn=False + ) - def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None): - super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False) - - self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \ - '{auth_account}/{auth_region}'.format(region=current_region, - id=DEFAULT_ACCOUNT_ID, - auth_account=authorized_account_id, - auth_region=authorized_aws_region) + self.aggregation_authorization_arn = ( + "arn:aws:config:{region}:{id}:aggregation-authorization/" + "{auth_account}/{auth_region}".format( + region=current_region, + id=DEFAULT_ACCOUNT_ID, + auth_account=authorized_account_id, + auth_region=authorized_aws_region, + ) + ) self.authorized_account_id = authorized_account_id self.authorized_aws_region = authorized_aws_region self.creation_time = datetime2int(datetime.utcnow()) @@ -308,7 +364,6 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable): class ConfigBackend(BaseBackend): - def __init__(self): self.recorders = {} self.delivery_channels = {} @@ -318,9 +373,11 @@ class ConfigBackend(BaseBackend): @staticmethod def _validate_resource_types(resource_list): # Load the service file: - resource_package = 'botocore' - resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) - config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) + resource_package = "botocore" + resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json")) + config_schema = json.loads( + pkg_resources.resource_string(resource_package, resource_path) + ) # Verify that each entry exists in the supported list: bad_list = [] @@ -328,72 +385,114 @@ class ConfigBackend(BaseBackend): # For PY2: r_str = str(resource) - if r_str not in config_schema['shapes']['ResourceType']['enum']: + if r_str not in config_schema["shapes"]["ResourceType"]["enum"]: bad_list.append(r_str) if bad_list: - raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum']) + raise InvalidResourceTypeException( + bad_list, config_schema["shapes"]["ResourceType"]["enum"] + ) @staticmethod def _validate_delivery_snapshot_properties(properties): # Load the service file: - resource_package = 'botocore' - resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) - conifg_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) + resource_package = "botocore" + resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json")) + conifg_schema = json.loads( + pkg_resources.resource_string(resource_package, resource_path) + ) # Verify that the deliveryFrequency is set to an acceptable value: - if properties.get('deliveryFrequency', None) not in \ - conifg_schema['shapes']['MaximumExecutionFrequency']['enum']: - raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None), - conifg_schema['shapes']['MaximumExecutionFrequency']['enum']) + if ( + properties.get("deliveryFrequency", None) + not in conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"] + ): + raise InvalidDeliveryFrequency( + properties.get("deliveryFrequency", None), + conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"], + ) def put_configuration_aggregator(self, config_aggregator, region): # Validate the name: - if len(config_aggregator['ConfigurationAggregatorName']) > 256: - raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName') + if len(config_aggregator["ConfigurationAggregatorName"]) > 256: + raise NameTooLongException( + config_aggregator["ConfigurationAggregatorName"], + "configurationAggregatorName", + ) account_sources = None org_source = None # Tag validation: - tags = validate_tags(config_aggregator.get('Tags', [])) + tags = validate_tags(config_aggregator.get("Tags", [])) # Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied: - if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'): - raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the' - ' AccountAggregationSource and the OrganizationAggregationSource. Include only ' - 'one aggregation source and try again.') + if config_aggregator.get("AccountAggregationSources") and config_aggregator.get( + "OrganizationAggregationSource" + ): + raise InvalidParameterValueException( + "The configuration aggregator cannot be created because your request contains both the" + " AccountAggregationSource and the OrganizationAggregationSource. Include only " + "one aggregation source and try again." + ) # If neither are supplied: - if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'): - raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either ' - 'the AccountAggregationSource or the OrganizationAggregationSource. Include the ' - 'appropriate aggregation source and try again.') + if not config_aggregator.get( + "AccountAggregationSources" + ) and not config_aggregator.get("OrganizationAggregationSource"): + raise InvalidParameterValueException( + "The configuration aggregator cannot be created because your request is missing either " + "the AccountAggregationSource or the OrganizationAggregationSource. Include the " + "appropriate aggregation source and try again." + ) - if config_aggregator.get('AccountAggregationSources'): + if config_aggregator.get("AccountAggregationSources"): # Currently, only 1 account aggregation source can be set: - if len(config_aggregator['AccountAggregationSources']) > 1: - raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources'])) + if len(config_aggregator["AccountAggregationSources"]) > 1: + raise TooManyAccountSources( + len(config_aggregator["AccountAggregationSources"]) + ) account_sources = [] - for a in config_aggregator['AccountAggregationSources']: - account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'), - all_aws_regions=a.get('AllAwsRegions'))) + for a in config_aggregator["AccountAggregationSources"]: + account_sources.append( + AccountAggregatorSource( + a["AccountIds"], + aws_regions=a.get("AwsRegions"), + all_aws_regions=a.get("AllAwsRegions"), + ) + ) else: - org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'], - aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'), - all_aws_regions=config_aggregator['OrganizationAggregationSource'].get( - 'AllAwsRegions')) + org_source = OrganizationAggregationSource( + config_aggregator["OrganizationAggregationSource"]["RoleArn"], + aws_regions=config_aggregator["OrganizationAggregationSource"].get( + "AwsRegions" + ), + all_aws_regions=config_aggregator["OrganizationAggregationSource"].get( + "AllAwsRegions" + ), + ) # Grab the existing one if it exists and update it: - if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']): - aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources, - org_source=org_source, tags=tags) - self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator + if not self.config_aggregators.get( + config_aggregator["ConfigurationAggregatorName"] + ): + aggregator = ConfigAggregator( + config_aggregator["ConfigurationAggregatorName"], + region, + account_sources=account_sources, + org_source=org_source, + tags=tags, + ) + self.config_aggregators[ + config_aggregator["ConfigurationAggregatorName"] + ] = aggregator else: - aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] + aggregator = self.config_aggregators[ + config_aggregator["ConfigurationAggregatorName"] + ] aggregator.tags = tags aggregator.account_aggregation_sources = account_sources aggregator.organization_aggregation_source = org_source @@ -404,7 +503,7 @@ class ConfigBackend(BaseBackend): def describe_configuration_aggregators(self, names, token, limit): limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit agg_list = [] - result = {'ConfigurationAggregators': []} + result = {"ConfigurationAggregators": []} if names: for name in names: @@ -434,11 +533,13 @@ class ConfigBackend(BaseBackend): start = sorted_aggregators.index(token) # Get the list of items to collect: - agg_list = sorted_aggregators[start:(start + limit)] - result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list] + agg_list = sorted_aggregators[start : (start + limit)] + result["ConfigurationAggregators"] = [ + self.config_aggregators[agg].to_dict() for agg in agg_list + ] if len(sorted_aggregators) > (start + limit): - result['NextToken'] = sorted_aggregators[start + limit] + result["NextToken"] = sorted_aggregators[start + limit] return result @@ -448,16 +549,22 @@ class ConfigBackend(BaseBackend): del self.config_aggregators[config_aggregator] - def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags): + def put_aggregation_authorization( + self, current_region, authorized_account, authorized_region, tags + ): # Tag validation: tags = validate_tags(tags or []) # Does this already exist? - key = '{}/{}'.format(authorized_account, authorized_region) + key = "{}/{}".format(authorized_account, authorized_region) agg_auth = self.aggregation_authorizations.get(key) if not agg_auth: - agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags) - self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth + agg_auth = ConfigAggregationAuthorization( + current_region, authorized_account, authorized_region, tags=tags + ) + self.aggregation_authorizations[ + "{}/{}".format(authorized_account, authorized_region) + ] = agg_auth else: # Only update the tags: agg_auth.tags = tags @@ -466,7 +573,7 @@ class ConfigBackend(BaseBackend): def describe_aggregation_authorizations(self, token, limit): limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit - result = {'AggregationAuthorizations': []} + result = {"AggregationAuthorizations": []} if not self.aggregation_authorizations: return result @@ -485,70 +592,82 @@ class ConfigBackend(BaseBackend): start = sorted_authorizations.index(token) # Get the list of items to collect: - auth_list = sorted_authorizations[start:(start + limit)] - result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list] + auth_list = sorted_authorizations[start : (start + limit)] + result["AggregationAuthorizations"] = [ + self.aggregation_authorizations[auth].to_dict() for auth in auth_list + ] if len(sorted_authorizations) > (start + limit): - result['NextToken'] = sorted_authorizations[start + limit] + result["NextToken"] = sorted_authorizations[start + limit] return result def delete_aggregation_authorization(self, authorized_account, authorized_region): # This will always return a 200 -- regardless if there is or isn't an existing # aggregation authorization. - key = '{}/{}'.format(authorized_account, authorized_region) + key = "{}/{}".format(authorized_account, authorized_region) self.aggregation_authorizations.pop(key, None) def put_configuration_recorder(self, config_recorder): # Validate the name: - if not config_recorder.get('name'): - raise InvalidConfigurationRecorderNameException(config_recorder.get('name')) - if len(config_recorder.get('name')) > 256: - raise NameTooLongException(config_recorder.get('name'), 'configurationRecorder.name') + if not config_recorder.get("name"): + raise InvalidConfigurationRecorderNameException(config_recorder.get("name")) + if len(config_recorder.get("name")) > 256: + raise NameTooLongException( + config_recorder.get("name"), "configurationRecorder.name" + ) # We're going to assume that the passed in Role ARN is correct. # Config currently only allows 1 configuration recorder for an account: - if len(self.recorders) == 1 and not self.recorders.get(config_recorder['name']): - raise MaxNumberOfConfigurationRecordersExceededException(config_recorder['name']) + if len(self.recorders) == 1 and not self.recorders.get(config_recorder["name"]): + raise MaxNumberOfConfigurationRecordersExceededException( + config_recorder["name"] + ) # Is this updating an existing one? recorder_status = None - if self.recorders.get(config_recorder['name']): - recorder_status = self.recorders[config_recorder['name']].status + if self.recorders.get(config_recorder["name"]): + recorder_status = self.recorders[config_recorder["name"]].status # Validate the Recording Group: - if config_recorder.get('recordingGroup') is None: + if config_recorder.get("recordingGroup") is None: recording_group = RecordingGroup() else: - rg = config_recorder['recordingGroup'] + rg = config_recorder["recordingGroup"] # If an empty dict is passed in, then bad: if not rg: raise InvalidRecordingGroupException() # Can't have both the resource types specified and the other flags as True. - if rg.get('resourceTypes') and ( - rg.get('allSupported', False) or - rg.get('includeGlobalResourceTypes', False)): + if rg.get("resourceTypes") and ( + rg.get("allSupported", False) + or rg.get("includeGlobalResourceTypes", False) + ): raise InvalidRecordingGroupException() # Must supply resourceTypes if 'allSupported' is not supplied: - if not rg.get('allSupported') and not rg.get('resourceTypes'): + if not rg.get("allSupported") and not rg.get("resourceTypes"): raise InvalidRecordingGroupException() # Validate that the list provided is correct: - self._validate_resource_types(rg.get('resourceTypes', [])) + self._validate_resource_types(rg.get("resourceTypes", [])) recording_group = RecordingGroup( - all_supported=rg.get('allSupported', True), - include_global_resource_types=rg.get('includeGlobalResourceTypes', False), - resource_types=rg.get('resourceTypes', []) + all_supported=rg.get("allSupported", True), + include_global_resource_types=rg.get( + "includeGlobalResourceTypes", False + ), + resource_types=rg.get("resourceTypes", []), ) - self.recorders[config_recorder['name']] = \ - ConfigRecorder(config_recorder['roleARN'], recording_group, name=config_recorder['name'], - status=recorder_status) + self.recorders[config_recorder["name"]] = ConfigRecorder( + config_recorder["roleARN"], + recording_group, + name=config_recorder["name"], + status=recorder_status, + ) def describe_configuration_recorders(self, recorder_names): recorders = [] @@ -590,43 +709,54 @@ class ConfigBackend(BaseBackend): raise NoAvailableConfigurationRecorderException() # Validate the name: - if not delivery_channel.get('name'): - raise InvalidDeliveryChannelNameException(delivery_channel.get('name')) - if len(delivery_channel.get('name')) > 256: - raise NameTooLongException(delivery_channel.get('name'), 'deliveryChannel.name') + if not delivery_channel.get("name"): + raise InvalidDeliveryChannelNameException(delivery_channel.get("name")) + if len(delivery_channel.get("name")) > 256: + raise NameTooLongException( + delivery_channel.get("name"), "deliveryChannel.name" + ) # We are going to assume that the bucket exists -- but will verify if the bucket provided is blank: - if not delivery_channel.get('s3BucketName'): + if not delivery_channel.get("s3BucketName"): raise NoSuchBucketException() # We are going to assume that the bucket has the correct policy attached to it. We are only going to verify # if the prefix provided is not an empty string: - if delivery_channel.get('s3KeyPrefix', None) == '': + if delivery_channel.get("s3KeyPrefix", None) == "": raise InvalidS3KeyPrefixException() # Ditto for SNS -- Only going to assume that the ARN provided is not an empty string: - if delivery_channel.get('snsTopicARN', None) == '': + if delivery_channel.get("snsTopicARN", None) == "": raise InvalidSNSTopicARNException() # Config currently only allows 1 delivery channel for an account: - if len(self.delivery_channels) == 1 and not self.delivery_channels.get(delivery_channel['name']): - raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel['name']) + if len(self.delivery_channels) == 1 and not self.delivery_channels.get( + delivery_channel["name"] + ): + raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel["name"]) - if not delivery_channel.get('configSnapshotDeliveryProperties'): + if not delivery_channel.get("configSnapshotDeliveryProperties"): dp = None else: # Validate the config snapshot delivery properties: - self._validate_delivery_snapshot_properties(delivery_channel['configSnapshotDeliveryProperties']) + self._validate_delivery_snapshot_properties( + delivery_channel["configSnapshotDeliveryProperties"] + ) dp = ConfigDeliverySnapshotProperties( - delivery_channel['configSnapshotDeliveryProperties']['deliveryFrequency']) + delivery_channel["configSnapshotDeliveryProperties"][ + "deliveryFrequency" + ] + ) - self.delivery_channels[delivery_channel['name']] = \ - ConfigDeliveryChannel(delivery_channel['name'], delivery_channel['s3BucketName'], - prefix=delivery_channel.get('s3KeyPrefix', None), - sns_arn=delivery_channel.get('snsTopicARN', None), - snapshot_properties=dp) + self.delivery_channels[delivery_channel["name"]] = ConfigDeliveryChannel( + delivery_channel["name"], + delivery_channel["s3BucketName"], + prefix=delivery_channel.get("s3KeyPrefix", None), + sns_arn=delivery_channel.get("snsTopicARN", None), + snapshot_properties=dp, + ) def describe_delivery_channels(self, channel_names): channels = [] @@ -680,8 +810,280 @@ class ConfigBackend(BaseBackend): del self.delivery_channels[channel_name] + def list_discovered_resources( + self, + resource_type, + backend_region, + resource_ids, + resource_name, + limit, + next_token, + ): + """This will query against the mocked AWS Config (non-aggregated) listing function that must exist for the resource backend. + + :param resource_type: + :param backend_region: + :param ids: + :param name: + :param limit: + :param next_token: + :return: + """ + identifiers = [] + new_token = None + + limit = limit or DEFAULT_PAGE_SIZE + if limit > DEFAULT_PAGE_SIZE: + raise InvalidLimit(limit) + + if resource_ids and resource_name: + raise InvalidResourceParameters() + + # Only 20 maximum Resource IDs: + if resource_ids and len(resource_ids) > 20: + raise TooManyResourceIds() + + # If the resource type exists and the backend region is implemented in moto, then + # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: + if RESOURCE_MAP.get(resource_type, {}): + # Is this a global resource type? -- if so, re-write the region to 'global': + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource_type].backends.get("global"): + backend_region = "global" + + # For non-aggregated queries, the we only care about the backend_region. Need to verify that moto has implemented + # the region for the given backend: + if RESOURCE_MAP[resource_type].backends.get(backend_region): + # Fetch the resources for the backend's region: + identifiers, new_token = RESOURCE_MAP[ + resource_type + ].list_config_service_resources( + resource_ids, + resource_name, + limit, + next_token, + backend_region=backend_query_region, + ) + + result = { + "resourceIdentifiers": [ + { + "resourceType": identifier["type"], + "resourceId": identifier["id"], + "resourceName": identifier["name"], + } + for identifier in identifiers + ] + } + + if new_token: + result["nextToken"] = new_token + + return result + + def list_aggregate_discovered_resources( + self, aggregator_name, resource_type, filters, limit, next_token + ): + """This will query against the mocked AWS Config listing function that must exist for the resource backend. + + As far a moto goes -- the only real difference between this function and the `list_discovered_resources` function is that + this will require a Config Aggregator be set up a priori and can search based on resource regions. + + :param aggregator_name: + :param resource_type: + :param filters: + :param limit: + :param next_token: + :return: + """ + if not self.config_aggregators.get(aggregator_name): + raise NoSuchConfigurationAggregatorException() + + identifiers = [] + new_token = None + filters = filters or {} + + limit = limit or DEFAULT_PAGE_SIZE + if limit > DEFAULT_PAGE_SIZE: + raise InvalidLimit(limit) + + # If the resource type exists and the backend region is implemented in moto, then + # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: + if RESOURCE_MAP.get(resource_type, {}): + # We only care about a filter's Region, Resource Name, and Resource ID: + resource_region = filters.get("Region") + resource_id = [filters["ResourceId"]] if filters.get("ResourceId") else None + resource_name = filters.get("ResourceName") + + identifiers, new_token = RESOURCE_MAP[ + resource_type + ].list_config_service_resources( + resource_id, + resource_name, + limit, + next_token, + resource_region=resource_region, + ) + + result = { + "ResourceIdentifiers": [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": identifier["region"], + "ResourceType": identifier["type"], + "ResourceId": identifier["id"], + "ResourceName": identifier["name"], + } + for identifier in identifiers + ] + } + + if new_token: + result["NextToken"] = new_token + + return result + + def get_resource_config_history(self, resource_type, id, backend_region): + """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. + + NOTE: This is --NOT-- returning history as it is not supported in moto at this time. (PR's welcome!) + As such, the later_time, earlier_time, limit, and next_token are ignored as this will only + return 1 item. (If no items, it raises an exception) + """ + # If the type isn't implemented then we won't find the item: + if resource_type not in RESOURCE_MAP: + raise ResourceNotDiscoveredException(resource_type, id) + + # Is the resource type global? + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource_type].backends.get("global"): + backend_region = "global" + + # If the backend region isn't implemented then we won't find the item: + if not RESOURCE_MAP[resource_type].backends.get(backend_region): + raise ResourceNotDiscoveredException(resource_type, id) + + # Get the item: + item = RESOURCE_MAP[resource_type].get_config_resource( + id, backend_region=backend_query_region + ) + if not item: + raise ResourceNotDiscoveredException(resource_type, id) + + item["accountId"] = DEFAULT_ACCOUNT_ID + + return {"configurationItems": [item]} + + def batch_get_resource_config(self, resource_keys, backend_region): + """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. + + :param resource_keys: + :param backend_region: + """ + # Can't have more than 100 items + if len(resource_keys) > 100: + raise TooManyResourceKeys( + ["com.amazonaws.starling.dove.ResourceKey@12345"] * len(resource_keys) + ) + + results = [] + for resource in resource_keys: + # Does the resource type exist? + if not RESOURCE_MAP.get(resource["resourceType"]): + # Not found so skip. + continue + + # Is the resource type global? + config_backend_region = backend_region + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource["resourceType"]].backends.get("global"): + config_backend_region = "global" + + # If the backend region isn't implemented then we won't find the item: + if not RESOURCE_MAP[resource["resourceType"]].backends.get( + config_backend_region + ): + continue + + # Get the item: + item = RESOURCE_MAP[resource["resourceType"]].get_config_resource( + resource["resourceId"], backend_region=backend_query_region + ) + if not item: + continue + + item["accountId"] = DEFAULT_ACCOUNT_ID + + results.append(item) + + return { + "baseConfigurationItems": results, + "unprocessedResourceKeys": [], + } # At this time, moto is not adding unprocessed items. + + def batch_get_aggregate_resource_config( + self, aggregator_name, resource_identifiers + ): + """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. + + As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that + this will require a Config Aggregator be set up a priori and can search based on resource regions. + + Note: moto will IGNORE the resource account ID in the search query. + """ + if not self.config_aggregators.get(aggregator_name): + raise NoSuchConfigurationAggregatorException() + + # Can't have more than 100 items + if len(resource_identifiers) > 100: + raise TooManyResourceKeys( + ["com.amazonaws.starling.dove.AggregateResourceIdentifier@12345"] + * len(resource_identifiers) + ) + + found = [] + not_found = [] + for identifier in resource_identifiers: + resource_type = identifier["ResourceType"] + resource_region = identifier["SourceRegion"] + resource_id = identifier["ResourceId"] + resource_name = identifier.get("ResourceName", None) + + # Does the resource type exist? + if not RESOURCE_MAP.get(resource_type): + not_found.append(identifier) + continue + + # Get the item: + item = RESOURCE_MAP[resource_type].get_config_resource( + resource_id, + resource_name=resource_name, + resource_region=resource_region, + ) + if not item: + not_found.append(identifier) + continue + + item["accountId"] = DEFAULT_ACCOUNT_ID + + # The 'tags' field is not included in aggregate results for some reason... + item.pop("tags", None) + + found.append(item) + + return { + "BaseConfigurationItems": found, + "UnprocessedResourceIdentifiers": not_found, + } + config_backends = {} boto3_session = Session() -for region in boto3_session.get_available_regions('config'): +for region in boto3_session.get_available_regions("config"): config_backends[region] = ConfigBackend() diff --git a/moto/config/responses.py b/moto/config/responses.py index 03612d403..e977945c9 100644 --- a/moto/config/responses.py +++ b/moto/config/responses.py @@ -4,83 +4,150 @@ from .models import config_backends class ConfigResponse(BaseResponse): - @property def config_backend(self): return config_backends[self.region] def put_configuration_recorder(self): - self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder')) + self.config_backend.put_configuration_recorder( + self._get_param("ConfigurationRecorder") + ) return "" def put_configuration_aggregator(self): - aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region) - schema = {'ConfigurationAggregator': aggregator} + aggregator = self.config_backend.put_configuration_aggregator( + json.loads(self.body), self.region + ) + schema = {"ConfigurationAggregator": aggregator} return json.dumps(schema) def describe_configuration_aggregators(self): - aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'), - self._get_param('NextToken'), - self._get_param('Limit')) + aggregators = self.config_backend.describe_configuration_aggregators( + self._get_param("ConfigurationAggregatorNames"), + self._get_param("NextToken"), + self._get_param("Limit"), + ) return json.dumps(aggregators) def delete_configuration_aggregator(self): - self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName')) + self.config_backend.delete_configuration_aggregator( + self._get_param("ConfigurationAggregatorName") + ) return "" def put_aggregation_authorization(self): - agg_auth = self.config_backend.put_aggregation_authorization(self.region, - self._get_param('AuthorizedAccountId'), - self._get_param('AuthorizedAwsRegion'), - self._get_param('Tags')) - schema = {'AggregationAuthorization': agg_auth} + agg_auth = self.config_backend.put_aggregation_authorization( + self.region, + self._get_param("AuthorizedAccountId"), + self._get_param("AuthorizedAwsRegion"), + self._get_param("Tags"), + ) + schema = {"AggregationAuthorization": agg_auth} return json.dumps(schema) def describe_aggregation_authorizations(self): - authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit')) + authorizations = self.config_backend.describe_aggregation_authorizations( + self._get_param("NextToken"), self._get_param("Limit") + ) return json.dumps(authorizations) def delete_aggregation_authorization(self): - self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion')) + self.config_backend.delete_aggregation_authorization( + self._get_param("AuthorizedAccountId"), + self._get_param("AuthorizedAwsRegion"), + ) return "" def describe_configuration_recorders(self): - recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames')) - schema = {'ConfigurationRecorders': recorders} + recorders = self.config_backend.describe_configuration_recorders( + self._get_param("ConfigurationRecorderNames") + ) + schema = {"ConfigurationRecorders": recorders} return json.dumps(schema) def describe_configuration_recorder_status(self): recorder_statuses = self.config_backend.describe_configuration_recorder_status( - self._get_param('ConfigurationRecorderNames')) - schema = {'ConfigurationRecordersStatus': recorder_statuses} + self._get_param("ConfigurationRecorderNames") + ) + schema = {"ConfigurationRecordersStatus": recorder_statuses} return json.dumps(schema) def put_delivery_channel(self): - self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel')) + self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel")) return "" def describe_delivery_channels(self): - delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames')) - schema = {'DeliveryChannels': delivery_channels} + delivery_channels = self.config_backend.describe_delivery_channels( + self._get_param("DeliveryChannelNames") + ) + schema = {"DeliveryChannels": delivery_channels} return json.dumps(schema) def describe_delivery_channel_status(self): raise NotImplementedError() def delete_delivery_channel(self): - self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName')) + self.config_backend.delete_delivery_channel( + self._get_param("DeliveryChannelName") + ) return "" def delete_configuration_recorder(self): - self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.delete_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" def start_configuration_recorder(self): - self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.start_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" def stop_configuration_recorder(self): - self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.stop_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" + + def list_discovered_resources(self): + schema = self.config_backend.list_discovered_resources( + self._get_param("resourceType"), + self.region, + self._get_param("resourceIds"), + self._get_param("resourceName"), + self._get_param("limit"), + self._get_param("nextToken"), + ) + return json.dumps(schema) + + def list_aggregate_discovered_resources(self): + schema = self.config_backend.list_aggregate_discovered_resources( + self._get_param("ConfigurationAggregatorName"), + self._get_param("ResourceType"), + self._get_param("Filters"), + self._get_param("Limit"), + self._get_param("NextToken"), + ) + return json.dumps(schema) + + def get_resource_config_history(self): + schema = self.config_backend.get_resource_config_history( + self._get_param("resourceType"), self._get_param("resourceId"), self.region + ) + return json.dumps(schema) + + def batch_get_resource_config(self): + schema = self.config_backend.batch_get_resource_config( + self._get_param("resourceKeys"), self.region + ) + return json.dumps(schema) + + def batch_get_aggregate_resource_config(self): + schema = self.config_backend.batch_get_aggregate_resource_config( + self._get_param("ConfigurationAggregatorName"), + self._get_param("ResourceIdentifiers"), + ) + return json.dumps(schema) diff --git a/moto/config/urls.py b/moto/config/urls.py index fd7b6969f..62cf34a52 100644 --- a/moto/config/urls.py +++ b/moto/config/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import ConfigResponse -url_bases = [ - "https?://config.(.+).amazonaws.com", -] +url_bases = ["https?://config.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ConfigResponse.dispatch, -} +url_paths = {"{0}/$": ConfigResponse.dispatch} diff --git a/moto/core/__init__.py b/moto/core/__init__.py index 801e675df..045124fab 100644 --- a/moto/core/__init__.py +++ b/moto/core/__init__.py @@ -1,7 +1,9 @@ from __future__ import unicode_literals -from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa +from .models import BaseModel, BaseBackend, moto_api_backend, ACCOUNT_ID # noqa from .responses import ActionAuthenticatorMixin moto_api_backends = {"global": moto_api_backend} -set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count +set_initial_no_auth_action_count = ( + ActionAuthenticatorMixin.set_initial_no_auth_action_count +) diff --git a/moto/core/access_control.py b/moto/core/access_control.py index 3fb11eebd..8ba0c3ba1 100644 --- a/moto/core/access_control.py +++ b/moto/core/access_control.py @@ -24,9 +24,15 @@ from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials from six import string_types -from moto.iam.models import ACCOUNT_ID, Policy +from moto.core import ACCOUNT_ID +from moto.iam.models import Policy from moto.iam import iam_backend -from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError, AuthFailureError +from moto.core.exceptions import ( + SignatureDoesNotMatchError, + AccessDeniedError, + InvalidClientTokenIdError, + AuthFailureError, +) from moto.s3.exceptions import ( BucketAccessDeniedError, S3AccessDeniedError, @@ -35,7 +41,7 @@ from moto.s3.exceptions import ( S3InvalidAccessKeyIdError, BucketInvalidAccessKeyIdError, BucketSignatureDoesNotMatchError, - S3SignatureDoesNotMatchError + S3SignatureDoesNotMatchError, ) from moto.sts import sts_backend @@ -50,9 +56,8 @@ def create_access_key(access_key_id, headers): class IAMUserAccessKey(object): - def __init__(self, access_key_id, headers): - iam_users = iam_backend.list_users('/', None, None) + iam_users = iam_backend.list_users("/", None, None) for iam_user in iam_users: for access_key in iam_user.access_keys: if access_key.access_key_id == access_key_id: @@ -67,8 +72,7 @@ class IAMUserAccessKey(object): @property def arn(self): return "arn:aws:iam::{account_id}:user/{iam_user_name}".format( - account_id=ACCOUNT_ID, - iam_user_name=self._owner_user_name + account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name ) def create_credentials(self): @@ -79,27 +83,34 @@ class IAMUserAccessKey(object): inline_policy_names = iam_backend.list_user_policies(self._owner_user_name) for inline_policy_name in inline_policy_names: - inline_policy = iam_backend.get_user_policy(self._owner_user_name, inline_policy_name) + inline_policy = iam_backend.get_user_policy( + self._owner_user_name, inline_policy_name + ) user_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_user_policies(self._owner_user_name) + attached_policies, _ = iam_backend.list_attached_user_policies( + self._owner_user_name + ) user_policies += attached_policies user_groups = iam_backend.get_groups_for_user(self._owner_user_name) for user_group in user_groups: inline_group_policy_names = iam_backend.list_group_policies(user_group.name) for inline_group_policy_name in inline_group_policy_names: - inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name) + inline_user_group_policy = iam_backend.get_group_policy( + user_group.name, inline_group_policy_name + ) user_policies.append(inline_user_group_policy) - attached_group_policies, _ = iam_backend.list_attached_group_policies(user_group.name) + attached_group_policies, _ = iam_backend.list_attached_group_policies( + user_group.name + ) user_policies += attached_group_policies return user_policies class AssumedRoleAccessKey(object): - def __init__(self, access_key_id, headers): for assumed_role in sts_backend.assumed_roles: if assumed_role.access_key_id == access_key_id: @@ -118,28 +129,33 @@ class AssumedRoleAccessKey(object): return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( account_id=ACCOUNT_ID, role_name=self._owner_role_name, - session_name=self._session_name + session_name=self._session_name, ) def create_credentials(self): - return Credentials(self._access_key_id, self._secret_access_key, self._session_token) + return Credentials( + self._access_key_id, self._secret_access_key, self._session_token + ) def collect_policies(self): role_policies = [] inline_policy_names = iam_backend.list_role_policies(self._owner_role_name) for inline_policy_name in inline_policy_names: - _, inline_policy = iam_backend.get_role_policy(self._owner_role_name, inline_policy_name) + _, inline_policy = iam_backend.get_role_policy( + self._owner_role_name, inline_policy_name + ) role_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_role_policies(self._owner_role_name) + attached_policies, _ = iam_backend.list_attached_role_policies( + self._owner_role_name + ) role_policies += attached_policies return role_policies class CreateAccessKeyFailure(Exception): - def __init__(self, reason, *args): super(CreateAccessKeyFailure, self).__init__(*args) self.reason = reason @@ -147,32 +163,54 @@ class CreateAccessKeyFailure(Exception): @six.add_metaclass(ABCMeta) class IAMRequestBase(object): - def __init__(self, method, path, data, headers): - log.debug("Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( - class_name=self.__class__.__name__, method=method, path=path, data=data, headers=headers)) + log.debug( + "Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( + class_name=self.__class__.__name__, + method=method, + path=path, + data=data, + headers=headers, + ) + ) self._method = method self._path = path self._data = data self._headers = headers - credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization']) - credential_data = credential_scope.split('/') + credential_scope = self._get_string_between( + "Credential=", ",", self._headers["Authorization"] + ) + credential_data = credential_scope.split("/") self._region = credential_data[2] self._service = credential_data[3] - self._action = self._service + ":" + (self._data["Action"][0] if isinstance(self._data["Action"], list) else self._data["Action"]) + self._action = ( + self._service + + ":" + + ( + self._data["Action"][0] + if isinstance(self._data["Action"], list) + else self._data["Action"] + ) + ) try: - self._access_key = create_access_key(access_key_id=credential_data[0], headers=headers) + self._access_key = create_access_key( + access_key_id=credential_data[0], headers=headers + ) except CreateAccessKeyFailure as e: self._raise_invalid_access_key(e.reason) def check_signature(self): - original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization']) + original_signature = self._get_string_between( + "Signature=", ",", self._headers["Authorization"] + ) calculated_signature = self._calculate_signature() if original_signature != calculated_signature: self._raise_signature_does_not_match() def check_action_permitted(self): - if self._action == 'sts:GetCallerIdentity': # always allowed, even if there's an explicit Deny for it + if ( + self._action == "sts:GetCallerIdentity" + ): # always allowed, even if there's an explicit Deny for it return True policies = self._access_key.collect_policies() @@ -213,10 +251,14 @@ class IAMRequestBase(object): return headers def _create_aws_request(self): - signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';') + signed_headers = self._get_string_between( + "SignedHeaders=", ",", self._headers["Authorization"] + ).split(";") headers = self._create_headers_for_aws_request(signed_headers, self._headers) - request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers) - request.context['timestamp'] = headers['X-Amz-Date'] + request = AWSRequest( + method=self._method, url=self._path, data=self._data, headers=headers + ) + request.context["timestamp"] = headers["X-Amz-Date"] return request @@ -234,7 +276,6 @@ class IAMRequestBase(object): class IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): if self._service == "ec2": raise AuthFailureError() @@ -251,14 +292,10 @@ class IAMRequest(IAMRequestBase): return SigV4Auth(credentials, self._service, self._region) def _raise_access_denied(self): - raise AccessDeniedError( - user_arn=self._access_key.arn, - action=self._action - ) + raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action) class S3IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): if "BucketName" in self._data: raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"]) @@ -288,10 +325,13 @@ class S3IAMRequest(IAMRequestBase): class IAMPolicy(object): - def __init__(self, policy): if isinstance(policy, Policy): - default_version = next(policy_version for policy_version in policy.versions if policy_version.is_default) + default_version = next( + policy_version + for policy_version in policy.versions + if policy_version.is_default + ) policy_document = default_version.document elif isinstance(policy, string_types): policy_document = policy @@ -321,7 +361,6 @@ class IAMPolicy(object): class IAMPolicyStatement(object): - def __init__(self, statement): self._statement = statement diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 06cfd8895..ea91eda63 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException from jinja2 import DictLoader, Environment -SINGLE_ERROR_RESPONSE = u""" +SINGLE_ERROR_RESPONSE = """ {{error_type}} {{message}} @@ -13,8 +13,8 @@ SINGLE_ERROR_RESPONSE = u""" """ -ERROR_RESPONSE = u""" - +ERROR_RESPONSE = """ + {{error_type}} @@ -23,10 +23,10 @@ ERROR_RESPONSE = u""" 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE - + """ -ERROR_JSON_RESPONSE = u"""{ +ERROR_JSON_RESPONSE = """{ "message": "{{message}}", "__type": "{{error_type}}" } @@ -37,18 +37,19 @@ class RESTError(HTTPException): code = 400 templates = { - 'single_error': SINGLE_ERROR_RESPONSE, - 'error': ERROR_RESPONSE, - 'error_json': ERROR_JSON_RESPONSE, + "single_error": SINGLE_ERROR_RESPONSE, + "error": ERROR_RESPONSE, + "error_json": ERROR_JSON_RESPONSE, } - def __init__(self, error_type, message, template='error', **kwargs): + def __init__(self, error_type, message, template="error", **kwargs): super(RESTError, self).__init__() env = Environment(loader=DictLoader(self.templates)) self.error_type = error_type self.message = message self.description = env.get_template(template).render( - error_type=error_type, message=message, **kwargs) + error_type=error_type, message=message, **kwargs + ) class DryRunClientError(RESTError): @@ -56,12 +57,11 @@ class DryRunClientError(RESTError): class JsonRESTError(RESTError): - def __init__(self, error_type, message, template='error_json', **kwargs): - super(JsonRESTError, self).__init__( - error_type, message, template, **kwargs) + def __init__(self, error_type, message, template="error_json", **kwargs): + super(JsonRESTError, self).__init__(error_type, message, template, **kwargs) def get_headers(self, *args, **kwargs): - return [('Content-Type', 'application/json')] + return [("Content-Type", "application/json")] def get_body(self, *args, **kwargs): return self.description @@ -72,8 +72,9 @@ class SignatureDoesNotMatchError(RESTError): def __init__(self): super(SignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.") + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.", + ) class InvalidClientTokenIdError(RESTError): @@ -81,8 +82,9 @@ class InvalidClientTokenIdError(RESTError): def __init__(self): super(InvalidClientTokenIdError, self).__init__( - 'InvalidClientTokenId', - "The security token included in the request is invalid.") + "InvalidClientTokenId", + "The security token included in the request is invalid.", + ) class AccessDeniedError(RESTError): @@ -90,11 +92,11 @@ class AccessDeniedError(RESTError): def __init__(self, user_arn, action): super(AccessDeniedError, self).__init__( - 'AccessDenied', + "AccessDenied", "User: {user_arn} is not authorized to perform: {operation}".format( - user_arn=user_arn, - operation=action - )) + user_arn=user_arn, operation=action + ), + ) class AuthFailureError(RESTError): @@ -102,5 +104,17 @@ class AuthFailureError(RESTError): def __init__(self): super(AuthFailureError, self).__init__( - 'AuthFailure', - "AWS was not able to validate the provided access credentials") + "AuthFailure", + "AWS was not able to validate the provided access credentials", + ) + + +class InvalidNextTokenException(JsonRESTError): + """For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core.""" + + code = 400 + + def __init__(self): + super(InvalidNextTokenException, self).__init__( + "InvalidNextTokenException", "The nextToken provided is invalid" + ) diff --git a/moto/core/models.py b/moto/core/models.py index 896f9ac4a..3be3bbd8e 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -23,6 +23,9 @@ from .utils import ( ) +ACCOUNT_ID = os.environ.get("MOTO_ACCOUNT_ID", "123456789012") + + class BaseMockAWS(object): nested_count = 0 @@ -31,15 +34,20 @@ class BaseMockAWS(object): self.backends_for_urls = {} from moto.backends import BACKENDS + default_backends = { - "instance_metadata": BACKENDS['instance_metadata']['global'], - "moto_api": BACKENDS['moto_api']['global'], + "instance_metadata": BACKENDS["instance_metadata"]["global"], + "moto_api": BACKENDS["moto_api"]["global"], } self.backends_for_urls.update(self.backends) self.backends_for_urls.update(default_backends) # "Mock" the AWS credentials as they can't be mocked in Botocore currently - FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"} + FAKE_KEYS = { + "AWS_ACCESS_KEY_ID": "foobar_key", + "AWS_SECRET_ACCESS_KEY": "foobar_secret", + } + self.default_session_mock = mock.patch("boto3.DEFAULT_SESSION", None) self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS) if self.__class__.nested_count == 0: @@ -58,6 +66,7 @@ class BaseMockAWS(object): self.stop() def start(self, reset=True): + self.default_session_mock.start() self.env_variables_mocks.start() self.__class__.nested_count += 1 @@ -68,11 +77,12 @@ class BaseMockAWS(object): self.enable_patching() def stop(self): + self.default_session_mock.stop() self.env_variables_mocks.stop() self.__class__.nested_count -= 1 if self.__class__.nested_count < 0: - raise RuntimeError('Called stop() before start().') + raise RuntimeError("Called stop() before start().") if self.__class__.nested_count == 0: self.disable_patching() @@ -85,6 +95,7 @@ class BaseMockAWS(object): finally: self.stop() return result + functools.update_wrapper(wrapper, func) wrapper.__wrapped__ = func return wrapper @@ -122,7 +133,6 @@ class BaseMockAWS(object): class HttprettyMockAWS(BaseMockAWS): - def reset(self): HTTPretty.reset() @@ -144,18 +154,26 @@ class HttprettyMockAWS(BaseMockAWS): HTTPretty.reset() -RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, - responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT] +RESPONSES_METHODS = [ + responses.GET, + responses.DELETE, + responses.HEAD, + responses.OPTIONS, + responses.PATCH, + responses.POST, + responses.PUT, +] class CallbackResponse(responses.CallbackResponse): - ''' + """ Need to subclass so we can change a couple things - ''' + """ + def get_response(self, request): - ''' + """ Need to override this so we can pass decode_content=False - ''' + """ headers = self.get_headers() result = self.callback(request) @@ -177,17 +195,17 @@ class CallbackResponse(responses.CallbackResponse): ) def _url_matches(self, url, other, match_querystring=False): - ''' + """ Need to override this so we can fix querystrings breaking regex matching - ''' + """ if not match_querystring: - other = other.split('?', 1)[0] + other = other.split("?", 1)[0] if responses._is_string(url): if responses._has_unicode(url): url = responses._clean_unicode(url) if not isinstance(other, six.text_type): - other = other.encode('ascii').decode('utf8') + other = other.encode("ascii").decode("utf8") return self._url_matches_strict(url, other) elif isinstance(url, responses.Pattern) and url.match(other): return True @@ -195,66 +213,23 @@ class CallbackResponse(responses.CallbackResponse): return False -botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') +botocore_mock = responses.RequestsMock( + assert_all_requests_are_fired=False, + target="botocore.vendored.requests.adapters.HTTPAdapter.send", +) responses_mock = responses._default_mock +# Add passthrough to allow any other requests to work +# Since this uses .startswith, it applies to http and https requests. +responses_mock.add_passthru("http") -class ResponsesMockAWS(BaseMockAWS): - def reset(self): - botocore_mock.reset() - responses_mock.reset() - - def enable_patching(self): - if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'): - # Check for unactivated patcher - botocore_mock.start() - - if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): - responses_mock.start() - - for method in RESPONSES_METHODS: - for backend in self.backends_for_urls.values(): - for key, value in backend.urls.items(): - responses_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - botocore_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - - def disable_patching(self): - try: - botocore_mock.stop() - except RuntimeError: - pass - - try: - responses_mock.stop() - except RuntimeError: - pass - - -BOTOCORE_HTTP_METHODS = [ - 'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT' -] +BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"] class MockRawResponse(BytesIO): def __init__(self, input): if isinstance(input, six.text_type): - input = input.encode('utf-8') + input = input.encode("utf-8") super(MockRawResponse, self).__init__(input) def stream(self, **kwargs): @@ -285,7 +260,7 @@ class BotocoreStubber(object): found_index = None matchers = self.methods.get(request.method) - base_url = request.url.split('?', 1)[0] + base_url = request.url.split("?", 1)[0] for i, (pattern, callback) in enumerate(matchers): if pattern.match(base_url): if found_index is None: @@ -298,8 +273,10 @@ class BotocoreStubber(object): if response_callback is not None: for header, value in request.headers.items(): if isinstance(value, six.binary_type): - request.headers[header] = value.decode('utf-8') - status, headers, body = response_callback(request, request.url, request.headers) + request.headers[header] = value.decode("utf-8") + status, headers, body = response_callback( + request, request.url, request.headers + ) body = MockRawResponse(body) response = AWSResponse(request.url, status, headers, body) @@ -307,7 +284,15 @@ class BotocoreStubber(object): botocore_stubber = BotocoreStubber() -BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) +BUILTIN_HANDLERS.append(("before-send", botocore_stubber)) + + +def not_implemented_callback(request): + status = 400 + headers = {} + response = "The method is not implemented" + + return status, headers, response class BotocoreEventMockAWS(BaseMockAWS): @@ -323,7 +308,9 @@ class BotocoreEventMockAWS(BaseMockAWS): pattern = re.compile(key) botocore_stubber.register_response(method, pattern, value) - if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): + if not hasattr(responses_mock, "_patcher") or not hasattr( + responses_mock._patcher, "target" + ): responses_mock.start() for method in RESPONSES_METHODS: @@ -339,6 +326,24 @@ class BotocoreEventMockAWS(BaseMockAWS): match_querystring=False, ) ) + responses_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) + botocore_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) def disable_patching(self): botocore_stubber.enabled = False @@ -354,9 +359,9 @@ MockAWS = BotocoreEventMockAWS class ServerModeMockAWS(BaseMockAWS): - def reset(self): import requests + requests.post("http://localhost:5000/moto-api/reset") def enable_patching(self): @@ -368,13 +373,13 @@ class ServerModeMockAWS(BaseMockAWS): import mock def fake_boto3_client(*args, **kwargs): - if 'endpoint_url' not in kwargs: - kwargs['endpoint_url'] = "http://localhost:5000" + if "endpoint_url" not in kwargs: + kwargs["endpoint_url"] = "http://localhost:5000" return real_boto3_client(*args, **kwargs) def fake_boto3_resource(*args, **kwargs): - if 'endpoint_url' not in kwargs: - kwargs['endpoint_url'] = "http://localhost:5000" + if "endpoint_url" not in kwargs: + kwargs["endpoint_url"] = "http://localhost:5000" return real_boto3_resource(*args, **kwargs) def fake_httplib_send_output(self, message_body=None, *args, **kwargs): @@ -382,7 +387,7 @@ class ServerModeMockAWS(BaseMockAWS): bytes_buffer = [] for chunk in mixed_buffer: if isinstance(chunk, six.text_type): - bytes_buffer.append(chunk.encode('utf-8')) + bytes_buffer.append(chunk.encode("utf-8")) else: bytes_buffer.append(chunk) msg = b"\r\n".join(bytes_buffer) @@ -403,10 +408,12 @@ class ServerModeMockAWS(BaseMockAWS): if message_body is not None: self.send(message_body) - self._client_patcher = mock.patch('boto3.client', fake_boto3_client) - self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource) + self._client_patcher = mock.patch("boto3.client", fake_boto3_client) + self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource) if six.PY2: - self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output) + self._httplib_patcher = mock.patch( + "httplib.HTTPConnection._send_output", fake_httplib_send_output + ) self._client_patcher.start() self._resource_patcher.start() @@ -422,7 +429,6 @@ class ServerModeMockAWS(BaseMockAWS): class Model(type): - def __new__(self, clsname, bases, namespace): cls = super(Model, self).__new__(self, clsname, bases, namespace) cls.__models__ = {} @@ -437,9 +443,11 @@ class Model(type): @staticmethod def prop(model_name): """ decorator to mark a class method as returning model values """ + def dec(f): f.__returns_model__ = model_name return f + return dec @@ -449,7 +457,7 @@ model_data = defaultdict(dict) class InstanceTrackerMeta(type): def __new__(meta, name, bases, dct): cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct) - if name == 'BaseModel': + if name == "BaseModel": return cls service = cls.__module__.split(".")[1] @@ -468,7 +476,6 @@ class BaseModel(object): class BaseBackend(object): - def _reset_model_refs(self): # Remove all references to the models stored for service, models in model_data.items(): @@ -484,8 +491,9 @@ class BaseBackend(object): def _url_module(self): backend_module = self.__class__.__module__ backend_urls_module_name = backend_module.replace("models", "urls") - backend_urls_module = __import__(backend_urls_module_name, fromlist=[ - 'url_bases', 'url_paths']) + backend_urls_module = __import__( + backend_urls_module_name, fromlist=["url_bases", "url_paths"] + ) return backend_urls_module @property @@ -541,9 +549,9 @@ class BaseBackend(object): def decorator(self, func=None): if settings.TEST_SERVER_MODE: - mocked_backend = ServerModeMockAWS({'global': self}) + mocked_backend = ServerModeMockAWS({"global": self}) else: - mocked_backend = MockAWS({'global': self}) + mocked_backend = MockAWS({"global": self}) if func: return mocked_backend(func) @@ -552,9 +560,100 @@ class BaseBackend(object): def deprecated_decorator(self, func=None): if func: - return HttprettyMockAWS({'global': self})(func) + return HttprettyMockAWS({"global": self})(func) else: - return HttprettyMockAWS({'global': self}) + return HttprettyMockAWS({"global": self}) + + # def list_config_service_resources(self, resource_ids, resource_name, limit, next_token): + # """For AWS Config. This will list all of the resources of the given type and optional resource name and region""" + # raise NotImplementedError() + + +class ConfigQueryModel(object): + def __init__(self, backends): + """Inits based on the resource type's backends (1 for each region if applicable)""" + self.backends = backends + + def list_config_service_resources( + self, + resource_ids, + resource_name, + limit, + next_token, + backend_region=None, + resource_region=None, + ): + """For AWS Config. This will list all of the resources of the given type and optional resource name and region. + + This supports both aggregated and non-aggregated listing. The following notes the difference: + + - Non-Aggregated Listing - + This only lists resources within a region. The way that this is implemented in moto is based on the region + for the resource backend. + + You must set the `backend_region` to the region that the API request arrived from. resource_region can be set to `None`. + + - Aggregated Listing - + This lists resources from all potential regional backends. For non-global resource types, this should collect a full + list of resources from all the backends, and then be able to filter from the resource region. This is because an + aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation + from all resources in all regions for a given resource type*. + + The `backend_region` should be set to `None` for these queries, and the `resource_region` should optionally be set to + the `Filters` region parameter to filter out resources that reside in a specific region. + + For aggregated listings, pagination logic should be set such that the next page can properly span all the region backends. + As such, the proper way to implement is to first obtain a full list of results from all the region backends, and then filter + from there. It may be valuable to make this a concatenation of the region and resource name. + + :param resource_region: + :param resource_ids: + :param resource_name: + :param limit: + :param next_token: + :param backend_region: The region for the backend to pull results from. Set to `None` if this is an aggregated query. + :return: This should return a list of Dicts that have the following fields: + [ + { + 'type': 'AWS::The AWS Config data type', + 'name': 'The name of the resource', + 'id': 'The ID of the resource', + 'region': 'The region of the resource -- if global, then you may want to have the calling logic pass in the + aggregator region in for the resource region -- or just us-east-1 :P' + } + , ... + ] + """ + raise NotImplementedError() + + def get_config_resource( + self, resource_id, resource_name=None, backend_region=None, resource_region=None + ): + """For AWS Config. This will query the backend for the specific resource type configuration. + + This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests + will call this function N times to fetch the N objects needing to be fetched. + + - Non-Aggregated Fetching - + This only fetches a resource config within a region. The way that this is implemented in moto is based on the region + for the resource backend. + + You must set the `backend_region` to the region that the API request arrived from. `resource_region` should be set to `None`. + + - Aggregated Fetching - + This fetches resources from all potential regional backends. For non-global resource types, this should collect a full + list of resources from all the backends, and then be able to filter from the resource region. This is because an + aggregator can aggregate resources from multiple regions. In moto, aggregated regions will *assume full aggregation + from all resources in all regions for a given resource type*. + + ... + :param resource_id: + :param resource_name: + :param backend_region: + :param resource_region: + :return: + """ + raise NotImplementedError() class base_decorator(object): @@ -580,9 +679,9 @@ class deprecated_base_decorator(base_decorator): class MotoAPIBackend(BaseBackend): - def reset(self): from moto.backends import BACKENDS + for name, backends in BACKENDS.items(): if name == "moto_api": continue diff --git a/moto/core/responses.py b/moto/core/responses.py index b60f10a20..c708edb8b 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -40,7 +40,7 @@ def _decode_dict(d): newkey = [] for k in key: if isinstance(k, six.binary_type): - newkey.append(k.decode('utf-8')) + newkey.append(k.decode("utf-8")) else: newkey.append(k) else: @@ -52,7 +52,7 @@ def _decode_dict(d): newvalue = [] for v in value: if isinstance(v, six.binary_type): - newvalue.append(v.decode('utf-8')) + newvalue.append(v.decode("utf-8")) else: newvalue.append(v) else: @@ -83,12 +83,15 @@ class DynamicDictLoader(DictLoader): class _TemplateEnvironmentMixin(object): + LEFT_PATTERN = re.compile(r"[\s\n]+<") + RIGHT_PATTERN = re.compile(r">[\s\n]+") def __init__(self): super(_TemplateEnvironmentMixin, self).__init__() self.loader = DynamicDictLoader({}) self.environment = Environment( - loader=self.loader, autoescape=self.should_autoescape) + loader=self.loader, autoescape=self.should_autoescape + ) @property def should_autoescape(self): @@ -101,9 +104,16 @@ class _TemplateEnvironmentMixin(object): def response_template(self, source): template_id = id(source) if not self.contains_template(template_id): - self.loader.update({template_id: source}) - self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True, - lstrip_blocks=True) + collapsed = re.sub( + self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source) + ) + self.loader.update({template_id: collapsed}) + self.environment = Environment( + loader=self.loader, + autoescape=self.should_autoescape, + trim_blocks=True, + lstrip_blocks=True, + ) return self.environment.get_template(template_id) @@ -112,8 +122,13 @@ class ActionAuthenticatorMixin(object): request_count = 0 def _authenticate_and_authorize_action(self, iam_request_cls): - if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT: - iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers) + if ( + ActionAuthenticatorMixin.request_count + >= settings.INITIAL_NO_AUTH_ACTION_COUNT + ): + iam_request = iam_request_cls( + method=self.method, path=self.path, data=self.data, headers=self.headers + ) iam_request.check_signature() iam_request.check_action_permitted() else: @@ -130,10 +145,17 @@ class ActionAuthenticatorMixin(object): def decorator(function): def wrapper(*args, **kwargs): if settings.TEST_SERVER_MODE: - response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode()) - original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT'] + response = requests.post( + "http://localhost:5000/moto-api/reset-auth", + data=str(initial_no_auth_action_count).encode(), + ) + original_initial_no_auth_action_count = response.json()[ + "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT" + ] else: - original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT + original_initial_no_auth_action_count = ( + settings.INITIAL_NO_AUTH_ACTION_COUNT + ) original_request_count = ActionAuthenticatorMixin.request_count settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count ActionAuthenticatorMixin.request_count = 0 @@ -141,10 +163,15 @@ class ActionAuthenticatorMixin(object): result = function(*args, **kwargs) finally: if settings.TEST_SERVER_MODE: - requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode()) + requests.post( + "http://localhost:5000/moto-api/reset-auth", + data=str(original_initial_no_auth_action_count).encode(), + ) else: ActionAuthenticatorMixin.request_count = original_request_count - settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count + settings.INITIAL_NO_AUTH_ACTION_COUNT = ( + original_initial_no_auth_action_count + ) return result functools.update_wrapper(wrapper, function) @@ -156,11 +183,13 @@ class ActionAuthenticatorMixin(object): class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): - default_region = 'us-east-1' + default_region = "us-east-1" # to extract region, use [^.] - region_regex = re.compile(r'\.(?P[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com') - param_list_regex = re.compile(r'(.*)\.(\d+)\.') - access_key_regex = re.compile(r'AWS.*(?P(?[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com") + param_list_regex = re.compile(r"(.*)\.(\d+)\.") + access_key_regex = re.compile( + r"AWS.*(?P(? '^/cars/.*/drivers/.*/drive$' """ - def _convert(elem, is_last): - if not re.match('^{.*}$', elem): - return elem - name = elem.replace('{', '').replace('}', '') - if is_last: - return '(?P<%s>[^/]*)' % name - return '(?P<%s>.*)' % name - elems = uri.split('/') + def _convert(elem, is_last): + if not re.match("^{.*}$", elem): + return elem + name = elem.replace("{", "").replace("}", "").replace("+", "") + if is_last: + return "(?P<%s>[^/]*)" % name + return "(?P<%s>.*)" % name + + elems = uri.split("/") num_elems = len(elems) - regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)])) + regexp = "^{}$".format( + "/".join( + [_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)] + ) + ) return regexp def _get_action_from_method_and_request_uri(self, method, request_uri): @@ -288,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # service response class should have 'SERVICE_NAME' class member, # if you want to get action from method and url - if not hasattr(self, 'SERVICE_NAME'): + if not hasattr(self, "SERVICE_NAME"): return None service = self.SERVICE_NAME conn = boto3.client(service, region_name=self.region) # make cache if it does not exist yet - if not hasattr(self, 'method_urls'): + if not hasattr(self, "method_urls"): self.method_urls = defaultdict(lambda: defaultdict(str)) op_names = conn._service_model.operation_names for op_name in op_names: op_model = conn._service_model.operation_model(op_name) - _method = op_model.http['method'] - uri_regexp = self.uri_to_regexp(op_model.http['requestUri']) + _method = op_model.http["method"] + uri_regexp = self.uri_to_regexp(op_model.http["requestUri"]) self.method_urls[_method][uri_regexp] = op_model.name regexp_and_names = self.method_urls[method] for regexp, name in regexp_and_names.items(): @@ -311,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return None def _get_action(self): - action = self.querystring.get('Action', [""])[0] + action = self.querystring.get("Action", [""])[0] if not action: # Some services use a header for the action # Headers are case-insensitive. Probably a better way to do this. - match = self.headers.get( - 'x-amz-target') or self.headers.get('X-Amz-Target') + match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target") if match: action = match.split(".")[-1] # get action from method and uri @@ -347,10 +387,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return self._send_response(headers, response) if not action: - return 404, headers, '' + return 404, headers, "" raise NotImplementedError( - "The {0} action has not been implemented".format(action)) + "The {0} action has not been implemented".format(action) + ) @staticmethod def _send_response(headers, response): @@ -358,11 +399,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): body, new_headers = response else: status, new_headers, body = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) headers.update(new_headers) # Cast status to string if "status" in headers: - headers['status'] = str(headers['status']) + headers["status"] = str(headers["status"]) return status, headers, body def _get_param(self, param_name, if_none=None): @@ -396,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _get_bool_param(self, param_name, if_none=None): val = self._get_param(param_name) if val is not None: - if val.lower() == 'true': + if val.lower() == "true": return True - elif val.lower() == 'false': + elif val.lower() == "false": return False return if_none @@ -416,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if is_tracked(name) or not name.startswith(param_prefix): continue - if len(name) > len(param_prefix) and \ - not name[len(param_prefix):].startswith('.'): + if len(name) > len(param_prefix) and not name[ + len(param_prefix) : + ].startswith("."): continue - match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None + match = ( + self.param_list_regex.search(name[len(param_prefix) :]) + if len(name) > len(param_prefix) + else None + ) if match: prefix = param_prefix + match.group(1) value = self._get_multi_param(prefix) @@ -435,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if len(value_dict) > 1: # strip off period prefix - value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()} + value_dict = { + name[len(param_prefix) + 1 :]: value + for name, value in value_dict.items() + } else: value_dict = list(value_dict.values())[0] @@ -454,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): index = 1 while True: value_dict = self._get_multi_param_helper(prefix + str(index)) - if not value_dict: + if not value_dict and value_dict != "": break values.append(value_dict) @@ -479,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): params = {} for key, value in self.querystring.items(): if key.startswith(param_prefix): - params[camelcase_to_underscores( - key.replace(param_prefix, ""))] = value[0] + params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[ + 0 + ] return params def _get_list_prefix(self, param_prefix): @@ -513,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): new_items = {} for key, value in self.querystring.items(): if key.startswith(index_prefix): - new_items[camelcase_to_underscores( - key.replace(index_prefix, ""))] = value[0] + new_items[ + camelcase_to_underscores(key.replace(index_prefix, "")) + ] = value[0] if not new_items: break results.append(new_items) param_index += 1 return results - def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'): + def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"): results = {} param_index = 1 while 1: - index_prefix = '{0}.{1}.'.format(param_prefix, param_index) + index_prefix = "{0}.{1}.".format(param_prefix, param_index) k, v = None, None for key, value in self.querystring.items(): @@ -552,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): param_index = 1 while True: - key_name = 'tag.{0}._key'.format(param_index) - value_name = 'tag.{0}._value'.format(param_index) + key_name = "tag.{0}._key".format(param_index) + value_name = "tag.{0}._value".format(param_index) try: results[resource_type][tag[key_name]] = tag[value_name] @@ -563,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return results - def _get_object_map(self, prefix, name='Name', value='Value'): + def _get_object_map(self, prefix, name="Name", value="Value"): """ Given a query dict like { @@ -591,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): index = 1 while True: # Loop through looking for keys representing object name - name_key = '{0}.{1}.{2}'.format(prefix, index, name) + name_key = "{0}.{1}.{2}".format(prefix, index, name) obj_name = self.querystring.get(name_key) if not obj_name: # Found all keys break obj = {} - value_key_prefix = '{0}.{1}.{2}.'.format( - prefix, index, value) + value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value) for k, v in self.querystring.items(): if k.startswith(value_key_prefix): _, value_key = k.split(value_key_prefix, 1) @@ -613,31 +663,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): @property def request_json(self): - return 'JSON' in self.querystring.get('ContentType', []) + return "JSON" in self.querystring.get("ContentType", []) def is_not_dryrun(self, action): - if 'true' in self.querystring.get('DryRun', ['false']): - message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action - raise DryRunClientError( - error_type="DryRunOperation", message=message) + if "true" in self.querystring.get("DryRun", ["false"]): + message = ( + "An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set" + % action + ) + raise DryRunClientError(error_type="DryRunOperation", message=message) return True class MotoAPIResponse(BaseResponse): - def reset_response(self, request, full_url, headers): if request.method == "POST": from .models import moto_api_backend + moto_api_backend.reset() return 200, {}, json.dumps({"status": "ok"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"}) def reset_auth_response(self, request, full_url, headers): if request.method == "POST": - previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT + previous_initial_no_auth_action_count = ( + settings.INITIAL_NO_AUTH_ACTION_COUNT + ) settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode()) ActionAuthenticatorMixin.request_count = 0 - return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)}) + return ( + 200, + {}, + json.dumps( + { + "status": "ok", + "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str( + previous_initial_no_auth_action_count + ), + } + ), + ) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"}) def model_data(self, request, full_url, headers): @@ -665,7 +730,8 @@ class MotoAPIResponse(BaseResponse): def dashboard(self, request, full_url, headers): from flask import render_template - return render_template('dashboard.html') + + return render_template("dashboard.html") class _RecursiveDictRef(object): @@ -676,7 +742,7 @@ class _RecursiveDictRef(object): self.dic = {} def __repr__(self): - return '{!r}'.format(self.dic) + return "{!r}".format(self.dic) def __getattr__(self, key): return self.dic.__getattr__(key) @@ -700,21 +766,21 @@ class AWSServiceSpec(object): """ def __init__(self, path): - self.path = resource_filename('botocore', path) - with io.open(self.path, 'r', encoding='utf-8') as f: + self.path = resource_filename("botocore", path) + with io.open(self.path, "r", encoding="utf-8") as f: spec = json.load(f) - self.metadata = spec['metadata'] - self.operations = spec['operations'] - self.shapes = spec['shapes'] + self.metadata = spec["metadata"] + self.operations = spec["operations"] + self.shapes = spec["shapes"] def input_spec(self, operation): try: op = self.operations[operation] except KeyError: - raise ValueError('Invalid operation: {}'.format(operation)) - if 'input' not in op: + raise ValueError("Invalid operation: {}".format(operation)) + if "input" not in op: return {} - shape = self.shapes[op['input']['shape']] + shape = self.shapes[op["input"]["shape"]] return self._expand(shape) def output_spec(self, operation): @@ -728,129 +794,133 @@ class AWSServiceSpec(object): try: op = self.operations[operation] except KeyError: - raise ValueError('Invalid operation: {}'.format(operation)) - if 'output' not in op: + raise ValueError("Invalid operation: {}".format(operation)) + if "output" not in op: return {} - shape = self.shapes[op['output']['shape']] + shape = self.shapes[op["output"]["shape"]] return self._expand(shape) def _expand(self, shape): def expand(dic, seen=None): seen = seen or {} - if dic['type'] == 'structure': + if dic["type"] == "structure": nodes = {} - for k, v in dic['members'].items(): + for k, v in dic["members"].items(): seen_till_here = dict(seen) if k in seen_till_here: nodes[k] = seen_till_here[k] continue seen_till_here[k] = _RecursiveDictRef() - nodes[k] = expand(self.shapes[v['shape']], seen_till_here) + nodes[k] = expand(self.shapes[v["shape"]], seen_till_here) seen_till_here[k].set_reference(k, nodes[k]) - nodes['type'] = 'structure' + nodes["type"] = "structure" return nodes - elif dic['type'] == 'list': + elif dic["type"] == "list": seen_till_here = dict(seen) - shape = dic['member']['shape'] + shape = dic["member"]["shape"] if shape in seen_till_here: return seen_till_here[shape] seen_till_here[shape] = _RecursiveDictRef() expanded = expand(self.shapes[shape], seen_till_here) seen_till_here[shape].set_reference(shape, expanded) - return {'type': 'list', 'member': expanded} + return {"type": "list", "member": expanded} - elif dic['type'] == 'map': + elif dic["type"] == "map": seen_till_here = dict(seen) - node = {'type': 'map'} + node = {"type": "map"} - if 'shape' in dic['key']: - shape = dic['key']['shape'] + if "shape" in dic["key"]: + shape = dic["key"]["shape"] seen_till_here[shape] = _RecursiveDictRef() - node['key'] = expand(self.shapes[shape], seen_till_here) - seen_till_here[shape].set_reference(shape, node['key']) + node["key"] = expand(self.shapes[shape], seen_till_here) + seen_till_here[shape].set_reference(shape, node["key"]) else: - node['key'] = dic['key']['type'] + node["key"] = dic["key"]["type"] - if 'shape' in dic['value']: - shape = dic['value']['shape'] + if "shape" in dic["value"]: + shape = dic["value"]["shape"] seen_till_here[shape] = _RecursiveDictRef() - node['value'] = expand(self.shapes[shape], seen_till_here) - seen_till_here[shape].set_reference(shape, node['value']) + node["value"] = expand(self.shapes[shape], seen_till_here) + seen_till_here[shape].set_reference(shape, node["value"]) else: - node['value'] = dic['value']['type'] + node["value"] = dic["value"]["type"] return node else: - return {'type': dic['type']} + return {"type": dic["type"]} return expand(shape) def to_str(value, spec): - vtype = spec['type'] - if vtype == 'boolean': - return 'true' if value else 'false' - elif vtype == 'integer': + vtype = spec["type"] + if vtype == "boolean": + return "true" if value else "false" + elif vtype == "integer": return str(value) - elif vtype == 'float': + elif vtype == "float": return str(value) - elif vtype == 'double': + elif vtype == "double": return str(value) - elif vtype == 'timestamp': - return datetime.datetime.utcfromtimestamp( - value).replace(tzinfo=pytz.utc).isoformat() - elif vtype == 'string': + elif vtype == "timestamp": + return ( + datetime.datetime.utcfromtimestamp(value) + .replace(tzinfo=pytz.utc) + .isoformat() + ) + elif vtype == "string": return str(value) elif value is None: - return 'null' + return "null" else: - raise TypeError('Unknown type {}'.format(vtype)) + raise TypeError("Unknown type {}".format(vtype)) def from_str(value, spec): - vtype = spec['type'] - if vtype == 'boolean': - return True if value == 'true' else False - elif vtype == 'integer': + vtype = spec["type"] + if vtype == "boolean": + return True if value == "true" else False + elif vtype == "integer": return int(value) - elif vtype == 'float': + elif vtype == "float": return float(value) - elif vtype == 'double': + elif vtype == "double": return float(value) - elif vtype == 'timestamp': + elif vtype == "timestamp": return value - elif vtype == 'string': + elif vtype == "string": return value - raise TypeError('Unknown type {}'.format(vtype)) + raise TypeError("Unknown type {}".format(vtype)) def flatten_json_request_body(prefix, dict_body, spec): """Convert a JSON request body into query params.""" - if len(spec) == 1 and 'type' in spec: + if len(spec) == 1 and "type" in spec: return {prefix: to_str(dict_body, spec)} flat = {} for key, value in dict_body.items(): - node_type = spec[key]['type'] - if node_type == 'list': + node_type = spec[key]["type"] + if node_type == "list": for idx, v in enumerate(value, 1): - pref = key + '.member.' + str(idx) - flat.update(flatten_json_request_body( - pref, v, spec[key]['member'])) - elif node_type == 'map': + pref = key + ".member." + str(idx) + flat.update(flatten_json_request_body(pref, v, spec[key]["member"])) + elif node_type == "map": for idx, (k, v) in enumerate(value.items(), 1): - pref = key + '.entry.' + str(idx) - flat.update(flatten_json_request_body( - pref + '.key', k, spec[key]['key'])) - flat.update(flatten_json_request_body( - pref + '.value', v, spec[key]['value'])) + pref = key + ".entry." + str(idx) + flat.update( + flatten_json_request_body(pref + ".key", k, spec[key]["key"]) + ) + flat.update( + flatten_json_request_body(pref + ".value", v, spec[key]["value"]) + ) else: flat.update(flatten_json_request_body(key, value, spec[key])) if prefix: - prefix = prefix + '.' + prefix = prefix + "." return dict((prefix + k, v) for k, v in flat.items()) @@ -873,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): od = OrderedDict() for k, v in value.items(): - if k.startswith('@'): + if k.startswith("@"): continue if k not in spec: # this can happen when with an older version of # botocore for which the node in XML template is not # defined in service spec. - log.warning( - 'Field %s is not defined by the botocore version in use', k) + log.warning("Field %s is not defined by the botocore version in use", k) continue - if spec[k]['type'] == 'list': + if spec[k]["type"] == "list": if v is None: od[k] = [] - elif len(spec[k]['member']) == 1: - if isinstance(v['member'], list): - od[k] = transform(v['member'], spec[k]['member']) + elif len(spec[k]["member"]) == 1: + if isinstance(v["member"], list): + od[k] = transform(v["member"], spec[k]["member"]) else: - od[k] = [transform(v['member'], spec[k]['member'])] - elif isinstance(v['member'], list): - od[k] = [transform(o, spec[k]['member']) - for o in v['member']] - elif isinstance(v['member'], OrderedDict): - od[k] = [transform(v['member'], spec[k]['member'])] + od[k] = [transform(v["member"], spec[k]["member"])] + elif isinstance(v["member"], list): + od[k] = [transform(o, spec[k]["member"]) for o in v["member"]] + elif isinstance(v["member"], OrderedDict): + od[k] = [transform(v["member"], spec[k]["member"])] else: - raise ValueError('Malformatted input') - elif spec[k]['type'] == 'map': + raise ValueError("Malformatted input") + elif spec[k]["type"] == "map": if v is None: od[k] = {} else: - items = ([v['entry']] if not isinstance(v['entry'], list) else - v['entry']) + items = ( + [v["entry"]] if not isinstance(v["entry"], list) else v["entry"] + ) for item in items: - key = from_str(item['key'], spec[k]['key']) - val = from_str(item['value'], spec[k]['value']) + key = from_str(item["key"], spec[k]["key"]) + val = from_str(item["value"], spec[k]["value"]) if k not in od: od[k] = {} od[k][key] = val @@ -921,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): dic = xmltodict.parse(xml) output_spec = service_spec.output_spec(operation) try: - for k in (result_node or (operation + 'Response', operation + 'Result')): + for k in result_node or (operation + "Response", operation + "Result"): dic = dic[k] except KeyError: return None diff --git a/moto/core/urls.py b/moto/core/urls.py index 46025221e..12036b5c3 100644 --- a/moto/core/urls.py +++ b/moto/core/urls.py @@ -1,15 +1,13 @@ from __future__ import unicode_literals from .responses import MotoAPIResponse -url_bases = [ - "https?://motoapi.amazonaws.com" -] +url_bases = ["https?://motoapi.amazonaws.com"] response_instance = MotoAPIResponse() url_paths = { - '{0}/moto-api/$': response_instance.dashboard, - '{0}/moto-api/data.json': response_instance.model_data, - '{0}/moto-api/reset': response_instance.reset_response, - '{0}/moto-api/reset-auth': response_instance.reset_auth_response, + "{0}/moto-api/$": response_instance.dashboard, + "{0}/moto-api/data.json": response_instance.model_data, + "{0}/moto-api/reset": response_instance.reset_response, + "{0}/moto-api/reset-auth": response_instance.reset_auth_response, } diff --git a/moto/core/utils.py b/moto/core/utils.py index ca670e871..efad5679c 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -8,6 +8,7 @@ import random import re import six import string +from botocore.exceptions import ClientError from six.moves.urllib.parse import urlparse @@ -15,9 +16,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase def camelcase_to_underscores(argument): - ''' Converts a camelcase param like theNewAttribute to the equivalent - python underscore variable like the_new_attribute''' - result = '' + """ Converts a camelcase param like theNewAttribute to the equivalent + python underscore variable like the_new_attribute""" + result = "" prev_char_title = True if not argument: return argument @@ -41,18 +42,18 @@ def camelcase_to_underscores(argument): def underscores_to_camelcase(argument): - ''' Converts a camelcase param like the_new_attribute to the equivalent + """ Converts a camelcase param like the_new_attribute to the equivalent camelcase version like theNewAttribute. Note that the first letter is - NOT capitalized by this function ''' - result = '' + NOT capitalized by this function """ + result = "" previous_was_underscore = False for char in argument: - if char != '_': + if char != "_": if previous_was_underscore: result += char.upper() else: result += char - previous_was_underscore = char == '_' + previous_was_underscore = char == "_" return result @@ -69,12 +70,18 @@ def method_names_from_class(clazz): def get_random_hex(length=8): - chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] - return ''.join(six.text_type(random.choice(chars)) for x in range(length)) + chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] + return "".join(six.text_type(random.choice(chars)) for x in range(length)) def get_random_message_id(): - return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12)) + return "{0}-{1}-{2}-{3}-{4}".format( + get_random_hex(8), + get_random_hex(4), + get_random_hex(4), + get_random_hex(4), + get_random_hex(12), + ) def convert_regex_to_flask_path(url_path): @@ -97,7 +104,6 @@ def convert_regex_to_flask_path(url_path): class convert_httpretty_response(object): - def __init__(self, callback): self.callback = callback @@ -114,13 +120,12 @@ class convert_httpretty_response(object): def __call__(self, request, url, headers, **kwargs): result = self.callback(request, url, headers) status, headers, response = result - if 'server' not in headers: + if "server" not in headers: headers["server"] = "amazon.com" return status, headers, response class convert_flask_to_httpretty_response(object): - def __init__(self, callback): self.callback = callback @@ -137,7 +142,10 @@ class convert_flask_to_httpretty_response(object): def __call__(self, args=None, **kwargs): from flask import request, Response - result = self.callback(request, request.url, {}) + try: + result = self.callback(request, request.url, {}) + except ClientError as exc: + result = 400, {}, exc.response["Error"]["Message"] # result is a status, headers, response tuple if len(result) == 3: status, headers, content = result @@ -145,13 +153,12 @@ class convert_flask_to_httpretty_response(object): status, headers, content = 200, {}, result response = Response(response=content, status=status, headers=headers) - if request.method == "HEAD" and 'content-length' in headers: - response.headers['Content-Length'] = headers['content-length'] + if request.method == "HEAD" and "content-length" in headers: + response.headers["Content-Length"] = headers["content-length"] return response class convert_flask_to_responses_response(object): - def __init__(self, callback): self.callback = callback @@ -176,14 +183,14 @@ class convert_flask_to_responses_response(object): def iso_8601_datetime_with_milliseconds(datetime): - return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z' + return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" def iso_8601_datetime_without_milliseconds(datetime): - return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z' + return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z" -RFC1123 = '%a, %d %b %Y %H:%M:%S GMT' +RFC1123 = "%a, %d %b %Y %H:%M:%S GMT" def rfc_1123_datetime(datetime): @@ -212,16 +219,16 @@ def gen_amz_crc32(response, headerdict=None): crc = str(binascii.crc32(response)) if headerdict is not None and isinstance(headerdict, dict): - headerdict.update({'x-amz-crc32': crc}) + headerdict.update({"x-amz-crc32": crc}) return crc def gen_amzn_requestid_long(headerdict=None): - req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) + req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) if headerdict is not None and isinstance(headerdict, dict): - headerdict.update({'x-amzn-requestid': req_id}) + headerdict.update({"x-amzn-requestid": req_id}) return req_id @@ -239,13 +246,13 @@ def amz_crc32(f): else: if len(response) == 2: body, new_headers = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) else: status, new_headers, body = response headers.update(new_headers) # Cast status to string if "status" in headers: - headers['status'] = str(headers['status']) + headers["status"] = str(headers["status"]) try: # Doesnt work on python2 for some odd unicode strings @@ -271,7 +278,7 @@ def amzn_request_id(f): else: if len(response) == 2: body, new_headers = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) else: status, new_headers, body = response headers.update(new_headers) @@ -280,7 +287,7 @@ def amzn_request_id(f): # Update request ID in XML try: - body = re.sub(r'(?<=).*(?=<\/RequestId>)', request_id, body) + body = re.sub(r"(?<=).*(?=<\/RequestId>)", request_id, body) except Exception: # Will just ignore if it cant work on bytes (which are str's on python2) pass @@ -293,7 +300,31 @@ def path_url(url): parsed_url = urlparse(url) path = parsed_url.path if not path: - path = '/' + path = "/" if parsed_url.query: - path = path + '?' + parsed_url.query + path = path + "?" + parsed_url.query return path + + +def py2_strip_unicode_keys(blob): + """For Python 2 Only -- this will convert unicode keys in nested Dicts, Lists, and Sets to standard strings.""" + if type(blob) == unicode: # noqa + return str(blob) + + elif type(blob) == dict: + for key in list(blob.keys()): + value = blob.pop(key) + blob[str(key)] = py2_strip_unicode_keys(value) + + elif type(blob) == list: + for i in range(0, len(blob)): + blob[i] = py2_strip_unicode_keys(blob[i]) + + elif type(blob) == set: + new_set = set() + for value in blob: + new_set.add(py2_strip_unicode_keys(value)) + + blob = new_set + + return blob diff --git a/moto/datapipeline/__init__.py b/moto/datapipeline/__init__.py index 2565ddd5a..42ee5d6ff 100644 --- a/moto/datapipeline/__init__.py +++ b/moto/datapipeline/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import datapipeline_backends from ..core.models import base_decorator, deprecated_base_decorator -datapipeline_backend = datapipeline_backends['us-east-1'] +datapipeline_backend = datapipeline_backends["us-east-1"] mock_datapipeline = base_decorator(datapipeline_backends) mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends) diff --git a/moto/datapipeline/models.py b/moto/datapipeline/models.py index bb8417a20..cc1fe777e 100644 --- a/moto/datapipeline/models.py +++ b/moto/datapipeline/models.py @@ -8,85 +8,65 @@ from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys class PipelineObject(BaseModel): - def __init__(self, object_id, name, fields): self.object_id = object_id self.name = name self.fields = fields def to_json(self): - return { - "fields": self.fields, - "id": self.object_id, - "name": self.name, - } + return {"fields": self.fields, "id": self.object_id, "name": self.name} class Pipeline(BaseModel): - def __init__(self, name, unique_id, **kwargs): self.name = name self.unique_id = unique_id - self.description = kwargs.get('description', '') + self.description = kwargs.get("description", "") self.pipeline_id = get_random_pipeline_id() self.creation_time = datetime.datetime.utcnow() self.objects = [] self.status = "PENDING" - self.tags = kwargs.get('tags', []) + self.tags = kwargs.get("tags", []) @property def physical_resource_id(self): return self.pipeline_id def to_meta_json(self): - return { - "id": self.pipeline_id, - "name": self.name, - } + return {"id": self.pipeline_id, "name": self.name} def to_json(self): return { "description": self.description, - "fields": [{ - "key": "@pipelineState", - "stringValue": self.status, - }, { - "key": "description", - "stringValue": self.description - }, { - "key": "name", - "stringValue": self.name - }, { - "key": "@creationTime", - "stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'), - }, { - "key": "@id", - "stringValue": self.pipeline_id, - }, { - "key": "@sphere", - "stringValue": "PIPELINE" - }, { - "key": "@version", - "stringValue": "1" - }, { - "key": "@userId", - "stringValue": "924374875933" - }, { - "key": "@accountId", - "stringValue": "924374875933" - }, { - "key": "uniqueId", - "stringValue": self.unique_id - }], + "fields": [ + {"key": "@pipelineState", "stringValue": self.status}, + {"key": "description", "stringValue": self.description}, + {"key": "name", "stringValue": self.name}, + { + "key": "@creationTime", + "stringValue": datetime.datetime.strftime( + self.creation_time, "%Y-%m-%dT%H-%M-%S" + ), + }, + {"key": "@id", "stringValue": self.pipeline_id}, + {"key": "@sphere", "stringValue": "PIPELINE"}, + {"key": "@version", "stringValue": "1"}, + {"key": "@userId", "stringValue": "924374875933"}, + {"key": "@accountId", "stringValue": "924374875933"}, + {"key": "uniqueId", "stringValue": self.unique_id}, + ], "name": self.name, "pipelineId": self.pipeline_id, - "tags": self.tags + "tags": self.tags, } def set_pipeline_objects(self, pipeline_objects): self.objects = [ - PipelineObject(pipeline_object['id'], pipeline_object[ - 'name'], pipeline_object['fields']) + PipelineObject( + pipeline_object["id"], + pipeline_object["name"], + pipeline_object["fields"], + ) for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects) ] @@ -94,15 +74,19 @@ class Pipeline(BaseModel): self.status = "SCHEDULED" @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): datapipeline_backend = datapipeline_backends[region_name] properties = cloudformation_json["Properties"] cloudformation_unique_id = "cf-" + properties["Name"] pipeline = datapipeline_backend.create_pipeline( - properties["Name"], cloudformation_unique_id) + properties["Name"], cloudformation_unique_id + ) datapipeline_backend.put_pipeline_definition( - pipeline.pipeline_id, properties["PipelineObjects"]) + pipeline.pipeline_id, properties["PipelineObjects"] + ) if properties["Activate"]: pipeline.activate() @@ -110,7 +94,6 @@ class Pipeline(BaseModel): class DataPipelineBackend(BaseBackend): - def __init__(self): self.pipelines = OrderedDict() @@ -123,8 +106,11 @@ class DataPipelineBackend(BaseBackend): return self.pipelines.values() def describe_pipelines(self, pipeline_ids): - pipelines = [pipeline for pipeline in self.pipelines.values( - ) if pipeline.pipeline_id in pipeline_ids] + pipelines = [ + pipeline + for pipeline in self.pipelines.values() + if pipeline.pipeline_id in pipeline_ids + ] return pipelines def get_pipeline(self, pipeline_id): @@ -144,7 +130,8 @@ class DataPipelineBackend(BaseBackend): def describe_objects(self, object_ids, pipeline_id): pipeline = self.get_pipeline(pipeline_id) pipeline_objects = [ - pipeline_object for pipeline_object in pipeline.objects + pipeline_object + for pipeline_object in pipeline.objects if pipeline_object.object_id in object_ids ] return pipeline_objects diff --git a/moto/datapipeline/responses.py b/moto/datapipeline/responses.py index e462e3981..42e1ff2c3 100644 --- a/moto/datapipeline/responses.py +++ b/moto/datapipeline/responses.py @@ -7,7 +7,6 @@ from .models import datapipeline_backends class DataPipelineResponse(BaseResponse): - @property def parameters(self): # TODO this should really be moved to core/responses.py @@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse): return datapipeline_backends[self.region] def create_pipeline(self): - name = self.parameters.get('name') - unique_id = self.parameters.get('uniqueId') - description = self.parameters.get('description', '') - tags = self.parameters.get('tags', []) - pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags) - return json.dumps({ - "pipelineId": pipeline.pipeline_id, - }) + name = self.parameters.get("name") + unique_id = self.parameters.get("uniqueId") + description = self.parameters.get("description", "") + tags = self.parameters.get("tags", []) + pipeline = self.datapipeline_backend.create_pipeline( + name, unique_id, description=description, tags=tags + ) + return json.dumps({"pipelineId": pipeline.pipeline_id}) def list_pipelines(self): pipelines = list(self.datapipeline_backend.list_pipelines()) pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] max_pipelines = 50 - marker = self.parameters.get('marker') + marker = self.parameters.get("marker") if marker: start = pipeline_ids.index(marker) + 1 else: start = 0 - pipelines_resp = pipelines[start:start + max_pipelines] + pipelines_resp = pipelines[start : start + max_pipelines] has_more_results = False marker = None if start + max_pipelines < len(pipeline_ids) - 1: has_more_results = True marker = pipelines_resp[-1].pipeline_id - return json.dumps({ - "hasMoreResults": has_more_results, - "marker": marker, - "pipelineIdList": [ - pipeline.to_meta_json() for pipeline in pipelines_resp - ] - }) + return json.dumps( + { + "hasMoreResults": has_more_results, + "marker": marker, + "pipelineIdList": [ + pipeline.to_meta_json() for pipeline in pipelines_resp + ], + } + ) def describe_pipelines(self): pipeline_ids = self.parameters["pipelineIds"] pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) - return json.dumps({ - "pipelineDescriptionList": [ - pipeline.to_json() for pipeline in pipelines - ] - }) + return json.dumps( + {"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]} + ) def delete_pipeline(self): pipeline_id = self.parameters["pipelineId"] @@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse): pipeline_id = self.parameters["pipelineId"] pipeline_objects = self.parameters["pipelineObjects"] - self.datapipeline_backend.put_pipeline_definition( - pipeline_id, pipeline_objects) + self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects) return json.dumps({"errored": False}) def get_pipeline_definition(self): pipeline_id = self.parameters["pipelineId"] pipeline_definition = self.datapipeline_backend.get_pipeline_definition( - pipeline_id) - return json.dumps({ - "pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition] - }) + pipeline_id + ) + return json.dumps( + { + "pipelineObjects": [ + pipeline_object.to_json() for pipeline_object in pipeline_definition + ] + } + ) def describe_objects(self): pipeline_id = self.parameters["pipelineId"] object_ids = self.parameters["objectIds"] pipeline_objects = self.datapipeline_backend.describe_objects( - object_ids, pipeline_id) - return json.dumps({ - "hasMoreResults": False, - "marker": None, - "pipelineObjects": [ - pipeline_object.to_json() for pipeline_object in pipeline_objects - ] - }) + object_ids, pipeline_id + ) + return json.dumps( + { + "hasMoreResults": False, + "marker": None, + "pipelineObjects": [ + pipeline_object.to_json() for pipeline_object in pipeline_objects + ], + } + ) def activate_pipeline(self): pipeline_id = self.parameters["pipelineId"] diff --git a/moto/datapipeline/urls.py b/moto/datapipeline/urls.py index 40805874b..078b44b19 100644 --- a/moto/datapipeline/urls.py +++ b/moto/datapipeline/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DataPipelineResponse -url_bases = [ - "https?://datapipeline.(.+).amazonaws.com", -] +url_bases = ["https?://datapipeline.(.+).amazonaws.com"] -url_paths = { - '{0}/$': DataPipelineResponse.dispatch, -} +url_paths = {"{0}/$": DataPipelineResponse.dispatch} diff --git a/moto/datapipeline/utils.py b/moto/datapipeline/utils.py index 75df4a9a5..b14fe6f1a 100644 --- a/moto/datapipeline/utils.py +++ b/moto/datapipeline/utils.py @@ -1,5 +1,5 @@ -import collections import six +from moto.compat import collections_abc from moto.core.utils import get_random_hex @@ -8,13 +8,15 @@ def get_random_pipeline_id(): def remove_capitalization_of_dict_keys(obj): - if isinstance(obj, collections.Mapping): + if isinstance(obj, collections_abc.Mapping): result = obj.__class__() for key, value in obj.items(): normalized_key = key[:1].lower() + key[1:] result[normalized_key] = remove_capitalization_of_dict_keys(value) return result - elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types): + elif isinstance(obj, collections_abc.Iterable) and not isinstance( + obj, six.string_types + ): result = obj.__class__() for item in obj: result += (remove_capitalization_of_dict_keys(item),) diff --git a/moto/datasync/__init__.py b/moto/datasync/__init__.py new file mode 100644 index 000000000..85134e4f1 --- /dev/null +++ b/moto/datasync/__init__.py @@ -0,0 +1,8 @@ +from __future__ import unicode_literals + +from ..core.models import base_decorator, deprecated_base_decorator +from .models import datasync_backends + +datasync_backend = datasync_backends["us-east-1"] +mock_datasync = base_decorator(datasync_backends) +mock_datasync_deprecated = deprecated_base_decorator(datasync_backends) diff --git a/moto/datasync/exceptions.py b/moto/datasync/exceptions.py new file mode 100644 index 000000000..b0f2d8f0f --- /dev/null +++ b/moto/datasync/exceptions.py @@ -0,0 +1,15 @@ +from __future__ import unicode_literals + +from moto.core.exceptions import JsonRESTError + + +class DataSyncClientError(JsonRESTError): + code = 400 + + +class InvalidRequestException(DataSyncClientError): + def __init__(self, msg=None): + self.code = 400 + super(InvalidRequestException, self).__init__( + "InvalidRequestException", msg or "The request is not valid." + ) diff --git a/moto/datasync/models.py b/moto/datasync/models.py new file mode 100644 index 000000000..17a2659fb --- /dev/null +++ b/moto/datasync/models.py @@ -0,0 +1,230 @@ +import boto3 +from moto.compat import OrderedDict +from moto.core import BaseBackend, BaseModel + +from .exceptions import InvalidRequestException + + +class Location(BaseModel): + def __init__( + self, location_uri, region_name=None, typ=None, metadata=None, arn_counter=0 + ): + self.uri = location_uri + self.region_name = region_name + self.metadata = metadata + self.typ = typ + # Generate ARN + self.arn = "arn:aws:datasync:{0}:111222333444:location/loc-{1}".format( + region_name, str(arn_counter).zfill(17) + ) + + +class Task(BaseModel): + def __init__( + self, + source_location_arn, + destination_location_arn, + name, + region_name, + arn_counter=0, + metadata=None, + ): + self.source_location_arn = source_location_arn + self.destination_location_arn = destination_location_arn + self.name = name + self.metadata = metadata + # For simplicity Tasks are either available or running + self.status = "AVAILABLE" + self.current_task_execution_arn = None + # Generate ARN + self.arn = "arn:aws:datasync:{0}:111222333444:task/task-{1}".format( + region_name, str(arn_counter).zfill(17) + ) + + +class TaskExecution(BaseModel): + + # For simplicity, task_execution can never fail + # Some documentation refers to this list: + # 'Status': 'QUEUED'|'LAUNCHING'|'PREPARING'|'TRANSFERRING'|'VERIFYING'|'SUCCESS'|'ERROR' + # Others refers to this list: + # INITIALIZING | PREPARING | TRANSFERRING | VERIFYING | SUCCESS/FAILURE + # Checking with AWS Support... + TASK_EXECUTION_INTERMEDIATE_STATES = ( + "INITIALIZING", + # 'QUEUED', 'LAUNCHING', + "PREPARING", + "TRANSFERRING", + "VERIFYING", + ) + + TASK_EXECUTION_FAILURE_STATES = ("ERROR",) + TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",) + # Also COMPLETED state? + + def __init__(self, task_arn, arn_counter=0): + self.task_arn = task_arn + self.arn = "{0}/execution/exec-{1}".format(task_arn, str(arn_counter).zfill(17)) + self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[0] + + # Simulate a task execution + def iterate_status(self): + if self.status in self.TASK_EXECUTION_FAILURE_STATES: + return + if self.status in self.TASK_EXECUTION_SUCCESS_STATES: + return + if self.status in self.TASK_EXECUTION_INTERMEDIATE_STATES: + for i, status in enumerate(self.TASK_EXECUTION_INTERMEDIATE_STATES): + if status == self.status: + if i < len(self.TASK_EXECUTION_INTERMEDIATE_STATES) - 1: + self.status = self.TASK_EXECUTION_INTERMEDIATE_STATES[i + 1] + else: + self.status = self.TASK_EXECUTION_SUCCESS_STATES[0] + return + raise Exception( + "TaskExecution.iterate_status: Unknown status={0}".format(self.status) + ) + + def cancel(self): + if self.status not in self.TASK_EXECUTION_INTERMEDIATE_STATES: + raise InvalidRequestException( + "Sync task cannot be cancelled in its current status: {0}".format( + self.status + ) + ) + self.status = "ERROR" + + +class DataSyncBackend(BaseBackend): + def __init__(self, region_name): + self.region_name = region_name + # Always increase when new things are created + # This ensures uniqueness + self.arn_counter = 0 + self.locations = OrderedDict() + self.tasks = OrderedDict() + self.task_executions = OrderedDict() + + def reset(self): + region_name = self.region_name + self._reset_model_refs() + self.__dict__ = {} + self.__init__(region_name) + + def create_location(self, location_uri, typ=None, metadata=None): + """ + # AWS DataSync allows for duplicate LocationUris + for arn, location in self.locations.items(): + if location.uri == location_uri: + raise Exception('Location already exists') + """ + if not typ: + raise Exception("Location type must be specified") + self.arn_counter = self.arn_counter + 1 + location = Location( + location_uri, + region_name=self.region_name, + arn_counter=self.arn_counter, + metadata=metadata, + typ=typ, + ) + self.locations[location.arn] = location + return location.arn + + def _get_location(self, location_arn, typ): + if location_arn not in self.locations: + raise InvalidRequestException( + "Location {0} is not found.".format(location_arn) + ) + location = self.locations[location_arn] + if location.typ != typ: + raise InvalidRequestException( + "Invalid Location type: {0}".format(location.typ) + ) + return location + + def delete_location(self, location_arn): + if location_arn in self.locations: + del self.locations[location_arn] + else: + raise InvalidRequestException + + def create_task( + self, source_location_arn, destination_location_arn, name, metadata=None + ): + if source_location_arn not in self.locations: + raise InvalidRequestException( + "Location {0} not found.".format(source_location_arn) + ) + if destination_location_arn not in self.locations: + raise InvalidRequestException( + "Location {0} not found.".format(destination_location_arn) + ) + self.arn_counter = self.arn_counter + 1 + task = Task( + source_location_arn, + destination_location_arn, + name, + region_name=self.region_name, + arn_counter=self.arn_counter, + metadata=metadata, + ) + self.tasks[task.arn] = task + return task.arn + + def _get_task(self, task_arn): + if task_arn in self.tasks: + return self.tasks[task_arn] + else: + raise InvalidRequestException + + def update_task(self, task_arn, name, metadata): + if task_arn in self.tasks: + task = self.tasks[task_arn] + task.name = name + task.metadata = metadata + else: + raise InvalidRequestException( + "Sync task {0} is not found.".format(task_arn) + ) + + def delete_task(self, task_arn): + if task_arn in self.tasks: + del self.tasks[task_arn] + else: + raise InvalidRequestException + + def start_task_execution(self, task_arn): + self.arn_counter = self.arn_counter + 1 + if task_arn in self.tasks: + task = self.tasks[task_arn] + if task.status == "AVAILABLE": + task_execution = TaskExecution(task_arn, arn_counter=self.arn_counter) + self.task_executions[task_execution.arn] = task_execution + self.tasks[task_arn].current_task_execution_arn = task_execution.arn + self.tasks[task_arn].status = "RUNNING" + return task_execution.arn + raise InvalidRequestException("Invalid request.") + + def _get_task_execution(self, task_execution_arn): + if task_execution_arn in self.task_executions: + return self.task_executions[task_execution_arn] + else: + raise InvalidRequestException + + def cancel_task_execution(self, task_execution_arn): + if task_execution_arn in self.task_executions: + task_execution = self.task_executions[task_execution_arn] + task_execution.cancel() + task_arn = task_execution.task_arn + self.tasks[task_arn].current_task_execution_arn = None + self.tasks[task_arn].status = "AVAILABLE" + return + raise InvalidRequestException( + "Sync task {0} is not found.".format(task_execution_arn) + ) + + +datasync_backends = {} +for region in boto3.Session().get_available_regions("datasync"): + datasync_backends[region] = DataSyncBackend(region_name=region) diff --git a/moto/datasync/responses.py b/moto/datasync/responses.py new file mode 100644 index 000000000..03811fb6e --- /dev/null +++ b/moto/datasync/responses.py @@ -0,0 +1,162 @@ +import json + +from moto.core.responses import BaseResponse + +from .models import datasync_backends + + +class DataSyncResponse(BaseResponse): + @property + def datasync_backend(self): + return datasync_backends[self.region] + + def list_locations(self): + locations = list() + for arn, location in self.datasync_backend.locations.items(): + locations.append({"LocationArn": location.arn, "LocationUri": location.uri}) + return json.dumps({"Locations": locations}) + + def _get_location(self, location_arn, typ): + return self.datasync_backend._get_location(location_arn, typ) + + def create_location_s3(self): + # s3://bucket_name/folder/ + s3_bucket_arn = self._get_param("S3BucketArn") + subdirectory = self._get_param("Subdirectory") + metadata = {"S3Config": self._get_param("S3Config")} + location_uri_elts = ["s3:/", s3_bucket_arn.split(":")[-1]] + if subdirectory: + location_uri_elts.append(subdirectory) + location_uri = "/".join(location_uri_elts) + arn = self.datasync_backend.create_location( + location_uri, metadata=metadata, typ="S3" + ) + return json.dumps({"LocationArn": arn}) + + def describe_location_s3(self): + location_arn = self._get_param("LocationArn") + location = self._get_location(location_arn, typ="S3") + return json.dumps( + { + "LocationArn": location.arn, + "LocationUri": location.uri, + "S3Config": location.metadata["S3Config"], + } + ) + + def create_location_smb(self): + # smb://smb.share.fqdn/AWS_Test/ + subdirectory = self._get_param("Subdirectory") + server_hostname = self._get_param("ServerHostname") + metadata = { + "AgentArns": self._get_param("AgentArns"), + "User": self._get_param("User"), + "Domain": self._get_param("Domain"), + "MountOptions": self._get_param("MountOptions"), + } + + location_uri = "/".join(["smb:/", server_hostname, subdirectory]) + arn = self.datasync_backend.create_location( + location_uri, metadata=metadata, typ="SMB" + ) + return json.dumps({"LocationArn": arn}) + + def describe_location_smb(self): + location_arn = self._get_param("LocationArn") + location = self._get_location(location_arn, typ="SMB") + return json.dumps( + { + "LocationArn": location.arn, + "LocationUri": location.uri, + "AgentArns": location.metadata["AgentArns"], + "User": location.metadata["User"], + "Domain": location.metadata["Domain"], + "MountOptions": location.metadata["MountOptions"], + } + ) + + def delete_location(self): + location_arn = self._get_param("LocationArn") + self.datasync_backend.delete_location(location_arn) + return json.dumps({}) + + def create_task(self): + destination_location_arn = self._get_param("DestinationLocationArn") + source_location_arn = self._get_param("SourceLocationArn") + name = self._get_param("Name") + metadata = { + "CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"), + "Options": self._get_param("Options"), + "Excludes": self._get_param("Excludes"), + "Tags": self._get_param("Tags"), + } + arn = self.datasync_backend.create_task( + source_location_arn, destination_location_arn, name, metadata=metadata + ) + return json.dumps({"TaskArn": arn}) + + def update_task(self): + task_arn = self._get_param("TaskArn") + self.datasync_backend.update_task( + task_arn, + name=self._get_param("Name"), + metadata={ + "CloudWatchLogGroupArn": self._get_param("CloudWatchLogGroupArn"), + "Options": self._get_param("Options"), + "Excludes": self._get_param("Excludes"), + "Tags": self._get_param("Tags"), + }, + ) + return json.dumps({}) + + def list_tasks(self): + tasks = list() + for arn, task in self.datasync_backend.tasks.items(): + tasks.append( + {"Name": task.name, "Status": task.status, "TaskArn": task.arn} + ) + return json.dumps({"Tasks": tasks}) + + def delete_task(self): + task_arn = self._get_param("TaskArn") + self.datasync_backend.delete_task(task_arn) + return json.dumps({}) + + def describe_task(self): + task_arn = self._get_param("TaskArn") + task = self.datasync_backend._get_task(task_arn) + return json.dumps( + { + "TaskArn": task.arn, + "Status": task.status, + "Name": task.name, + "CurrentTaskExecutionArn": task.current_task_execution_arn, + "SourceLocationArn": task.source_location_arn, + "DestinationLocationArn": task.destination_location_arn, + "CloudWatchLogGroupArn": task.metadata["CloudWatchLogGroupArn"], + "Options": task.metadata["Options"], + "Excludes": task.metadata["Excludes"], + } + ) + + def start_task_execution(self): + task_arn = self._get_param("TaskArn") + arn = self.datasync_backend.start_task_execution(task_arn) + return json.dumps({"TaskExecutionArn": arn}) + + def cancel_task_execution(self): + task_execution_arn = self._get_param("TaskExecutionArn") + self.datasync_backend.cancel_task_execution(task_execution_arn) + return json.dumps({}) + + def describe_task_execution(self): + task_execution_arn = self._get_param("TaskExecutionArn") + task_execution = self.datasync_backend._get_task_execution(task_execution_arn) + result = json.dumps( + {"TaskExecutionArn": task_execution.arn, "Status": task_execution.status} + ) + if task_execution.status == "SUCCESS": + self.datasync_backend.tasks[task_execution.task_arn].status = "AVAILABLE" + # Simulate task being executed + task_execution.iterate_status() + return result diff --git a/moto/datasync/urls.py b/moto/datasync/urls.py new file mode 100644 index 000000000..69ba3cccb --- /dev/null +++ b/moto/datasync/urls.py @@ -0,0 +1,7 @@ +from __future__ import unicode_literals + +from .responses import DataSyncResponse + +url_bases = ["https?://(.*?)(datasync)(.*?).amazonaws.com"] + +url_paths = {"{0}/$": DataSyncResponse.dispatch} diff --git a/moto/dynamodb/comparisons.py b/moto/dynamodb/comparisons.py index d9b391557..5418f906f 100644 --- a/moto/dynamodb/comparisons.py +++ b/moto/dynamodb/comparisons.py @@ -1,19 +1,22 @@ from __future__ import unicode_literals + # TODO add tests for all of these COMPARISON_FUNCS = { - 'EQ': lambda item_value, test_value: item_value == test_value, - 'NE': lambda item_value, test_value: item_value != test_value, - 'LE': lambda item_value, test_value: item_value <= test_value, - 'LT': lambda item_value, test_value: item_value < test_value, - 'GE': lambda item_value, test_value: item_value >= test_value, - 'GT': lambda item_value, test_value: item_value > test_value, - 'NULL': lambda item_value: item_value is None, - 'NOT_NULL': lambda item_value: item_value is not None, - 'CONTAINS': lambda item_value, test_value: test_value in item_value, - 'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, - 'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), - 'IN': lambda item_value, *test_values: item_value in test_values, - 'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value, + "EQ": lambda item_value, test_value: item_value == test_value, + "NE": lambda item_value, test_value: item_value != test_value, + "LE": lambda item_value, test_value: item_value <= test_value, + "LT": lambda item_value, test_value: item_value < test_value, + "GE": lambda item_value, test_value: item_value >= test_value, + "GT": lambda item_value, test_value: item_value > test_value, + "NULL": lambda item_value: item_value is None, + "NOT_NULL": lambda item_value: item_value is not None, + "CONTAINS": lambda item_value, test_value: test_value in item_value, + "NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value, + "BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value), + "IN": lambda item_value, *test_values: item_value in test_values, + "BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value + <= item_value + <= upper_test_value, } diff --git a/moto/dynamodb/models.py b/moto/dynamodb/models.py index 300189a0e..f5771ec6e 100644 --- a/moto/dynamodb/models.py +++ b/moto/dynamodb/models.py @@ -6,13 +6,13 @@ import json from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time +from moto.core import ACCOUNT_ID from .comparisons import get_comparison_func class DynamoJsonEncoder(json.JSONEncoder): - def default(self, obj): - if hasattr(obj, 'to_json'): + if hasattr(obj, "to_json"): return obj.to_json() @@ -33,10 +33,7 @@ class DynamoType(object): return hash((self.type, self.value)) def __eq__(self, other): - return ( - self.type == other.type and - self.value == other.value - ) + return self.type == other.type and self.value == other.value def __repr__(self): return "DynamoType: {0}".format(self.to_json()) @@ -54,7 +51,6 @@ class DynamoType(object): class Item(BaseModel): - def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): self.hash_key = hash_key self.hash_key_type = hash_key_type @@ -73,9 +69,7 @@ class Item(BaseModel): for attribute_key, attribute in self.attrs.items(): attributes[attribute_key] = attribute.value - return { - "Attributes": attributes - } + return {"Attributes": attributes} def describe_attrs(self, attributes): if attributes: @@ -85,16 +79,20 @@ class Item(BaseModel): included[key] = value else: included = self.attrs - return { - "Item": included - } + return {"Item": included} class Table(BaseModel): - - def __init__(self, name, hash_key_attr, hash_key_type, - range_key_attr=None, range_key_type=None, read_capacity=None, - write_capacity=None): + def __init__( + self, + name, + hash_key_attr, + hash_key_type, + range_key_attr=None, + range_key_type=None, + read_capacity=None, + write_capacity=None, + ): self.name = name self.hash_key_attr = hash_key_attr self.hash_key_type = hash_key_type @@ -117,12 +115,12 @@ class Table(BaseModel): "KeySchema": { "HashKeyElement": { "AttributeName": self.hash_key_attr, - "AttributeType": self.hash_key_type - }, + "AttributeType": self.hash_key_type, + } }, "ProvisionedThroughput": { "ReadCapacityUnits": self.read_capacity, - "WriteCapacityUnits": self.write_capacity + "WriteCapacityUnits": self.write_capacity, }, "TableName": self.name, "TableStatus": "ACTIVE", @@ -133,19 +131,29 @@ class Table(BaseModel): if self.has_range_key: results["Table"]["KeySchema"]["RangeKeyElement"] = { "AttributeName": self.range_key_attr, - "AttributeType": self.range_key_type + "AttributeType": self.range_key_type, } return results @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0] - key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + key_attr = [ + i["AttributeName"] + for i in properties["KeySchema"] + if i["KeyType"] == "HASH" + ][0] + key_type = [ + i["AttributeType"] + for i in properties["AttributeDefinitions"] + if i["AttributeName"] == key_attr + ][0] spec = { - 'name': properties['TableName'], - 'hash_key_attr': key_attr, - 'hash_key_type': key_type + "name": properties["TableName"], + "hash_key_attr": key_attr, + "hash_key_type": key_type, } # TODO: optional properties still missing: # range_key_attr, range_key_type, read_capacity, write_capacity @@ -173,8 +181,9 @@ class Table(BaseModel): else: range_value = None - item = Item(hash_value, self.hash_key_type, range_value, - self.range_key_type, item_attrs) + item = Item( + hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs + ) if range_value: self.items[hash_value][range_value] = item @@ -185,7 +194,8 @@ class Table(BaseModel): def get_item(self, hash_key, range_key): if self.has_range_key and not range_key: raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) try: if range_key: return self.items[hash_key][range_key] @@ -228,7 +238,10 @@ class Table(BaseModel): for result in self.all_items(): scanned_count += 1 passes_all_conditions = True - for attribute_name, (comparison_operator, comparison_objs) in filters.items(): + for ( + attribute_name, + (comparison_operator, comparison_objs), + ) in filters.items(): attribute = result.attrs.get(attribute_name) if attribute: @@ -236,7 +249,7 @@ class Table(BaseModel): if not attribute.compare(comparison_operator, comparison_objs): passes_all_conditions = False break - elif comparison_operator == 'NULL': + elif comparison_operator == "NULL": # Comparison is NULL and we don't have the attribute continue else: @@ -261,15 +274,17 @@ class Table(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'StreamArn': - region = 'us-east-1' - time = '2000-01-01T00:00:00.000' - return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time) + + if attribute_name == "StreamArn": + region = "us-east-1" + time = "2000-01-01T00:00:00.000" + return "arn:aws:dynamodb:{0}:{1}:table/{2}/stream/{3}".format( + region, ACCOUNT_ID, self.name, time + ) raise UnformattedGetAttTemplateException() class DynamoDBBackend(BaseBackend): - def __init__(self): self.tables = OrderedDict() @@ -310,8 +325,7 @@ class DynamoDBBackend(BaseBackend): return None, None hash_key = DynamoType(hash_key_dict) - range_values = [DynamoType(range_value) - for range_value in range_value_dicts] + range_values = [DynamoType(range_value) for range_value in range_value_dicts] return table.query(hash_key, range_comparison, range_values) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 990069a46..85ae58fc5 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -8,7 +8,6 @@ from .models import dynamodb_backend, dynamo_json_dump class DynamoHandler(BaseResponse): - def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -16,15 +15,15 @@ class DynamoHandler(BaseResponse): ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables """ # Headers are case-insensitive. Probably a better way to do this. - match = headers.get('x-amz-target') or headers.get('X-Amz-Target') + match = headers.get("x-amz-target") or headers.get("X-Amz-Target") if match: return match.split(".")[1] def error(self, type_, status=400): - return status, self.response_headers, dynamo_json_dump({'__type': type_}) + return status, self.response_headers, dynamo_json_dump({"__type": type_}) def call_action(self): - self.body = json.loads(self.body or '{}') + self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: endpoint = camelcase_to_underscores(endpoint) @@ -41,7 +40,7 @@ class DynamoHandler(BaseResponse): def list_tables(self): body = self.body - limit = body.get('Limit') + limit = body.get("Limit") if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") start = list(dynamodb_backend.tables.keys()).index(last) + 1 @@ -49,7 +48,7 @@ class DynamoHandler(BaseResponse): start = 0 all_tables = list(dynamodb_backend.tables.keys()) if limit: - tables = all_tables[start:start + limit] + tables = all_tables[start : start + limit] else: tables = all_tables[start:] response = {"TableNames": tables} @@ -59,16 +58,16 @@ class DynamoHandler(BaseResponse): def create_table(self): body = self.body - name = body['TableName'] + name = body["TableName"] - key_schema = body['KeySchema'] - hash_key = key_schema['HashKeyElement'] - hash_key_attr = hash_key['AttributeName'] - hash_key_type = hash_key['AttributeType'] + key_schema = body["KeySchema"] + hash_key = key_schema["HashKeyElement"] + hash_key_attr = hash_key["AttributeName"] + hash_key_type = hash_key["AttributeType"] - range_key = key_schema.get('RangeKeyElement', {}) - range_key_attr = range_key.get('AttributeName') - range_key_type = range_key.get('AttributeType') + range_key = key_schema.get("RangeKeyElement", {}) + range_key_attr = range_key.get("AttributeName") + range_key_type = range_key.get("AttributeType") throughput = body["ProvisionedThroughput"] read_units = throughput["ReadCapacityUnits"] @@ -86,137 +85,131 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(table.describe) def delete_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = dynamodb_backend.delete_table(name) if table: return dynamo_json_dump(table.describe) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) def update_table(self): - name = self.body['TableName'] + name = self.body["TableName"] throughput = self.body["ProvisionedThroughput"] new_read_units = throughput["ReadCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"] table = dynamodb_backend.update_table_throughput( - name, new_read_units, new_write_units) + name, new_read_units, new_write_units + ) return dynamo_json_dump(table.describe) def describe_table(self): - name = self.body['TableName'] + name = self.body["TableName"] try: table = dynamodb_backend.tables[name] except KeyError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) return dynamo_json_dump(table.describe) def put_item(self): - name = self.body['TableName'] - item = self.body['Item'] + name = self.body["TableName"] + item = self.body["Item"] result = dynamodb_backend.put_item(name, item) if result: item_dict = result.to_json() - item_dict['ConsumedCapacityUnits'] = 1 + item_dict["ConsumedCapacityUnits"] = 1 return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) def batch_write_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] for table_name, table_requests in table_batches.items(): for table_request in table_requests: request_type = list(table_request)[0] request = list(table_request.values())[0] - if request_type == 'PutRequest': - item = request['Item'] + if request_type == "PutRequest": + item = request["Item"] dynamodb_backend.put_item(table_name, item) - elif request_type == 'DeleteRequest': - key = request['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - item = dynamodb_backend.delete_item( - table_name, hash_key, range_key) + elif request_type == "DeleteRequest": + key = request["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + item = dynamodb_backend.delete_item(table_name, hash_key, range_key) response = { "Responses": { - "Thread": { - "ConsumedCapacityUnits": 1.0 - }, - "Reply": { - "ConsumedCapacityUnits": 1.0 - } + "Thread": {"ConsumedCapacityUnits": 1.0}, + "Reply": {"ConsumedCapacityUnits": 1.0}, }, - "UnprocessedItems": {} + "UnprocessedItems": {}, } return dynamo_json_dump(response) def get_item(self): - name = self.body['TableName'] - key = self.body['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - attrs_to_get = self.body.get('AttributesToGet') + name = self.body["TableName"] + key = self.body["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + attrs_to_get = self.body.get("AttributesToGet") try: item = dynamodb_backend.get_item(name, hash_key, range_key) except ValueError: - er = 'com.amazon.coral.validate#ValidationException' + er = "com.amazon.coral.validate#ValidationException" return self.error(er, status=400) if item: item_dict = item.describe_attrs(attrs_to_get) - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) else: # Item not found - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er, status=404) def batch_get_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] - results = { - "Responses": { - "UnprocessedKeys": {} - } - } + results = {"Responses": {"UnprocessedKeys": {}}} for table_name, table_request in table_batches.items(): items = [] - keys = table_request['Keys'] - attributes_to_get = table_request.get('AttributesToGet') + keys = table_request["Keys"] + attributes_to_get = table_request.get("AttributesToGet") for key in keys: hash_key = key["HashKeyElement"] range_key = key.get("RangeKeyElement") - item = dynamodb_backend.get_item( - table_name, hash_key, range_key) + item = dynamodb_backend.get_item(table_name, hash_key, range_key) if item: item_describe = item.describe_attrs(attributes_to_get) items.append(item_describe) results["Responses"][table_name] = { - "Items": items, "ConsumedCapacityUnits": 1} + "Items": items, + "ConsumedCapacityUnits": 1, + } return dynamo_json_dump(results) def query(self): - name = self.body['TableName'] - hash_key = self.body['HashKeyValue'] - range_condition = self.body.get('RangeKeyCondition') + name = self.body["TableName"] + hash_key = self.body["HashKeyValue"] + range_condition = self.body.get("RangeKeyCondition") if range_condition: - range_comparison = range_condition['ComparisonOperator'] - range_values = range_condition['AttributeValueList'] + range_comparison = range_condition["ComparisonOperator"] + range_values = range_condition["AttributeValueList"] else: range_comparison = None range_values = [] items, last_page = dynamodb_backend.query( - name, hash_key, range_comparison, range_values) + name, hash_key, range_comparison, range_values + ) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) result = { @@ -234,10 +227,10 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) def scan(self): - name = self.body['TableName'] + name = self.body["TableName"] filters = {} - scan_filters = self.body.get('ScanFilter', {}) + scan_filters = self.body.get("ScanFilter", {}) for attribute_name, scan_filter in scan_filters.items(): # Keys are attribute names. Values are tuples of (comparison, # comparison_value) @@ -248,14 +241,14 @@ class DynamoHandler(BaseResponse): items, scanned_count, last_page = dynamodb_backend.scan(name, filters) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) result = { "Count": len(items), "Items": [item.attrs for item in items if item], "ConsumedCapacityUnits": 1, - "ScannedCount": scanned_count + "ScannedCount": scanned_count, } # Implement this when we do pagination @@ -267,19 +260,19 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) def delete_item(self): - name = self.body['TableName'] - key = self.body['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - return_values = self.body.get('ReturnValues', '') + name = self.body["TableName"] + key = self.body["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + return_values = self.body.get("ReturnValues", "") item = dynamodb_backend.delete_item(name, hash_key, range_key) if item: - if return_values == 'ALL_OLD': + if return_values == "ALL_OLD": item_dict = item.to_json() else: - item_dict = {'Attributes': []} - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict = {"Attributes": []} + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) diff --git a/moto/dynamodb/urls.py b/moto/dynamodb/urls.py index 6988f6e15..26f0701a2 100644 --- a/moto/dynamodb/urls.py +++ b/moto/dynamodb/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoHandler -url_bases = [ - "https?://dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/": DynamoHandler.dispatch, -} +url_paths = {"{0}/": DynamoHandler.dispatch} diff --git a/moto/dynamodb2/__init__.py b/moto/dynamodb2/__init__.py index a56a83b35..3d6e8ec1f 100644 --- a/moto/dynamodb2/__init__.py +++ b/moto/dynamodb2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import dynamodb_backends as dynamodb_backends2 from ..core.models import base_decorator, deprecated_base_decorator -dynamodb_backend2 = dynamodb_backends2['us-east-1'] +dynamodb_backend2 = dynamodb_backends2["us-east-1"] mock_dynamodb2 = base_decorator(dynamodb_backends2) mock_dynamodb2_deprecated = deprecated_base_decorator(dynamodb_backends2) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 151a314f1..69d7f74e0 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,7 +1,5 @@ from __future__ import unicode_literals import re -import six -import re from collections import deque from collections import namedtuple @@ -27,37 +25,35 @@ def get_expected(expected): expr = 'Id > 5 AND Subs < 7' """ ops = { - 'EQ': OpEqual, - 'NE': OpNotEqual, - 'LE': OpLessThanOrEqual, - 'LT': OpLessThan, - 'GE': OpGreaterThanOrEqual, - 'GT': OpGreaterThan, - 'NOT_NULL': FuncAttrExists, - 'NULL': FuncAttrNotExists, - 'CONTAINS': FuncContains, - 'NOT_CONTAINS': FuncNotContains, - 'BEGINS_WITH': FuncBeginsWith, - 'IN': FuncIn, - 'BETWEEN': FuncBetween, + "EQ": OpEqual, + "NE": OpNotEqual, + "LE": OpLessThanOrEqual, + "LT": OpLessThan, + "GE": OpGreaterThanOrEqual, + "GT": OpGreaterThan, + "NOT_NULL": FuncAttrExists, + "NULL": FuncAttrNotExists, + "CONTAINS": FuncContains, + "NOT_CONTAINS": FuncNotContains, + "BEGINS_WITH": FuncBeginsWith, + "IN": FuncIn, + "BETWEEN": FuncBetween, } # NOTE: Always uses ConditionalOperator=AND conditions = [] for key, cond in expected.items(): path = AttributePath([key]) - if 'Exists' in cond: - if cond['Exists']: - conditions.append(FuncAttrExists(path)) + if "Exists" in cond: + if cond["Exists"]: + conditions.append(FuncAttrExists(path)) else: - conditions.append(FuncAttrNotExists(path)) - elif 'Value' in cond: - conditions.append(OpEqual(path, AttributeValue(cond['Value']))) - elif 'ComparisonOperator' in cond: - operator_name = cond['ComparisonOperator'] - values = [ - AttributeValue(v) - for v in cond.get("AttributeValueList", [])] + conditions.append(FuncAttrNotExists(path)) + elif "Value" in cond: + conditions.append(OpEqual(path, AttributeValue(cond["Value"]))) + elif "ComparisonOperator" in cond: + operator_name = cond["ComparisonOperator"] + values = [AttributeValue(v) for v in cond.get("AttributeValueList", [])] OpClass = ops[operator_name] conditions.append(OpClass(path, *values)) @@ -77,7 +73,8 @@ class Op(object): """ Base class for a FilterExpression operator """ - OP = '' + + OP = "" def __init__(self, lhs, rhs): self.lhs = lhs @@ -87,45 +84,42 @@ class Op(object): raise NotImplementedError("Expr not defined for {0}".format(type(self))) def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return "({0} {1} {2})".format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these -EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa -NE_FUNCTION = lambda item_value, test_value: item_value != test_value # flake8: noqa -LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # flake8: noqa -LT_FUNCTION = lambda item_value, test_value: item_value < test_value # flake8: noqa -GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # flake8: noqa -GT_FUNCTION = lambda item_value, test_value: item_value > test_value # flake8: noqa +EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # noqa +NE_FUNCTION = lambda item_value, test_value: item_value != test_value # noqa +LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # noqa +LT_FUNCTION = lambda item_value, test_value: item_value < test_value # noqa +GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # noqa +GT_FUNCTION = lambda item_value, test_value: item_value > test_value # noqa COMPARISON_FUNCS = { - 'EQ': EQ_FUNCTION, - '=': EQ_FUNCTION, - - 'NE': NE_FUNCTION, - '!=': NE_FUNCTION, - - 'LE': LE_FUNCTION, - '<=': LE_FUNCTION, - - 'LT': LT_FUNCTION, - '<': LT_FUNCTION, - - 'GE': GE_FUNCTION, - '>=': GE_FUNCTION, - - 'GT': GT_FUNCTION, - '>': GT_FUNCTION, - + "EQ": EQ_FUNCTION, + "=": EQ_FUNCTION, + "NE": NE_FUNCTION, + "!=": NE_FUNCTION, + "LE": LE_FUNCTION, + "<=": LE_FUNCTION, + "LT": LT_FUNCTION, + "<": LT_FUNCTION, + "GE": GE_FUNCTION, + ">=": GE_FUNCTION, + "GT": GT_FUNCTION, + ">": GT_FUNCTION, # NULL means the value should not exist at all - 'NULL': lambda item_value: False, + "NULL": lambda item_value: False, # NOT_NULL means the value merely has to exist, and values of None are valid - 'NOT_NULL': lambda item_value: True, - 'CONTAINS': lambda item_value, test_value: test_value in item_value, - 'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, - 'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), - 'IN': lambda item_value, *test_values: item_value in test_values, - 'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value, + "NOT_NULL": lambda item_value: True, + "CONTAINS": lambda item_value, test_value: test_value in item_value, + "NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value, + "BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value), + "IN": lambda item_value, *test_values: item_value in test_values, + "BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value + <= item_value + <= upper_test_value, } @@ -138,8 +132,12 @@ class RecursionStopIteration(StopIteration): class ConditionExpressionParser: - def __init__(self, condition_expression, expression_attribute_names, - expression_attribute_values): + def __init__( + self, + condition_expression, + expression_attribute_names, + expression_attribute_values, + ): self.condition_expression = condition_expression self.expression_attribute_names = expression_attribute_names self.expression_attribute_values = expression_attribute_values @@ -203,52 +201,49 @@ class ConditionExpressionParser: # Condition nodes # --------------- - OR = 'OR' - AND = 'AND' - NOT = 'NOT' - PARENTHESES = 'PARENTHESES' - FUNCTION = 'FUNCTION' - BETWEEN = 'BETWEEN' - IN = 'IN' - COMPARISON = 'COMPARISON' + OR = "OR" + AND = "AND" + NOT = "NOT" + PARENTHESES = "PARENTHESES" + FUNCTION = "FUNCTION" + BETWEEN = "BETWEEN" + IN = "IN" + COMPARISON = "COMPARISON" # Operand nodes # ------------- - EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE' - PATH = 'PATH' + EXPRESSION_ATTRIBUTE_VALUE = "EXPRESSION_ATTRIBUTE_VALUE" + PATH = "PATH" # Literal nodes # -------------- - LITERAL = 'LITERAL' - + LITERAL = "LITERAL" class Nonterminal: """Enum defining nonterminals for productions.""" - CONDITION = 'CONDITION' - OPERAND = 'OPERAND' - COMPARATOR = 'COMPARATOR' - FUNCTION_NAME = 'FUNCTION_NAME' - IDENTIFIER = 'IDENTIFIER' - AND = 'AND' - OR = 'OR' - NOT = 'NOT' - BETWEEN = 'BETWEEN' - IN = 'IN' - COMMA = 'COMMA' - LEFT_PAREN = 'LEFT_PAREN' - RIGHT_PAREN = 'RIGHT_PAREN' - WHITESPACE = 'WHITESPACE' + CONDITION = "CONDITION" + OPERAND = "OPERAND" + COMPARATOR = "COMPARATOR" + FUNCTION_NAME = "FUNCTION_NAME" + IDENTIFIER = "IDENTIFIER" + AND = "AND" + OR = "OR" + NOT = "NOT" + BETWEEN = "BETWEEN" + IN = "IN" + COMMA = "COMMA" + LEFT_PAREN = "LEFT_PAREN" + RIGHT_PAREN = "RIGHT_PAREN" + WHITESPACE = "WHITESPACE" - - Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"]) def _lex_condition_expression(self): nodes = deque() remaining_expression = self.condition_expression while remaining_expression: - node, remaining_expression = \ - self._lex_one_node(remaining_expression) + node, remaining_expression = self._lex_one_node(remaining_expression) if node.nonterminal == self.Nonterminal.WHITESPACE: continue nodes.append(node) @@ -256,49 +251,52 @@ class ConditionExpressionParser: def _lex_one_node(self, remaining_expression): # TODO: Handle indexing like [1] - attribute_regex = '(:|#)?[A-z0-9\-_]+' - patterns = [( - self.Nonterminal.WHITESPACE, re.compile('^ +') - ), ( - self.Nonterminal.COMPARATOR, re.compile( - '^(' - # Put long expressions first for greedy matching - '<>|' - '<=|' - '>=|' - '=|' - '<|' - '>)'), - ), ( - self.Nonterminal.OPERAND, re.compile( - '^' + - attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') - ), ( - self.Nonterminal.COMMA, re.compile('^,') - ), ( - self.Nonterminal.LEFT_PAREN, re.compile('^\(') - ), ( - self.Nonterminal.RIGHT_PAREN, re.compile('^\)') - )] + attribute_regex = "(:|#)?[A-z0-9\-_]+" + patterns = [ + (self.Nonterminal.WHITESPACE, re.compile("^ +")), + ( + self.Nonterminal.COMPARATOR, + re.compile( + "^(" + # Put long expressions first for greedy matching + "<>|" + "<=|" + ">=|" + "=|" + "<|" + ">)" + ), + ), + ( + self.Nonterminal.OPERAND, + re.compile( + "^" + attribute_regex + "(\." + attribute_regex + "|\[[0-9]\])*" + ), + ), + (self.Nonterminal.COMMA, re.compile("^,")), + (self.Nonterminal.LEFT_PAREN, re.compile("^\(")), + (self.Nonterminal.RIGHT_PAREN, re.compile("^\)")), + ] for nonterminal, pattern in patterns: match = pattern.match(remaining_expression) if match: match_text = match.group() break - else: # pragma: no cover - raise ValueError("Cannot parse condition starting at: " + - remaining_expression) + else: # pragma: no cover + raise ValueError( + "Cannot parse condition starting at: " + remaining_expression + ) - value = match_text node = self.Node( nonterminal=nonterminal, kind=self.Kind.LITERAL, text=match_text, value=match_text, - children=[]) + children=[], + ) - remaining_expression = remaining_expression[len(match_text):] + remaining_expression = remaining_expression[len(match_text) :] return node, remaining_expression @@ -309,10 +307,8 @@ class ConditionExpressionParser: node = nodes.popleft() if node.nonterminal == self.Nonterminal.OPERAND: - path = node.value.replace('[', '.[').split('.') - children = [ - self._parse_path_element(name) - for name in path] + path = node.value.replace("[", ".[").split(".") + children = [self._parse_path_element(name) for name in path] if len(children) == 1: child = children[0] if child.nonterminal != self.Nonterminal.IDENTIFIER: @@ -322,36 +318,40 @@ class ConditionExpressionParser: for child in children: self._assert( child.nonterminal == self.Nonterminal.IDENTIFIER, - "Cannot use %s in path" % child.text, [node]) - output.append(self.Node( - nonterminal=self.Nonterminal.OPERAND, - kind=self.Kind.PATH, - text=node.text, - value=None, - children=children)) + "Cannot use %s in path" % child.text, + [node], + ) + output.append( + self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children, + ) + ) else: output.append(node) return output def _parse_path_element(self, name): reserved = { - 'and': self.Nonterminal.AND, - 'or': self.Nonterminal.OR, - 'in': self.Nonterminal.IN, - 'between': self.Nonterminal.BETWEEN, - 'not': self.Nonterminal.NOT, + "and": self.Nonterminal.AND, + "or": self.Nonterminal.OR, + "in": self.Nonterminal.IN, + "between": self.Nonterminal.BETWEEN, + "not": self.Nonterminal.NOT, } functions = { - 'attribute_exists', - 'attribute_not_exists', - 'attribute_type', - 'begins_with', - 'contains', - 'size', + "attribute_exists", + "attribute_not_exists", + "attribute_type", + "begins_with", + "contains", + "size", } - if name.lower() in reserved: # e.g. AND nonterminal = reserved[name.lower()] @@ -360,7 +360,8 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) + children=[], + ) elif name in functions: # e.g. attribute_exists return self.Node( @@ -368,33 +369,37 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) - elif name.startswith(':'): + children=[], + ) + elif name.startswith(":"): # e.g. :value0 return self.Node( nonterminal=self.Nonterminal.OPERAND, kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, text=name, value=self._lookup_expression_attribute_value(name), - children=[]) - elif name.startswith('#'): + children=[], + ) + elif name.startswith("#"): # e.g. #name0 return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, kind=self.Kind.LITERAL, text=name, value=self._lookup_expression_attribute_name(name), - children=[]) - elif name.startswith('['): + children=[], + ) + elif name.startswith("["): # e.g. [123] - if not name.endswith(']'): # pragma: no cover + if not name.endswith("]"): # pragma: no cover raise ValueError("Bad path element %s" % name) return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, kind=self.Kind.LITERAL, text=name, value=int(name[1:-1]), - children=[]) + children=[], + ) else: # e.g. ItemId return self.Node( @@ -402,7 +407,8 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) + children=[], + ) def _lookup_expression_attribute_value(self, name): return self.expression_attribute_values[name] @@ -465,7 +471,7 @@ class ConditionExpressionParser: if len(nodes) < len(production): return False for i in range(len(production)): - if production[i] == '*': + if production[i] == "*": continue expected = getattr(self.Nonterminal, production[i]) if nodes[i].nonterminal != expected: @@ -477,22 +483,24 @@ class ConditionExpressionParser: output = deque() while nodes: - if self._matches(nodes, ['*', 'COMPARATOR']): + if self._matches(nodes, ["*", "COMPARATOR"]): self._assert( - self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), - "Bad comparison", list(nodes)[:3]) + self._matches(nodes, ["OPERAND", "COMPARATOR", "OPERAND"]), + "Bad comparison", + list(nodes)[:3], + ) lhs = nodes.popleft() comparator = nodes.popleft() rhs = nodes.popleft() - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.COMPARISON, - text=" ".join([ - lhs.text, - comparator.text, - rhs.text]), - value=None, - children=[lhs, comparator, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([lhs.text, comparator.text, rhs.text]), + value=None, + children=[lhs, comparator, rhs], + ) + ) else: output.append(nodes.popleft()) return output @@ -501,37 +509,40 @@ class ConditionExpressionParser: """Apply condition := operand IN ( operand , ... ).""" output = deque() while nodes: - if self._matches(nodes, ['*', 'IN']): + if self._matches(nodes, ["*", "IN"]): self._assert( - self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), - "Bad IN expression", list(nodes)[:3]) + self._matches(nodes, ["OPERAND", "IN", "LEFT_PAREN"]), + "Bad IN expression", + list(nodes)[:3], + ) lhs = nodes.popleft() in_node = nodes.popleft() left_paren = nodes.popleft() all_children = [lhs, in_node, left_paren] rhs = [] while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): + if self._matches(nodes, ["OPERAND", "COMMA"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] rhs.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + elif self._matches(nodes, ["OPERAND", "RIGHT_PAREN"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] rhs.append(operand) break # Close else: - self._assert( - False, - "Bad IN expression starting at", nodes) - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.IN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs] + rhs)) + self._assert(False, "Bad IN expression starting at", nodes) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs, + ) + ) else: output.append(nodes.popleft()) return output @@ -540,23 +551,29 @@ class ConditionExpressionParser: """Apply condition := operand BETWEEN operand AND operand.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'BETWEEN']): + if self._matches(nodes, ["*", "BETWEEN"]): self._assert( - self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', - 'AND', 'OPERAND']), - "Bad BETWEEN expression", list(nodes)[:5]) + self._matches( + nodes, ["OPERAND", "BETWEEN", "OPERAND", "AND", "OPERAND"] + ), + "Bad BETWEEN expression", + list(nodes)[:5], + ) lhs = nodes.popleft() between_node = nodes.popleft() low = nodes.popleft() and_node = nodes.popleft() high = nodes.popleft() all_children = [lhs, between_node, low, and_node, high] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.BETWEEN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, low, high])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high], + ) + ) else: output.append(nodes.popleft()) return output @@ -566,30 +583,33 @@ class ConditionExpressionParser: output = deque() either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} expected_argument_kind_map = { - 'attribute_exists': [{self.Kind.PATH}], - 'attribute_not_exists': [{self.Kind.PATH}], - 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'begins_with': [either_kind, either_kind], - 'contains': [either_kind, either_kind], - 'size': [{self.Kind.PATH}], + "attribute_exists": [{self.Kind.PATH}], + "attribute_not_exists": [{self.Kind.PATH}], + "attribute_type": [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + "begins_with": [either_kind, either_kind], + "contains": [either_kind, either_kind], + "size": [{self.Kind.PATH}], } while nodes: - if self._matches(nodes, ['FUNCTION_NAME']): + if self._matches(nodes, ["FUNCTION_NAME"]): self._assert( - self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', - 'OPERAND', '*']), - "Bad function expression at", list(nodes)[:4]) + self._matches( + nodes, ["FUNCTION_NAME", "LEFT_PAREN", "OPERAND", "*"] + ), + "Bad function expression at", + list(nodes)[:4], + ) function_name = nodes.popleft() left_paren = nodes.popleft() all_children = [function_name, left_paren] arguments = [] while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): + if self._matches(nodes, ["OPERAND", "COMMA"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] arguments.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + elif self._matches(nodes, ["OPERAND", "RIGHT_PAREN"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] @@ -598,25 +618,34 @@ class ConditionExpressionParser: else: self._assert( False, - "Bad function expression", all_children + list(nodes)[:2]) + "Bad function expression", + all_children + list(nodes)[:2], + ) expected_kinds = expected_argument_kind_map[function_name.value] self._assert( len(arguments) == len(expected_kinds), - "Wrong number of arguments in", all_children) + "Wrong number of arguments in", + all_children, + ) for i in range(len(expected_kinds)): self._assert( arguments[i].kind in expected_kinds[i], - "Wrong type for argument %d in" % i, all_children) - if function_name.value == 'size': + "Wrong type for argument %d in" % i, + all_children, + ) + if function_name.value == "size": nonterminal = self.Nonterminal.OPERAND else: nonterminal = self.Nonterminal.CONDITION - nodes.appendleft(self.Node( - nonterminal=nonterminal, - kind=self.Kind.FUNCTION, - text=" ".join([t.text for t in all_children]), - value=None, - children=[function_name] + arguments)) + nodes.appendleft( + self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments, + ) + ) else: output.append(nodes.popleft()) return output @@ -625,38 +654,40 @@ class ConditionExpressionParser: """Apply condition := ( condition ) and booleans.""" output = deque() while nodes: - if self._matches(nodes, ['LEFT_PAREN']): - parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) - self._assert( - len(parsed) >= 1, - "Failed to close parentheses at", nodes) + if self._matches(nodes, ["LEFT_PAREN"]): + parsed = self._apply_parens_and_booleans( + nodes, left_paren=nodes.popleft() + ) + self._assert(len(parsed) >= 1, "Failed to close parentheses at", nodes) parens = parsed.popleft() self._assert( parens.kind == self.Kind.PARENTHESES, - "Failed to close parentheses at", nodes) + "Failed to close parentheses at", + nodes, + ) output.append(parens) nodes = parsed - elif self._matches(nodes, ['RIGHT_PAREN']): - self._assert( - left_paren is not None, - "Unmatched ) at", nodes) + elif self._matches(nodes, ["RIGHT_PAREN"]): + self._assert(left_paren is not None, "Unmatched ) at", nodes) close_paren = nodes.popleft() children = self._apply_booleans(output) all_children = [left_paren] + list(children) + [close_paren] - return deque([ - self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.PARENTHESES, - text=" ".join([t.text for t in all_children]), - value=None, - children=list(children), - )] + list(nodes)) + return deque( + [ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ) + ] + + list(nodes) + ) else: output.append(nodes.popleft()) - self._assert( - left_paren is None, - "Unmatched ( at", list(output)) + self._assert(left_paren is None, "Unmatched ( at", list(output)) return self._apply_booleans(output) def _apply_booleans(self, nodes): @@ -665,30 +696,35 @@ class ConditionExpressionParser: nodes = self._apply_and(nodes) nodes = self._apply_or(nodes) # The expression should reduce to a single condition - self._assert( - len(nodes) == 1, - "Unexpected expression at", list(nodes)[1:]) + self._assert(len(nodes) == 1, "Unexpected expression at", list(nodes)[1:]) self._assert( nodes[0].nonterminal == self.Nonterminal.CONDITION, - "Incomplete condition", nodes) + "Incomplete condition", + nodes, + ) return nodes def _apply_not(self, nodes): """Apply condition := NOT condition.""" output = deque() while nodes: - if self._matches(nodes, ['NOT']): + if self._matches(nodes, ["NOT"]): self._assert( - self._matches(nodes, ['NOT', 'CONDITION']), - "Bad NOT expression", list(nodes)[:2]) + self._matches(nodes, ["NOT", "CONDITION"]), + "Bad NOT expression", + list(nodes)[:2], + ) not_node = nodes.popleft() child = nodes.popleft() - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.NOT, - text=" ".join([not_node.text, child.text]), - value=None, - children=[child])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child], + ) + ) else: output.append(nodes.popleft()) @@ -698,20 +734,25 @@ class ConditionExpressionParser: """Apply condition := condition AND condition.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'AND']): + if self._matches(nodes, ["*", "AND"]): self._assert( - self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), - "Bad AND expression", list(nodes)[:3]) + self._matches(nodes, ["CONDITION", "AND", "CONDITION"]), + "Bad AND expression", + list(nodes)[:3], + ) lhs = nodes.popleft() and_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, and_node, rhs] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.AND, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs], + ) + ) else: output.append(nodes.popleft()) @@ -721,20 +762,25 @@ class ConditionExpressionParser: """Apply condition := condition OR condition.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'OR']): + if self._matches(nodes, ["*", "OR"]): self._assert( - self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), - "Bad OR expression", list(nodes)[:3]) + self._matches(nodes, ["CONDITION", "OR", "CONDITION"]), + "Bad OR expression", + list(nodes)[:3], + ) lhs = nodes.popleft() or_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, or_node, rhs] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.OR, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs], + ) + ) else: output.append(nodes.popleft()) @@ -748,30 +794,25 @@ class ConditionExpressionParser: elif node.kind == self.Kind.FUNCTION: # size() function_node = node.children[0] - arguments = node.children[1:] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Unknown operand: %r" % node) - def _make_op_condition(self, node): if node.kind == self.Kind.OR: lhs, rhs = node.children - return OpOr( - self._make_op_condition(lhs), - self._make_op_condition(rhs)) + return OpOr(self._make_op_condition(lhs), self._make_op_condition(rhs)) elif node.kind == self.Kind.AND: lhs, rhs = node.children - return OpAnd( - self._make_op_condition(lhs), - self._make_op_condition(rhs)) + return OpAnd(self._make_op_condition(lhs), self._make_op_condition(rhs)) elif node.kind == self.Kind.NOT: - child, = node.children + (child,) = node.children return OpNot(self._make_op_condition(child)) elif node.kind == self.Kind.PARENTHESES: - child, = node.children + (child,) = node.children return self._make_op_condition(child) elif node.kind == self.Kind.FUNCTION: function_node = node.children[0] @@ -784,7 +825,8 @@ class ConditionExpressionParser: return FuncBetween( self._make_operand(query), self._make_operand(low), - self._make_operand(high)) + self._make_operand(high), + ) elif node.kind == self.Kind.IN: query = node.children[0] possible_values = node.children[1:] @@ -794,26 +836,11 @@ class ConditionExpressionParser: elif node.kind == self.Kind.COMPARISON: lhs, comparator, rhs = node.children return COMPARATOR_CLASS[comparator.value]( - self._make_operand(lhs), - self._make_operand(rhs)) - else: # pragma: no cover + self._make_operand(lhs), self._make_operand(rhs) + ) + else: # pragma: no cover raise ValueError("Unknown expression node kind %r" % node.kind) - def _print_debug(self, nodes): # pragma: no cover - print('ROOT') - for node in nodes: - self._print_node_recursive(node, depth=1) - - def _print_node_recursive(self, node, depth=0): # pragma: no cover - if len(node.children) > 0: - print(' ' * depth, node.nonterminal, node.kind) - for child in node.children: - self._print_node_recursive(child, depth=depth + 1) - else: - print(' ' * depth, node.nonterminal, node.kind, node.value) - - - def _assert(self, condition, message, nodes): if not condition: raise ValueError(message + " " + " ".join([t.text for t in nodes])) @@ -888,21 +915,20 @@ class AttributeValue(Operand): def expr(self, item): # TODO: Reuse DynamoType code - if self.type == 'N': + if self.type == "N": try: return int(self.value) except ValueError: return float(self.value) - elif self.type in ['SS', 'NS', 'BS']: + elif self.type in ["SS", "NS", "BS"]: sub_type = self.type[0] - return set([AttributeValue({sub_type: v}).expr(item) - for v in self.value]) - elif self.type == 'L': + return set([AttributeValue({sub_type: v}).expr(item) for v in self.value]) + elif self.type == "L": return [AttributeValue(v).expr(item) for v in self.value] - elif self.type == 'M': - return dict([ - (k, AttributeValue(v).expr(item)) - for k, v in self.value.items()]) + elif self.type == "M": + return dict( + [(k, AttributeValue(v).expr(item)) for k, v in self.value.items()] + ) else: return self.value return self.value @@ -915,7 +941,7 @@ class AttributeValue(Operand): class OpDefault(Op): - OP = 'NONE' + OP = "NONE" def expr(self, item): """If no condition is specified, always True.""" @@ -923,7 +949,7 @@ class OpDefault(Op): class OpNot(Op): - OP = 'NOT' + OP = "NOT" def __init__(self, lhs): super(OpNot, self).__init__(lhs, None) @@ -933,38 +959,49 @@ class OpNot(Op): return not lhs def __str__(self): - return '({0} {1})'.format(self.OP, self.lhs) + return "({0} {1})".format(self.OP, self.lhs) class OpAnd(Op): - OP = 'AND' + OP = "AND" def expr(self, item): lhs = self.lhs.expr(item) - rhs = self.rhs.expr(item) - return lhs and rhs + return lhs and self.rhs.expr(item) class OpLessThan(Op): - OP = '<' + OP = "<" def expr(self, item): lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) - return lhs < rhs + # In python3 None is not a valid comparator when using < or > so must be handled specially + if lhs and rhs: + return lhs < rhs + elif lhs is None and rhs: + return True + else: + return False class OpGreaterThan(Op): - OP = '>' + OP = ">" def expr(self, item): lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) - return lhs > rhs + # In python3 None is not a valid comparator when using < or > so must be handled specially + if lhs and rhs: + return lhs > rhs + elif lhs and rhs is None: + return True + else: + return False class OpEqual(Op): - OP = '=' + OP = "=" def expr(self, item): lhs = self.lhs.expr(item) @@ -973,7 +1010,7 @@ class OpEqual(Op): class OpNotEqual(Op): - OP = '<>' + OP = "<>" def expr(self, item): lhs = self.lhs.expr(item) @@ -982,25 +1019,37 @@ class OpNotEqual(Op): class OpLessThanOrEqual(Op): - OP = '<=' + OP = "<=" def expr(self, item): lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) - return lhs <= rhs + # In python3 None is not a valid comparator when using < or > so must be handled specially + if lhs and rhs: + return lhs <= rhs + elif lhs is None and rhs or lhs is None and rhs is None: + return True + else: + return False class OpGreaterThanOrEqual(Op): - OP = '>=' + OP = ">=" def expr(self, item): lhs = self.lhs.expr(item) rhs = self.rhs.expr(item) - return lhs >= rhs + # In python3 None is not a valid comparator when using < or > so must be handled specially + if lhs and rhs: + return lhs >= rhs + elif lhs and rhs is None or lhs is None and rhs is None: + return True + else: + return False class OpOr(Op): - OP = 'OR' + OP = "OR" def expr(self, item): lhs = self.lhs.expr(item) @@ -1011,7 +1060,8 @@ class Func(object): """ Base class for a FilterExpression function """ - FUNC = 'Unknown' + + FUNC = "Unknown" def __init__(self, *arguments): self.arguments = arguments @@ -1020,13 +1070,13 @@ class Func(object): raise NotImplementedError def __repr__(self): - return '{0}({1})'.format( - self.FUNC, - " ".join([repr(arg) for arg in self.arguments])) + return "{0}({1})".format( + self.FUNC, " ".join([repr(arg) for arg in self.arguments]) + ) class FuncAttrExists(Func): - FUNC = 'attribute_exists' + FUNC = "attribute_exists" def __init__(self, attribute): self.attr = attribute @@ -1041,7 +1091,7 @@ def FuncAttrNotExists(attribute): class FuncAttrType(Func): - FUNC = 'attribute_type' + FUNC = "attribute_type" def __init__(self, attribute, _type): self.attr = attribute @@ -1053,7 +1103,7 @@ class FuncAttrType(Func): class FuncBeginsWith(Func): - FUNC = 'begins_with' + FUNC = "begins_with" def __init__(self, attribute, substr): self.attr = attribute @@ -1061,15 +1111,15 @@ class FuncBeginsWith(Func): super(FuncBeginsWith, self).__init__(attribute, substr) def expr(self, item): - if self.attr.get_type(item) != 'S': + if self.attr.get_type(item) != "S": return False - if self.substr.get_type(item) != 'S': + if self.substr.get_type(item) != "S": return False return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): - FUNC = 'contains' + FUNC = "contains" def __init__(self, attribute, operand): self.attr = attribute @@ -1077,7 +1127,7 @@ class FuncContains(Func): super(FuncContains, self).__init__(attribute, operand) def expr(self, item): - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"): try: return self.operand.expr(item) in self.attr.expr(item) except TypeError: @@ -1090,7 +1140,7 @@ def FuncNotContains(attribute, operand): class FuncSize(Func): - FUNC = 'size' + FUNC = "size" def __init__(self, attribute): self.attr = attribute @@ -1098,15 +1148,15 @@ class FuncSize(Func): def expr(self, item): if self.attr.get_type(item) is None: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) + raise ValueError("Invalid attribute name {0}".format(self.attr)) - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + if self.attr.get_type(item) in ("S", "SS", "NS", "B", "BS", "L", "M"): return len(self.attr.expr(item)) - raise ValueError('Invalid filter expression') + raise ValueError("Invalid filter expression") class FuncBetween(Func): - FUNC = 'BETWEEN' + FUNC = "BETWEEN" def __init__(self, attribute, start, end): self.attr = attribute @@ -1115,11 +1165,23 @@ class FuncBetween(Func): super(FuncBetween, self).__init__(attribute, start, end) def expr(self, item): - return self.start.expr(item) <= self.attr.expr(item) <= self.end.expr(item) + # In python3 None is not a valid comparator when using < or > so must be handled specially + start = self.start.expr(item) + attr = self.attr.expr(item) + end = self.end.expr(item) + if start and attr and end: + return start <= attr <= end + elif start is None and attr is None: + # None is between None and None as well as None is between None and any number + return True + elif start is None and attr and end: + return attr <= end + else: + return False class FuncIn(Func): - FUNC = 'IN' + FUNC = "IN" def __init__(self, attribute, *possible_values): self.attr = attribute @@ -1135,20 +1197,20 @@ class FuncIn(Func): COMPARATOR_CLASS = { - '<': OpLessThan, - '>': OpGreaterThan, - '<=': OpLessThanOrEqual, - '>=': OpGreaterThanOrEqual, - '=': OpEqual, - '<>': OpNotEqual + "<": OpLessThan, + ">": OpGreaterThan, + "<=": OpLessThanOrEqual, + ">=": OpGreaterThanOrEqual, + "=": OpEqual, + "<>": OpNotEqual, } FUNC_CLASS = { - 'attribute_exists': FuncAttrExists, - 'attribute_not_exists': FuncAttrNotExists, - 'attribute_type': FuncAttrType, - 'begins_with': FuncBeginsWith, - 'contains': FuncContains, - 'size': FuncSize, - 'between': FuncBetween + "attribute_exists": FuncAttrExists, + "attribute_not_exists": FuncAttrNotExists, + "attribute_type": FuncAttrType, + "begins_with": FuncBeginsWith, + "contains": FuncContains, + "size": FuncSize, + "between": FuncBetween, } diff --git a/moto/dynamodb2/exceptions.py b/moto/dynamodb2/exceptions.py index 9df973292..1f3b5f974 100644 --- a/moto/dynamodb2/exceptions.py +++ b/moto/dynamodb2/exceptions.py @@ -1,2 +1,10 @@ class InvalidIndexNameError(ValueError): pass + + +class InvalidUpdateExpression(ValueError): + pass + + +class ItemSizeTooLarge(Exception): + message = "Item size has exceeded the maximum allowed size" diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index e868caaa8..121f564a4 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -9,6 +9,7 @@ import uuid import six import boto3 +from botocore.exceptions import ParamValidationError from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time @@ -16,13 +17,12 @@ from moto.core.exceptions import JsonRESTError from .comparisons import get_comparison_func from .comparisons import get_filter_expression from .comparisons import get_expected -from .exceptions import InvalidIndexNameError +from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge class DynamoJsonEncoder(json.JSONEncoder): - def default(self, obj): - if hasattr(obj, 'to_json'): + if hasattr(obj, "to_json"): return obj.to_json() @@ -30,35 +30,133 @@ def dynamo_json_dump(dynamo_object): return json.dumps(dynamo_object, cls=DynamoJsonEncoder) +def bytesize(val): + return len(str(val).encode("utf-8")) + + +def attribute_is_list(attr): + """ + Checks if attribute denotes a list, and returns the name of the list and the given list index if so + :param attr: attr or attr[index] + :return: attr, index or None + """ + list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr) + if list_index_update: + attr = list_index_update.group(1) + return attr, list_index_update.group(2) if list_index_update else None + + class DynamoType(object): """ http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes """ def __init__(self, type_as_dict): - self.type = list(type_as_dict)[0] - self.value = list(type_as_dict.values())[0] + if type(type_as_dict) == DynamoType: + self.type = type_as_dict.type + self.value = type_as_dict.value + else: + self.type = list(type_as_dict)[0] + self.value = list(type_as_dict.values())[0] + if self.is_list(): + self.value = [DynamoType(val) for val in self.value] + elif self.is_map(): + self.value = dict((k, DynamoType(v)) for k, v in self.value.items()) + + def get(self, key): + if not key: + return self + else: + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + if key_head not in self.value: + self.value[key_head] = DynamoType({"NONE": None}) + return self.value[key_head].get(key_tail) + + def set(self, key, new_value, index=None): + if index: + index = int(index) + if type(self.value) is not list: + raise InvalidUpdateExpression + if index >= len(self.value): + self.value.append(new_value) + # {'L': [DynamoType, ..]} ==> DynamoType.set() + self.value[min(index, len(self.value) - 1)].set(key, new_value) + else: + attr = (key or "").split(".").pop(0) + attr, list_index = attribute_is_list(attr) + if not key: + # {'S': value} ==> {'S': new_value} + self.type = new_value.type + self.value = new_value.value + else: + if attr not in self.value: # nonexistingattribute + type_of_new_attr = "M" if "." in key else new_value.type + self.value[attr] = DynamoType({type_of_new_attr: {}}) + # {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value) + self.value[attr].set( + ".".join(key.split(".")[1:]), new_value, list_index + ) + + def delete(self, key, index=None): + if index: + if not key: + if int(index) < len(self.value): + del self.value[int(index)] + elif "." in key: + self.value[int(index)].delete(".".join(key.split(".")[1:])) + else: + self.value[int(index)].delete(key) + else: + attr = key.split(".")[0] + attr, list_index = attribute_is_list(attr) + + if list_index: + self.value[attr].delete(".".join(key.split(".")[1:]), list_index) + elif "." in key: + self.value[attr].delete(".".join(key.split(".")[1:])) + else: + self.value.pop(key) + + def filter(self, projection_expressions): + nested_projections = [ + expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr + ] + if self.is_map(): + expressions_to_delete = [] + for attr in self.value: + if ( + attr not in projection_expressions + and attr not in nested_projections + ): + expressions_to_delete.append(attr) + elif attr in nested_projections: + relevant_expressions = [ + expr[len(attr + ".") :] + for expr in projection_expressions + if expr.startswith(attr + ".") + ] + self.value[attr].filter(relevant_expressions) + for expr in expressions_to_delete: + self.value.pop(expr) def __hash__(self): return hash((self.type, self.value)) def __eq__(self, other): - return ( - self.type == other.type and - self.value == other.value - ) + return self.type == other.type and self.value == other.value def __lt__(self, other): - return self.value < other.value + return self.cast_value < other.cast_value def __le__(self, other): - return self.value <= other.value + return self.cast_value <= other.cast_value def __gt__(self, other): - return self.value > other.value + return self.cast_value > other.cast_value def __ge__(self, other): - return self.value >= other.value + return self.cast_value >= other.cast_value def __repr__(self): return "DynamoType: {0}".format(self.to_json()) @@ -72,14 +170,11 @@ class DynamoType(object): return float(self.value) elif self.is_set(): sub_type = self.type[0] - return set([DynamoType({sub_type: v}).cast_value - for v in self.value]) + return set([DynamoType({sub_type: v}).cast_value for v in self.value]) elif self.is_list(): return [DynamoType(v).cast_value for v in self.value] elif self.is_map(): - return dict([ - (k, DynamoType(v).cast_value) - for k, v in self.value.items()]) + return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()]) else: return self.value @@ -89,16 +184,39 @@ class DynamoType(object): Returns DynamoType or None. """ - if isinstance(key, six.string_types) and self.is_map() and key in self.value: - return DynamoType(self.value[key]) + if isinstance(key, six.string_types) and self.is_map(): + if "." in key and key.split(".")[0] in self.value: + return self.value[key.split(".")[0]].child_attr( + ".".join(key.split(".")[1:]) + ) + elif "." not in key and key in self.value: + return DynamoType(self.value[key]) if isinstance(key, int) and self.is_list(): idx = key - if idx >= 0 and idx < len(self.value): + if 0 <= idx < len(self.value): return DynamoType(self.value[idx]) return None + def size(self): + if self.is_number(): + value_size = len(str(self.value)) + elif self.is_set(): + sub_type = self.type[0] + value_size = sum([DynamoType({sub_type: v}).size() for v in self.value]) + elif self.is_list(): + value_size = sum([v.size() for v in self.value]) + elif self.is_map(): + value_size = sum( + [bytesize(k) + DynamoType(v).size() for k, v in self.value.items()] + ) + elif type(self.value) == bool: + value_size = 1 + else: + value_size = bytesize(self.value) + return value_size + def to_json(self): return {self.type: self.value} @@ -111,30 +229,53 @@ class DynamoType(object): return comparison_func(self.cast_value, *range_values) def is_number(self): - return self.type == 'N' + return self.type == "N" def is_set(self): - return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + return self.type == "SS" or self.type == "NS" or self.type == "BS" def is_list(self): - return self.type == 'L' + return self.type == "L" def is_map(self): - return self.type == 'M' + return self.type == "M" def same_type(self, other): return self.type == other.type -class Item(BaseModel): +# https://github.com/spulec/moto/issues/1874 +# Ensure that the total size of an item does not exceed 400kb +class LimitedSizeDict(dict): + def __init__(self, *args, **kwargs): + self.update(*args, **kwargs) + def __setitem__(self, key, value): + current_item_size = sum( + [ + item.size() if type(item) == DynamoType else bytesize(str(item)) + for item in (list(self.keys()) + list(self.values())) + ] + ) + new_item_size = bytesize(key) + ( + value.size() if type(value) == DynamoType else bytesize(str(value)) + ) + # Official limit is set to 400000 (400KB) + # Manual testing confirms that the actual limit is between 409 and 410KB + # We'll set the limit to something in between to be safe + if (current_item_size + new_item_size) > 405000: + raise ItemSizeTooLarge + super(LimitedSizeDict, self).__setitem__(key, value) + + +class Item(BaseModel): def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): self.hash_key = hash_key self.hash_key_type = hash_key_type self.range_key = range_key self.range_key_type = range_key_type - self.attrs = {} + self.attrs = LimitedSizeDict() for key, value in attrs.items(): self.attrs[key] = DynamoType(value) @@ -144,13 +285,9 @@ class Item(BaseModel): def to_json(self): attributes = {} for attribute_key, attribute in self.attrs.items(): - attributes[attribute_key] = { - attribute.type: attribute.value - } + attributes[attribute_key] = {attribute.type: attribute.value} - return { - "Attributes": attributes - } + return {"Attributes": attributes} def describe_attrs(self, attributes): if attributes: @@ -160,83 +297,80 @@ class Item(BaseModel): included[key] = value else: included = self.attrs - return { - "Item": included - } + return {"Item": included} - def update(self, update_expression, expression_attribute_names, expression_attribute_values): + def update( + self, update_expression, expression_attribute_names, expression_attribute_values + ): # Update subexpressions are identifiable by the operator keyword, so split on that and # get rid of the empty leading string. - parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression, flags=re.I) if p] + parts = [ + p + for p in re.split( + r"\b(SET|REMOVE|ADD|DELETE)\b", update_expression, flags=re.I + ) + if p + ] # make sure that we correctly found only operator/value pairs - assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression) + assert ( + len(parts) % 2 == 0 + ), "Mismatched operators and values in update expression: '{}'".format( + update_expression + ) for action, valstr in zip(parts[:-1:2], parts[1::2]): action = action.upper() # "Should" retain arguments in side (...) - values = re.split(r',(?![^(]*\))', valstr) + values = re.split(r",(?![^(]*\))", valstr) for value in values: # A Real value value = value.lstrip(":").rstrip(",").strip() for k, v in expression_attribute_names.items(): - value = re.sub(r'{0}\b'.format(k), v, value) + value = re.sub(r"{0}\b".format(k), v, value) if action == "REMOVE": - self.attrs.pop(value, None) - elif action == 'SET': + key = value + attr, list_index = attribute_is_list(key.split(".")[0]) + if "." not in key: + if list_index: + new_list = DynamoType(self.attrs[attr]) + new_list.delete(None, list_index) + self.attrs[attr] = new_list + else: + self.attrs.pop(value, None) + else: + # Handle nested dict updates + self.attrs[attr].delete(".".join(key.split(".")[1:])) + elif action == "SET": key, value = value.split("=", 1) key = key.strip() value = value.strip() - # If not exists, changes value to a default if needed, else its the same as it was - if value.startswith('if_not_exists'): - # Function signature - match = re.match(r'.*if_not_exists\s*\((?P.+),\s*(?P.+)\).*', value) - if not match: - raise TypeError - - path, value = match.groups() - - # If it already exists, get its value so we dont overwrite it - if path in self.attrs: - value = self.attrs[path] + # check whether key is a list + attr, list_index = attribute_is_list(key.split(".")[0]) + # If value not exists, changes value to a default if needed, else its the same as it was + value = self._get_default(value) + # If operation == list_append, get the original value and append it + value = self._get_appended_list(value, expression_attribute_values) if type(value) != DynamoType: if value in expression_attribute_values: - value = DynamoType(expression_attribute_values[value]) + dyn_value = DynamoType(expression_attribute_values[value]) else: - value = DynamoType({"S": value}) - - if '.' not in key: - self.attrs[key] = value + dyn_value = DynamoType({"S": value}) else: - # Handle nested dict updates - key_parts = key.split('.') - attr = key_parts.pop(0) - if attr not in self.attrs: - raise ValueError + dyn_value = value - last_val = self.attrs[attr].value - for key_part in key_parts: - # Hack but it'll do, traverses into a dict - last_val_type = list(last_val.keys()) - if last_val_type and last_val_type[0] == 'M': - last_val = last_val['M'] + if "." in key and attr not in self.attrs: + raise ValueError # Setting nested attr not allowed if first attr does not exist yet + elif attr not in self.attrs: + self.attrs[attr] = dyn_value # set new top-level attribute + else: + self.attrs[attr].set( + ".".join(key.split(".")[1:]), dyn_value, list_index + ) # set value recursively - if key_part not in last_val: - last_val[key_part] = {'M': {}} - - last_val = last_val[key_part] - - # We have reference to a nested object but we cant just assign to it - current_type = list(last_val.keys())[0] - if current_type == value.type: - last_val[current_type] = value.value - else: - last_val[value.type] = value.value - del last_val[current_type] - - elif action == 'ADD': + elif action == "ADD": key, value = value.split(" ", 1) key = key.strip() value_str = value.strip() @@ -248,27 +382,39 @@ class Item(BaseModel): # Handle adding numbers - value gets added to existing value, # or added to 0 if it doesn't exist yet if dyn_value.is_number(): - existing = self.attrs.get(key, DynamoType({"N": '0'})) + existing = self.attrs.get(key, DynamoType({"N": "0"})) if not existing.same_type(dyn_value): raise TypeError() - self.attrs[key] = DynamoType({"N": str( - decimal.Decimal(existing.value) + - decimal.Decimal(dyn_value.value) - )}) + self.attrs[key] = DynamoType( + { + "N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(dyn_value.value) + ) + } + ) # Handle adding sets - value is added to the set, or set is # created with only this value if it doesn't exist yet # New value must be of same set type as previous value elif dyn_value.is_set(): - existing = self.attrs.get(key, DynamoType({dyn_value.type: {}})) - if not existing.same_type(dyn_value): + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + if key_head not in self.attrs: + self.attrs[key_head] = DynamoType({dyn_value.type: {}}) + existing = self.attrs.get(key_head) + existing = existing.get(key_tail) + if existing.value and not existing.same_type(dyn_value): raise TypeError() - new_set = set(existing.value).union(dyn_value.value) - self.attrs[key] = DynamoType({existing.type: list(new_set)}) + new_set = set(existing.value or []).union(dyn_value.value) + existing.set( + key=None, + new_value=DynamoType({dyn_value.type: list(new_set)}), + ) else: # Number and Sets are the only supported types for ADD raise TypeError - elif action == 'DELETE': + elif action == "DELETE": key, value = value.split(" ", 1) key = key.strip() value_str = value.strip() @@ -279,24 +425,67 @@ class Item(BaseModel): if not dyn_value.is_set(): raise TypeError - existing = self.attrs.get(key, None) + key_head = key.split(".")[0] + key_tail = ".".join(key.split(".")[1:]) + existing = self.attrs.get(key_head) + existing = existing.get(key_tail) if existing: if not existing.same_type(dyn_value): raise TypeError new_set = set(existing.value).difference(dyn_value.value) - self.attrs[key] = DynamoType({existing.type: list(new_set)}) + existing.set( + key=None, + new_value=DynamoType({existing.type: list(new_set)}), + ) else: - raise NotImplementedError('{} update action not yet supported'.format(action)) + raise NotImplementedError( + "{} update action not yet supported".format(action) + ) + + def _get_appended_list(self, value, expression_attribute_values): + if type(value) != DynamoType: + list_append_re = re.match("list_append\\((.+),(.+)\\)", value) + if list_append_re: + new_value = expression_attribute_values[list_append_re.group(2).strip()] + old_list_key = list_append_re.group(1) + # Get the existing value + old_list = self.attrs[old_list_key.split(".")[0]] + if "." in old_list_key: + # Value is nested inside a map - find the appropriate child attr + old_list = old_list.child_attr( + ".".join(old_list_key.split(".")[1:]) + ) + if not old_list.is_list(): + raise ParamValidationError + old_list.value.extend(new_value["L"]) + value = old_list + return value + + def _get_default(self, value): + if value.startswith("if_not_exists"): + # Function signature + match = re.match( + r".*if_not_exists\s*\((?P.+),\s*(?P.+)\).*", value + ) + if not match: + raise TypeError + + path, value = match.groups() + + # If it already exists, get its value so we dont overwrite it + if path in self.attrs: + value = self.attrs[path] + return value def update_with_attribute_updates(self, attribute_updates): for attribute_name, update_action in attribute_updates.items(): - action = update_action['Action'] - if action == 'DELETE' and 'Value' not in update_action: + action = update_action["Action"] + if action == "DELETE" and "Value" not in update_action: if attribute_name in self.attrs: del self.attrs[attribute_name] continue - new_value = list(update_action['Value'].values())[0] - if action == 'PUT': + new_value = list(update_action["Value"].values())[0] + if action == "PUT": # TODO deal with other types if isinstance(new_value, list): self.attrs[attribute_name] = DynamoType({"L": new_value}) @@ -304,50 +493,72 @@ class Item(BaseModel): self.attrs[attribute_name] = DynamoType({"SS": new_value}) elif isinstance(new_value, dict): self.attrs[attribute_name] = DynamoType({"M": new_value}) - elif set(update_action['Value'].keys()) == set(['N']): + elif set(update_action["Value"].keys()) == set(["N"]): self.attrs[attribute_name] = DynamoType({"N": new_value}) - elif set(update_action['Value'].keys()) == set(['NULL']): + elif set(update_action["Value"].keys()) == set(["NULL"]): if attribute_name in self.attrs: del self.attrs[attribute_name] else: self.attrs[attribute_name] = DynamoType({"S": new_value}) - elif action == 'ADD': - if set(update_action['Value'].keys()) == set(['N']): - existing = self.attrs.get( - attribute_name, DynamoType({"N": '0'})) - self.attrs[attribute_name] = DynamoType({"N": str( - decimal.Decimal(existing.value) + - decimal.Decimal(new_value) - )}) - elif set(update_action['Value'].keys()) == set(['SS']): + elif action == "ADD": + if set(update_action["Value"].keys()) == set(["N"]): + existing = self.attrs.get(attribute_name, DynamoType({"N": "0"})) + self.attrs[attribute_name] = DynamoType( + { + "N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(new_value) + ) + } + ) + elif set(update_action["Value"].keys()) == set(["SS"]): existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) new_set = set(existing.value).union(set(new_value)) - self.attrs[attribute_name] = DynamoType({ - "SS": list(new_set) - }) + self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) else: # TODO: implement other data types raise NotImplementedError( - 'ADD not supported for %s' % ', '.join(update_action['Value'].keys())) - elif action == 'DELETE': - if set(update_action['Value'].keys()) == set(['SS']): + "ADD not supported for %s" + % ", ".join(update_action["Value"].keys()) + ) + elif action == "DELETE": + if set(update_action["Value"].keys()) == set(["SS"]): existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) new_set = set(existing.value).difference(set(new_value)) - self.attrs[attribute_name] = DynamoType({ - "SS": list(new_set) - }) + self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) else: raise NotImplementedError( - 'ADD not supported for %s' % ', '.join(update_action['Value'].keys())) + "ADD not supported for %s" + % ", ".join(update_action["Value"].keys()) + ) else: raise NotImplementedError( - '%s action not support for update_with_attribute_updates' % action) + "%s action not support for update_with_attribute_updates" % action + ) + + # Filter using projection_expression + # Ensure a deep copy is used to filter, otherwise actual data will be removed + def filter(self, projection_expression): + expressions = [x.strip() for x in projection_expression.split(",")] + top_level_expressions = [ + expr[0 : expr.index(".")] for expr in expressions if "." in expr + ] + for attr in list(self.attrs): + if attr not in expressions and attr not in top_level_expressions: + self.attrs.pop(attr) + if attr in top_level_expressions: + relevant_expressions = [ + expr[len(attr + ".") :] + for expr in expressions + if expr.startswith(attr + ".") + ] + self.attrs[attr].filter(relevant_expressions) class StreamRecord(BaseModel): def __init__(self, table, stream_type, event_name, old, new, seq): - old_a = old.to_json()['Attributes'] if old is not None else {} - new_a = new.to_json()['Attributes'] if new is not None else {} + old_a = old.to_json()["Attributes"] if old is not None else {} + new_a = new.to_json()["Attributes"] if new is not None else {} rec = old if old is not None else new keys = {table.hash_key_attr: rec.hash_key.to_json()} @@ -355,28 +566,27 @@ class StreamRecord(BaseModel): keys[table.range_key_attr] = rec.range_key.to_json() self.record = { - 'eventID': uuid.uuid4().hex, - 'eventName': event_name, - 'eventSource': 'aws:dynamodb', - 'eventVersion': '1.0', - 'awsRegion': 'us-east-1', - 'dynamodb': { - 'StreamViewType': stream_type, - 'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(), - 'SequenceNumber': seq, - 'SizeBytes': 1, - 'Keys': keys - } + "eventID": uuid.uuid4().hex, + "eventName": event_name, + "eventSource": "aws:dynamodb", + "eventVersion": "1.0", + "awsRegion": "us-east-1", + "dynamodb": { + "StreamViewType": stream_type, + "ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(), + "SequenceNumber": str(seq), + "SizeBytes": 1, + "Keys": keys, + }, } - if stream_type in ('NEW_IMAGE', 'NEW_AND_OLD_IMAGES'): - self.record['dynamodb']['NewImage'] = new_a - if stream_type in ('OLD_IMAGE', 'NEW_AND_OLD_IMAGES'): - self.record['dynamodb']['OldImage'] = old_a + if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["NewImage"] = new_a + if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["OldImage"] = old_a # This is a substantial overestimate but it's the easiest to do now - self.record['dynamodb']['SizeBytes'] = len( - json.dumps(self.record['dynamodb'])) + self.record["dynamodb"]["SizeBytes"] = len(json.dumps(self.record["dynamodb"])) def to_json(self): return self.record @@ -385,30 +595,43 @@ class StreamRecord(BaseModel): class StreamShard(BaseModel): def __init__(self, table): self.table = table - self.id = 'shardId-00000001541626099285-f35f62ef' + self.id = "shardId-00000001541626099285-f35f62ef" self.starting_sequence_number = 1100000000017454423009 self.items = [] self.created_on = datetime.datetime.utcnow() def to_json(self): return { - 'ShardId': self.id, - 'SequenceNumberRange': { - 'StartingSequenceNumber': str(self.starting_sequence_number) - } + "ShardId": self.id, + "SequenceNumberRange": { + "StartingSequenceNumber": str(self.starting_sequence_number) + }, } def add(self, old, new): - t = self.table.stream_specification['StreamViewType'] + t = self.table.stream_specification["StreamViewType"] if old is None: - event_name = 'INSERT' + event_name = "INSERT" elif new is None: - event_name = 'DELETE' + event_name = "DELETE" else: - event_name = 'MODIFY' + event_name = "MODIFY" seq = len(self.items) + self.starting_sequence_number - self.items.append( - StreamRecord(self.table, t, event_name, old, new, seq)) + self.items.append(StreamRecord(self.table, t, event_name, old, new, seq)) + result = None + from moto.awslambda import lambda_backends + + for arn, esm in self.table.lambda_event_source_mappings.items(): + region = arn[ + len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:")) + ] + + result = lambda_backends[region].send_dynamodb_items( + arn, self.items, esm.event_source_arn + ) + + if result: + self.items = [] def get(self, start, quantity): start -= self.starting_sequence_number @@ -418,8 +641,16 @@ class StreamShard(BaseModel): class Table(BaseModel): - - def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None, global_indexes=None, streams=None): + def __init__( + self, + table_name, + schema=None, + attr=None, + throughput=None, + indexes=None, + global_indexes=None, + streams=None, + ): self.name = table_name self.attr = attr self.schema = schema @@ -435,8 +666,7 @@ class Table(BaseModel): self.range_key_attr = elem["AttributeName"] self.range_key_type = elem["KeyType"] if throughput is None: - self.throughput = { - 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10} + self.throughput = {"WriteCapacityUnits": 10, "ReadCapacityUnits": 10} else: self.throughput = throughput self.throughput["NumberOfDecreasesToday"] = 0 @@ -447,65 +677,72 @@ class Table(BaseModel): self.table_arn = self._generate_arn(table_name) self.tags = [] self.ttl = { - 'TimeToLiveStatus': 'DISABLED' # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', + "TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', # 'AttributeName': 'string' # Can contain this } self.set_stream_specification(streams) + self.lambda_event_source_mappings = {} @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] params = {} - if 'KeySchema' in properties: - params['schema'] = properties['KeySchema'] - if 'AttributeDefinitions' in properties: - params['attr'] = properties['AttributeDefinitions'] - if 'GlobalSecondaryIndexes' in properties: - params['global_indexes'] = properties['GlobalSecondaryIndexes'] - if 'ProvisionedThroughput' in properties: - params['throughput'] = properties['ProvisionedThroughput'] - if 'LocalSecondaryIndexes' in properties: - params['indexes'] = properties['LocalSecondaryIndexes'] + if "KeySchema" in properties: + params["schema"] = properties["KeySchema"] + if "AttributeDefinitions" in properties: + params["attr"] = properties["AttributeDefinitions"] + if "GlobalSecondaryIndexes" in properties: + params["global_indexes"] = properties["GlobalSecondaryIndexes"] + if "ProvisionedThroughput" in properties: + params["throughput"] = properties["ProvisionedThroughput"] + if "LocalSecondaryIndexes" in properties: + params["indexes"] = properties["LocalSecondaryIndexes"] - table = dynamodb_backends[region_name].create_table(name=properties['TableName'], **params) + table = dynamodb_backends[region_name].create_table( + name=properties["TableName"], **params + ) return table def _generate_arn(self, name): - return 'arn:aws:dynamodb:us-east-1:123456789011:table/' + name + return "arn:aws:dynamodb:us-east-1:123456789011:table/" + name def set_stream_specification(self, streams): self.stream_specification = streams - if streams and (streams.get('StreamEnabled') or streams.get('StreamViewType')): - self.stream_specification['StreamEnabled'] = True + if streams and (streams.get("StreamEnabled") or streams.get("StreamViewType")): + self.stream_specification["StreamEnabled"] = True self.latest_stream_label = datetime.datetime.utcnow().isoformat() self.stream_shard = StreamShard(self) else: - self.stream_specification = {'StreamEnabled': False} + self.stream_specification = {"StreamEnabled": False} self.latest_stream_label = None self.stream_shard = None - def describe(self, base_key='TableDescription'): + def describe(self, base_key="TableDescription"): results = { base_key: { - 'AttributeDefinitions': self.attr, - 'ProvisionedThroughput': self.throughput, - 'TableSizeBytes': 0, - 'TableName': self.name, - 'TableStatus': 'ACTIVE', - 'TableArn': self.table_arn, - 'KeySchema': self.schema, - 'ItemCount': len(self), - 'CreationDateTime': unix_time(self.created_at), - 'GlobalSecondaryIndexes': [index for index in self.global_indexes], - 'LocalSecondaryIndexes': [index for index in self.indexes], + "AttributeDefinitions": self.attr, + "ProvisionedThroughput": self.throughput, + "TableSizeBytes": 0, + "TableName": self.name, + "TableStatus": "ACTIVE", + "TableArn": self.table_arn, + "KeySchema": self.schema, + "ItemCount": len(self), + "CreationDateTime": unix_time(self.created_at), + "GlobalSecondaryIndexes": [index for index in self.global_indexes], + "LocalSecondaryIndexes": [index for index in self.indexes], } } - if self.stream_specification and self.stream_specification['StreamEnabled']: - results[base_key]['StreamSpecification'] = self.stream_specification + if self.stream_specification and self.stream_specification["StreamEnabled"]: + results[base_key]["StreamSpecification"] = self.stream_specification if self.latest_stream_label: - results[base_key]['LatestStreamLabel'] = self.latest_stream_label - results[base_key]['LatestStreamArn'] = self.table_arn + '/stream/' + self.latest_stream_label + results[base_key]["LatestStreamLabel"] = self.latest_stream_label + results[base_key]["LatestStreamArn"] = ( + self.table_arn + "/stream/" + self.latest_stream_label + ) return results def __len__(self): @@ -522,9 +759,9 @@ class Table(BaseModel): keys = [self.hash_key_attr] for index in self.global_indexes: hash_key = None - for key in index['KeySchema']: - if key['KeyType'] == 'HASH': - hash_key = key['AttributeName'] + for key in index["KeySchema"]: + if key["KeyType"] == "HASH": + hash_key = key["AttributeName"] keys.append(hash_key) return keys @@ -533,15 +770,21 @@ class Table(BaseModel): keys = [self.range_key_attr] for index in self.global_indexes: range_key = None - for key in index['KeySchema']: - if key['KeyType'] == 'RANGE': - range_key = keys.append(key['AttributeName']) + for key in index["KeySchema"]: + if key["KeyType"] == "RANGE": + range_key = keys.append(key["AttributeName"]) keys.append(range_key) return keys - def put_item(self, item_attrs, expected=None, condition_expression=None, - expression_attribute_names=None, - expression_attribute_values=None, overwrite=False): + def put_item( + self, + item_attrs, + expected=None, + condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, + overwrite=False, + ): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -552,26 +795,27 @@ class Table(BaseModel): expected = {} lookup_range_value = range_value else: - expected_range_value = expected.get( - self.range_key_attr, {}).get("Value") - if(expected_range_value is None): + expected_range_value = expected.get(self.range_key_attr, {}).get("Value") + if expected_range_value is None: lookup_range_value = range_value else: lookup_range_value = DynamoType(expected_range_value) current = self.get_item(hash_value, lookup_range_value) - item = Item(hash_value, self.hash_key_type, range_value, - self.range_key_type, item_attrs) + item = Item( + hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs + ) if not overwrite: if not get_expected(expected).expr(current): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") condition_op = get_filter_expression( condition_expression, expression_attribute_names, - expression_attribute_values) + expression_attribute_values, + ) if not condition_op.expr(current): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") if range_value: self.items[hash_value][range_value] = item @@ -593,18 +837,27 @@ class Table(BaseModel): def has_range_key(self): return self.range_key_attr is not None - def get_item(self, hash_key, range_key=None): + def get_item(self, hash_key, range_key=None, projection_expression=None): if self.has_range_key and not range_key: raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) try: + result = None + if range_key: - return self.items[hash_key][range_key] + result = self.items[hash_key][range_key] + elif hash_key in self.items: + result = self.items[hash_key] - if hash_key in self.items: - return self.items[hash_key] + if projection_expression and result: + result = copy.deepcopy(result) + result.filter(projection_expression) - raise KeyError + if not result: + raise KeyError + + return result except KeyError: return None @@ -622,30 +875,42 @@ class Table(BaseModel): except KeyError: return None - def query(self, hash_key, range_comparison, range_objs, limit, - exclusive_start_key, scan_index_forward, projection_expression, - index_name=None, filter_expression=None, **filter_kwargs): + def query( + self, + hash_key, + range_comparison, + range_objs, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=None, + filter_expression=None, + **filter_kwargs + ): results = [] if index_name: all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name not in indexes_by_name: - raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % ( - index_name, self.name, ', '.join(indexes_by_name.keys()) - )) + raise ValueError( + "Invalid index: %s for table: %s. Available indexes are: %s" + % (index_name, self.name, ", ".join(indexes_by_name.keys())) + ) index = indexes_by_name[index_name] try: - index_hash_key = [key for key in index[ - 'KeySchema'] if key['KeyType'] == 'HASH'][0] + index_hash_key = [ + key for key in index["KeySchema"] if key["KeyType"] == "HASH" + ][0] except IndexError: - raise ValueError('Missing Hash Key. KeySchema: %s' % - index['KeySchema']) + raise ValueError("Missing Hash Key. KeySchema: %s" % index["KeySchema"]) try: - index_range_key = [key for key in index[ - 'KeySchema'] if key['KeyType'] == 'RANGE'][0] + index_range_key = [ + key for key in index["KeySchema"] if key["KeyType"] == "RANGE" + ][0] except IndexError: index_range_key = None @@ -653,26 +918,32 @@ class Table(BaseModel): for item in self.all_items(): if not isinstance(item, Item): continue - item_hash_key = item.attrs.get(index_hash_key['AttributeName']) + item_hash_key = item.attrs.get(index_hash_key["AttributeName"]) if index_range_key is None: if item_hash_key and item_hash_key == hash_key: possible_results.append(item) else: - item_range_key = item.attrs.get(index_range_key['AttributeName']) + item_range_key = item.attrs.get(index_range_key["AttributeName"]) if item_hash_key and item_hash_key == hash_key and item_range_key: possible_results.append(item) else: - possible_results = [item for item in list(self.all_items()) if isinstance( - item, Item) and item.hash_key == hash_key] - + possible_results = [ + item + for item in list(self.all_items()) + if isinstance(item, Item) and item.hash_key == hash_key + ] if range_comparison: if index_name and not index_range_key: raise ValueError( - 'Range Key comparison but no range key found for index: %s' % index_name) + "Range Key comparison but no range key found for index: %s" + % index_name + ) elif index_name: for result in possible_results: - if result.attrs.get(index_range_key['AttributeName']).compare(range_comparison, range_objs): + if result.attrs.get(index_range_key["AttributeName"]).compare( + range_comparison, range_objs + ): results.append(result) else: for result in possible_results: @@ -682,9 +953,12 @@ class Table(BaseModel): if filter_kwargs: for result in possible_results: for field, value in filter_kwargs.items(): - dynamo_types = [DynamoType(ele) for ele in value[ - "AttributeValueList"]] - if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types): + dynamo_types = [ + DynamoType(ele) for ele in value["AttributeValueList"] + ] + if result.attrs.get(field).compare( + value["ComparisonOperator"], dynamo_types + ): results.append(result) if not range_comparison and not filter_kwargs: @@ -695,8 +969,11 @@ class Table(BaseModel): if index_name: if index_range_key: - results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value - if item.attrs.get(index_range_key['AttributeName']) else None) + results.sort( + key=lambda item: item.attrs[index_range_key["AttributeName"]].value + if item.attrs.get(index_range_key["AttributeName"]) + else None + ) else: results.sort(key=lambda item: item.range_key) @@ -708,16 +985,14 @@ class Table(BaseModel): if filter_expression is not None: results = [item for item in results if filter_expression.expr(item)] + results = copy.deepcopy(results) if projection_expression: - expressions = [x.strip() for x in projection_expression.split(',')] - results = copy.deepcopy(results) for result in results: - for attr in list(result.attrs): - if attr not in expressions: - result.attrs.pop(attr) + result.filter(projection_expression) - results, last_evaluated_key = self._trim_results(results, limit, - exclusive_start_key) + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key + ) return results, scanned_count, last_evaluated_key def all_items(self): @@ -734,9 +1009,9 @@ class Table(BaseModel): def has_idx_items(self, index_name): all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) idx = indexes_by_name[index_name] - idx_col_set = set([i['AttributeName'] for i in idx['KeySchema']]) + idx_col_set = set([i["AttributeName"] for i in idx["KeySchema"]]) for hash_set in self.items.values(): if self.range_key_attr: @@ -747,15 +1022,25 @@ class Table(BaseModel): if idx_col_set.issubset(set(hash_set.attrs)): yield hash_set - def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None, projection_expression=None): + def scan( + self, + filters, + limit, + exclusive_start_key, + filter_expression=None, + index_name=None, + projection_expression=None, + ): results = [] scanned_count = 0 all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name: if index_name not in indexes_by_name: - raise InvalidIndexNameError('The table does not have the specified index: %s' % index_name) + raise InvalidIndexNameError( + "The table does not have the specified index: %s" % index_name + ) items = self.has_idx_items(index_name) else: items = self.all_items() @@ -763,7 +1048,10 @@ class Table(BaseModel): for item in items: scanned_count += 1 passes_all_conditions = True - for attribute_name, (comparison_operator, comparison_objs) in filters.items(): + for ( + attribute_name, + (comparison_operator, comparison_objs), + ) in filters.items(): attribute = item.attrs.get(attribute_name) if attribute: @@ -771,7 +1059,7 @@ class Table(BaseModel): if not attribute.compare(comparison_operator, comparison_objs): passes_all_conditions = False break - elif comparison_operator == 'NULL': + elif comparison_operator == "NULL": # Comparison is NULL and we don't have the attribute continue else: @@ -787,42 +1075,41 @@ class Table(BaseModel): results.append(item) if projection_expression: - expressions = [x.strip() for x in projection_expression.split(',')] results = copy.deepcopy(results) for result in results: - for attr in list(result.attrs): - if attr not in expressions: - result.attrs.pop(attr) + result.filter(projection_expression) - results, last_evaluated_key = self._trim_results(results, limit, - exclusive_start_key, index_name) + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key, index_name + ) return results, scanned_count, last_evaluated_key - def _trim_results(self, results, limit, exclusive_start_key, scaned_index=None): + def _trim_results(self, results, limit, exclusive_start_key, scanned_index=None): if exclusive_start_key is not None: hash_key = DynamoType(exclusive_start_key.get(self.hash_key_attr)) range_key = exclusive_start_key.get(self.range_key_attr) if range_key is not None: range_key = DynamoType(range_key) for i in range(len(results)): - if results[i].hash_key == hash_key and results[i].range_key == range_key: - results = results[i + 1:] + if ( + results[i].hash_key == hash_key + and results[i].range_key == range_key + ): + results = results[i + 1 :] break last_evaluated_key = None if limit and len(results) > limit: results = results[:limit] - last_evaluated_key = { - self.hash_key_attr: results[-1].hash_key - } + last_evaluated_key = {self.hash_key_attr: results[-1].hash_key} if results[-1].range_key is not None: last_evaluated_key[self.range_key_attr] = results[-1].range_key - if scaned_index: + if scanned_index: all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) - idx = indexes_by_name[scaned_index] - idx_col_list = [i['AttributeName'] for i in idx['KeySchema']] + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) + idx = indexes_by_name[scanned_index] + idx_col_list = [i["AttributeName"] for i in idx["KeySchema"]] for col in idx_col_list: last_evaluated_key[col] = results[-1].attrs[col] @@ -840,7 +1127,6 @@ class Table(BaseModel): class DynamoDBBackend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.tables = OrderedDict() @@ -869,7 +1155,9 @@ class DynamoDBBackend(BaseBackend): def untag_resource(self, table_arn, tag_keys): for table in self.tables: if self.tables[table].table_arn == table_arn: - self.tables[table].tags = [tag for tag in self.tables[table].tags if tag['Key'] not in tag_keys] + self.tables[table].tags = [ + tag for tag in self.tables[table].tags if tag["Key"] not in tag_keys + ] def list_tags_of_resource(self, table_arn): required_table = None @@ -885,55 +1173,76 @@ class DynamoDBBackend(BaseBackend): def update_table_streams(self, name, stream_specification): table = self.tables[name] - if (stream_specification.get('StreamEnabled') or stream_specification.get('StreamViewType')) and table.latest_stream_label: - raise ValueError('Table already has stream enabled') + if ( + stream_specification.get("StreamEnabled") + or stream_specification.get("StreamViewType") + ) and table.latest_stream_label: + raise ValueError("Table already has stream enabled") table.set_stream_specification(stream_specification) return table def update_table_global_indexes(self, name, global_index_updates): table = self.tables[name] - gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes) + gsis_by_name = dict((i["IndexName"], i) for i in table.global_indexes) for gsi_update in global_index_updates: - gsi_to_create = gsi_update.get('Create') - gsi_to_update = gsi_update.get('Update') - gsi_to_delete = gsi_update.get('Delete') + gsi_to_create = gsi_update.get("Create") + gsi_to_update = gsi_update.get("Update") + gsi_to_delete = gsi_update.get("Delete") if gsi_to_delete: - index_name = gsi_to_delete['IndexName'] + index_name = gsi_to_delete["IndexName"] if index_name not in gsis_by_name: - raise ValueError('Global Secondary Index does not exist, but tried to delete: %s' % - gsi_to_delete['IndexName']) + raise ValueError( + "Global Secondary Index does not exist, but tried to delete: %s" + % gsi_to_delete["IndexName"] + ) del gsis_by_name[index_name] if gsi_to_update: - index_name = gsi_to_update['IndexName'] + index_name = gsi_to_update["IndexName"] if index_name not in gsis_by_name: - raise ValueError('Global Secondary Index does not exist, but tried to update: %s' % - gsi_to_update['IndexName']) + raise ValueError( + "Global Secondary Index does not exist, but tried to update: %s" + % gsi_to_update["IndexName"] + ) gsis_by_name[index_name].update(gsi_to_update) if gsi_to_create: - if gsi_to_create['IndexName'] in gsis_by_name: + if gsi_to_create["IndexName"] in gsis_by_name: raise ValueError( - 'Global Secondary Index already exists: %s' % gsi_to_create['IndexName']) + "Global Secondary Index already exists: %s" + % gsi_to_create["IndexName"] + ) - gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create + gsis_by_name[gsi_to_create["IndexName"]] = gsi_to_create # in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other # parts of the codebase table.global_indexes = list(gsis_by_name.values()) return table - def put_item(self, table_name, item_attrs, expected=None, - condition_expression=None, expression_attribute_names=None, - expression_attribute_values=None, overwrite=False): + def put_item( + self, + table_name, + item_attrs, + expected=None, + condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, + overwrite=False, + ): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs, expected, condition_expression, - expression_attribute_names, - expression_attribute_values, overwrite) + return table.put_item( + item_attrs, + expected, + condition_expression, + expression_attribute_names, + expression_attribute_values, + overwrite, + ) def get_table_keys_name(self, table_name, keys): """ @@ -959,42 +1268,80 @@ class DynamoDBBackend(BaseBackend): return potential_hash, potential_range def get_keys_value(self, table, keys): - if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys): + if table.hash_key_attr not in keys or ( + table.has_range_key and table.range_key_attr not in keys + ): raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) hash_key = DynamoType(keys[table.hash_key_attr]) - range_key = DynamoType( - keys[table.range_key_attr]) if table.has_range_key else None + range_key = ( + DynamoType(keys[table.range_key_attr]) if table.has_range_key else None + ) return hash_key, range_key def get_table(self, table_name): return self.tables.get(table_name) - def get_item(self, table_name, keys): + def get_item(self, table_name, keys, projection_expression=None): table = self.get_table(table_name) if not table: raise ValueError("No table found") hash_key, range_key = self.get_keys_value(table, keys) - return table.get_item(hash_key, range_key) + return table.get_item(hash_key, range_key, projection_expression) - def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts, - limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, - expr_names=None, expr_values=None, filter_expression=None, - **filter_kwargs): + def query( + self, + table_name, + hash_key_dict, + range_comparison, + range_value_dicts, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=None, + expr_names=None, + expr_values=None, + filter_expression=None, + **filter_kwargs + ): table = self.tables.get(table_name) if not table: return None, None hash_key = DynamoType(hash_key_dict) - range_values = [DynamoType(range_value) - for range_value in range_value_dicts] + range_values = [DynamoType(range_value) for range_value in range_value_dicts] - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) + filter_expression = get_filter_expression( + filter_expression, expr_names, expr_values + ) - return table.query(hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) + return table.query( + hash_key, + range_comparison, + range_values, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name, + filter_expression, + **filter_kwargs + ) - def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name, projection_expression): + def scan( + self, + table_name, + filters, + limit, + exclusive_start_key, + filter_expression, + expr_names, + expr_values, + index_name, + projection_expression, + ): table = self.tables.get(table_name) if not table: return None, None, None @@ -1004,14 +1351,37 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) + filter_expression = get_filter_expression( + filter_expression, expr_names, expr_values + ) - projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')]) + projection_expression = ",".join( + [ + expr_names.get(attr, attr) + for attr in projection_expression.replace(" ", "").split(",") + ] + ) - return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression) + return table.scan( + scan_filters, + limit, + exclusive_start_key, + filter_expression, + index_name, + projection_expression, + ) - def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected=None, condition_expression=None): + def update_item( + self, + table_name, + key, + update_expression, + attribute_updates, + expression_attribute_names, + expression_attribute_values, + expected=None, + condition_expression=None, + ): table = self.get_table(table_name) if all([table.hash_key_attr in key, table.range_key_attr in key]): @@ -1034,67 +1404,87 @@ class DynamoDBBackend(BaseBackend): expected = {} if not get_expected(expected).expr(item): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") condition_op = get_filter_expression( condition_expression, expression_attribute_names, - expression_attribute_values) + expression_attribute_values, + ) if not condition_op.expr(item): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") # Update does not fail on new items, so create one if item is None: - data = { - table.hash_key_attr: { - hash_value.type: hash_value.value, - }, - } + data = {table.hash_key_attr: {hash_value.type: hash_value.value}} if range_value: - data.update({ - table.range_key_attr: { - range_value.type: range_value.value, - } - }) + data.update( + {table.range_key_attr: {range_value.type: range_value.value}} + ) table.put_item(data) item = table.get_item(hash_value, range_value) if update_expression: - item.update(update_expression, expression_attribute_names, - expression_attribute_values) + item.update( + update_expression, + expression_attribute_names, + expression_attribute_values, + ) else: item.update_with_attribute_updates(attribute_updates) return item - def delete_item(self, table_name, keys): + def delete_item( + self, + table_name, + key, + expression_attribute_names=None, + expression_attribute_values=None, + condition_expression=None, + ): table = self.get_table(table_name) if not table: return None - hash_key, range_key = self.get_keys_value(table, keys) - return table.delete_item(hash_key, range_key) + + hash_value, range_value = self.get_keys_value(table, key) + item = table.get_item(hash_value, range_value) + + condition_op = get_filter_expression( + condition_expression, + expression_attribute_names, + expression_attribute_values, + ) + if not condition_op.expr(item): + raise ValueError("The conditional request failed") + + return table.delete_item(hash_value, range_value) def update_ttl(self, table_name, ttl_spec): table = self.tables.get(table_name) if table is None: - raise JsonRESTError('ResourceNotFound', 'Table not found') + raise JsonRESTError("ResourceNotFound", "Table not found") - if 'Enabled' not in ttl_spec or 'AttributeName' not in ttl_spec: - raise JsonRESTError('InvalidParameterValue', - 'TimeToLiveSpecification does not contain Enabled and AttributeName') + if "Enabled" not in ttl_spec or "AttributeName" not in ttl_spec: + raise JsonRESTError( + "InvalidParameterValue", + "TimeToLiveSpecification does not contain Enabled and AttributeName", + ) - if ttl_spec['Enabled']: - table.ttl['TimeToLiveStatus'] = 'ENABLED' + if ttl_spec["Enabled"]: + table.ttl["TimeToLiveStatus"] = "ENABLED" else: - table.ttl['TimeToLiveStatus'] = 'DISABLED' - table.ttl['AttributeName'] = ttl_spec['AttributeName'] + table.ttl["TimeToLiveStatus"] = "DISABLED" + table.ttl["AttributeName"] = ttl_spec["AttributeName"] def describe_ttl(self, table_name): table = self.tables.get(table_name) if table is None: - raise JsonRESTError('ResourceNotFound', 'Table not found') + raise JsonRESTError("ResourceNotFound", "Table not found") return table.ttl available_regions = boto3.session.Session().get_available_regions("dynamodb") -dynamodb_backends = {region: DynamoDBBackend(region_name=region) for region in available_regions} +dynamodb_backends = { + region: DynamoDBBackend(region_name=region) for region in available_regions +} diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 3e9fbb553..c9f3529a9 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -1,11 +1,12 @@ from __future__ import unicode_literals +import itertools import json import six import re from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores, amzn_request_id -from .exceptions import InvalidIndexNameError +from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge from .models import dynamodb_backends, dynamo_json_dump @@ -15,25 +16,30 @@ def has_empty_keys_or_values(_dict): if not isinstance(_dict, dict): return False return any( - key == '' or value == '' or - has_empty_keys_or_values(value) + key == "" or value == "" or has_empty_keys_or_values(value) for key, value in _dict.items() ) def get_empty_str_error(): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return (400, - {'server': 'amazon.com'}, - dynamo_json_dump({'__type': er, - 'message': ('One or more parameter values were ' - 'invalid: An AttributeValue may not ' - 'contain an empty string')} - )) + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return ( + 400, + {"server": "amazon.com"}, + dynamo_json_dump( + { + "__type": er, + "message": ( + "One or more parameter values were " + "invalid: An AttributeValue may not " + "contain an empty string" + ), + } + ), + ) class DynamoHandler(BaseResponse): - def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -41,12 +47,16 @@ class DynamoHandler(BaseResponse): ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables """ # Headers are case-insensitive. Probably a better way to do this. - match = headers.get('x-amz-target') or headers.get('X-Amz-Target') + match = headers.get("x-amz-target") or headers.get("X-Amz-Target") if match: return match.split(".")[1] def error(self, type_, message, status=400): - return status, self.response_headers, dynamo_json_dump({'__type': type_, 'message': message}) + return ( + status, + self.response_headers, + dynamo_json_dump({"__type": type_, "message": message}), + ) @property def dynamodb_backend(self): @@ -58,7 +68,7 @@ class DynamoHandler(BaseResponse): @amzn_request_id def call_action(self): - self.body = json.loads(self.body or '{}') + self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: endpoint = camelcase_to_underscores(endpoint) @@ -75,7 +85,7 @@ class DynamoHandler(BaseResponse): def list_tables(self): body = self.body - limit = body.get('Limit', 100) + limit = body.get("Limit", 100) if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") start = list(self.dynamodb_backend.tables.keys()).index(last) + 1 @@ -83,7 +93,7 @@ class DynamoHandler(BaseResponse): start = 0 all_tables = list(self.dynamodb_backend.tables.keys()) if limit: - tables = all_tables[start:start + limit] + tables = all_tables[start : start + limit] else: tables = all_tables[start:] response = {"TableNames": tables} @@ -95,245 +105,296 @@ class DynamoHandler(BaseResponse): def create_table(self): body = self.body # get the table name - table_name = body['TableName'] + table_name = body["TableName"] # check billing mode and get the throughput if "BillingMode" in body.keys() and body["BillingMode"] == "PAY_PER_REQUEST": if "ProvisionedThroughput" in body.keys(): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, - 'ProvisionedThroughput cannot be specified \ - when BillingMode is PAY_PER_REQUEST') + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "ProvisionedThroughput cannot be specified \ + when BillingMode is PAY_PER_REQUEST", + ) throughput = None - else: # Provisioned (default billing mode) + else: # Provisioned (default billing mode) throughput = body.get("ProvisionedThroughput") # getting the schema - key_schema = body['KeySchema'] + key_schema = body["KeySchema"] # getting attribute definition attr = body["AttributeDefinitions"] # getting the indexes global_indexes = body.get("GlobalSecondaryIndexes", []) local_secondary_indexes = body.get("LocalSecondaryIndexes", []) + # Verify AttributeDefinitions list all + expected_attrs = [] + expected_attrs.extend([key["AttributeName"] for key in key_schema]) + expected_attrs.extend( + schema["AttributeName"] + for schema in itertools.chain( + *list(idx["KeySchema"] for idx in local_secondary_indexes) + ) + ) + expected_attrs.extend( + schema["AttributeName"] + for schema in itertools.chain( + *list(idx["KeySchema"] for idx in global_indexes) + ) + ) + expected_attrs = list(set(expected_attrs)) + expected_attrs.sort() + actual_attrs = [item["AttributeName"] for item in attr] + actual_attrs.sort() + if actual_attrs != expected_attrs: + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "One or more parameter values were invalid: " + "Some index key attributes are not defined in AttributeDefinitions. " + "Keys: " + + str(expected_attrs) + + ", AttributeDefinitions: " + + str(actual_attrs), + ) # get the stream specification streams = body.get("StreamSpecification") - table = self.dynamodb_backend.create_table(table_name, - schema=key_schema, - throughput=throughput, - attr=attr, - global_indexes=global_indexes, - indexes=local_secondary_indexes, - streams=streams) + table = self.dynamodb_backend.create_table( + table_name, + schema=key_schema, + throughput=throughput, + attr=attr, + global_indexes=global_indexes, + indexes=local_secondary_indexes, + streams=streams, + ) if table is not None: return dynamo_json_dump(table.describe()) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException' - return self.error(er, 'Resource in use') + er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" + return self.error(er, "Resource in use") def delete_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = self.dynamodb_backend.delete_table(name) if table is not None: return dynamo_json_dump(table.describe()) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def tag_resource(self): - table_arn = self.body['ResourceArn'] - tags = self.body['Tags'] + table_arn = self.body["ResourceArn"] + tags = self.body["Tags"] self.dynamodb_backend.tag_resource(table_arn, tags) - return '' + return "" def untag_resource(self): - table_arn = self.body['ResourceArn'] - tags = self.body['TagKeys'] + table_arn = self.body["ResourceArn"] + tags = self.body["TagKeys"] self.dynamodb_backend.untag_resource(table_arn, tags) - return '' + return "" def list_tags_of_resource(self): try: - table_arn = self.body['ResourceArn'] + table_arn = self.body["ResourceArn"] all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn) - all_tag_keys = [tag['Key'] for tag in all_tags] - marker = self.body.get('NextToken') + all_tag_keys = [tag["Key"] for tag in all_tags] + marker = self.body.get("NextToken") if marker: start = all_tag_keys.index(marker) + 1 else: start = 0 max_items = 10 # there is no default, but using 10 to make testing easier - tags_resp = all_tags[start:start + max_items] + tags_resp = all_tags[start : start + max_items] next_marker = None if len(all_tags) > start + max_items: - next_marker = tags_resp[-1]['Key'] + next_marker = tags_resp[-1]["Key"] if next_marker: - return json.dumps({'Tags': tags_resp, - 'NextToken': next_marker}) - return json.dumps({'Tags': tags_resp}) + return json.dumps({"Tags": tags_resp, "NextToken": next_marker}) + return json.dumps({"Tags": tags_resp}) except AttributeError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def update_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = self.dynamodb_backend.get_table(name) - if 'GlobalSecondaryIndexUpdates' in self.body: + if "GlobalSecondaryIndexUpdates" in self.body: table = self.dynamodb_backend.update_table_global_indexes( - name, self.body['GlobalSecondaryIndexUpdates']) - if 'ProvisionedThroughput' in self.body: + name, self.body["GlobalSecondaryIndexUpdates"] + ) + if "ProvisionedThroughput" in self.body: throughput = self.body["ProvisionedThroughput"] table = self.dynamodb_backend.update_table_throughput(name, throughput) - if 'StreamSpecification' in self.body: + if "StreamSpecification" in self.body: try: - table = self.dynamodb_backend.update_table_streams(name, self.body['StreamSpecification']) + table = self.dynamodb_backend.update_table_streams( + name, self.body["StreamSpecification"] + ) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException' - return self.error(er, 'Cannot enable stream') + er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" + return self.error(er, "Cannot enable stream") return dynamo_json_dump(table.describe()) def describe_table(self): - name = self.body['TableName'] + name = self.body["TableName"] try: table = self.dynamodb_backend.tables[name] except KeyError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') - return dynamo_json_dump(table.describe(base_key='Table')) + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") + return dynamo_json_dump(table.describe(base_key="Table")) def put_item(self): - name = self.body['TableName'] - item = self.body['Item'] - return_values = self.body.get('ReturnValues', 'NONE') + name = self.body["TableName"] + item = self.body["Item"] + return_values = self.body.get("ReturnValues", "NONE") - if return_values not in ('ALL_OLD', 'NONE'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + if return_values not in ("ALL_OLD", "NONE"): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") if has_empty_keys_or_values(item): return get_empty_str_error() - overwrite = 'Expected' not in self.body + overwrite = "Expected" not in self.body if not overwrite: - expected = self.body['Expected'] + expected = self.body["Expected"] else: expected = None - if return_values == 'ALL_OLD': + if return_values == "ALL_OLD": existing_item = self.dynamodb_backend.get_item(name, item) if existing_item: - existing_attributes = existing_item.to_json()['Attributes'] + existing_attributes = existing_item.to_json()["Attributes"] else: existing_attributes = {} # Attempt to parse simple ConditionExpressions into an Expected # expression - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) if condition_expression: overwrite = False try: result = self.dynamodb_backend.put_item( - name, item, expected, condition_expression, - expression_attribute_names, expression_attribute_values, - overwrite) + name, + item, + expected, + condition_expression, + expression_attribute_names, + expression_attribute_values, + overwrite, + ) + except ItemSizeTooLarge: + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, ItemSizeTooLarge.message) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) if result: item_dict = result.to_json() - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 1 - } - if return_values == 'ALL_OLD': - item_dict['Attributes'] = existing_attributes + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 1} + if return_values == "ALL_OLD": + item_dict["Attributes"] = existing_attributes else: - item_dict.pop('Attributes', None) + item_dict.pop("Attributes", None) return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def batch_write_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] for table_name, table_requests in table_batches.items(): for table_request in table_requests: request_type = list(table_request.keys())[0] request = list(table_request.values())[0] - if request_type == 'PutRequest': - item = request['Item'] + if request_type == "PutRequest": + item = request["Item"] self.dynamodb_backend.put_item(table_name, item) - elif request_type == 'DeleteRequest': - keys = request['Key'] + elif request_type == "DeleteRequest": + keys = request["Key"] item = self.dynamodb_backend.delete_item(table_name, keys) response = { "ConsumedCapacity": [ { - 'TableName': table_name, - 'CapacityUnits': 1.0, - 'Table': {'CapacityUnits': 1.0} - } for table_name, table_requests in table_batches.items() + "TableName": table_name, + "CapacityUnits": 1.0, + "Table": {"CapacityUnits": 1.0}, + } + for table_name, table_requests in table_batches.items() ], "ItemCollectionMetrics": {}, - "UnprocessedItems": {} + "UnprocessedItems": {}, } return dynamo_json_dump(response) def get_item(self): - name = self.body['TableName'] - key = self.body['Key'] + name = self.body["TableName"] + key = self.body["Key"] + projection_expression = self.body.get("ProjectionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + + projection_expression = self._adjust_projection_expression( + projection_expression, expression_attribute_names + ) + try: - item = self.dynamodb_backend.get_item(name, key) + item = self.dynamodb_backend.get_item(name, key, projection_expression) except ValueError: - er = 'com.amazon.coral.validate#ValidationException' - return self.error(er, 'Validation Exception') + er = "com.amazon.coral.validate#ValidationException" + return self.error(er, "Validation Exception") if item: item_dict = item.describe_attrs(attributes=None) - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 0.5 - } + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 0.5} return dynamo_json_dump(item_dict) else: # Item not found - return 200, self.response_headers, '{}' + return 200, self.response_headers, "{}" def batch_get_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] - results = { - "ConsumedCapacity": [], - "Responses": { - }, - "UnprocessedKeys": { - } - } + results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}} for table_name, table_request in table_batches.items(): - keys = table_request['Keys'] + keys = table_request["Keys"] if self._contains_duplicates(keys): - er = 'com.amazon.coral.validate#ValidationException' - return self.error(er, 'Provided list of item keys contains duplicates') - attributes_to_get = table_request.get('AttributesToGet') + er = "com.amazon.coral.validate#ValidationException" + return self.error(er, "Provided list of item keys contains duplicates") + attributes_to_get = table_request.get("AttributesToGet") + projection_expression = table_request.get("ProjectionExpression") + expression_attribute_names = table_request.get( + "ExpressionAttributeNames", {} + ) + + projection_expression = self._adjust_projection_expression( + projection_expression, expression_attribute_names + ) + results["Responses"][table_name] = [] for key in keys: - item = self.dynamodb_backend.get_item(table_name, key) + item = self.dynamodb_backend.get_item( + table_name, key, projection_expression + ) if item: item_describe = item.describe_attrs(attributes_to_get) - results["Responses"][table_name].append( - item_describe["Item"]) + results["Responses"][table_name].append(item_describe["Item"]) - results["ConsumedCapacity"].append({ - "CapacityUnits": len(keys), - "TableName": table_name - }) + results["ConsumedCapacity"].append( + {"CapacityUnits": len(keys), "TableName": table_name} + ) return dynamo_json_dump(results) def _contains_duplicates(self, keys): @@ -346,147 +407,168 @@ class DynamoHandler(BaseResponse): return False def query(self): - name = self.body['TableName'] + name = self.body["TableName"] # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} - key_condition_expression = self.body.get('KeyConditionExpression') - projection_expression = self.body.get('ProjectionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - filter_expression = self.body.get('FilterExpression') - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + key_condition_expression = self.body.get("KeyConditionExpression") + projection_expression = self.body.get("ProjectionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + filter_expression = self.body.get("FilterExpression") + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) - if projection_expression and expression_attribute_names: - expressions = [x.strip() for x in projection_expression.split(',')] - for expression in expressions: - if expression in expression_attribute_names: - projection_expression = projection_expression.replace(expression, expression_attribute_names[expression]) + projection_expression = self._adjust_projection_expression( + projection_expression, expression_attribute_names + ) filter_kwargs = {} if key_condition_expression: - value_alias_map = self.body.get('ExpressionAttributeValues', {}) + value_alias_map = self.body.get("ExpressionAttributeValues", {}) table = self.dynamodb_backend.get_table(name) # If table does not exist if table is None: - return self.error('com.amazonaws.dynamodb.v20120810#ResourceNotFoundException', - 'Requested resource not found') + return self.error( + "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException", + "Requested resource not found", + ) - index_name = self.body.get('IndexName') + index_name = self.body.get("IndexName") if index_name: - all_indexes = (table.global_indexes or []) + \ - (table.indexes or []) - indexes_by_name = dict((i['IndexName'], i) - for i in all_indexes) + all_indexes = (table.global_indexes or []) + (table.indexes or []) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name not in indexes_by_name: - raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % ( - index_name, name, ', '.join(indexes_by_name.keys()) - )) + er = "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" + return self.error( + er, + "Invalid index: {} for table: {}. Available indexes are: {}".format( + index_name, name, ", ".join(indexes_by_name.keys()) + ), + ) - index = indexes_by_name[index_name]['KeySchema'] + index = indexes_by_name[index_name]["KeySchema"] else: index = table.schema - reverse_attribute_lookup = dict((v, k) for k, v in - six.iteritems(self.body.get('ExpressionAttributeNames', {}))) + reverse_attribute_lookup = dict( + (v, k) + for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {})) + ) if " AND " in key_condition_expression: expressions = key_condition_expression.split(" AND ", 1) - index_hash_key = [key for key in index if key['KeyType'] == 'HASH'][0] - hash_key_var = reverse_attribute_lookup.get(index_hash_key['AttributeName'], - index_hash_key['AttributeName']) - hash_key_regex = r'(^|[\s(]){0}\b'.format(hash_key_var) - i, hash_key_expression = next((i, e) for i, e in enumerate(expressions) - if re.search(hash_key_regex, e)) - hash_key_expression = hash_key_expression.strip('()') + index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0] + hash_key_var = reverse_attribute_lookup.get( + index_hash_key["AttributeName"], index_hash_key["AttributeName"] + ) + hash_key_regex = r"(^|[\s(]){0}\b".format(hash_key_var) + i, hash_key_expression = next( + (i, e) + for i, e in enumerate(expressions) + if re.search(hash_key_regex, e) + ) + hash_key_expression = hash_key_expression.strip("()") expressions.pop(i) # TODO implement more than one range expression and OR operators - range_key_expression = expressions[0].strip('()') + range_key_expression = expressions[0].strip("()") range_key_expression_components = range_key_expression.split() range_comparison = range_key_expression_components[1] - if 'AND' in range_key_expression: - range_comparison = 'BETWEEN' + if "AND" in range_key_expression: + range_comparison = "BETWEEN" range_values = [ value_alias_map[range_key_expression_components[2]], value_alias_map[range_key_expression_components[4]], ] - elif 'begins_with' in range_key_expression: - range_comparison = 'BEGINS_WITH' + elif "begins_with" in range_key_expression: + range_comparison = "BEGINS_WITH" range_values = [ - value_alias_map[range_key_expression_components[1]], + value_alias_map[range_key_expression_components[-1]] ] else: - range_values = [value_alias_map[ - range_key_expression_components[2]]] + range_values = [value_alias_map[range_key_expression_components[2]]] else: - hash_key_expression = key_condition_expression.strip('()') + hash_key_expression = key_condition_expression.strip("()") range_comparison = None range_values = [] + if "=" not in hash_key_expression: + return self.error( + "com.amazonaws.dynamodb.v20111205#ValidationException", + "Query key condition not supported", + ) hash_key_value_alias = hash_key_expression.split("=")[1].strip() # Temporary fix until we get proper KeyConditionExpression function - hash_key = value_alias_map.get(hash_key_value_alias, {'S': hash_key_value_alias}) + hash_key = value_alias_map.get( + hash_key_value_alias, {"S": hash_key_value_alias} + ) else: # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} - key_conditions = self.body.get('KeyConditions') + key_conditions = self.body.get("KeyConditions") query_filters = self.body.get("QueryFilter") if key_conditions: - hash_key_name, range_key_name = self.dynamodb_backend.get_table_keys_name( - name, key_conditions.keys()) + ( + hash_key_name, + range_key_name, + ) = self.dynamodb_backend.get_table_keys_name( + name, key_conditions.keys() + ) for key, value in key_conditions.items(): if key not in (hash_key_name, range_key_name): filter_kwargs[key] = value if hash_key_name is None: er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" - return self.error(er, 'Requested resource not found') - hash_key = key_conditions[hash_key_name][ - 'AttributeValueList'][0] + return self.error(er, "Requested resource not found") + hash_key = key_conditions[hash_key_name]["AttributeValueList"][0] if len(key_conditions) == 1: range_comparison = None range_values = [] else: if range_key_name is None and not filter_kwargs: er = "com.amazon.coral.validate#ValidationException" - return self.error(er, 'Validation Exception') + return self.error(er, "Validation Exception") else: range_condition = key_conditions.get(range_key_name) if range_condition: - range_comparison = range_condition[ - 'ComparisonOperator'] - range_values = range_condition[ - 'AttributeValueList'] + range_comparison = range_condition["ComparisonOperator"] + range_values = range_condition["AttributeValueList"] else: range_comparison = None range_values = [] if query_filters: filter_kwargs.update(query_filters) - index_name = self.body.get('IndexName') - exclusive_start_key = self.body.get('ExclusiveStartKey') + index_name = self.body.get("IndexName") + exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") scan_index_forward = self.body.get("ScanIndexForward") items, scanned_count, last_evaluated_key = self.dynamodb_backend.query( - name, hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, - expr_names=expression_attribute_names, expr_values=expression_attribute_values, - filter_expression=filter_expression, **filter_kwargs + name, + hash_key, + range_comparison, + range_values, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=index_name, + expr_names=expression_attribute_names, + expr_values=expression_attribute_values, + filter_expression=filter_expression, + **filter_kwargs ) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") result = { "Count": len(items), - 'ConsumedCapacity': { - 'TableName': name, - 'CapacityUnits': 1, - }, - "ScannedCount": scanned_count + "ConsumedCapacity": {"TableName": name, "CapacityUnits": 1}, + "ScannedCount": scanned_count, } - if self.body.get('Select', '').upper() != 'COUNT': + if self.body.get("Select", "").upper() != "COUNT": result["Items"] = [item.attrs for item in items] if last_evaluated_key is not None: @@ -494,11 +576,30 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) + def _adjust_projection_expression(self, projection_expression, expr_attr_names): + def _adjust(expression): + return ( + expr_attr_names[expression] + if expression in expr_attr_names + else expression + ) + + if projection_expression and expr_attr_names: + expressions = [x.strip() for x in projection_expression.split(",")] + return ",".join( + [ + ".".join([_adjust(expr) for expr in nested_expr.split(".")]) + for nested_expr in expressions + ] + ) + + return projection_expression + def scan(self): - name = self.body['TableName'] + name = self.body["TableName"] filters = {} - scan_filters = self.body.get('ScanFilter', {}) + scan_filters = self.body.get("ScanFilter", {}) for attribute_name, scan_filter in scan_filters.items(): # Keys are attribute names. Values are tuples of (comparison, # comparison_value) @@ -506,173 +607,217 @@ class DynamoHandler(BaseResponse): comparison_values = scan_filter.get("AttributeValueList", []) filters[attribute_name] = (comparison_operator, comparison_values) - filter_expression = self.body.get('FilterExpression') - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - projection_expression = self.body.get('ProjectionExpression', '') - exclusive_start_key = self.body.get('ExclusiveStartKey') + filter_expression = self.body.get("FilterExpression") + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + projection_expression = self.body.get("ProjectionExpression", "") + exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") - index_name = self.body.get('IndexName') + index_name = self.body.get("IndexName") try: - items, scanned_count, last_evaluated_key = self.dynamodb_backend.scan(name, filters, - limit, - exclusive_start_key, - filter_expression, - expression_attribute_names, - expression_attribute_values, - index_name, - projection_expression) + items, scanned_count, last_evaluated_key = self.dynamodb_backend.scan( + name, + filters, + limit, + exclusive_start_key, + filter_expression, + expression_attribute_names, + expression_attribute_values, + index_name, + projection_expression, + ) except InvalidIndexNameError as err: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' + er = "com.amazonaws.dynamodb.v20111205#ValidationException" return self.error(er, str(err)) except ValueError as err: - er = 'com.amazonaws.dynamodb.v20111205#ValidationError' - return self.error(er, 'Bad Filter Expression: {0}'.format(err)) + er = "com.amazonaws.dynamodb.v20111205#ValidationError" + return self.error(er, "Bad Filter Expression: {0}".format(err)) except Exception as err: - er = 'com.amazonaws.dynamodb.v20111205#InternalFailure' - return self.error(er, 'Internal error. {0}'.format(err)) + er = "com.amazonaws.dynamodb.v20111205#InternalFailure" + return self.error(er, "Internal error. {0}".format(err)) # Items should be a list, at least an empty one. Is None if table does not exist. # Should really check this at the beginning if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") result = { "Count": len(items), "Items": [item.attrs for item in items], - 'ConsumedCapacity': { - 'TableName': name, - 'CapacityUnits': 1, - }, - "ScannedCount": scanned_count + "ConsumedCapacity": {"TableName": name, "CapacityUnits": 1}, + "ScannedCount": scanned_count, } if last_evaluated_key is not None: result["LastEvaluatedKey"] = last_evaluated_key return dynamo_json_dump(result) def delete_item(self): - name = self.body['TableName'] - keys = self.body['Key'] - return_values = self.body.get('ReturnValues', 'NONE') - if return_values not in ('ALL_OLD', 'NONE'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + name = self.body["TableName"] + key = self.body["Key"] + return_values = self.body.get("ReturnValues", "NONE") + if return_values not in ("ALL_OLD", "NONE"): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") table = self.dynamodb_backend.get_table(name) if not table: - er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) - item = self.dynamodb_backend.delete_item(name, keys) - if item and return_values == 'ALL_OLD': + # Attempt to parse simple ConditionExpressions into an Expected + # expression + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) + + try: + item = self.dynamodb_backend.delete_item( + name, + key, + expression_attribute_names, + expression_attribute_values, + condition_expression, + ) + except ValueError: + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) + + if item and return_values == "ALL_OLD": item_dict = item.to_json() else: - item_dict = {'Attributes': {}} - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict = {"Attributes": {}} + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) def update_item(self): - name = self.body['TableName'] - key = self.body['Key'] - return_values = self.body.get('ReturnValues', 'NONE') - update_expression = self.body.get('UpdateExpression', '').strip() - attribute_updates = self.body.get('AttributeUpdates') - expression_attribute_names = self.body.get( - 'ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get( - 'ExpressionAttributeValues', {}) + name = self.body["TableName"] + key = self.body["Key"] + return_values = self.body.get("ReturnValues", "NONE") + update_expression = self.body.get("UpdateExpression", "").strip() + attribute_updates = self.body.get("AttributeUpdates") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) existing_item = self.dynamodb_backend.get_item(name, key) if existing_item: - existing_attributes = existing_item.to_json()['Attributes'] + existing_attributes = existing_item.to_json()["Attributes"] else: existing_attributes = {} - if return_values not in ('NONE', 'ALL_OLD', 'ALL_NEW', 'UPDATED_OLD', - 'UPDATED_NEW'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + if return_values not in ( + "NONE", + "ALL_OLD", + "ALL_NEW", + "UPDATED_OLD", + "UPDATED_NEW", + ): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") if has_empty_keys_or_values(expression_attribute_values): return get_empty_str_error() - if 'Expected' in self.body: - expected = self.body['Expected'] + if "Expected" in self.body: + expected = self.body["Expected"] else: expected = None # Attempt to parse simple ConditionExpressions into an Expected # expression - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) # Support spaces between operators in an update expression # E.g. `a = b + c` -> `a=b+c` if update_expression: - update_expression = re.sub( - r'\s*([=\+-])\s*', '\\1', update_expression) + update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression) try: item = self.dynamodb_backend.update_item( - name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected, condition_expression + name, + key, + update_expression, + attribute_updates, + expression_attribute_names, + expression_attribute_values, + expected, + condition_expression, ) + except InvalidUpdateExpression: + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "The document path provided in the update expression is invalid for update", + ) + except ItemSizeTooLarge: + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, ItemSizeTooLarge.message) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) except TypeError: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Validation Exception') + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Validation Exception") item_dict = item.to_json() - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 0.5 - } + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 0.5} unchanged_attributes = { - k for k in existing_attributes.keys() - if existing_attributes[k] == item_dict['Attributes'].get(k) + k + for k in existing_attributes.keys() + if existing_attributes[k] == item_dict["Attributes"].get(k) } - changed_attributes = set(existing_attributes.keys()).union(item_dict['Attributes'].keys()).difference(unchanged_attributes) + changed_attributes = ( + set(existing_attributes.keys()) + .union(item_dict["Attributes"].keys()) + .difference(unchanged_attributes) + ) - if return_values == 'NONE': - item_dict['Attributes'] = {} - elif return_values == 'ALL_OLD': - item_dict['Attributes'] = existing_attributes - elif return_values == 'UPDATED_OLD': - item_dict['Attributes'] = { - k: v for k, v in existing_attributes.items() - if k in changed_attributes + if return_values == "NONE": + item_dict["Attributes"] = {} + elif return_values == "ALL_OLD": + item_dict["Attributes"] = existing_attributes + elif return_values == "UPDATED_OLD": + item_dict["Attributes"] = { + k: v for k, v in existing_attributes.items() if k in changed_attributes } - elif return_values == 'UPDATED_NEW': - item_dict['Attributes'] = { - k: v for k, v in item_dict['Attributes'].items() + elif return_values == "UPDATED_NEW": + item_dict["Attributes"] = { + k: v + for k, v in item_dict["Attributes"].items() if k in changed_attributes } return dynamo_json_dump(item_dict) def describe_limits(self): - return json.dumps({ - 'AccountMaxReadCapacityUnits': 20000, - 'TableMaxWriteCapacityUnits': 10000, - 'AccountMaxWriteCapacityUnits': 20000, - 'TableMaxReadCapacityUnits': 10000 - }) + return json.dumps( + { + "AccountMaxReadCapacityUnits": 20000, + "TableMaxWriteCapacityUnits": 10000, + "AccountMaxWriteCapacityUnits": 20000, + "TableMaxReadCapacityUnits": 10000, + } + ) def update_time_to_live(self): - name = self.body['TableName'] - ttl_spec = self.body['TimeToLiveSpecification'] + name = self.body["TableName"] + ttl_spec = self.body["TimeToLiveSpecification"] self.dynamodb_backend.update_ttl(name, ttl_spec) - return json.dumps({'TimeToLiveSpecification': ttl_spec}) + return json.dumps({"TimeToLiveSpecification": ttl_spec}) def describe_time_to_live(self): - name = self.body['TableName'] + name = self.body["TableName"] ttl_spec = self.dynamodb_backend.describe_ttl(name) - return json.dumps({'TimeToLiveDescription': ttl_spec}) + return json.dumps({"TimeToLiveDescription": ttl_spec}) diff --git a/moto/dynamodb2/urls.py b/moto/dynamodb2/urls.py index 6988f6e15..26f0701a2 100644 --- a/moto/dynamodb2/urls.py +++ b/moto/dynamodb2/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoHandler -url_bases = [ - "https?://dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/": DynamoHandler.dispatch, -} +url_paths = {"{0}/": DynamoHandler.dispatch} diff --git a/moto/dynamodbstreams/__init__.py b/moto/dynamodbstreams/__init__.py index b35879eba..85dd5404c 100644 --- a/moto/dynamodbstreams/__init__.py +++ b/moto/dynamodbstreams/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import dynamodbstreams_backends from ..core.models import base_decorator -dynamodbstreams_backend = dynamodbstreams_backends['us-east-1'] +dynamodbstreams_backend = dynamodbstreams_backends["us-east-1"] mock_dynamodbstreams = base_decorator(dynamodbstreams_backends) diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 41cc6e280..6e99d8ef6 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -10,51 +10,59 @@ from moto.dynamodb2.models import dynamodb_backends class ShardIterator(BaseModel): - def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None): - self.id = base64.b64encode(os.urandom(472)).decode('utf-8') + def __init__( + self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None + ): + self.id = base64.b64encode(os.urandom(472)).decode("utf-8") self.streams_backend = streams_backend self.stream_shard = stream_shard self.shard_iterator_type = shard_iterator_type - if shard_iterator_type == 'TRIM_HORIZON': + if shard_iterator_type == "TRIM_HORIZON": self.sequence_number = stream_shard.starting_sequence_number - elif shard_iterator_type == 'LATEST': - self.sequence_number = stream_shard.starting_sequence_number + len(stream_shard.items) - elif shard_iterator_type == 'AT_SEQUENCE_NUMBER': + elif shard_iterator_type == "LATEST": + self.sequence_number = stream_shard.starting_sequence_number + len( + stream_shard.items + ) + elif shard_iterator_type == "AT_SEQUENCE_NUMBER": self.sequence_number = sequence_number - elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER': + elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": self.sequence_number = sequence_number + 1 @property def arn(self): - return '{}/stream/{}|1|{}'.format( + return "{}/stream/{}|1|{}".format( self.stream_shard.table.table_arn, self.stream_shard.table.latest_stream_label, - self.id) + self.id, + ) def to_json(self): - return { - 'ShardIterator': self.arn - } + return {"ShardIterator": self.arn} def get(self, limit=1000): items = self.stream_shard.get(self.sequence_number, limit) try: - last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items) - new_shard_iterator = ShardIterator(self.streams_backend, - self.stream_shard, - 'AFTER_SEQUENCE_NUMBER', - last_sequence_number) + last_sequence_number = max( + int(i["dynamodb"]["SequenceNumber"]) for i in items + ) + new_shard_iterator = ShardIterator( + self.streams_backend, + self.stream_shard, + "AFTER_SEQUENCE_NUMBER", + last_sequence_number, + ) except ValueError: - new_shard_iterator = ShardIterator(self.streams_backend, - self.stream_shard, - 'AT_SEQUENCE_NUMBER', - self.sequence_number) + new_shard_iterator = ShardIterator( + self.streams_backend, + self.stream_shard, + "AT_SEQUENCE_NUMBER", + self.sequence_number, + ) - self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator - return { - 'NextShardIterator': new_shard_iterator.arn, - 'Records': items - } + self.streams_backend.shard_iterators[ + new_shard_iterator.arn + ] = new_shard_iterator + return {"NextShardIterator": new_shard_iterator.arn, "Records": items} class DynamoDBStreamsBackend(BaseBackend): @@ -72,23 +80,27 @@ class DynamoDBStreamsBackend(BaseBackend): return dynamodb_backends[self.region] def _get_table_from_arn(self, arn): - table_name = arn.split(':', 6)[5].split('/')[1] + table_name = arn.split(":", 6)[5].split("/")[1] return self.dynamodb.get_table(table_name) def describe_stream(self, arn): table = self._get_table_from_arn(arn) - resp = {'StreamDescription': { - 'StreamArn': arn, - 'StreamLabel': table.latest_stream_label, - 'StreamStatus': ('ENABLED' if table.latest_stream_label - else 'DISABLED'), - 'StreamViewType': table.stream_specification['StreamViewType'], - 'CreationRequestDateTime': table.stream_shard.created_on.isoformat(), - 'TableName': table.name, - 'KeySchema': table.schema, - 'Shards': ([table.stream_shard.to_json()] if table.stream_shard - else []) - }} + resp = { + "StreamDescription": { + "StreamArn": arn, + "StreamLabel": table.latest_stream_label, + "StreamStatus": ( + "ENABLED" if table.latest_stream_label else "DISABLED" + ), + "StreamViewType": table.stream_specification["StreamViewType"], + "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), + "TableName": table.name, + "KeySchema": table.schema, + "Shards": ( + [table.stream_shard.to_json()] if table.stream_shard else [] + ), + } + } return json.dumps(resp) @@ -98,22 +110,26 @@ class DynamoDBStreamsBackend(BaseBackend): if table_name is not None and table.name != table_name: continue if table.latest_stream_label: - d = table.describe(base_key='Table') - streams.append({ - 'StreamArn': d['Table']['LatestStreamArn'], - 'TableName': d['Table']['TableName'], - 'StreamLabel': d['Table']['LatestStreamLabel'] - }) + d = table.describe(base_key="Table") + streams.append( + { + "StreamArn": d["Table"]["LatestStreamArn"], + "TableName": d["Table"]["TableName"], + "StreamLabel": d["Table"]["LatestStreamLabel"], + } + ) - return json.dumps({'Streams': streams}) + return json.dumps({"Streams": streams}) - def get_shard_iterator(self, arn, shard_id, shard_iterator_type, sequence_number=None): + def get_shard_iterator( + self, arn, shard_id, shard_iterator_type, sequence_number=None + ): table = self._get_table_from_arn(arn) assert table.stream_shard.id == shard_id - shard_iterator = ShardIterator(self, table.stream_shard, - shard_iterator_type, - sequence_number) + shard_iterator = ShardIterator( + self, table.stream_shard, shard_iterator_type, sequence_number + ) self.shard_iterators[shard_iterator.arn] = shard_iterator return json.dumps(shard_iterator.to_json()) @@ -123,7 +139,7 @@ class DynamoDBStreamsBackend(BaseBackend): return json.dumps(shard_iterator.get(limit)) -available_regions = boto3.session.Session().get_available_regions( - 'dynamodbstreams') -dynamodbstreams_backends = {region: DynamoDBStreamsBackend(region=region) - for region in available_regions} +available_regions = boto3.session.Session().get_available_regions("dynamodbstreams") +dynamodbstreams_backends = { + region: DynamoDBStreamsBackend(region=region) for region in available_regions +} diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index c9c113615..d4f5c78a6 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -3,32 +3,38 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse from .models import dynamodbstreams_backends +from six import string_types class DynamoDBStreamsHandler(BaseResponse): - @property def backend(self): return dynamodbstreams_backends[self.region] def describe_stream(self): - arn = self._get_param('StreamArn') + arn = self._get_param("StreamArn") return self.backend.describe_stream(arn) def list_streams(self): - table_name = self._get_param('TableName') + table_name = self._get_param("TableName") return self.backend.list_streams(table_name) def get_shard_iterator(self): - arn = self._get_param('StreamArn') - shard_id = self._get_param('ShardId') - shard_iterator_type = self._get_param('ShardIteratorType') - return self.backend.get_shard_iterator(arn, shard_id, - shard_iterator_type) + arn = self._get_param("StreamArn") + shard_id = self._get_param("ShardId") + shard_iterator_type = self._get_param("ShardIteratorType") + sequence_number = self._get_param("SequenceNumber") + # according to documentation sequence_number param should be string + if isinstance(sequence_number, string_types): + sequence_number = int(sequence_number) + + return self.backend.get_shard_iterator( + arn, shard_id, shard_iterator_type, sequence_number + ) def get_records(self): - arn = self._get_param('ShardIterator') - limit = self._get_param('Limit') + arn = self._get_param("ShardIterator") + limit = self._get_param("Limit") if limit is None: limit = 1000 return self.backend.get_records(arn, limit) diff --git a/moto/dynamodbstreams/urls.py b/moto/dynamodbstreams/urls.py index 1d0f94c35..a7589ae13 100644 --- a/moto/dynamodbstreams/urls.py +++ b/moto/dynamodbstreams/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoDBStreamsHandler -url_bases = [ - "https?://streams.dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://streams.dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/$": DynamoDBStreamsHandler.dispatch, -} +url_paths = {"{0}/$": DynamoDBStreamsHandler.dispatch} diff --git a/moto/ec2/__init__.py b/moto/ec2/__init__.py index ba8cbe0a0..c16912f57 100644 --- a/moto/ec2/__init__.py +++ b/moto/ec2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ec2_backends from ..core.models import base_decorator, deprecated_base_decorator -ec2_backend = ec2_backends['us-east-1'] +ec2_backend = ec2_backends["us-east-1"] mock_ec2 = base_decorator(ec2_backends) mock_ec2_deprecated = deprecated_base_decorator(ec2_backends) diff --git a/moto/ec2/exceptions.py b/moto/ec2/exceptions.py index b7a49cc57..b2c1792f2 100644 --- a/moto/ec2/exceptions.py +++ b/moto/ec2/exceptions.py @@ -7,499 +7,468 @@ class EC2ClientError(RESTError): class DependencyViolationError(EC2ClientError): - def __init__(self, message): - super(DependencyViolationError, self).__init__( - "DependencyViolation", message) + super(DependencyViolationError, self).__init__("DependencyViolation", message) class MissingParameterError(EC2ClientError): - def __init__(self, parameter): super(MissingParameterError, self).__init__( "MissingParameter", - "The request must contain the parameter {0}" - .format(parameter)) + "The request must contain the parameter {0}".format(parameter), + ) class InvalidDHCPOptionsIdError(EC2ClientError): - def __init__(self, dhcp_options_id): super(InvalidDHCPOptionsIdError, self).__init__( "InvalidDhcpOptionID.NotFound", - "DhcpOptionID {0} does not exist." - .format(dhcp_options_id)) + "DhcpOptionID {0} does not exist.".format(dhcp_options_id), + ) class MalformedDHCPOptionsIdError(EC2ClientError): - def __init__(self, dhcp_options_id): super(MalformedDHCPOptionsIdError, self).__init__( "InvalidDhcpOptionsId.Malformed", - "Invalid id: \"{0}\" (expecting \"dopt-...\")" - .format(dhcp_options_id)) + 'Invalid id: "{0}" (expecting "dopt-...")'.format(dhcp_options_id), + ) class InvalidKeyPairNameError(EC2ClientError): - def __init__(self, key): super(InvalidKeyPairNameError, self).__init__( - "InvalidKeyPair.NotFound", - "The keypair '{0}' does not exist." - .format(key)) + "InvalidKeyPair.NotFound", "The keypair '{0}' does not exist.".format(key) + ) class InvalidKeyPairDuplicateError(EC2ClientError): - def __init__(self, key): super(InvalidKeyPairDuplicateError, self).__init__( - "InvalidKeyPair.Duplicate", - "The keypair '{0}' already exists." - .format(key)) + "InvalidKeyPair.Duplicate", "The keypair '{0}' already exists.".format(key) + ) class InvalidKeyPairFormatError(EC2ClientError): - def __init__(self): super(InvalidKeyPairFormatError, self).__init__( - "InvalidKeyPair.Format", - "Key is not in valid OpenSSH public key format") + "InvalidKeyPair.Format", "Key is not in valid OpenSSH public key format" + ) class InvalidVPCIdError(EC2ClientError): - def __init__(self, vpc_id): super(InvalidVPCIdError, self).__init__( - "InvalidVpcID.NotFound", - "VpcID {0} does not exist." - .format(vpc_id)) + "InvalidVpcID.NotFound", "VpcID {0} does not exist.".format(vpc_id) + ) class InvalidSubnetIdError(EC2ClientError): - def __init__(self, subnet_id): super(InvalidSubnetIdError, self).__init__( "InvalidSubnetID.NotFound", - "The subnet ID '{0}' does not exist" - .format(subnet_id)) + "The subnet ID '{0}' does not exist".format(subnet_id), + ) class InvalidNetworkAclIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidNetworkAclIdError, self).__init__( "InvalidNetworkAclID.NotFound", - "The network acl ID '{0}' does not exist" - .format(network_acl_id)) + "The network acl ID '{0}' does not exist".format(network_acl_id), + ) class InvalidVpnGatewayIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidVpnGatewayIdError, self).__init__( "InvalidVpnGatewayID.NotFound", - "The virtual private gateway ID '{0}' does not exist" - .format(network_acl_id)) + "The virtual private gateway ID '{0}' does not exist".format( + network_acl_id + ), + ) class InvalidVpnConnectionIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidVpnConnectionIdError, self).__init__( "InvalidVpnConnectionID.NotFound", - "The vpnConnection ID '{0}' does not exist" - .format(network_acl_id)) + "The vpnConnection ID '{0}' does not exist".format(network_acl_id), + ) class InvalidCustomerGatewayIdError(EC2ClientError): - def __init__(self, customer_gateway_id): super(InvalidCustomerGatewayIdError, self).__init__( "InvalidCustomerGatewayID.NotFound", - "The customer gateway ID '{0}' does not exist" - .format(customer_gateway_id)) + "The customer gateway ID '{0}' does not exist".format(customer_gateway_id), + ) class InvalidNetworkInterfaceIdError(EC2ClientError): - def __init__(self, eni_id): super(InvalidNetworkInterfaceIdError, self).__init__( "InvalidNetworkInterfaceID.NotFound", - "The network interface ID '{0}' does not exist" - .format(eni_id)) + "The network interface ID '{0}' does not exist".format(eni_id), + ) class InvalidNetworkAttachmentIdError(EC2ClientError): - def __init__(self, attachment_id): super(InvalidNetworkAttachmentIdError, self).__init__( "InvalidAttachmentID.NotFound", - "The network interface attachment ID '{0}' does not exist" - .format(attachment_id)) + "The network interface attachment ID '{0}' does not exist".format( + attachment_id + ), + ) class InvalidSecurityGroupDuplicateError(EC2ClientError): - def __init__(self, name): super(InvalidSecurityGroupDuplicateError, self).__init__( "InvalidGroup.Duplicate", - "The security group '{0}' already exists" - .format(name)) + "The security group '{0}' already exists".format(name), + ) class InvalidSecurityGroupNotFoundError(EC2ClientError): - def __init__(self, name): super(InvalidSecurityGroupNotFoundError, self).__init__( "InvalidGroup.NotFound", - "The security group '{0}' does not exist" - .format(name)) + "The security group '{0}' does not exist".format(name), + ) class InvalidPermissionNotFoundError(EC2ClientError): - def __init__(self): super(InvalidPermissionNotFoundError, self).__init__( "InvalidPermission.NotFound", - "The specified rule does not exist in this security group") + "The specified rule does not exist in this security group", + ) class InvalidPermissionDuplicateError(EC2ClientError): - def __init__(self): super(InvalidPermissionDuplicateError, self).__init__( - "InvalidPermission.Duplicate", - "The specified rule already exists") + "InvalidPermission.Duplicate", "The specified rule already exists" + ) class InvalidRouteTableIdError(EC2ClientError): - def __init__(self, route_table_id): super(InvalidRouteTableIdError, self).__init__( "InvalidRouteTableID.NotFound", - "The routeTable ID '{0}' does not exist" - .format(route_table_id)) + "The routeTable ID '{0}' does not exist".format(route_table_id), + ) class InvalidRouteError(EC2ClientError): - def __init__(self, route_table_id, cidr): super(InvalidRouteError, self).__init__( "InvalidRoute.NotFound", - "no route with destination-cidr-block {0} in route table {1}" - .format(cidr, route_table_id)) + "no route with destination-cidr-block {0} in route table {1}".format( + cidr, route_table_id + ), + ) class InvalidInstanceIdError(EC2ClientError): - def __init__(self, instance_id): super(InvalidInstanceIdError, self).__init__( "InvalidInstanceID.NotFound", - "The instance ID '{0}' does not exist" - .format(instance_id)) + "The instance ID '{0}' does not exist".format(instance_id), + ) class InvalidAMIIdError(EC2ClientError): - def __init__(self, ami_id): super(InvalidAMIIdError, self).__init__( "InvalidAMIID.NotFound", - "The image id '[{0}]' does not exist" - .format(ami_id)) + "The image id '[{0}]' does not exist".format(ami_id), + ) class InvalidAMIAttributeItemValueError(EC2ClientError): - def __init__(self, attribute, value): super(InvalidAMIAttributeItemValueError, self).__init__( "InvalidAMIAttributeItemValue", - "Invalid attribute item value \"{0}\" for {1} item type." - .format(value, attribute)) + 'Invalid attribute item value "{0}" for {1} item type.'.format( + value, attribute + ), + ) class MalformedAMIIdError(EC2ClientError): - def __init__(self, ami_id): super(MalformedAMIIdError, self).__init__( "InvalidAMIID.Malformed", - "Invalid id: \"{0}\" (expecting \"ami-...\")" - .format(ami_id)) + 'Invalid id: "{0}" (expecting "ami-...")'.format(ami_id), + ) class InvalidSnapshotIdError(EC2ClientError): - def __init__(self, snapshot_id): super(InvalidSnapshotIdError, self).__init__( - "InvalidSnapshot.NotFound", - "") # Note: AWS returns empty message for this, as of 2014.08.22. + "InvalidSnapshot.NotFound", "" + ) # Note: AWS returns empty message for this, as of 2014.08.22. class InvalidVolumeIdError(EC2ClientError): - def __init__(self, volume_id): super(InvalidVolumeIdError, self).__init__( "InvalidVolume.NotFound", - "The volume '{0}' does not exist." - .format(volume_id)) + "The volume '{0}' does not exist.".format(volume_id), + ) class InvalidVolumeAttachmentError(EC2ClientError): - def __init__(self, volume_id, instance_id): super(InvalidVolumeAttachmentError, self).__init__( "InvalidAttachment.NotFound", - "Volume {0} can not be detached from {1} because it is not attached" - .format(volume_id, instance_id)) + "Volume {0} can not be detached from {1} because it is not attached".format( + volume_id, instance_id + ), + ) class InvalidDomainError(EC2ClientError): - def __init__(self, domain): super(InvalidDomainError, self).__init__( - "InvalidParameterValue", - "Invalid value '{0}' for domain." - .format(domain)) + "InvalidParameterValue", "Invalid value '{0}' for domain.".format(domain) + ) class InvalidAddressError(EC2ClientError): - def __init__(self, ip): super(InvalidAddressError, self).__init__( - "InvalidAddress.NotFound", - "Address '{0}' not found." - .format(ip)) + "InvalidAddress.NotFound", "Address '{0}' not found.".format(ip) + ) class InvalidAllocationIdError(EC2ClientError): - def __init__(self, allocation_id): super(InvalidAllocationIdError, self).__init__( "InvalidAllocationID.NotFound", - "Allocation ID '{0}' not found." - .format(allocation_id)) + "Allocation ID '{0}' not found.".format(allocation_id), + ) class InvalidAssociationIdError(EC2ClientError): - def __init__(self, association_id): super(InvalidAssociationIdError, self).__init__( "InvalidAssociationID.NotFound", - "Association ID '{0}' not found." - .format(association_id)) + "Association ID '{0}' not found.".format(association_id), + ) class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): - def __init__(self, association_id): super(InvalidVpcCidrBlockAssociationIdError, self).__init__( "InvalidVpcCidrBlockAssociationIdError.NotFound", - "The vpc CIDR block association ID '{0}' does not exist" - .format(association_id)) + "The vpc CIDR block association ID '{0}' does not exist".format( + association_id + ), + ) class InvalidVPCPeeringConnectionIdError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): super(InvalidVPCPeeringConnectionIdError, self).__init__( "InvalidVpcPeeringConnectionId.NotFound", - "VpcPeeringConnectionID {0} does not exist." - .format(vpc_peering_connection_id)) + "VpcPeeringConnectionID {0} does not exist.".format( + vpc_peering_connection_id + ), + ) class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__( "InvalidStateTransition", - "VpcPeeringConnectionID {0} is not in the correct state for the request." - .format(vpc_peering_connection_id)) + "VpcPeeringConnectionID {0} is not in the correct state for the request.".format( + vpc_peering_connection_id + ), + ) class InvalidParameterValueError(EC2ClientError): - def __init__(self, parameter_value): super(InvalidParameterValueError, self).__init__( "InvalidParameterValue", - "Value {0} is invalid for parameter." - .format(parameter_value)) + "Value {0} is invalid for parameter.".format(parameter_value), + ) class InvalidParameterValueErrorTagNull(EC2ClientError): - def __init__(self): super(InvalidParameterValueErrorTagNull, self).__init__( "InvalidParameterValue", - "Tag value cannot be null. Use empty string instead.") + "Tag value cannot be null. Use empty string instead.", + ) class InvalidParameterValueErrorUnknownAttribute(EC2ClientError): - def __init__(self, parameter_value): super(InvalidParameterValueErrorUnknownAttribute, self).__init__( "InvalidParameterValue", - "Value ({0}) for parameter attribute is invalid. Unknown attribute." - .format(parameter_value)) + "Value ({0}) for parameter attribute is invalid. Unknown attribute.".format( + parameter_value + ), + ) class InvalidInternetGatewayIdError(EC2ClientError): - def __init__(self, internet_gateway_id): super(InvalidInternetGatewayIdError, self).__init__( "InvalidInternetGatewayID.NotFound", - "InternetGatewayID {0} does not exist." - .format(internet_gateway_id)) + "InternetGatewayID {0} does not exist.".format(internet_gateway_id), + ) class GatewayNotAttachedError(EC2ClientError): - def __init__(self, internet_gateway_id, vpc_id): super(GatewayNotAttachedError, self).__init__( "Gateway.NotAttached", - "InternetGatewayID {0} is not attached to a VPC {1}." - .format(internet_gateway_id, vpc_id)) + "InternetGatewayID {0} is not attached to a VPC {1}.".format( + internet_gateway_id, vpc_id + ), + ) class ResourceAlreadyAssociatedError(EC2ClientError): - def __init__(self, resource_id): super(ResourceAlreadyAssociatedError, self).__init__( "Resource.AlreadyAssociated", - "Resource {0} is already associated." - .format(resource_id)) + "Resource {0} is already associated.".format(resource_id), + ) class TagLimitExceeded(EC2ClientError): - def __init__(self): super(TagLimitExceeded, self).__init__( "TagLimitExceeded", - "The maximum number of Tags for a resource has been reached.") + "The maximum number of Tags for a resource has been reached.", + ) class InvalidID(EC2ClientError): - def __init__(self, resource_id): super(InvalidID, self).__init__( - "InvalidID", - "The ID '{0}' is not valid" - .format(resource_id)) + "InvalidID", "The ID '{0}' is not valid".format(resource_id) + ) class InvalidCIDRSubnetError(EC2ClientError): - def __init__(self, cidr): super(InvalidCIDRSubnetError, self).__init__( "InvalidParameterValue", - "invalid CIDR subnet specification: {0}" - .format(cidr)) + "invalid CIDR subnet specification: {0}".format(cidr), + ) class RulesPerSecurityGroupLimitExceededError(EC2ClientError): - def __init__(self): super(RulesPerSecurityGroupLimitExceededError, self).__init__( "RulesPerSecurityGroupLimitExceeded", - 'The maximum number of rules per security group ' - 'has been reached.') + "The maximum number of rules per security group " "has been reached.", + ) class MotoNotImplementedError(NotImplementedError): - def __init__(self, blurb): super(MotoNotImplementedError, self).__init__( "{0} has not been implemented in Moto yet." " Feel free to open an issue at" - " https://github.com/spulec/moto/issues".format(blurb)) + " https://github.com/spulec/moto/issues".format(blurb) + ) class FilterNotImplementedError(MotoNotImplementedError): - def __init__(self, filter_name, method_name): super(FilterNotImplementedError, self).__init__( - "The filter '{0}' for {1}".format( - filter_name, method_name)) + "The filter '{0}' for {1}".format(filter_name, method_name) + ) class CidrLimitExceeded(EC2ClientError): - def __init__(self, vpc_id, max_cidr_limit): super(CidrLimitExceeded, self).__init__( "CidrLimitExceeded", - "This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(vpc_id, max_cidr_limit) + "This network '{0}' has met its maximum number of allowed CIDRs: {1}".format( + vpc_id, max_cidr_limit + ), ) class OperationNotPermitted(EC2ClientError): - def __init__(self, association_id): super(OperationNotPermitted, self).__init__( "OperationNotPermitted", "The vpc CIDR block with association ID {} may not be disassociated. " - "It is the primary IPv4 CIDR block of the VPC".format(association_id) + "It is the primary IPv4 CIDR block of the VPC".format(association_id), ) class InvalidAvailabilityZoneError(EC2ClientError): - def __init__(self, availability_zone_value, valid_availability_zones): super(InvalidAvailabilityZoneError, self).__init__( "InvalidParameterValue", "Value ({0}) for parameter availabilityZone is invalid. " - "Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones) + "Subnets can currently only be created in the following availability zones: {1}.".format( + availability_zone_value, valid_availability_zones + ), ) class NetworkAclEntryAlreadyExistsError(EC2ClientError): - def __init__(self, rule_number): super(NetworkAclEntryAlreadyExistsError, self).__init__( "NetworkAclEntryAlreadyExists", - "The network acl entry identified by {} already exists.".format(rule_number) + "The network acl entry identified by {} already exists.".format( + rule_number + ), ) class InvalidSubnetRangeError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidSubnetRangeError, self).__init__( - "InvalidSubnet.Range", - "The CIDR '{}' is invalid.".format(cidr_block) + "InvalidSubnet.Range", "The CIDR '{}' is invalid.".format(cidr_block) ) class InvalidCIDRBlockParameterError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidCIDRBlockParameterError, self).__init__( "InvalidParameterValue", - "Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) + "Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + cidr_block + ), ) class InvalidDestinationCIDRBlockParameterError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidDestinationCIDRBlockParameterError, self).__init__( "InvalidParameterValue", - "Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) + "Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format( + cidr_block + ), ) class InvalidSubnetConflictError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidSubnetConflictError, self).__init__( "InvalidSubnet.Conflict", - "The CIDR '{}' conflicts with another subnet".format(cidr_block) + "The CIDR '{}' conflicts with another subnet".format(cidr_block), ) class InvalidVPCRangeError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidVPCRangeError, self).__init__( - "InvalidVpc.Range", - "The CIDR '{}' is invalid.".format(cidr_block) + "InvalidVpc.Range", "The CIDR '{}' is invalid.".format(cidr_block) ) @@ -509,7 +478,9 @@ class OperationNotPermitted2(EC2ClientError): super(OperationNotPermitted2, self).__init__( "OperationNotPermitted", "Incorrect region ({0}) specified for this request." - "VPC peering connection {1} must be accepted in region {2}".format(client_region, pcx_id, acceptor_region) + "VPC peering connection {1} must be accepted in region {2}".format( + client_region, pcx_id, acceptor_region + ), ) @@ -519,9 +490,9 @@ class OperationNotPermitted3(EC2ClientError): super(OperationNotPermitted3, self).__init__( "OperationNotPermitted", "Incorrect region ({0}) specified for this request." - "VPC peering connection {1} must be accepted or rejected in region {2}".format(client_region, - pcx_id, - acceptor_region) + "VPC peering connection {1} must be accepted or rejected in region {2}".format( + client_region, pcx_id, acceptor_region + ), ) @@ -529,5 +500,5 @@ class InvalidLaunchTemplateNameError(EC2ClientError): def __init__(self): super(InvalidLaunchTemplateNameError, self).__init__( "InvalidLaunchTemplateName.AlreadyExistsException", - "Launch template name already in use." + "Launch template name already in use.", ) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 10d6f2b28..374494faa 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -23,7 +23,10 @@ from boto.ec2.launchspecification import LaunchSpecification from moto.compat import OrderedDict from moto.core import BaseBackend from moto.core.models import Model, BaseModel -from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores +from moto.core.utils import ( + iso_8601_datetime_with_milliseconds, + camelcase_to_underscores, +) from .exceptions import ( CidrLimitExceeded, DependencyViolationError, @@ -84,7 +87,8 @@ from .exceptions import ( OperationNotPermitted3, ResourceAlreadyAssociatedError, RulesPerSecurityGroupLimitExceededError, - TagLimitExceeded) + TagLimitExceeded, +) from .utils import ( EC2_RESOURCE_TO_PREFIX, EC2_PREFIX_TO_RESOURCE, @@ -132,27 +136,30 @@ from .utils import ( is_tag_filter, tag_filter_matches, rsa_public_key_parse, - rsa_public_key_fingerprint + rsa_public_key_fingerprint, ) INSTANCE_TYPES = json.load( - open(resource_filename(__name__, 'resources/instance_types.json'), 'r') + open(resource_filename(__name__, "resources/instance_types.json"), "r") ) AMIS = json.load( - open(os.environ.get('MOTO_AMIS_PATH') or resource_filename( - __name__, 'resources/amis.json'), 'r') + open( + os.environ.get("MOTO_AMIS_PATH") + or resource_filename(__name__, "resources/amis.json"), + "r", + ) ) OWNER_ID = "111122223333" def utc_date_and_time(): - return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z') + return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000Z") def validate_resource_ids(resource_ids): if not resource_ids: - raise MissingParameterError(parameter='resourceIdSet') + raise MissingParameterError(parameter="resourceIdSet") for resource_id in resource_ids: if not is_valid_resource_id(resource_id): raise InvalidID(resource_id=resource_id) @@ -160,7 +167,7 @@ def validate_resource_ids(resource_ids): class InstanceState(object): - def __init__(self, name='pending', code=0): + def __init__(self, name="pending", code=0): self.name = name self.code = code @@ -173,8 +180,7 @@ class StateReason(object): class TaggedEC2Resource(BaseModel): def get_tags(self, *args, **kwargs): - tags = self.ec2_backend.describe_tags( - filters={'resource-id': [self.id]}) + tags = self.ec2_backend.describe_tags(filters={"resource-id": [self.id]}) return tags def add_tag(self, key, value): @@ -187,28 +193,38 @@ class TaggedEC2Resource(BaseModel): def get_filter_value(self, filter_name, method_name=None): tags = self.get_tags() - if filter_name.startswith('tag:'): - tagname = filter_name.replace('tag:', '', 1) + if filter_name.startswith("tag:"): + tagname = filter_name.replace("tag:", "", 1) for tag in tags: - if tag['key'] == tagname: - return tag['value'] + if tag["key"] == tagname: + return tag["value"] - return '' - elif filter_name == 'tag-key': - return [tag['key'] for tag in tags] - elif filter_name == 'tag-value': - return [tag['value'] for tag in tags] + return "" + elif filter_name == "tag-key": + return [tag["key"] for tag in tags] + elif filter_name == "tag-value": + return [tag["value"] for tag in tags] else: raise FilterNotImplementedError(filter_name, method_name) class NetworkInterface(TaggedEC2Resource): - def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0, - public_ip_auto_assign=True, group_ids=None, description=None): + def __init__( + self, + ec2_backend, + subnet, + private_ip_address, + private_ip_addresses=None, + device_index=0, + public_ip_auto_assign=True, + group_ids=None, + description=None, + ): self.ec2_backend = ec2_backend self.id = random_eni_id() self.device_index = device_index self.private_ip_address = private_ip_address or random_private_ip() + self.private_ip_addresses = private_ip_addresses self.subnet = subnet self.instance = None self.attachment_id = None @@ -231,32 +247,39 @@ class NetworkInterface(TaggedEC2Resource): if not group: # Create with specific group ID. group = SecurityGroup( - self.ec2_backend, group_id, group_id, group_id, vpc_id=subnet.vpc_id) + self.ec2_backend, + group_id, + group_id, + group_id, + vpc_id=subnet.vpc_id, + ) self.ec2_backend.groups[subnet.vpc_id][group_id] = group if group: self._group_set.append(group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - security_group_ids = properties.get('SecurityGroups', []) + security_group_ids = properties.get("SecurityGroups", []) ec2_backend = ec2_backends[region_name] - subnet_id = properties.get('SubnetId') + subnet_id = properties.get("SubnetId") if subnet_id: subnet = ec2_backend.get_subnet(subnet_id) else: subnet = None - private_ip_address = properties.get('PrivateIpAddress', None) - description = properties.get('Description', None) + private_ip_address = properties.get("PrivateIpAddress", None) + description = properties.get("Description", None) network_interface = ec2_backend.create_network_interface( subnet, private_ip_address, group_ids=security_group_ids, - description=description + description=description, ) return network_interface @@ -280,11 +303,13 @@ class NetworkInterface(TaggedEC2Resource): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'PrimaryPrivateIpAddress': + + if attribute_name == "PrimaryPrivateIpAddress": return self.private_ip_address - elif attribute_name == 'SecondaryPrivateIpAddresses': + elif attribute_name == "SecondaryPrivateIpAddresses": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SecondaryPrivateIpAddresses" ]"') + '"Fn::GetAtt" : [ "{0}" , "SecondaryPrivateIpAddresses" ]"' + ) raise UnformattedGetAttTemplateException() @property @@ -292,23 +317,24 @@ class NetworkInterface(TaggedEC2Resource): return self.id def get_filter_value(self, filter_name): - if filter_name == 'network-interface-id': + if filter_name == "network-interface-id": return self.id - elif filter_name in ('addresses.private-ip-address', 'private-ip-address'): + elif filter_name in ("addresses.private-ip-address", "private-ip-address"): return self.private_ip_address - elif filter_name == 'subnet-id': + elif filter_name == "subnet-id": return self.subnet.id - elif filter_name == 'vpc-id': + elif filter_name == "vpc-id": return self.subnet.vpc_id - elif filter_name == 'group-id': + elif filter_name == "group-id": return [group.id for group in self._group_set] - elif filter_name == 'availability-zone': + elif filter_name == "availability-zone": return self.subnet.availability_zone - elif filter_name == 'description': + elif filter_name == "description": return self.description else: return super(NetworkInterface, self).get_filter_value( - filter_name, 'DescribeNetworkInterfaces') + filter_name, "DescribeNetworkInterfaces" + ) class NetworkInterfaceBackend(object): @@ -316,9 +342,24 @@ class NetworkInterfaceBackend(object): self.enis = {} super(NetworkInterfaceBackend, self).__init__() - def create_network_interface(self, subnet, private_ip_address, group_ids=None, description=None, **kwargs): + def create_network_interface( + self, + subnet, + private_ip_address, + private_ip_addresses=None, + group_ids=None, + description=None, + **kwargs + ): eni = NetworkInterface( - self, subnet, private_ip_address, group_ids=group_ids, description=description, **kwargs) + self, + subnet, + private_ip_address, + private_ip_addresses, + group_ids=group_ids, + description=description, + **kwargs + ) self.enis[eni.id] = eni return eni @@ -339,11 +380,12 @@ class NetworkInterfaceBackend(object): if filters: for (_filter, _filter_value) in filters.items(): - if _filter == 'network-interface-id': - _filter = 'id' - enis = [eni for eni in enis if getattr( - eni, _filter) in _filter_value] - elif _filter == 'group-id': + if _filter == "network-interface-id": + _filter = "id" + enis = [ + eni for eni in enis if getattr(eni, _filter) in _filter_value + ] + elif _filter == "group-id": original_enis = enis enis = [] for eni in original_enis: @@ -351,15 +393,18 @@ class NetworkInterfaceBackend(object): if group.id in _filter_value: enis.append(eni) break - elif _filter == 'private-ip-address:': - enis = [eni for eni in enis if eni.private_ip_address in _filter_value] - elif _filter == 'subnet-id': + elif _filter == "private-ip-address:": + enis = [ + eni for eni in enis if eni.private_ip_address in _filter_value + ] + elif _filter == "subnet-id": enis = [eni for eni in enis if eni.subnet.id in _filter_value] - elif _filter == 'description': + elif _filter == "description": enis = [eni for eni in enis if eni.description in _filter_value] else: self.raise_not_implemented_error( - "The filter '{0}' for DescribeNetworkInterfaces".format(_filter)) + "The filter '{0}' for DescribeNetworkInterfaces".format(_filter) + ) return enis def attach_network_interface(self, eni_id, instance_id, device_index): @@ -390,17 +435,30 @@ class NetworkInterfaceBackend(object): if eni_ids: enis = [eni for eni in enis if eni.id in eni_ids] if len(enis) != len(eni_ids): - invalid_id = list(set(eni_ids).difference( - set([eni.id for eni in enis])))[0] + invalid_id = list( + set(eni_ids).difference(set([eni.id for eni in enis])) + )[0] raise InvalidNetworkInterfaceIdError(invalid_id) return generic_filter(filters, enis) class Instance(TaggedEC2Resource, BotoInstance): - VALID_ATTRIBUTES = {'instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination', - 'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping', - 'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'} + VALID_ATTRIBUTES = { + "instanceType", + "kernel", + "ramdisk", + "userData", + "disableApiTermination", + "instanceInitiatedShutdownBehavior", + "rootDeviceName", + "blockDeviceMapping", + "productCodes", + "sourceDestCheck", + "groupSet", + "ebsOptimized", + "sriovNetSupport", + } def __init__(self, ec2_backend, image_id, user_data, security_groups, **kwargs): super(Instance, self).__init__() @@ -424,7 +482,9 @@ class Instance(TaggedEC2Resource, BotoInstance): self.launch_time = utc_date_and_time() self.ami_launch_index = kwargs.get("ami_launch_index", 0) self.disable_api_termination = kwargs.get("disable_api_termination", False) - self.instance_initiated_shutdown_behavior = kwargs.get("instance_initiated_shutdown_behavior", "stop") + self.instance_initiated_shutdown_behavior = kwargs.get( + "instance_initiated_shutdown_behavior", "stop" + ) self.sriov_net_support = "simple" self._spot_fleet_id = kwargs.get("spot_fleet_id", None) self.associate_public_ip = kwargs.get("associate_public_ip", False) @@ -432,29 +492,31 @@ class Instance(TaggedEC2Resource, BotoInstance): # If we are in EC2-Classic, autoassign a public IP self.associate_public_ip = True - amis = self.ec2_backend.describe_images(filters={'image-id': image_id}) + amis = self.ec2_backend.describe_images(filters={"image-id": image_id}) ami = amis[0] if amis else None if ami is None: - warnings.warn('Could not find AMI with image-id:{0}, ' - 'in the near future this will ' - 'cause an error.\n' - 'Use ec2_backend.describe_images() to ' - 'find suitable image for your test'.format(image_id), - PendingDeprecationWarning) + warnings.warn( + "Could not find AMI with image-id:{0}, " + "in the near future this will " + "cause an error.\n" + "Use ec2_backend.describe_images() to " + "find suitable image for your test".format(image_id), + PendingDeprecationWarning, + ) self.platform = ami.platform if ami else None - self.virtualization_type = ami.virtualization_type if ami else 'paravirtual' - self.architecture = ami.architecture if ami else 'x86_64' + self.virtualization_type = ami.virtualization_type if ami else "paravirtual" + self.architecture = ami.architecture if ami else "x86_64" # handle weird bug around user_data -- something grabs the repr(), so # it must be clean if isinstance(self.user_data, list) and len(self.user_data) > 0: if six.PY3 and isinstance(self.user_data[0], six.binary_type): # string will have a "b" prefix -- need to get rid of it - self.user_data[0] = self.user_data[0].decode('utf-8') + self.user_data[0] = self.user_data[0].decode("utf-8") elif six.PY2 and isinstance(self.user_data[0], six.text_type): # string will have a "u" prefix -- need to get rid of it - self.user_data[0] = self.user_data[0].encode('utf-8') + self.user_data[0] = self.user_data[0].encode("utf-8") if self.subnet_id: subnet = ec2_backend.get_subnet(self.subnet_id) @@ -463,11 +525,11 @@ class Instance(TaggedEC2Resource, BotoInstance): if self.associate_public_ip is None: # Mapping public ip hasnt been explicitly enabled or disabled - self.associate_public_ip = subnet.map_public_ip_on_launch == 'true' + self.associate_public_ip = subnet.map_public_ip_on_launch == "true" elif placement: self._placement.zone = placement else: - self._placement.zone = ec2_backend.region_name + 'a' + self._placement.zone = ec2_backend.region_name + "a" self.block_device_mapping = BlockDeviceMapping() @@ -475,7 +537,7 @@ class Instance(TaggedEC2Resource, BotoInstance): self.prep_nics( kwargs.get("nics", {}), private_ip=kwargs.get("private_ip"), - associate_public_ip=self.associate_public_ip + associate_public_ip=self.associate_public_ip, ) def __del__(self): @@ -491,12 +553,12 @@ class Instance(TaggedEC2Resource, BotoInstance): def setup_defaults(self): # Default have an instance with root volume should you not wish to # override with attach volume cmd. - volume = self.ec2_backend.create_volume(8, 'us-east-1a') - self.ec2_backend.attach_volume(volume.id, self.id, '/dev/sda1') + volume = self.ec2_backend.create_volume(8, "us-east-1a") + self.ec2_backend.attach_volume(volume.id, self.id, "/dev/sda1") def teardown_defaults(self): - volume_id = self.block_device_mapping['/dev/sda1'].volume_id - self.ec2_backend.detach_volume(volume_id, self.id, '/dev/sda1') + volume_id = self.block_device_mapping["/dev/sda1"].volume_id + self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1") self.ec2_backend.delete_volume(volume_id) @property @@ -509,7 +571,7 @@ class Instance(TaggedEC2Resource, BotoInstance): @property def private_dns(self): - formatted_ip = self.private_ip.replace('.', '-') + formatted_ip = self.private_ip.replace(".", "-") if self.region_name == "us-east-1": return "ip-{0}.ec2.internal".format(formatted_ip) else: @@ -522,30 +584,36 @@ class Instance(TaggedEC2Resource, BotoInstance): @property def public_dns(self): if self.public_ip: - formatted_ip = self.public_ip.replace('.', '-') + formatted_ip = self.public_ip.replace(".", "-") if self.region_name == "us-east-1": return "ec2-{0}.compute-1.amazonaws.com".format(formatted_ip) else: - return "ec2-{0}.{1}.compute.amazonaws.com".format(formatted_ip, self.region_name) + return "ec2-{0}.{1}.compute.amazonaws.com".format( + formatted_ip, self.region_name + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - security_group_ids = properties.get('SecurityGroups', []) - group_names = [ec2_backend.get_security_group_from_id( - group_id).name for group_id in security_group_ids] + security_group_ids = properties.get("SecurityGroups", []) + group_names = [ + ec2_backend.get_security_group_from_id(group_id).name + for group_id in security_group_ids + ] reservation = ec2_backend.add_instances( - image_id=properties['ImageId'], - user_data=properties.get('UserData'), + image_id=properties["ImageId"], + user_data=properties.get("UserData"), count=1, security_group_names=group_names, instance_type=properties.get("InstanceType", "m1.small"), subnet_id=properties.get("SubnetId"), key_name=properties.get("KeyName"), - private_ip=properties.get('PrivateIpAddress'), + private_ip=properties.get("PrivateIpAddress"), ) instance = reservation.instances[0] for tag in properties.get("Tags", []): @@ -553,19 +621,24 @@ class Instance(TaggedEC2Resource, BotoInstance): return instance @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] all_instances = ec2_backend.all_instances() # the resource_name for instances is the stack name, logical id, and random suffix separated # by hyphens. So to lookup the instances using the 'aws:cloudformation:logical-id' tag, we need to # extract the logical-id from the resource_name - logical_id = resource_name.split('-')[1] + logical_id = resource_name.split("-")[1] for instance in all_instances: instance_tags = instance.get_tags() for tag in instance_tags: - if tag['key'] == 'aws:cloudformation:logical-id' and tag['value'] == logical_id: + if ( + tag["key"] == "aws:cloudformation:logical-id" + and tag["value"] == logical_id + ): instance.delete(region_name) @property @@ -590,9 +663,12 @@ class Instance(TaggedEC2Resource, BotoInstance): self._state.code = 80 self._reason = "User initiated ({0})".format( - datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')) - self._state_reason = StateReason("Client.UserInitiatedShutdown: User initiated shutdown", - "Client.UserInitiatedShutdown") + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") + ) + self._state_reason = StateReason( + "Client.UserInitiatedShutdown: User initiated shutdown", + "Client.UserInitiatedShutdown", + ) def delete(self, region): self.terminate() @@ -606,18 +682,26 @@ class Instance(TaggedEC2Resource, BotoInstance): if self._spot_fleet_id: spot_fleet = self.ec2_backend.get_spot_fleet_request(self._spot_fleet_id) for spec in spot_fleet.launch_specs: - if spec.instance_type == self.instance_type and spec.subnet_id == self.subnet_id: + if ( + spec.instance_type == self.instance_type + and spec.subnet_id == self.subnet_id + ): break spot_fleet.fulfilled_capacity -= spec.weighted_capacity - spot_fleet.spot_requests = [req for req in spot_fleet.spot_requests if req.instance != self] + spot_fleet.spot_requests = [ + req for req in spot_fleet.spot_requests if req.instance != self + ] self._state.name = "terminated" self._state.code = 48 self._reason = "User initiated ({0})".format( - datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')) - self._state_reason = StateReason("Client.UserInitiatedShutdown: User initiated shutdown", - "Client.UserInitiatedShutdown") + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") + ) + self._state_reason = StateReason( + "Client.UserInitiatedShutdown: User initiated shutdown", + "Client.UserInitiatedShutdown", + ) def reboot(self, *args, **kwargs): self._state.name = "running" @@ -653,22 +737,24 @@ class Instance(TaggedEC2Resource, BotoInstance): private_ip = random_private_ip() # Primary NIC defaults - primary_nic = {'SubnetId': self.subnet_id, - 'PrivateIpAddress': private_ip, - 'AssociatePublicIpAddress': associate_public_ip} + primary_nic = { + "SubnetId": self.subnet_id, + "PrivateIpAddress": private_ip, + "AssociatePublicIpAddress": associate_public_ip, + } primary_nic = dict((k, v) for k, v in primary_nic.items() if v) # If empty NIC spec but primary NIC values provided, create NIC from # them. if primary_nic and not nic_spec: nic_spec[0] = primary_nic - nic_spec[0]['DeviceIndex'] = 0 + nic_spec[0]["DeviceIndex"] = 0 # Flesh out data structures and associations for nic in nic_spec.values(): - device_index = int(nic.get('DeviceIndex')) + device_index = int(nic.get("DeviceIndex")) - nic_id = nic.get('NetworkInterfaceId') + nic_id = nic.get("NetworkInterfaceId") if nic_id: # If existing NIC found, use it. use_nic = self.ec2_backend.get_network_interface(nic_id) @@ -680,21 +766,21 @@ class Instance(TaggedEC2Resource, BotoInstance): if device_index == 0 and primary_nic: nic.update(primary_nic) - if 'SubnetId' in nic: - subnet = self.ec2_backend.get_subnet(nic['SubnetId']) + if "SubnetId" in nic: + subnet = self.ec2_backend.get_subnet(nic["SubnetId"]) else: subnet = None - group_id = nic.get('SecurityGroupId') + group_id = nic.get("SecurityGroupId") group_ids = [group_id] if group_id else [] - use_nic = self.ec2_backend.create_network_interface(subnet, - nic.get( - 'PrivateIpAddress'), - device_index=device_index, - public_ip_auto_assign=nic.get( - 'AssociatePublicIpAddress', False), - group_ids=group_ids) + use_nic = self.ec2_backend.create_network_interface( + subnet, + nic.get("PrivateIpAddress"), + device_index=device_index, + public_ip_auto_assign=nic.get("AssociatePublicIpAddress", False), + group_ids=group_ids, + ) self.attach_eni(use_nic, device_index) @@ -717,15 +803,16 @@ class Instance(TaggedEC2Resource, BotoInstance): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AvailabilityZone': + + if attribute_name == "AvailabilityZone": return self.placement - elif attribute_name == 'PrivateDnsName': + elif attribute_name == "PrivateDnsName": return self.private_dns - elif attribute_name == 'PublicDnsName': + elif attribute_name == "PublicDnsName": return self.public_dns - elif attribute_name == 'PrivateIp': + elif attribute_name == "PrivateIp": return self.private_ip - elif attribute_name == 'PublicIp': + elif attribute_name == "PublicIp": return self.public_ip raise UnformattedGetAttTemplateException() @@ -741,28 +828,26 @@ class InstanceBackend(object): return instance raise InvalidInstanceIdError(instance_id) - def add_instances(self, image_id, count, user_data, security_group_names, - **kwargs): + def add_instances(self, image_id, count, user_data, security_group_names, **kwargs): new_reservation = Reservation() new_reservation.id = random_reservation_id() - security_groups = [self.get_security_group_from_name(name) - for name in security_group_names] - security_groups.extend(self.get_security_group_from_id(sg_id) - for sg_id in kwargs.pop("security_group_ids", [])) + security_groups = [ + self.get_security_group_from_name(name) for name in security_group_names + ] + security_groups.extend( + self.get_security_group_from_id(sg_id) + for sg_id in kwargs.pop("security_group_ids", []) + ) self.reservations[new_reservation.id] = new_reservation tags = kwargs.pop("tags", {}) - instance_tags = tags.get('instance', {}) + instance_tags = tags.get("instance", {}) for index in range(count): kwargs["ami_launch_index"] = index new_instance = Instance( - self, - image_id, - user_data, - security_groups, - **kwargs + self, image_id, user_data, security_groups, **kwargs ) new_reservation.instances.append(new_instance) new_instance.add_tags(instance_tags) @@ -789,7 +874,8 @@ class InstanceBackend(object): terminated_instances = [] if not instance_ids: raise EC2ClientError( - "InvalidParameterCombination", "No instances specified") + "InvalidParameterCombination", "No instances specified" + ) for instance in self.get_multi_instances_by_id(instance_ids): instance.terminate() terminated_instances.append(instance) @@ -814,15 +900,15 @@ class InstanceBackend(object): new_group_list = [] for new_group_id in new_group_id_list: new_group_list.append(self.get_security_group_from_id(new_group_id)) - setattr(instance, 'security_groups', new_group_list) + setattr(instance, "security_groups", new_group_list) return instance def describe_instance_attribute(self, instance_id, attribute): if attribute not in Instance.VALID_ATTRIBUTES: raise InvalidParameterValueErrorUnknownAttribute(attribute) - if attribute == 'groupSet': - key = 'security_groups' + if attribute == "groupSet": + key = "security_groups" else: key = camelcase_to_underscores(attribute) instance = self.get_instance(instance_id) @@ -875,25 +961,34 @@ class InstanceBackend(object): reservations = [] for reservation in self.all_reservations(): reservation_instance_ids = [ - instance.id for instance in reservation.instances] + instance.id for instance in reservation.instances + ] matching_reservation = any( - instance_id in reservation_instance_ids for instance_id in instance_ids) + instance_id in reservation_instance_ids for instance_id in instance_ids + ) if matching_reservation: reservation.instances = [ - instance for instance in reservation.instances if instance.id in instance_ids] + instance + for instance in reservation.instances + if instance.id in instance_ids + ] reservations.append(reservation) found_instance_ids = [ - instance.id for reservation in reservations for instance in reservation.instances] + instance.id + for reservation in reservations + for instance in reservation.instances + ] if len(found_instance_ids) != len(instance_ids): - invalid_id = list(set(instance_ids).difference( - set(found_instance_ids)))[0] + invalid_id = list(set(instance_ids).difference(set(found_instance_ids)))[0] raise InvalidInstanceIdError(invalid_id) if filters is not None: reservations = filter_reservations(reservations, filters) return reservations def all_reservations(self, filters=None): - reservations = [copy.copy(reservation) for reservation in self.reservations.values()] + reservations = [ + copy.copy(reservation) for reservation in self.reservations.values() + ] if filters is not None: reservations = filter_reservations(reservations, filters) return reservations @@ -906,12 +1001,12 @@ class KeyPair(object): self.material = material def get_filter_value(self, filter_name): - if filter_name == 'key-name': + if filter_name == "key-name": return self.name - elif filter_name == 'fingerprint': + elif filter_name == "fingerprint": return self.fingerprint else: - raise FilterNotImplementedError(filter_name, 'DescribeKeyPairs') + raise FilterNotImplementedError(filter_name, "DescribeKeyPairs") class KeyPairBackend(object): @@ -934,8 +1029,11 @@ class KeyPairBackend(object): def describe_key_pairs(self, key_names=None, filters=None): results = [] if key_names: - results = [keypair for keypair in self.keypairs.values() - if keypair.name in key_names] + results = [ + keypair + for keypair in self.keypairs.values() + if keypair.name in key_names + ] if len(key_names) > len(results): unknown_keys = set(key_names) - set(results) raise InvalidKeyPairNameError(unknown_keys) @@ -957,35 +1055,35 @@ class KeyPairBackend(object): raise InvalidKeyPairFormatError() fingerprint = rsa_public_key_fingerprint(rsa_public_key) - keypair = KeyPair(key_name, material=public_key_material, fingerprint=fingerprint) + keypair = KeyPair( + key_name, material=public_key_material, fingerprint=fingerprint + ) self.keypairs[key_name] = keypair return keypair class TagBackend(object): - VALID_TAG_FILTERS = ['key', - 'resource-id', - 'resource-type', - 'value'] + VALID_TAG_FILTERS = ["key", "resource-id", "resource-type", "value"] - VALID_TAG_RESOURCE_FILTER_TYPES = ['customer-gateway', - 'dhcp-options', - 'image', - 'instance', - 'internet-gateway', - 'network-acl', - 'network-interface', - 'reserved-instances', - 'route-table', - 'security-group', - 'snapshot', - 'spot-instances-request', - 'subnet', - 'volume', - 'vpc', - 'vpc-peering-connection' - 'vpn-connection', - 'vpn-gateway'] + VALID_TAG_RESOURCE_FILTER_TYPES = [ + "customer-gateway", + "dhcp-options", + "image", + "instance", + "internet-gateway", + "network-acl", + "network-interface", + "reserved-instances", + "route-table", + "security-group", + "snapshot", + "spot-instances-request", + "subnet", + "volume", + "vpc", + "vpc-peering-connection" "vpn-connection", + "vpn-gateway", + ] def __init__(self): self.tags = defaultdict(dict) @@ -996,7 +1094,11 @@ class TagBackend(object): raise InvalidParameterValueErrorTagNull() for resource_id in resource_ids: if resource_id in self.tags: - if len(self.tags[resource_id]) + len([tag for tag in tags if not tag.startswith("aws:")]) > 50: + if ( + len(self.tags[resource_id]) + + len([tag for tag in tags if not tag.startswith("aws:")]) + > 50 + ): raise TagLimitExceeded() elif len([tag for tag in tags if not tag.startswith("aws:")]) > 50: raise TagLimitExceeded() @@ -1017,6 +1119,7 @@ class TagBackend(object): def describe_tags(self, filters=None): import re + results = [] key_filters = [] resource_id_filters = [] @@ -1025,21 +1128,24 @@ class TagBackend(object): if filters is not None: for tag_filter in filters: if tag_filter in self.VALID_TAG_FILTERS: - if tag_filter == 'key': + if tag_filter == "key": for value in filters[tag_filter]: - key_filters.append(re.compile( - simple_aws_filter_to_re(value))) - if tag_filter == 'resource-id': + key_filters.append( + re.compile(simple_aws_filter_to_re(value)) + ) + if tag_filter == "resource-id": for value in filters[tag_filter]: resource_id_filters.append( - re.compile(simple_aws_filter_to_re(value))) - if tag_filter == 'resource-type': + re.compile(simple_aws_filter_to_re(value)) + ) + if tag_filter == "resource-type": for value in filters[tag_filter]: resource_type_filters.append(value) - if tag_filter == 'value': + if tag_filter == "value": for value in filters[tag_filter]: - value_filters.append(re.compile( - simple_aws_filter_to_re(value))) + value_filters.append( + re.compile(simple_aws_filter_to_re(value)) + ) for resource_id, tags in self.tags.items(): for key, value in tags.items(): add_result = False @@ -1064,7 +1170,10 @@ class TagBackend(object): id_pass = True if resource_type_filters: for resource_type in resource_type_filters: - if EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)] == resource_type: + if ( + EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)] + == resource_type + ): type_pass = True else: type_pass = True @@ -1079,24 +1188,41 @@ class TagBackend(object): # If we're not filtering, or we are filtering and this if add_result: result = { - 'resource_id': resource_id, - 'key': key, - 'value': value, - 'resource_type': EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)], + "resource_id": resource_id, + "key": key, + "value": value, + "resource_type": EC2_PREFIX_TO_RESOURCE[ + get_prefix(resource_id) + ], } results.append(result) return results class Ami(TaggedEC2Resource): - def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None, - name=None, description=None, owner_id=OWNER_ID, - public=False, virtualization_type=None, architecture=None, - state='available', creation_date=None, platform=None, - image_type='machine', image_location=None, hypervisor=None, - root_device_type='standard', root_device_name='/dev/sda1', sriov='simple', - region_name='us-east-1a' - ): + def __init__( + self, + ec2_backend, + ami_id, + instance=None, + source_ami=None, + name=None, + description=None, + owner_id=OWNER_ID, + public=False, + virtualization_type=None, + architecture=None, + state="available", + creation_date=None, + platform=None, + image_type="machine", + image_location=None, + hypervisor=None, + root_device_type="standard", + root_device_name="/dev/sda1", + sriov="simple", + region_name="us-east-1a", + ): self.ec2_backend = ec2_backend self.id = ami_id self.state = state @@ -1113,7 +1239,9 @@ class Ami(TaggedEC2Resource): self.root_device_name = root_device_name self.root_device_type = root_device_type self.sriov = sriov - self.creation_date = utc_date_and_time() if creation_date is None else creation_date + self.creation_date = ( + utc_date_and_time() if creation_date is None else creation_date + ) if instance: self.instance = instance @@ -1142,42 +1270,42 @@ class Ami(TaggedEC2Resource): self.launch_permission_users = set() if public: - self.launch_permission_groups.add('all') + self.launch_permission_groups.add("all") # AWS auto-creates these, we should reflect the same. volume = self.ec2_backend.create_volume(15, region_name) self.ebs_snapshot = self.ec2_backend.create_snapshot( - volume.id, "Auto-created snapshot for AMI %s" % self.id, owner_id) + volume.id, "Auto-created snapshot for AMI %s" % self.id, owner_id + ) self.ec2_backend.delete_volume(volume.id) @property def is_public(self): - return 'all' in self.launch_permission_groups + return "all" in self.launch_permission_groups @property def is_public_string(self): return str(self.is_public).lower() def get_filter_value(self, filter_name): - if filter_name == 'virtualization-type': + if filter_name == "virtualization-type": return self.virtualization_type - elif filter_name == 'kernel-id': + elif filter_name == "kernel-id": return self.kernel_id - elif filter_name in ['architecture', 'platform']: + elif filter_name in ["architecture", "platform"]: return getattr(self, filter_name) - elif filter_name == 'image-id': + elif filter_name == "image-id": return self.id - elif filter_name == 'is-public': + elif filter_name == "is-public": return self.is_public_string - elif filter_name == 'state': + elif filter_name == "state": return self.state - elif filter_name == 'name': + elif filter_name == "name": return self.name - elif filter_name == 'owner-id': + elif filter_name == "owner-id": return self.owner_id else: - return super(Ami, self).get_filter_value( - filter_name, 'DescribeImages') + return super(Ami, self).get_filter_value(filter_name, "DescribeImages") class AmiBackend(object): @@ -1193,7 +1321,7 @@ class AmiBackend(object): def _load_amis(self): for ami in AMIS: - ami_id = ami['ami_id'] + ami_id = ami["ami_id"] self.amis[ami_id] = Ami(self, **ami) def create_image(self, instance_id, name=None, description=None, context=None): @@ -1201,35 +1329,51 @@ class AmiBackend(object): ami_id = random_ami_id() instance = self.get_instance(instance_id) - ami = Ami(self, ami_id, instance=instance, source_ami=None, - name=name, description=description, - owner_id=context.get_current_user() if context else OWNER_ID) + ami = Ami( + self, + ami_id, + instance=instance, + source_ami=None, + name=name, + description=description, + owner_id=context.get_current_user() if context else OWNER_ID, + ) self.amis[ami_id] = ami return ami def copy_image(self, source_image_id, source_region, name=None, description=None): source_ami = ec2_backends[source_region].describe_images( - ami_ids=[source_image_id])[0] + ami_ids=[source_image_id] + )[0] ami_id = random_ami_id() - ami = Ami(self, ami_id, instance=None, source_ami=source_ami, - name=name, description=description) + ami = Ami( + self, + ami_id, + instance=None, + source_ami=source_ami, + name=name, + description=description, + ) self.amis[ami_id] = ami return ami - def describe_images(self, ami_ids=(), filters=None, exec_users=None, owners=None, - context=None): + def describe_images( + self, ami_ids=(), filters=None, exec_users=None, owners=None, context=None + ): images = self.amis.values() if len(ami_ids): # boto3 seems to default to just searching based on ami ids if that parameter is passed # and if no images are found, it raises an errors - malformed_ami_ids = [ami_id for ami_id in ami_ids if not ami_id.startswith('ami-')] + malformed_ami_ids = [ + ami_id for ami_id in ami_ids if not ami_id.startswith("ami-") + ] if malformed_ami_ids: raise MalformedAMIIdError(malformed_ami_ids) images = [ami for ami in images if ami.id in ami_ids] if len(images) == 0: - raise InvalidAMIIdError(ami_ids) + raise InvalidAMIIdError(ami_ids) else: # Limit images by launch permissions if exec_users: @@ -1243,10 +1387,14 @@ class AmiBackend(object): # Limit by owner ids if owners: # support filtering by Owners=['self'] - owners = list(map( - lambda o: context.get_current_user() - if context and o == 'self' else o, - owners)) + owners = list( + map( + lambda o: context.get_current_user() + if context and o == "self" + else o, + owners, + ) + ) images = [ami for ami in images if ami.owner_id in owners] # Generic filters @@ -1281,7 +1429,7 @@ class AmiBackend(object): if len(user_id) != 12 or not user_id.isdigit(): raise InvalidAMIAttributeItemValueError("userId", user_id) - if group and group != 'all': + if group and group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) def add_launch_permission(self, ami_id, user_ids=None, group=None): @@ -1328,96 +1476,150 @@ class RegionsAndZonesBackend(object): regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()] zones = { - 'ap-south-1': [ + "ap-south-1": [ Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"), - Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3") + Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3"), ], - 'eu-west-3': [ + "eu-west-3": [ Zone(region_name="eu-west-3", name="eu-west-3a", zone_id="euw3-az1"), Zone(region_name="eu-west-3", name="eu-west-3b", zone_id="euw3-az2"), - Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3") + Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3"), ], - 'eu-north-1': [ + "eu-north-1": [ Zone(region_name="eu-north-1", name="eu-north-1a", zone_id="eun1-az1"), Zone(region_name="eu-north-1", name="eu-north-1b", zone_id="eun1-az2"), - Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3") + Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3"), ], - 'eu-west-2': [ + "eu-west-2": [ Zone(region_name="eu-west-2", name="eu-west-2a", zone_id="euw2-az2"), Zone(region_name="eu-west-2", name="eu-west-2b", zone_id="euw2-az3"), - Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1") + Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1"), ], - 'eu-west-1': [ + "eu-west-1": [ Zone(region_name="eu-west-1", name="eu-west-1a", zone_id="euw1-az3"), Zone(region_name="eu-west-1", name="eu-west-1b", zone_id="euw1-az1"), - Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2") + Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2"), ], - 'ap-northeast-3': [ - Zone(region_name="ap-northeast-3", name="ap-northeast-2a", zone_id="apne3-az1") + "ap-northeast-3": [ + Zone( + region_name="ap-northeast-3", + name="ap-northeast-2a", + zone_id="apne3-az1", + ) ], - 'ap-northeast-2': [ - Zone(region_name="ap-northeast-2", name="ap-northeast-2a", zone_id="apne2-az1"), - Zone(region_name="ap-northeast-2", name="ap-northeast-2c", zone_id="apne2-az3") + "ap-northeast-2": [ + Zone( + region_name="ap-northeast-2", + name="ap-northeast-2a", + zone_id="apne2-az1", + ), + Zone( + region_name="ap-northeast-2", + name="ap-northeast-2c", + zone_id="apne2-az3", + ), ], - 'ap-northeast-1': [ - Zone(region_name="ap-northeast-1", name="ap-northeast-1a", zone_id="apne1-az4"), - Zone(region_name="ap-northeast-1", name="ap-northeast-1c", zone_id="apne1-az1"), - Zone(region_name="ap-northeast-1", name="ap-northeast-1d", zone_id="apne1-az2") + "ap-northeast-1": [ + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1a", + zone_id="apne1-az4", + ), + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1c", + zone_id="apne1-az1", + ), + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1d", + zone_id="apne1-az2", + ), ], - 'sa-east-1': [ + "sa-east-1": [ Zone(region_name="sa-east-1", name="sa-east-1a", zone_id="sae1-az1"), - Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3") + Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3"), ], - 'ca-central-1': [ + "ca-central-1": [ Zone(region_name="ca-central-1", name="ca-central-1a", zone_id="cac1-az1"), - Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2") + Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2"), ], - 'ap-southeast-1': [ - Zone(region_name="ap-southeast-1", name="ap-southeast-1a", zone_id="apse1-az1"), - Zone(region_name="ap-southeast-1", name="ap-southeast-1b", zone_id="apse1-az2"), - Zone(region_name="ap-southeast-1", name="ap-southeast-1c", zone_id="apse1-az3") + "ap-southeast-1": [ + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1a", + zone_id="apse1-az1", + ), + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1b", + zone_id="apse1-az2", + ), + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1c", + zone_id="apse1-az3", + ), ], - 'ap-southeast-2': [ - Zone(region_name="ap-southeast-2", name="ap-southeast-2a", zone_id="apse2-az1"), - Zone(region_name="ap-southeast-2", name="ap-southeast-2b", zone_id="apse2-az3"), - Zone(region_name="ap-southeast-2", name="ap-southeast-2c", zone_id="apse2-az2") + "ap-southeast-2": [ + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2a", + zone_id="apse2-az1", + ), + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2b", + zone_id="apse2-az3", + ), + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2c", + zone_id="apse2-az2", + ), ], - 'eu-central-1': [ + "eu-central-1": [ Zone(region_name="eu-central-1", name="eu-central-1a", zone_id="euc1-az2"), Zone(region_name="eu-central-1", name="eu-central-1b", zone_id="euc1-az3"), - Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1") + Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1"), ], - 'us-east-1': [ + "us-east-1": [ Zone(region_name="us-east-1", name="us-east-1a", zone_id="use1-az6"), Zone(region_name="us-east-1", name="us-east-1b", zone_id="use1-az1"), Zone(region_name="us-east-1", name="us-east-1c", zone_id="use1-az2"), Zone(region_name="us-east-1", name="us-east-1d", zone_id="use1-az4"), Zone(region_name="us-east-1", name="us-east-1e", zone_id="use1-az3"), - Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5") + Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5"), ], - 'us-east-2': [ + "us-east-2": [ Zone(region_name="us-east-2", name="us-east-2a", zone_id="use2-az1"), Zone(region_name="us-east-2", name="us-east-2b", zone_id="use2-az2"), - Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3") + Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3"), ], - 'us-west-1': [ + "us-west-1": [ Zone(region_name="us-west-1", name="us-west-1a", zone_id="usw1-az3"), - Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1") + Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1"), ], - 'us-west-2': [ + "us-west-2": [ Zone(region_name="us-west-2", name="us-west-2a", zone_id="usw2-az2"), Zone(region_name="us-west-2", name="us-west-2b", zone_id="usw2-az1"), - Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3") + Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3"), ], - 'cn-north-1': [ + "cn-north-1": [ Zone(region_name="cn-north-1", name="cn-north-1a", zone_id="cnn1-az1"), - Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2") + Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2"), + ], + "us-gov-west-1": [ + Zone( + region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1" + ), + Zone( + region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2" + ), + Zone( + region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3" + ), ], - 'us-gov-west-1': [ - Zone(region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1"), - Zone(region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2"), - Zone(region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3") - ] } def describe_regions(self, region_names=[]): @@ -1442,23 +1644,27 @@ class RegionsAndZonesBackend(object): class SecurityRule(object): def __init__(self, ip_protocol, from_port, to_port, ip_ranges, source_groups): self.ip_protocol = ip_protocol - self.from_port = from_port - self.to_port = to_port self.ip_ranges = ip_ranges or [] self.source_groups = source_groups - @property - def unique_representation(self): - return "{0}-{1}-{2}-{3}-{4}".format( - self.ip_protocol, - self.from_port, - self.to_port, - self.ip_ranges, - self.source_groups - ) + if ip_protocol != "-1": + self.from_port = from_port + self.to_port = to_port def __eq__(self, other): - return self.unique_representation == other.unique_representation + if self.ip_protocol != other.ip_protocol: + return False + if self.ip_ranges != other.ip_ranges: + return False + if self.source_groups != other.source_groups: + return False + if self.ip_protocol != "-1": + if self.from_port != other.from_port: + return False + if self.to_port != other.to_port: + return False + + return True class SecurityGroup(TaggedEC2Resource): @@ -1468,20 +1674,22 @@ class SecurityGroup(TaggedEC2Resource): self.name = name self.description = description self.ingress_rules = [] - self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])] + self.egress_rules = [SecurityRule("-1", None, None, ["0.0.0.0/0"], [])] self.enis = {} self.vpc_id = vpc_id self.owner_id = OWNER_ID @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - vpc_id = properties.get('VpcId') + vpc_id = properties.get("VpcId") security_group = ec2_backend.create_security_group( name=resource_name, - description=properties.get('GroupDescription'), + description=properties.get("GroupDescription"), vpc_id=vpc_id, ) @@ -1490,15 +1698,15 @@ class SecurityGroup(TaggedEC2Resource): tag_value = tag["Value"] security_group.add_tag(tag_key, tag_value) - for ingress_rule in properties.get('SecurityGroupIngress', []): - source_group_id = ingress_rule.get('SourceSecurityGroupId') + for ingress_rule in properties.get("SecurityGroupIngress", []): + source_group_id = ingress_rule.get("SourceSecurityGroupId") ec2_backend.authorize_security_group_ingress( group_name_or_id=security_group.id, - ip_protocol=ingress_rule['IpProtocol'], - from_port=ingress_rule['FromPort'], - to_port=ingress_rule['ToPort'], - ip_ranges=ingress_rule.get('CidrIp'), + ip_protocol=ingress_rule["IpProtocol"], + from_port=ingress_rule["FromPort"], + to_port=ingress_rule["ToPort"], + ip_ranges=ingress_rule.get("CidrIp"), source_group_ids=[source_group_id], vpc_id=vpc_id, ) @@ -1506,28 +1714,33 @@ class SecurityGroup(TaggedEC2Resource): return security_group @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls._delete_security_group_given_vpc_id( - original_resource.name, original_resource.vpc_id, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, original_resource.vpc_id, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - vpc_id = properties.get('VpcId') - cls._delete_security_group_given_vpc_id( - resource_name, vpc_id, region_name) + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + vpc_id = properties.get("VpcId") + cls._delete_security_group_given_vpc_id(resource_name, vpc_id, region_name) @classmethod def _delete_security_group_given_vpc_id(cls, resource_name, vpc_id, region_name): ec2_backend = ec2_backends[region_name] - security_group = ec2_backend.get_security_group_from_name( - resource_name, vpc_id) + security_group = ec2_backend.get_security_group_from_name(resource_name, vpc_id) if security_group: security_group.delete(region_name) def delete(self, region_name): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ self.ec2_backend.delete_security_group(group_id=self.id) @property @@ -1538,18 +1751,18 @@ class SecurityGroup(TaggedEC2Resource): def to_attr(filter_name): attr = None - if filter_name == 'group-name': - attr = 'name' - elif filter_name == 'group-id': - attr = 'id' - elif filter_name == 'vpc-id': - attr = 'vpc_id' + if filter_name == "group-name": + attr = "name" + elif filter_name == "group-id": + attr = "id" + elif filter_name == "vpc-id": + attr = "vpc_id" else: - attr = filter_name.replace('-', '_') + attr = filter_name.replace("-", "_") return attr - if key.startswith('ip-permission'): + if key.startswith("ip-permission"): match = re.search(r"ip-permission.(*)", key) ingress_attr = to_attr(match.groups()[0]) @@ -1575,7 +1788,8 @@ class SecurityGroup(TaggedEC2Resource): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'GroupId': + + if attribute_name == "GroupId": return self.id raise UnformattedGetAttTemplateException() @@ -1590,13 +1804,13 @@ class SecurityGroup(TaggedEC2Resource): def get_number_of_ingress_rules(self): return sum( - len(rule.ip_ranges) + len(rule.source_groups) - for rule in self.ingress_rules) + len(rule.ip_ranges) + len(rule.source_groups) for rule in self.ingress_rules + ) def get_number_of_egress_rules(self): return sum( - len(rule.ip_ranges) + len(rule.source_groups) - for rule in self.egress_rules) + len(rule.ip_ranges) + len(rule.source_groups) for rule in self.egress_rules + ) class SecurityGroupBackend(object): @@ -1611,7 +1825,7 @@ class SecurityGroupBackend(object): def create_security_group(self, name, description, vpc_id=None, force=False): if not description: - raise MissingParameterError('GroupDescription') + raise MissingParameterError("GroupDescription") group_id = random_security_group_id() if not force: @@ -1624,30 +1838,27 @@ class SecurityGroupBackend(object): return group def describe_security_groups(self, group_ids=None, groupnames=None, filters=None): - matches = itertools.chain(*[x.values() - for x in self.groups.values()]) + matches = itertools.chain(*[x.values() for x in self.groups.values()]) if group_ids: - matches = [grp for grp in matches - if grp.id in group_ids] + matches = [grp for grp in matches if grp.id in group_ids] if len(group_ids) > len(matches): unknown_ids = set(group_ids) - set(matches) raise InvalidSecurityGroupNotFoundError(unknown_ids) if groupnames: - matches = [grp for grp in matches - if grp.name in groupnames] + matches = [grp for grp in matches if grp.name in groupnames] if len(groupnames) > len(matches): unknown_names = set(groupnames) - set(matches) raise InvalidSecurityGroupNotFoundError(unknown_names) if filters: - matches = [grp for grp in matches - if grp.matches_filters(filters)] + matches = [grp for grp in matches if grp.matches_filters(filters)] return matches def _delete_security_group(self, vpc_id, group_id): if self.groups[vpc_id][group_id].enis: raise DependencyViolationError( - "{0} is being utilized by {1}".format(group_id, 'ENIs')) + "{0} is being utilized by {1}".format(group_id, "ENIs") + ) return self.groups[vpc_id].pop(group_id) def delete_security_group(self, name=None, group_id=None): @@ -1668,7 +1879,8 @@ class SecurityGroupBackend(object): def get_security_group_from_id(self, group_id): # 2 levels of chaining necessary since it's a complex structure all_groups = itertools.chain.from_iterable( - [x.values() for x in self.groups.values()]) + [x.values() for x in self.groups.values()] + ) for group in all_groups: if group.id == group_id: return group @@ -1685,15 +1897,17 @@ class SecurityGroupBackend(object): group = self.get_security_group_from_name(group_name_or_id, vpc_id) return group - def authorize_security_group_ingress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def authorize_security_group_ingress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) if ip_ranges and not isinstance(ip_ranges, list): ip_ranges = [ip_ranges] @@ -1703,16 +1917,19 @@ class SecurityGroupBackend(object): raise InvalidCIDRSubnetError(cidr=cidr) self._verify_group_will_respect_rule_count_limit( - group, group.get_number_of_ingress_rules(), - ip_ranges, source_group_names, source_group_ids) + group, + group.get_number_of_ingress_rules(), + ip_ranges, + source_group_names, + source_group_ids, + ) source_group_names = source_group_names if source_group_names else [] source_group_ids = source_group_ids if source_group_ids else [] source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1723,25 +1940,27 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) group.add_ingress_rule(security_rule) - def revoke_security_group_ingress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def revoke_security_group_ingress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1751,21 +1970,24 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) if security_rule in group.ingress_rules: group.ingress_rules.remove(security_rule) return security_rule raise InvalidPermissionNotFoundError() - def authorize_security_group_egress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def authorize_security_group_egress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) if ip_ranges and not isinstance(ip_ranges, list): @@ -1776,16 +1998,19 @@ class SecurityGroupBackend(object): raise InvalidCIDRSubnetError(cidr=cidr) self._verify_group_will_respect_rule_count_limit( - group, group.get_number_of_egress_rules(), - ip_ranges, source_group_names, source_group_ids) + group, + group.get_number_of_egress_rules(), + ip_ranges, + source_group_names, + source_group_ids, + ) source_group_names = source_group_names if source_group_names else [] source_group_ids = source_group_ids if source_group_ids else [] source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1796,25 +2021,27 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) group.add_egress_rule(security_rule) - def revoke_security_group_egress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def revoke_security_group_egress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1824,15 +2051,21 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) if security_rule in group.egress_rules: group.egress_rules.remove(security_rule) return security_rule raise InvalidPermissionNotFoundError() def _verify_group_will_respect_rule_count_limit( - self, group, current_rule_nb, - ip_ranges, source_group_names=None, source_group_ids=None): + self, + group, + current_rule_nb, + ip_ranges, + source_group_names=None, + source_group_ids=None, + ): max_nb_rules = 50 if group.vpc_id else 100 future_group_nb_rules = current_rule_nb if ip_ranges: @@ -1851,12 +2084,14 @@ class SecurityGroupIngress(object): self.properties = properties @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - group_name = properties.get('GroupName') - group_id = properties.get('GroupId') + group_name = properties.get("GroupName") + group_id = properties.get("GroupId") ip_protocol = properties.get("IpProtocol") cidr_ip = properties.get("CidrIp") cidr_ipv6 = properties.get("CidrIpv6") @@ -1868,7 +2103,12 @@ class SecurityGroupIngress(object): to_port = properties.get("ToPort") assert group_id or group_name - assert source_security_group_name or cidr_ip or cidr_ipv6 or source_security_group_id + assert ( + source_security_group_name + or cidr_ip + or cidr_ipv6 + or source_security_group_id + ) assert ip_protocol if source_security_group_id: @@ -1886,10 +2126,12 @@ class SecurityGroupIngress(object): if group_id: security_group = ec2_backend.describe_security_groups(group_ids=[group_id])[ - 0] + 0 + ] else: security_group = ec2_backend.describe_security_groups( - groupnames=[group_name])[0] + groupnames=[group_name] + )[0] ec2_backend.authorize_security_group_ingress( group_name_or_id=security_group.id, @@ -1913,23 +2155,27 @@ class VolumeAttachment(object): self.status = status @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - instance_id = properties['InstanceId'] - volume_id = properties['VolumeId'] + instance_id = properties["InstanceId"] + volume_id = properties["VolumeId"] ec2_backend = ec2_backends[region_name] attachment = ec2_backend.attach_volume( volume_id=volume_id, instance_id=instance_id, - device_path=properties['Device'], + device_path=properties["Device"], ) return attachment class Volume(TaggedEC2Resource): - def __init__(self, ec2_backend, volume_id, size, zone, snapshot_id=None, encrypted=False): + def __init__( + self, ec2_backend, volume_id, size, zone, snapshot_id=None, encrypted=False + ): self.id = volume_id self.size = size self.zone = zone @@ -1940,13 +2186,14 @@ class Volume(TaggedEC2Resource): self.encrypted = encrypted @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] volume = ec2_backend.create_volume( - size=properties.get('Size'), - zone_name=properties.get('AvailabilityZone'), + size=properties.get("Size"), zone_name=properties.get("AvailabilityZone") ) return volume @@ -1957,72 +2204,80 @@ class Volume(TaggedEC2Resource): @property def status(self): if self.attachment: - return 'in-use' + return "in-use" else: - return 'available' + return "available" def get_filter_value(self, filter_name): - if filter_name.startswith('attachment') and not self.attachment: + if filter_name.startswith("attachment") and not self.attachment: return None - elif filter_name == 'attachment.attach-time': + elif filter_name == "attachment.attach-time": return self.attachment.attach_time - elif filter_name == 'attachment.device': + elif filter_name == "attachment.device": return self.attachment.device - elif filter_name == 'attachment.instance-id': + elif filter_name == "attachment.instance-id": return self.attachment.instance.id - elif filter_name == 'attachment.status': + elif filter_name == "attachment.status": return self.attachment.status - elif filter_name == 'create-time': + elif filter_name == "create-time": return self.create_time - elif filter_name == 'size': + elif filter_name == "size": return self.size - elif filter_name == 'snapshot-id': + elif filter_name == "snapshot-id": return self.snapshot_id - elif filter_name == 'status': + elif filter_name == "status": return self.status - elif filter_name == 'volume-id': + elif filter_name == "volume-id": return self.id - elif filter_name == 'encrypted': + elif filter_name == "encrypted": return str(self.encrypted).lower() - elif filter_name == 'availability-zone': + elif filter_name == "availability-zone": return self.zone.name else: - return super(Volume, self).get_filter_value( - filter_name, 'DescribeVolumes') + return super(Volume, self).get_filter_value(filter_name, "DescribeVolumes") class Snapshot(TaggedEC2Resource): - def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id=OWNER_ID): + def __init__( + self, + ec2_backend, + snapshot_id, + volume, + description, + encrypted=False, + owner_id=OWNER_ID, + ): self.id = snapshot_id self.volume = volume self.description = description self.start_time = utc_date_and_time() self.create_volume_permission_groups = set() self.ec2_backend = ec2_backend - self.status = 'completed' + self.status = "completed" self.encrypted = encrypted self.owner_id = owner_id def get_filter_value(self, filter_name): - if filter_name == 'description': + if filter_name == "description": return self.description - elif filter_name == 'snapshot-id': + elif filter_name == "snapshot-id": return self.id - elif filter_name == 'start-time': + elif filter_name == "start-time": return self.start_time - elif filter_name == 'volume-id': + elif filter_name == "volume-id": return self.volume.id - elif filter_name == 'volume-size': + elif filter_name == "volume-size": return self.volume.size - elif filter_name == 'encrypted': + elif filter_name == "encrypted": return str(self.encrypted).lower() - elif filter_name == 'status': + elif filter_name == "status": return self.status - elif filter_name == 'owner-id': + elif filter_name == "owner-id": return self.owner_id else: return super(Snapshot, self).get_filter_value( - filter_name, 'DescribeSnapshots') + filter_name, "DescribeSnapshots" + ) class EBSBackend(object): @@ -2048,8 +2303,7 @@ class EBSBackend(object): def describe_volumes(self, volume_ids=None, filters=None): matches = self.volumes.values() if volume_ids: - matches = [vol for vol in matches - if vol.id in volume_ids] + matches = [vol for vol in matches if vol.id in volume_ids] if len(volume_ids) > len(matches): unknown_ids = set(volume_ids) - set(matches) raise InvalidVolumeIdError(unknown_ids) @@ -2075,11 +2329,14 @@ class EBSBackend(object): if not volume or not instance: return False - volume.attachment = VolumeAttachment( - volume, instance, device_path, 'attached') + volume.attachment = VolumeAttachment(volume, instance, device_path, "attached") # Modify instance to capture mount of block device. - bdt = BlockDeviceType(volume_id=volume_id, status=volume.status, size=volume.size, - attach_time=utc_date_and_time()) + bdt = BlockDeviceType( + volume_id=volume_id, + status=volume.status, + size=volume.size, + attach_time=utc_date_and_time(), + ) instance.block_device_mapping[device_path] = bdt return volume.attachment @@ -2090,7 +2347,7 @@ class EBSBackend(object): old_attachment = volume.attachment if not old_attachment: raise InvalidVolumeAttachmentError(volume_id, instance_id) - old_attachment.status = 'detached' + old_attachment.status = "detached" volume.attachment = None return old_attachment @@ -2108,8 +2365,7 @@ class EBSBackend(object): def describe_snapshots(self, snapshot_ids=None, filters=None): matches = self.snapshots.values() if snapshot_ids: - matches = [snap for snap in matches - if snap.id in snapshot_ids] + matches = [snap for snap in matches if snap.id in snapshot_ids] if len(snapshot_ids) > len(matches): unknown_ids = set(snapshot_ids) - set(matches) raise InvalidSnapshotIdError(unknown_ids) @@ -2119,10 +2375,16 @@ class EBSBackend(object): def copy_snapshot(self, source_snapshot_id, source_region, description=None): source_snapshot = ec2_backends[source_region].describe_snapshots( - snapshot_ids=[source_snapshot_id])[0] + snapshot_ids=[source_snapshot_id] + )[0] snapshot_id = random_snapshot_id() - snapshot = Snapshot(self, snapshot_id, volume=source_snapshot.volume, - description=description, encrypted=source_snapshot.encrypted) + snapshot = Snapshot( + self, + snapshot_id, + volume=source_snapshot.volume, + description=description, + encrypted=source_snapshot.encrypted, + ) self.snapshots[snapshot_id] = snapshot return snapshot @@ -2144,9 +2406,10 @@ class EBSBackend(object): def add_create_volume_permission(self, snapshot_id, user_id=None, group=None): if user_id: self.raise_not_implemented_error( - "The UserId parameter for ModifySnapshotAttribute") + "The UserId parameter for ModifySnapshotAttribute" + ) - if group != 'all': + if group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) snapshot = self.get_snapshot(snapshot_id) snapshot.create_volume_permission_groups.add(group) @@ -2155,9 +2418,10 @@ class EBSBackend(object): def remove_create_volume_permission(self, snapshot_id, user_id=None, group=None): if user_id: self.raise_not_implemented_error( - "The UserId parameter for ModifySnapshotAttribute") + "The UserId parameter for ModifySnapshotAttribute" + ) - if group != 'all': + if group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) snapshot = self.get_snapshot(snapshot_id) snapshot.create_volume_permission_groups.discard(group) @@ -2165,34 +2429,48 @@ class EBSBackend(object): class VPC(TaggedEC2Resource): - def __init__(self, ec2_backend, vpc_id, cidr_block, is_default, instance_tenancy='default', - amazon_provided_ipv6_cidr_block=False): + def __init__( + self, + ec2_backend, + vpc_id, + cidr_block, + is_default, + instance_tenancy="default", + amazon_provided_ipv6_cidr_block=False, + ): self.ec2_backend = ec2_backend self.id = vpc_id self.cidr_block = cidr_block self.cidr_block_association_set = {} self.dhcp_options = None - self.state = 'available' + self.state = "available" self.instance_tenancy = instance_tenancy - self.is_default = 'true' if is_default else 'false' - self.enable_dns_support = 'true' + self.is_default = "true" if is_default else "false" + self.enable_dns_support = "true" + self.classic_link_enabled = "false" + self.classic_link_dns_supported = "false" # This attribute is set to 'true' only for default VPCs # or VPCs created using the wizard of the VPC console - self.enable_dns_hostnames = 'true' if is_default else 'false' + self.enable_dns_hostnames = "true" if is_default else "false" self.associate_vpc_cidr_block(cidr_block) if amazon_provided_ipv6_cidr_block: - self.associate_vpc_cidr_block(cidr_block, amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block) + self.associate_vpc_cidr_block( + cidr_block, + amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block, + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] vpc = ec2_backend.create_vpc( - cidr_block=properties['CidrBlock'], - instance_tenancy=properties.get('InstanceTenancy', 'default') + cidr_block=properties["CidrBlock"], + instance_tenancy=properties.get("InstanceTenancy", "default"), ) for tag in properties.get("Tags", []): tag_key = tag["Key"] @@ -2206,58 +2484,112 @@ class VPC(TaggedEC2Resource): return self.id def get_filter_value(self, filter_name): - if filter_name in ('vpc-id', 'vpcId'): + if filter_name in ("vpc-id", "vpcId"): return self.id - elif filter_name in ('cidr', 'cidr-block', 'cidrBlock'): + elif filter_name in ("cidr", "cidr-block", "cidrBlock"): return self.cidr_block - elif filter_name in ('cidr-block-association.cidr-block', 'ipv6-cidr-block-association.ipv6-cidr-block'): - return [c['cidr_block'] for c in self.get_cidr_block_association_set(ipv6='ipv6' in filter_name)] - elif filter_name in ('cidr-block-association.association-id', 'ipv6-cidr-block-association.association-id'): + elif filter_name in ( + "cidr-block-association.cidr-block", + "ipv6-cidr-block-association.ipv6-cidr-block", + ): + return [ + c["cidr_block"] + for c in self.get_cidr_block_association_set(ipv6="ipv6" in filter_name) + ] + elif filter_name in ( + "cidr-block-association.association-id", + "ipv6-cidr-block-association.association-id", + ): return self.cidr_block_association_set.keys() - elif filter_name in ('cidr-block-association.state', 'ipv6-cidr-block-association.state'): - return [c['cidr_block_state']['state'] for c in self.get_cidr_block_association_set(ipv6='ipv6' in filter_name)] - elif filter_name in ('instance_tenancy', 'InstanceTenancy'): + elif filter_name in ( + "cidr-block-association.state", + "ipv6-cidr-block-association.state", + ): + return [ + c["cidr_block_state"]["state"] + for c in self.get_cidr_block_association_set(ipv6="ipv6" in filter_name) + ] + elif filter_name in ("instance_tenancy", "InstanceTenancy"): return self.instance_tenancy - elif filter_name in ('is-default', 'isDefault'): + elif filter_name in ("is-default", "isDefault"): return self.is_default - elif filter_name == 'state': + elif filter_name == "state": return self.state - elif filter_name in ('dhcp-options-id', 'dhcpOptionsId'): + elif filter_name in ("dhcp-options-id", "dhcpOptionsId"): if not self.dhcp_options: return None return self.dhcp_options.id else: - return super(VPC, self).get_filter_value(filter_name, 'DescribeVpcs') + return super(VPC, self).get_filter_value(filter_name, "DescribeVpcs") - def associate_vpc_cidr_block(self, cidr_block, amazon_provided_ipv6_cidr_block=False): + def associate_vpc_cidr_block( + self, cidr_block, amazon_provided_ipv6_cidr_block=False + ): max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1 - if len(self.get_cidr_block_association_set(amazon_provided_ipv6_cidr_block)) >= max_associations: + if ( + len(self.get_cidr_block_association_set(amazon_provided_ipv6_cidr_block)) + >= max_associations + ): raise CidrLimitExceeded(self.id, max_associations) association_id = random_vpc_cidr_association_id() association_set = { - 'association_id': association_id, - 'cidr_block_state': {'state': 'associated', 'StatusMessage': ''} + "association_id": association_id, + "cidr_block_state": {"state": "associated", "StatusMessage": ""}, } - association_set['cidr_block'] = random_ipv6_cidr() if amazon_provided_ipv6_cidr_block else cidr_block + association_set["cidr_block"] = ( + random_ipv6_cidr() if amazon_provided_ipv6_cidr_block else cidr_block + ) self.cidr_block_association_set[association_id] = association_set return association_set + def enable_vpc_classic_link(self): + # Check if current cidr block doesn't fall within the 10.0.0.0/8 block, excluding 10.0.0.0/16 and 10.1.0.0/16. + # Doesn't check any route tables, maybe something for in the future? + # See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/vpc-classiclink.html#classiclink-limitations + network_address = ipaddress.ip_network(self.cidr_block).network_address + if ( + network_address not in ipaddress.ip_network("10.0.0.0/8") + or network_address in ipaddress.ip_network("10.0.0.0/16") + or network_address in ipaddress.ip_network("10.1.0.0/16") + ): + self.classic_link_enabled = "true" + + return self.classic_link_enabled + + def disable_vpc_classic_link(self): + self.classic_link_enabled = "false" + return self.classic_link_enabled + + def enable_vpc_classic_link_dns_support(self): + self.classic_link_dns_supported = "true" + return self.classic_link_dns_supported + + def disable_vpc_classic_link_dns_support(self): + self.classic_link_dns_supported = "false" + return self.classic_link_dns_supported + def disassociate_vpc_cidr_block(self, association_id): - if self.cidr_block == self.cidr_block_association_set.get(association_id, {}).get('cidr_block'): + if self.cidr_block == self.cidr_block_association_set.get( + association_id, {} + ).get("cidr_block"): raise OperationNotPermitted(association_id) response = self.cidr_block_association_set.pop(association_id, {}) if response: - response['vpc_id'] = self.id - response['cidr_block_state']['state'] = 'disassociating' + response["vpc_id"] = self.id + response["cidr_block_state"]["state"] = "disassociating" return response def get_cidr_block_association_set(self, ipv6=False): - return [c for c in self.cidr_block_association_set.values() if ('::/' if ipv6 else '.') in c.get('cidr_block')] + return [ + c + for c in self.cidr_block_association_set.values() + if ("::/" if ipv6 else ".") in c.get("cidr_block") + ] class VPCBackend(object): @@ -2275,15 +2607,29 @@ class VPCBackend(object): if inst is not None: yield inst - def create_vpc(self, cidr_block, instance_tenancy='default', amazon_provided_ipv6_cidr_block=False): + def create_vpc( + self, + cidr_block, + instance_tenancy="default", + amazon_provided_ipv6_cidr_block=False, + ): vpc_id = random_vpc_id() try: - vpc_cidr_block = ipaddress.IPv4Network(six.text_type(cidr_block), strict=False) + vpc_cidr_block = ipaddress.IPv4Network( + six.text_type(cidr_block), strict=False + ) except ValueError: raise InvalidCIDRBlockParameterError(cidr_block) if vpc_cidr_block.prefixlen < 16 or vpc_cidr_block.prefixlen > 28: raise InvalidVPCRangeError(cidr_block) - vpc = VPC(self, vpc_id, cidr_block, len(self.vpcs) == 0, instance_tenancy, amazon_provided_ipv6_cidr_block) + vpc = VPC( + self, + vpc_id, + cidr_block, + len(self.vpcs) == 0, + instance_tenancy, + amazon_provided_ipv6_cidr_block, + ) self.vpcs[vpc_id] = vpc # AWS creates a default main route table and security group. @@ -2292,10 +2638,11 @@ class VPCBackend(object): # AWS creates a default Network ACL self.create_network_acl(vpc_id, default=True) - default = self.get_security_group_from_name('default', vpc_id=vpc_id) + default = self.get_security_group_from_name("default", vpc_id=vpc_id) if not default: self.create_security_group( - 'default', 'default VPC security group', vpc_id=vpc_id) + "default", "default VPC security group", vpc_id=vpc_id + ) return vpc @@ -2314,8 +2661,7 @@ class VPCBackend(object): def get_all_vpcs(self, vpc_ids=None, filters=None): matches = self.vpcs.values() if vpc_ids: - matches = [vpc for vpc in matches - if vpc.id in vpc_ids] + matches = [vpc for vpc in matches if vpc.id in vpc_ids] if len(vpc_ids) > len(matches): unknown_ids = set(vpc_ids) - set(matches) raise InvalidVPCIdError(unknown_ids) @@ -2325,7 +2671,7 @@ class VPCBackend(object): def delete_vpc(self, vpc_id): # Delete route table if only main route table remains. - route_tables = self.get_all_route_tables(filters={'vpc-id': vpc_id}) + route_tables = self.get_all_route_tables(filters={"vpc-id": vpc_id}) if len(route_tables) > 1: raise DependencyViolationError( "The vpc {0} has dependencies and cannot be deleted.".format(vpc_id) @@ -2334,7 +2680,7 @@ class VPCBackend(object): self.delete_route_table(route_table.id) # Delete default security group if exists. - default = self.get_security_group_from_name('default', vpc_id=vpc_id) + default = self.get_security_group_from_name("default", vpc_id=vpc_id) if default: self.delete_security_group(group_id=default.id) @@ -2351,14 +2697,30 @@ class VPCBackend(object): def describe_vpc_attribute(self, vpc_id, attr_name): vpc = self.get_vpc(vpc_id) - if attr_name in ('enable_dns_support', 'enable_dns_hostnames'): + if attr_name in ("enable_dns_support", "enable_dns_hostnames"): return getattr(vpc, attr_name) else: raise InvalidParameterValueError(attr_name) + def enable_vpc_classic_link(self, vpc_id): + vpc = self.get_vpc(vpc_id) + return vpc.enable_vpc_classic_link() + + def disable_vpc_classic_link(self, vpc_id): + vpc = self.get_vpc(vpc_id) + return vpc.disable_vpc_classic_link() + + def enable_vpc_classic_link_dns_support(self, vpc_id): + vpc = self.get_vpc(vpc_id) + return vpc.enable_vpc_classic_link_dns_support() + + def disable_vpc_classic_link_dns_support(self, vpc_id): + vpc = self.get_vpc(vpc_id) + return vpc.disable_vpc_classic_link_dns_support() + def modify_vpc_attribute(self, vpc_id, attr_name, attr_value): vpc = self.get_vpc(vpc_id) - if attr_name in ('enable_dns_support', 'enable_dns_hostnames'): + if attr_name in ("enable_dns_support", "enable_dns_hostnames"): setattr(vpc, attr_name, attr_value) else: raise InvalidParameterValueError(attr_name) @@ -2371,35 +2733,37 @@ class VPCBackend(object): else: raise InvalidVpcCidrBlockAssociationIdError(association_id) - def associate_vpc_cidr_block(self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block): + def associate_vpc_cidr_block( + self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block + ): vpc = self.get_vpc(vpc_id) return vpc.associate_vpc_cidr_block(cidr_block, amazon_provided_ipv6_cidr_block) class VPCPeeringConnectionStatus(object): - def __init__(self, code='initiating-request', message=''): + def __init__(self, code="initiating-request", message=""): self.code = code self.message = message def deleted(self): - self.code = 'deleted' - self.message = 'Deleted by {deleter ID}' + self.code = "deleted" + self.message = "Deleted by {deleter ID}" def initiating(self): - self.code = 'initiating-request' - self.message = 'Initiating Request to {accepter ID}' + self.code = "initiating-request" + self.message = "Initiating Request to {accepter ID}" def pending(self): - self.code = 'pending-acceptance' - self.message = 'Pending Acceptance by {accepter ID}' + self.code = "pending-acceptance" + self.message = "Pending Acceptance by {accepter ID}" def accept(self): - self.code = 'active' - self.message = 'Active' + self.code = "active" + self.message = "Active" def reject(self): - self.code = 'rejected' - self.message = 'Inactive' + self.code = "rejected" + self.message = "Inactive" class VPCPeeringConnection(TaggedEC2Resource): @@ -2410,12 +2774,14 @@ class VPCPeeringConnection(TaggedEC2Resource): self._status = VPCPeeringConnectionStatus() @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - vpc = ec2_backend.get_vpc(properties['VpcId']) - peer_vpc = ec2_backend.get_vpc(properties['PeerVpcId']) + vpc = ec2_backend.get_vpc(properties["VpcId"]) + peer_vpc = ec2_backend.get_vpc(properties["PeerVpcId"]) vpc_pcx = ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) @@ -2474,7 +2840,7 @@ class VPCPeeringConnectionBackend(object): pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) - if vpc_pcx._status.code != 'pending-acceptance': + if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.accept() return vpc_pcx @@ -2486,20 +2852,33 @@ class VPCPeeringConnectionBackend(object): pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) - if vpc_pcx._status.code != 'pending-acceptance': + if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.reject() return vpc_pcx class Subnet(TaggedEC2Resource): - def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az, - map_public_ip_on_launch, owner_id=OWNER_ID, assign_ipv6_address_on_creation=False): + def __init__( + self, + ec2_backend, + subnet_id, + vpc_id, + cidr_block, + availability_zone, + default_for_az, + map_public_ip_on_launch, + owner_id=OWNER_ID, + assign_ipv6_address_on_creation=False, + ): self.ec2_backend = ec2_backend self.id = subnet_id self.vpc_id = vpc_id self.cidr_block = cidr_block self.cidr = ipaddress.IPv4Network(six.text_type(self.cidr_block), strict=False) + self._available_ip_addresses = ( + ipaddress.IPv4Network(six.text_type(self.cidr_block)).num_addresses - 5 + ) self._availability_zone = availability_zone self.default_for_az = default_for_az self.map_public_ip_on_launch = map_public_ip_on_launch @@ -2509,22 +2888,24 @@ class Subnet(TaggedEC2Resource): # Theory is we assign ip's as we go (as 16,777,214 usable IPs in a /8) self._subnet_ip_generator = self.cidr.hosts() - self.reserved_ips = [six.next(self._subnet_ip_generator) for _ in range(0, 3)] # Reserved by AWS + self.reserved_ips = [ + six.next(self._subnet_ip_generator) for _ in range(0, 3) + ] # Reserved by AWS self._unused_ips = set() # if instance is destroyed hold IP here for reuse self._subnet_ips = {} # has IP: instance @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - vpc_id = properties['VpcId'] - cidr_block = properties['CidrBlock'] - availability_zone = properties.get('AvailabilityZone') + vpc_id = properties["VpcId"] + cidr_block = properties["CidrBlock"] + availability_zone = properties.get("AvailabilityZone") ec2_backend = ec2_backends[region_name] subnet = ec2_backend.create_subnet( - vpc_id=vpc_id, - cidr_block=cidr_block, - availability_zone=availability_zone, + vpc_id=vpc_id, cidr_block=cidr_block, availability_zone=availability_zone ) for tag in properties.get("Tags", []): tag_key = tag["Key"] @@ -2533,6 +2914,21 @@ class Subnet(TaggedEC2Resource): return subnet + @property + def available_ip_addresses(self): + enis = [ + eni + for eni in self.ec2_backend.get_all_network_interfaces() + if eni.subnet.id == self.id + ] + addresses_taken = [ + eni.private_ip_address for eni in enis if eni.private_ip_address + ] + for eni in enis: + if eni.private_ip_addresses: + addresses_taken.extend(eni.private_ip_addresses) + return str(self._available_ip_addresses - len(addresses_taken)) + @property def availability_zone(self): return self._availability_zone.name @@ -2558,25 +2954,24 @@ class Subnet(TaggedEC2Resource): Taken from: http://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeSubnets.html """ - if filter_name in ('cidr', 'cidrBlock', 'cidr-block'): + if filter_name in ("cidr", "cidrBlock", "cidr-block"): return self.cidr_block - elif filter_name in ('vpc-id', 'vpcId'): + elif filter_name in ("vpc-id", "vpcId"): return self.vpc_id - elif filter_name == 'subnet-id': + elif filter_name == "subnet-id": return self.id - elif filter_name in ('availabilityZone', 'availability-zone'): + elif filter_name in ("availabilityZone", "availability-zone"): return self.availability_zone - elif filter_name in ('defaultForAz', 'default-for-az'): + elif filter_name in ("defaultForAz", "default-for-az"): return self.default_for_az else: - return super(Subnet, self).get_filter_value( - filter_name, 'DescribeSubnets') + return super(Subnet, self).get_filter_value(filter_name, "DescribeSubnets") def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AvailabilityZone': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "AvailabilityZone" ]"') + + if attribute_name == "AvailabilityZone": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "AvailabilityZone" ]"') raise UnformattedGetAttTemplateException() def get_available_subnet_ip(self, instance): @@ -2600,10 +2995,12 @@ class Subnet(TaggedEC2Resource): def request_ip(self, ip, instance): if ipaddress.ip_address(ip) not in self.cidr: - raise Exception('IP does not fall in the subnet CIDR of {0}'.format(self.cidr)) + raise Exception( + "IP does not fall in the subnet CIDR of {0}".format(self.cidr) + ) if ip in self._subnet_ips: - raise Exception('IP already in use') + raise Exception("IP already in use") try: self._unused_ips.remove(ip) except KeyError: @@ -2634,17 +3031,25 @@ class SubnetBackend(object): def create_subnet(self, vpc_id, cidr_block, availability_zone, context=None): subnet_id = random_subnet_id() - vpc = self.get_vpc(vpc_id) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's - vpc_cidr_block = ipaddress.IPv4Network(six.text_type(vpc.cidr_block), strict=False) + vpc = self.get_vpc( + vpc_id + ) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's + vpc_cidr_block = ipaddress.IPv4Network( + six.text_type(vpc.cidr_block), strict=False + ) try: - subnet_cidr_block = ipaddress.IPv4Network(six.text_type(cidr_block), strict=False) + subnet_cidr_block = ipaddress.IPv4Network( + six.text_type(cidr_block), strict=False + ) except ValueError: raise InvalidCIDRBlockParameterError(cidr_block) - if not (vpc_cidr_block.network_address <= subnet_cidr_block.network_address and - vpc_cidr_block.broadcast_address >= subnet_cidr_block.broadcast_address): + if not ( + vpc_cidr_block.network_address <= subnet_cidr_block.network_address + and vpc_cidr_block.broadcast_address >= subnet_cidr_block.broadcast_address + ): raise InvalidSubnetRangeError(cidr_block) - for subnet in self.get_all_subnets(filters={'vpc-id': vpc_id}): + for subnet in self.get_all_subnets(filters={"vpc-id": vpc_id}): if subnet.cidr.overlaps(subnet_cidr_block): raise InvalidSubnetConflictError(cidr_block) @@ -2653,14 +3058,36 @@ class SubnetBackend(object): default_for_az = str(availability_zone not in self.subnets).lower() map_public_ip_on_launch = default_for_az if availability_zone is None: - availability_zone = 'us-east-1a' + availability_zone = "us-east-1a" try: - availability_zone_data = next(zone for zones in RegionsAndZonesBackend.zones.values() for zone in zones if zone.name == availability_zone) + availability_zone_data = next( + zone + for zones in RegionsAndZonesBackend.zones.values() + for zone in zones + if zone.name == availability_zone + ) except StopIteration: - raise InvalidAvailabilityZoneError(availability_zone, ", ".join([zone.name for zones in RegionsAndZonesBackend.zones.values() for zone in zones])) - subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone_data, - default_for_az, map_public_ip_on_launch, - owner_id=context.get_current_user() if context else OWNER_ID, assign_ipv6_address_on_creation=False) + raise InvalidAvailabilityZoneError( + availability_zone, + ", ".join( + [ + zone.name + for zones in RegionsAndZonesBackend.zones.values() + for zone in zones + ] + ), + ) + subnet = Subnet( + self, + subnet_id, + vpc_id, + cidr_block, + availability_zone_data, + default_for_az, + map_public_ip_on_launch, + owner_id=context.get_current_user() if context else OWNER_ID, + assign_ipv6_address_on_creation=False, + ) # AWS associates a new subnet with the default Network ACL self.associate_default_network_acl_with_subnet(subnet_id, vpc_id) @@ -2669,11 +3096,9 @@ class SubnetBackend(object): def get_all_subnets(self, subnet_ids=None, filters=None): # Extract a list of all subnets - matches = itertools.chain(*[x.values() - for x in self.subnets.values()]) + matches = itertools.chain(*[x.values() for x in self.subnets.values()]) if subnet_ids: - matches = [sn for sn in matches - if sn.id in subnet_ids] + matches = [sn for sn in matches if sn.id in subnet_ids] if len(subnet_ids) > len(matches): unknown_ids = set(subnet_ids) - set(matches) raise InvalidSubnetIdError(unknown_ids) @@ -2690,7 +3115,7 @@ class SubnetBackend(object): def modify_subnet_attribute(self, subnet_id, attr_name, attr_value): subnet = self.get_subnet(subnet_id) - if attr_name in ('map_public_ip_on_launch', 'assign_ipv6_address_on_creation'): + if attr_name in ("map_public_ip_on_launch", "assign_ipv6_address_on_creation"): setattr(subnet, attr_name, attr_value) else: raise InvalidParameterValueError(attr_name) @@ -2702,16 +3127,17 @@ class SubnetRouteTableAssociation(object): self.subnet_id = subnet_id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - route_table_id = properties['RouteTableId'] - subnet_id = properties['SubnetId'] + route_table_id = properties["RouteTableId"] + subnet_id = properties["SubnetId"] ec2_backend = ec2_backends[region_name] subnet_association = ec2_backend.create_subnet_association( - route_table_id=route_table_id, - subnet_id=subnet_id, + route_table_id=route_table_id, subnet_id=subnet_id ) return subnet_association @@ -2722,10 +3148,10 @@ class SubnetRouteTableAssociationBackend(object): super(SubnetRouteTableAssociationBackend, self).__init__() def create_subnet_association(self, route_table_id, subnet_id): - subnet_association = SubnetRouteTableAssociation( - route_table_id, subnet_id) - self.subnet_associations["{0}:{1}".format( - route_table_id, subnet_id)] = subnet_association + subnet_association = SubnetRouteTableAssociation(route_table_id, subnet_id) + self.subnet_associations[ + "{0}:{1}".format(route_table_id, subnet_id) + ] = subnet_association return subnet_association @@ -2739,14 +3165,14 @@ class RouteTable(TaggedEC2Resource): self.routes = {} @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - vpc_id = properties['VpcId'] + vpc_id = properties["VpcId"] ec2_backend = ec2_backends[region_name] - route_table = ec2_backend.create_route_table( - vpc_id=vpc_id, - ) + route_table = ec2_backend.create_route_table(vpc_id=vpc_id) return route_table @property @@ -2758,9 +3184,9 @@ class RouteTable(TaggedEC2Resource): # Note: Boto only supports 'true'. # https://github.com/boto/boto/issues/1742 if self.main: - return 'true' + return "true" else: - return 'false' + return "false" elif filter_name == "route-table-id": return self.id elif filter_name == "vpc-id": @@ -2773,7 +3199,8 @@ class RouteTable(TaggedEC2Resource): return self.associations.values() else: return super(RouteTable, self).get_filter_value( - filter_name, 'DescribeRouteTables') + filter_name, "DescribeRouteTables" + ) class RouteTableBackend(object): @@ -2803,10 +3230,16 @@ class RouteTableBackend(object): if route_table_ids: route_tables = [ - route_table for route_table in route_tables if route_table.id in route_table_ids] + route_table + for route_table in route_tables + if route_table.id in route_table_ids + ] if len(route_tables) != len(route_table_ids): - invalid_id = list(set(route_table_ids).difference( - set([route_table.id for route_table in route_tables])))[0] + invalid_id = list( + set(route_table_ids).difference( + set([route_table.id for route_table in route_tables]) + ) + )[0] raise InvalidRouteTableIdError(invalid_id) return generic_filter(filters, route_tables) @@ -2815,7 +3248,9 @@ class RouteTableBackend(object): route_table = self.get_route_table(route_table_id) if route_table.associations: raise DependencyViolationError( - "The routeTable '{0}' has dependencies and cannot be deleted.".format(route_table_id) + "The routeTable '{0}' has dependencies and cannot be deleted.".format( + route_table_id + ) ) self.route_tables.pop(route_table_id) return True @@ -2823,9 +3258,12 @@ class RouteTableBackend(object): def associate_route_table(self, route_table_id, subnet_id): # Idempotent if association already exists. route_tables_by_subnet = self.get_all_route_tables( - filters={'association.subnet-id': [subnet_id]}) + filters={"association.subnet-id": [subnet_id]} + ) if route_tables_by_subnet: - for association_id, check_subnet_id in route_tables_by_subnet[0].associations.items(): + for association_id, check_subnet_id in route_tables_by_subnet[ + 0 + ].associations.items(): if subnet_id == check_subnet_id: return association_id @@ -2850,7 +3288,8 @@ class RouteTableBackend(object): # Find route table which currently has the association, error if none. route_tables_by_association_id = self.get_all_route_tables( - filters={'association.route-table-association-id': [association_id]}) + filters={"association.route-table-association-id": [association_id]} + ) if not route_tables_by_association_id: raise InvalidAssociationIdError(association_id) @@ -2861,33 +3300,47 @@ class RouteTableBackend(object): class Route(object): - def __init__(self, route_table, destination_cidr_block, local=False, - gateway=None, instance=None, interface=None, vpc_pcx=None): + def __init__( + self, + route_table, + destination_cidr_block, + local=False, + gateway=None, + instance=None, + nat_gateway=None, + interface=None, + vpc_pcx=None, + ): self.id = generate_route_id(route_table.id, destination_cidr_block) self.route_table = route_table self.destination_cidr_block = destination_cidr_block self.local = local self.gateway = gateway self.instance = instance + self.nat_gateway = nat_gateway self.interface = interface self.vpc_pcx = vpc_pcx @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - gateway_id = properties.get('GatewayId') - instance_id = properties.get('InstanceId') - interface_id = properties.get('NetworkInterfaceId') - pcx_id = properties.get('VpcPeeringConnectionId') + gateway_id = properties.get("GatewayId") + instance_id = properties.get("InstanceId") + interface_id = properties.get("NetworkInterfaceId") + nat_gateway_id = properties.get("NatGatewayId") + pcx_id = properties.get("VpcPeeringConnectionId") - route_table_id = properties['RouteTableId'] + route_table_id = properties["RouteTableId"] ec2_backend = ec2_backends[region_name] route_table = ec2_backend.create_route( route_table_id=route_table_id, - destination_cidr_block=properties.get('DestinationCidrBlock'), + destination_cidr_block=properties.get("DestinationCidrBlock"), gateway_id=gateway_id, instance_id=instance_id, + nat_gateway_id=nat_gateway_id, interface_id=interface_id, vpc_peering_connection_id=pcx_id, ) @@ -2898,20 +3351,27 @@ class RouteBackend(object): def __init__(self): super(RouteBackend, self).__init__() - def create_route(self, route_table_id, destination_cidr_block, local=False, - gateway_id=None, instance_id=None, interface_id=None, - vpc_peering_connection_id=None): + def create_route( + self, + route_table_id, + destination_cidr_block, + local=False, + gateway_id=None, + instance_id=None, + nat_gateway_id=None, + interface_id=None, + vpc_peering_connection_id=None, + ): route_table = self.get_route_table(route_table_id) if interface_id: - self.raise_not_implemented_error( - "CreateRoute to NetworkInterfaceId") + self.raise_not_implemented_error("CreateRoute to NetworkInterfaceId") gateway = None if gateway_id: - if EC2_RESOURCE_TO_PREFIX['vpn-gateway'] in gateway_id: + if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id: gateway = self.get_vpn_gateway(gateway_id) - elif EC2_RESOURCE_TO_PREFIX['internet-gateway'] in gateway_id: + elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id: gateway = self.get_internet_gateway(gateway_id) try: @@ -2919,39 +3379,55 @@ class RouteBackend(object): except ValueError: raise InvalidDestinationCIDRBlockParameterError(destination_cidr_block) - route = Route(route_table, destination_cidr_block, local=local, - gateway=gateway, - instance=self.get_instance( - instance_id) if instance_id else None, - interface=None, - vpc_pcx=self.get_vpc_peering_connection( - vpc_peering_connection_id) if vpc_peering_connection_id else None) + nat_gateway = None + if nat_gateway_id is not None: + nat_gateway = self.nat_gateways.get(nat_gateway_id) + + route = Route( + route_table, + destination_cidr_block, + local=local, + gateway=gateway, + instance=self.get_instance(instance_id) if instance_id else None, + nat_gateway=nat_gateway, + interface=None, + vpc_pcx=self.get_vpc_peering_connection(vpc_peering_connection_id) + if vpc_peering_connection_id + else None, + ) route_table.routes[route.id] = route return route - def replace_route(self, route_table_id, destination_cidr_block, - gateway_id=None, instance_id=None, interface_id=None, - vpc_peering_connection_id=None): + def replace_route( + self, + route_table_id, + destination_cidr_block, + gateway_id=None, + instance_id=None, + interface_id=None, + vpc_peering_connection_id=None, + ): route_table = self.get_route_table(route_table_id) route_id = generate_route_id(route_table.id, destination_cidr_block) route = route_table.routes[route_id] if interface_id: - self.raise_not_implemented_error( - "ReplaceRoute to NetworkInterfaceId") + self.raise_not_implemented_error("ReplaceRoute to NetworkInterfaceId") route.gateway = None if gateway_id: - if EC2_RESOURCE_TO_PREFIX['vpn-gateway'] in gateway_id: + if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id: route.gateway = self.get_vpn_gateway(gateway_id) - elif EC2_RESOURCE_TO_PREFIX['internet-gateway'] in gateway_id: + elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id: route.gateway = self.get_internet_gateway(gateway_id) - route.instance = self.get_instance( - instance_id) if instance_id else None + route.instance = self.get_instance(instance_id) if instance_id else None route.interface = None - route.vpc_pcx = self.get_vpc_peering_connection( - vpc_peering_connection_id) if vpc_peering_connection_id else None + route.vpc_pcx = ( + self.get_vpc_peering_connection(vpc_peering_connection_id) + if vpc_peering_connection_id + else None + ) route_table.routes[route.id] = route return route @@ -2977,7 +3453,9 @@ class InternetGateway(TaggedEC2Resource): self.vpc = None @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] return ec2_backend.create_internet_gateway() @@ -3052,16 +3530,18 @@ class VPCGatewayAttachment(BaseModel): self.vpc_id = vpc_id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] attachment = ec2_backend.create_vpc_gateway_attachment( - gateway_id=properties['InternetGatewayId'], - vpc_id=properties['VpcId'], + gateway_id=properties["InternetGatewayId"], vpc_id=properties["VpcId"] ) ec2_backend.attach_internet_gateway( - properties['InternetGatewayId'], properties['VpcId']) + properties["InternetGatewayId"], properties["VpcId"] + ) return attachment @property @@ -3081,11 +3561,30 @@ class VPCGatewayAttachmentBackend(object): class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource): - def __init__(self, ec2_backend, spot_request_id, price, image_id, type, - valid_from, valid_until, launch_group, availability_zone_group, - key_name, security_groups, user_data, instance_type, placement, - kernel_id, ramdisk_id, monitoring_enabled, subnet_id, tags, spot_fleet_id, - **kwargs): + def __init__( + self, + ec2_backend, + spot_request_id, + price, + image_id, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags, + spot_fleet_id, + **kwargs + ): super(SpotInstanceRequest, self).__init__(**kwargs) ls = LaunchSpecification() self.ec2_backend = ec2_backend @@ -3112,30 +3611,31 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource): if security_groups: for group_name in security_groups: - group = self.ec2_backend.get_security_group_from_name( - group_name) + group = self.ec2_backend.get_security_group_from_name(group_name) if group: ls.groups.append(group) else: # If not security groups, add the default - default_group = self.ec2_backend.get_security_group_from_name( - "default") + default_group = self.ec2_backend.get_security_group_from_name("default") ls.groups.append(default_group) self.instance = self.launch_instance() def get_filter_value(self, filter_name): - if filter_name == 'state': + if filter_name == "state": return self.state - elif filter_name == 'spot-instance-request-id': + elif filter_name == "spot-instance-request-id": return self.id else: return super(SpotInstanceRequest, self).get_filter_value( - filter_name, 'DescribeSpotInstanceRequests') + filter_name, "DescribeSpotInstanceRequests" + ) def launch_instance(self): reservation = self.ec2_backend.add_instances( - image_id=self.launch_specification.image_id, count=1, user_data=self.user_data, + image_id=self.launch_specification.image_id, + count=1, + user_data=self.user_data, instance_type=self.launch_specification.instance_type, subnet_id=self.launch_specification.subnet_id, key_name=self.launch_specification.key_name, @@ -3154,25 +3654,59 @@ class SpotRequestBackend(object): self.spot_instance_requests = {} super(SpotRequestBackend, self).__init__() - def request_spot_instances(self, price, image_id, count, type, valid_from, - valid_until, launch_group, availability_zone_group, - key_name, security_groups, user_data, - instance_type, placement, kernel_id, ramdisk_id, - monitoring_enabled, subnet_id, tags=None, spot_fleet_id=None): + def request_spot_instances( + self, + price, + image_id, + count, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags=None, + spot_fleet_id=None, + ): requests = [] tags = tags or {} for _ in range(count): spot_request_id = random_spot_request_id() - request = SpotInstanceRequest(self, - spot_request_id, price, image_id, type, valid_from, valid_until, - launch_group, availability_zone_group, key_name, security_groups, - user_data, instance_type, placement, kernel_id, ramdisk_id, - monitoring_enabled, subnet_id, tags, spot_fleet_id) + request = SpotInstanceRequest( + self, + spot_request_id, + price, + image_id, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags, + spot_fleet_id, + ) self.spot_instance_requests[spot_request_id] = request requests.append(request) return requests - @Model.prop('SpotInstanceRequest') + @Model.prop("SpotInstanceRequest") def describe_spot_instance_requests(self, filters=None): requests = self.spot_instance_requests.values() @@ -3186,9 +3720,21 @@ class SpotRequestBackend(object): class SpotFleetLaunchSpec(object): - def __init__(self, ebs_optimized, group_set, iam_instance_profile, image_id, - instance_type, key_name, monitoring, spot_price, subnet_id, tag_specifications, - user_data, weighted_capacity): + def __init__( + self, + ebs_optimized, + group_set, + iam_instance_profile, + image_id, + instance_type, + key_name, + monitoring, + spot_price, + subnet_id, + tag_specifications, + user_data, + weighted_capacity, + ): self.ebs_optimized = ebs_optimized self.group_set = group_set self.iam_instance_profile = iam_instance_profile @@ -3204,8 +3750,16 @@ class SpotFleetLaunchSpec(object): class SpotFleetRequest(TaggedEC2Resource): - def __init__(self, ec2_backend, spot_fleet_request_id, spot_price, - target_capacity, iam_fleet_role, allocation_strategy, launch_specs): + def __init__( + self, + ec2_backend, + spot_fleet_request_id, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ): self.ec2_backend = ec2_backend self.id = spot_fleet_request_id @@ -3218,21 +3772,23 @@ class SpotFleetRequest(TaggedEC2Resource): self.launch_specs = [] for spec in launch_specs: - self.launch_specs.append(SpotFleetLaunchSpec( - ebs_optimized=spec['ebs_optimized'], - group_set=[val for key, val in spec.items( - ) if key.startswith("group_set")], - iam_instance_profile=spec.get('iam_instance_profile._arn'), - image_id=spec['image_id'], - instance_type=spec['instance_type'], - key_name=spec.get('key_name'), - monitoring=spec.get('monitoring._enabled'), - spot_price=spec.get('spot_price', self.spot_price), - subnet_id=spec['subnet_id'], - tag_specifications=self._parse_tag_specifications(spec), - user_data=spec.get('user_data'), - weighted_capacity=spec['weighted_capacity'], - ) + self.launch_specs.append( + SpotFleetLaunchSpec( + ebs_optimized=spec["ebs_optimized"], + group_set=[ + val for key, val in spec.items() if key.startswith("group_set") + ], + iam_instance_profile=spec.get("iam_instance_profile._arn"), + image_id=spec["image_id"], + instance_type=spec["instance_type"], + key_name=spec.get("key_name"), + monitoring=spec.get("monitoring._enabled"), + spot_price=spec.get("spot_price", self.spot_price), + subnet_id=spec["subnet_id"], + tag_specifications=self._parse_tag_specifications(spec), + user_data=spec.get("user_data"), + weighted_capacity=spec["weighted_capacity"], + ) ) self.spot_requests = [] @@ -3243,26 +3799,34 @@ class SpotFleetRequest(TaggedEC2Resource): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json[ - 'Properties']['SpotFleetRequestConfigData'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"]["SpotFleetRequestConfigData"] ec2_backend = ec2_backends[region_name] - spot_price = properties.get('SpotPrice') - target_capacity = properties['TargetCapacity'] - iam_fleet_role = properties['IamFleetRole'] - allocation_strategy = properties['AllocationStrategy'] + spot_price = properties.get("SpotPrice") + target_capacity = properties["TargetCapacity"] + iam_fleet_role = properties["IamFleetRole"] + allocation_strategy = properties["AllocationStrategy"] launch_specs = properties["LaunchSpecifications"] launch_specs = [ - dict([(camelcase_to_underscores(key), val) - for key, val in launch_spec.items()]) - for launch_spec - in launch_specs + dict( + [ + (camelcase_to_underscores(key), val) + for key, val in launch_spec.items() + ] + ) + for launch_spec in launch_specs ] - spot_fleet_request = ec2_backend.request_spot_fleet(spot_price, - target_capacity, iam_fleet_role, allocation_strategy, - launch_specs) + spot_fleet_request = ec2_backend.request_spot_fleet( + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ) return spot_fleet_request @@ -3270,11 +3834,12 @@ class SpotFleetRequest(TaggedEC2Resource): weight_map = defaultdict(int) weight_so_far = 0 - if self.allocation_strategy == 'diversified': + if self.allocation_strategy == "diversified": launch_spec_index = 0 while True: launch_spec = self.launch_specs[ - launch_spec_index % len(self.launch_specs)] + launch_spec_index % len(self.launch_specs) + ] weight_map[launch_spec] += 1 weight_so_far += launch_spec.weighted_capacity if weight_so_far >= weight_to_add: @@ -3283,10 +3848,15 @@ class SpotFleetRequest(TaggedEC2Resource): else: # lowestPrice cheapest_spec = sorted( # FIXME: change `+inf` to the on demand price scaled to weighted capacity when it's not present - self.launch_specs, key=lambda spec: float(spec.spot_price or '+inf'))[0] - weight_so_far = weight_to_add + (weight_to_add % cheapest_spec.weighted_capacity) + self.launch_specs, + key=lambda spec: float(spec.spot_price or "+inf"), + )[0] + weight_so_far = weight_to_add + ( + weight_to_add % cheapest_spec.weighted_capacity + ) weight_map[cheapest_spec] = int( - weight_so_far // cheapest_spec.weighted_capacity) + weight_so_far // cheapest_spec.weighted_capacity + ) return weight_map, weight_so_far @@ -3324,7 +3894,10 @@ class SpotFleetRequest(TaggedEC2Resource): for req in self.spot_requests: instance = req.instance for spec in self.launch_specs: - if spec.instance_type == instance.instance_type and spec.subnet_id == instance.subnet_id: + if ( + spec.instance_type == instance.instance_type + and spec.subnet_id == instance.subnet_id + ): break if new_fulfilled_capacity - spec.weighted_capacity < self.target_capacity: @@ -3332,25 +3905,48 @@ class SpotFleetRequest(TaggedEC2Resource): new_fulfilled_capacity -= spec.weighted_capacity instance_ids.append(instance.id) - self.spot_requests = [req for req in self.spot_requests if req.instance.id not in instance_ids] + self.spot_requests = [ + req for req in self.spot_requests if req.instance.id not in instance_ids + ] self.ec2_backend.terminate_instances(instance_ids) def _parse_tag_specifications(self, spec): try: - tag_spec_num = max([int(key.split('.')[1]) for key in spec if key.startswith("tag_specification_set")]) + tag_spec_num = max( + [ + int(key.split(".")[1]) + for key in spec + if key.startswith("tag_specification_set") + ] + ) except ValueError: # no tag specifications return {} tag_specifications = {} for si in range(1, tag_spec_num + 1): - resource_type = spec["tag_specification_set.{si}._resource_type".format(si=si)] + resource_type = spec[ + "tag_specification_set.{si}._resource_type".format(si=si) + ] - tags = [key for key in spec if key.startswith("tag_specification_set.{si}._tag".format(si=si))] - tag_num = max([int(key.split('.')[3]) for key in tags]) - tag_specifications[resource_type] = dict(( - spec["tag_specification_set.{si}._tag.{ti}._key".format(si=si, ti=ti)], - spec["tag_specification_set.{si}._tag.{ti}._value".format(si=si, ti=ti)], - ) for ti in range(1, tag_num + 1)) + tags = [ + key + for key in spec + if key.startswith("tag_specification_set.{si}._tag".format(si=si)) + ] + tag_num = max([int(key.split(".")[3]) for key in tags]) + tag_specifications[resource_type] = dict( + ( + spec[ + "tag_specification_set.{si}._tag.{ti}._key".format(si=si, ti=ti) + ], + spec[ + "tag_specification_set.{si}._tag.{ti}._value".format( + si=si, ti=ti + ) + ], + ) + for ti in range(1, tag_num + 1) + ) return tag_specifications @@ -3360,12 +3956,25 @@ class SpotFleetBackend(object): self.spot_fleet_requests = {} super(SpotFleetBackend, self).__init__() - def request_spot_fleet(self, spot_price, target_capacity, iam_fleet_role, - allocation_strategy, launch_specs): + def request_spot_fleet( + self, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ): spot_fleet_request_id = random_spot_fleet_request_id() - request = SpotFleetRequest(self, spot_fleet_request_id, spot_price, - target_capacity, iam_fleet_role, allocation_strategy, launch_specs) + request = SpotFleetRequest( + self, + spot_fleet_request_id, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ) self.spot_fleet_requests[spot_fleet_request_id] = request return request @@ -3381,7 +3990,8 @@ class SpotFleetBackend(object): if spot_fleet_request_ids: requests = [ - request for request in requests if request.id in spot_fleet_request_ids] + request for request in requests if request.id in spot_fleet_request_ids + ] return requests @@ -3396,15 +4006,17 @@ class SpotFleetBackend(object): del self.spot_fleet_requests[spot_fleet_request_id] return spot_requests - def modify_spot_fleet_request(self, spot_fleet_request_id, target_capacity, terminate_instances): + def modify_spot_fleet_request( + self, spot_fleet_request_id, target_capacity, terminate_instances + ): if target_capacity < 0: - raise ValueError('Cannot reduce spot fleet capacity below 0') + raise ValueError("Cannot reduce spot fleet capacity below 0") spot_fleet_request = self.spot_fleet_requests[spot_fleet_request_id] delta = target_capacity - spot_fleet_request.fulfilled_capacity spot_fleet_request.target_capacity = target_capacity if delta > 0: spot_fleet_request.create_spot_requests(delta) - elif delta < 0 and terminate_instances == 'Default': + elif delta < 0 and terminate_instances == "Default": spot_fleet_request.terminate_instances() return True @@ -3422,18 +4034,19 @@ class ElasticAddress(object): self.association_id = None @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] - properties = cloudformation_json.get('Properties') + properties = cloudformation_json.get("Properties") instance_id = None if properties: - domain = properties.get('Domain') - eip = ec2_backend.allocate_address( - domain=domain if domain else 'standard') - instance_id = properties.get('InstanceId') + domain = properties.get("Domain") + eip = ec2_backend.allocate_address(domain=domain if domain else "standard") + instance_id = properties.get("InstanceId") else: - eip = ec2_backend.allocate_address(domain='standard') + eip = ec2_backend.allocate_address(domain="standard") if instance_id: instance = ec2_backend.get_instance_by_id(instance_id) @@ -3447,28 +4060,29 @@ class ElasticAddress(object): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AllocationId': + + if attribute_name == "AllocationId": return self.allocation_id raise UnformattedGetAttTemplateException() def get_filter_value(self, filter_name): - if filter_name == 'allocation-id': + if filter_name == "allocation-id": return self.allocation_id - elif filter_name == 'association-id': + elif filter_name == "association-id": return self.association_id - elif filter_name == 'domain': + elif filter_name == "domain": return self.domain - elif filter_name == 'instance-id' and self.instance: + elif filter_name == "instance-id" and self.instance: return self.instance.id - elif filter_name == 'network-interface-id' and self.eni: + elif filter_name == "network-interface-id" and self.eni: return self.eni.id - elif filter_name == 'private-ip-address' and self.eni: + elif filter_name == "private-ip-address" and self.eni: return self.eni.private_ip_address - elif filter_name == 'public-ip': + elif filter_name == "public-ip": return self.public_ip else: # TODO: implement network-interface-owner-id - raise FilterNotImplementedError(filter_name, 'DescribeAddresses') + raise FilterNotImplementedError(filter_name, "DescribeAddresses") class ElasticAddressBackend(object): @@ -3477,7 +4091,7 @@ class ElasticAddressBackend(object): super(ElasticAddressBackend, self).__init__() def allocate_address(self, domain, address=None): - if domain not in ['standard', 'vpc']: + if domain not in ["standard", "vpc"]: raise InvalidDomainError(domain) if address: address = ElasticAddress(domain, address) @@ -3487,8 +4101,7 @@ class ElasticAddressBackend(object): return address def address_by_ip(self, ips): - eips = [address for address in self.addresses - if address.public_ip in ips] + eips = [address for address in self.addresses if address.public_ip in ips] # TODO: Trim error message down to specific invalid address. if not eips or len(ips) > len(eips): @@ -3497,8 +4110,11 @@ class ElasticAddressBackend(object): return eips def address_by_allocation(self, allocation_ids): - eips = [address for address in self.addresses - if address.allocation_id in allocation_ids] + eips = [ + address + for address in self.addresses + if address.allocation_id in allocation_ids + ] # TODO: Trim error message down to specific invalid id. if not eips or len(allocation_ids) > len(eips): @@ -3507,8 +4123,11 @@ class ElasticAddressBackend(object): return eips def address_by_association(self, association_ids): - eips = [address for address in self.addresses - if address.association_id in association_ids] + eips = [ + address + for address in self.addresses + if address.association_id in association_ids + ] # TODO: Trim error message down to specific invalid id. if not eips or len(association_ids) > len(eips): @@ -3516,7 +4135,14 @@ class ElasticAddressBackend(object): return eips - def associate_address(self, instance=None, eni=None, address=None, allocation_id=None, reassociate=False): + def associate_address( + self, + instance=None, + eni=None, + address=None, + allocation_id=None, + reassociate=False, + ): eips = [] if address: eips = self.address_by_ip([address]) @@ -3524,10 +4150,10 @@ class ElasticAddressBackend(object): eips = self.address_by_allocation([allocation_id]) eip = eips[0] - new_instance_association = bool(instance and ( - not eip.instance or eip.instance.id == instance.id)) - new_eni_association = bool( - eni and (not eip.eni or eni.id == eip.eni.id)) + new_instance_association = bool( + instance and (not eip.instance or eip.instance.id == instance.id) + ) + new_eni_association = bool(eni and (not eip.eni or eni.id == eip.eni.id)) if new_instance_association or new_eni_association or reassociate: eip.instance = instance @@ -3547,14 +4173,12 @@ class ElasticAddressBackend(object): def describe_addresses(self, allocation_ids=None, public_ips=None, filters=None): matches = self.addresses if allocation_ids: - matches = [addr for addr in matches - if addr.allocation_id in allocation_ids] + matches = [addr for addr in matches if addr.allocation_id in allocation_ids] if len(allocation_ids) > len(matches): unknown_ids = set(allocation_ids) - set(matches) raise InvalidAllocationIdError(unknown_ids) if public_ips: - matches = [addr for addr in matches - if addr.public_ip in public_ips] + matches = [addr for addr in matches if addr.public_ip in public_ips] if len(public_ips) > len(matches): unknown_ips = set(allocation_ids) - set(matches) raise InvalidAddressError(unknown_ips) @@ -3596,9 +4220,15 @@ class ElasticAddressBackend(object): class DHCPOptionsSet(TaggedEC2Resource): - def __init__(self, ec2_backend, domain_name_servers=None, domain_name=None, - ntp_servers=None, netbios_name_servers=None, - netbios_node_type=None): + def __init__( + self, + ec2_backend, + domain_name_servers=None, + domain_name=None, + ntp_servers=None, + netbios_name_servers=None, + netbios_node_type=None, + ): self.ec2_backend = ec2_backend self._options = { "domain-name-servers": domain_name_servers, @@ -3623,16 +4253,17 @@ class DHCPOptionsSet(TaggedEC2Resource): Taken from: http://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeDhcpOptions.html """ - if filter_name == 'dhcp-options-id': + if filter_name == "dhcp-options-id": return self.id - elif filter_name == 'key': + elif filter_name == "key": return list(self._options.keys()) - elif filter_name == 'value': + elif filter_name == "value": values = [item for item in list(self._options.values()) if item] return itertools.chain(*values) else: return super(DHCPOptionsSet, self).get_filter_value( - filter_name, 'DescribeDhcpOptions') + filter_name, "DescribeDhcpOptions" + ) @property def options(self): @@ -3649,9 +4280,13 @@ class DHCPOptionsSetBackend(object): vpc.dhcp_options = dhcp_options def create_dhcp_options( - self, domain_name_servers=None, domain_name=None, - ntp_servers=None, netbios_name_servers=None, - netbios_node_type=None): + self, + domain_name_servers=None, + domain_name=None, + ntp_servers=None, + netbios_name_servers=None, + netbios_node_type=None, + ): NETBIOS_NODE_TYPES = [1, 2, 4, 8] @@ -3663,8 +4298,12 @@ class DHCPOptionsSetBackend(object): raise InvalidParameterValueError(netbios_node_type) options = DHCPOptionsSet( - self, domain_name_servers, domain_name, ntp_servers, - netbios_name_servers, netbios_node_type + self, + domain_name_servers, + domain_name, + ntp_servers, + netbios_name_servers, + netbios_node_type, ) self.dhcp_options_sets[options.id] = options return options @@ -3679,13 +4318,12 @@ class DHCPOptionsSetBackend(object): return options_sets or self.dhcp_options_sets.values() def delete_dhcp_options_set(self, options_id): - if not (options_id and options_id.startswith('dopt-')): + if not (options_id and options_id.startswith("dopt-")): raise MalformedDHCPOptionsIdError(options_id) if options_id in self.dhcp_options_sets: if self.dhcp_options_sets[options_id].vpc: - raise DependencyViolationError( - "Cannot delete assigned DHCP options.") + raise DependencyViolationError("Cannot delete assigned DHCP options.") self.dhcp_options_sets.pop(options_id) else: raise InvalidDHCPOptionsIdError(options_id) @@ -3696,21 +4334,31 @@ class DHCPOptionsSetBackend(object): if dhcp_options_ids: dhcp_options_sets = [ - dhcp_options_set for dhcp_options_set in dhcp_options_sets if dhcp_options_set.id in dhcp_options_ids] + dhcp_options_set + for dhcp_options_set in dhcp_options_sets + if dhcp_options_set.id in dhcp_options_ids + ] if len(dhcp_options_sets) != len(dhcp_options_ids): - invalid_id = list(set(dhcp_options_ids).difference( - set([dhcp_options_set.id for dhcp_options_set in dhcp_options_sets])))[0] + invalid_id = list( + set(dhcp_options_ids).difference( + set( + [ + dhcp_options_set.id + for dhcp_options_set in dhcp_options_sets + ] + ) + ) + )[0] raise InvalidDHCPOptionsIdError(invalid_id) return generic_filter(filters, dhcp_options_sets) class VPNConnection(TaggedEC2Resource): - def __init__(self, ec2_backend, id, type, - customer_gateway_id, vpn_gateway_id): + def __init__(self, ec2_backend, id, type, customer_gateway_id, vpn_gateway_id): self.ec2_backend = ec2_backend self.id = id - self.state = 'available' + self.state = "available" self.customer_gateway_configuration = {} self.type = type self.customer_gateway_id = customer_gateway_id @@ -3720,8 +4368,9 @@ class VPNConnection(TaggedEC2Resource): self.static_routes = None def get_filter_value(self, filter_name): - return super(VPNConnection, self).get_filter_value( - filter_name, 'DescribeVpnConnections') + return super(VPNConnection, self).get_filter_value( + filter_name, "DescribeVpnConnections" + ) class VPNConnectionBackend(object): @@ -3729,16 +4378,18 @@ class VPNConnectionBackend(object): self.vpn_connections = {} super(VPNConnectionBackend, self).__init__() - def create_vpn_connection(self, type, customer_gateway_id, - vpn_gateway_id, - static_routes_only=None): + def create_vpn_connection( + self, type, customer_gateway_id, vpn_gateway_id, static_routes_only=None + ): vpn_connection_id = random_vpn_connection_id() if static_routes_only: pass vpn_connection = VPNConnection( - self, id=vpn_connection_id, type=type, + self, + id=vpn_connection_id, + type=type, customer_gateway_id=customer_gateway_id, - vpn_gateway_id=vpn_gateway_id + vpn_gateway_id=vpn_gateway_id, ) self.vpn_connections[vpn_connection.id] = vpn_connection return vpn_connection @@ -3764,11 +4415,17 @@ class VPNConnectionBackend(object): vpn_connections = self.vpn_connections.values() if vpn_connection_ids: - vpn_connections = [vpn_connection for vpn_connection in vpn_connections - if vpn_connection.id in vpn_connection_ids] + vpn_connections = [ + vpn_connection + for vpn_connection in vpn_connections + if vpn_connection.id in vpn_connection_ids + ] if len(vpn_connections) != len(vpn_connection_ids): - invalid_id = list(set(vpn_connection_ids).difference( - set([vpn_connection.id for vpn_connection in vpn_connections])))[0] + invalid_id = list( + set(vpn_connection_ids).difference( + set([vpn_connection.id for vpn_connection in vpn_connections]) + ) + )[0] raise InvalidVpnConnectionIdError(invalid_id) return generic_filter(filters, vpn_connections) @@ -3796,25 +4453,40 @@ class NetworkAclBackend(object): def add_default_entries(self, network_acl_id): default_acl_entries = [ - {'rule_number': "100", 'rule_action': 'allow', 'egress': 'true'}, - {'rule_number': "32767", 'rule_action': 'deny', 'egress': 'true'}, - {'rule_number': "100", 'rule_action': 'allow', 'egress': 'false'}, - {'rule_number': "32767", 'rule_action': 'deny', 'egress': 'false'} + {"rule_number": "100", "rule_action": "allow", "egress": "true"}, + {"rule_number": "32767", "rule_action": "deny", "egress": "true"}, + {"rule_number": "100", "rule_action": "allow", "egress": "false"}, + {"rule_number": "32767", "rule_action": "deny", "egress": "false"}, ] for entry in default_acl_entries: - self.create_network_acl_entry(network_acl_id=network_acl_id, rule_number=entry['rule_number'], protocol='-1', - rule_action=entry['rule_action'], egress=entry['egress'], cidr_block='0.0.0.0/0', - icmp_code=None, icmp_type=None, port_range_from=None, port_range_to=None) + self.create_network_acl_entry( + network_acl_id=network_acl_id, + rule_number=entry["rule_number"], + protocol="-1", + rule_action=entry["rule_action"], + egress=entry["egress"], + cidr_block="0.0.0.0/0", + icmp_code=None, + icmp_type=None, + port_range_from=None, + port_range_to=None, + ) def get_all_network_acls(self, network_acl_ids=None, filters=None): network_acls = self.network_acls.values() if network_acl_ids: - network_acls = [network_acl for network_acl in network_acls - if network_acl.id in network_acl_ids] + network_acls = [ + network_acl + for network_acl in network_acls + if network_acl.id in network_acl_ids + ] if len(network_acls) != len(network_acl_ids): - invalid_id = list(set(network_acl_ids).difference( - set([network_acl.id for network_acl in network_acls])))[0] + invalid_id = list( + set(network_acl_ids).difference( + set([network_acl.id for network_acl in network_acls]) + ) + )[0] raise InvalidRouteTableIdError(invalid_id) return generic_filter(filters, network_acls) @@ -3825,46 +4497,91 @@ class NetworkAclBackend(object): raise InvalidNetworkAclIdError(network_acl_id) return deleted - def create_network_acl_entry(self, network_acl_id, rule_number, - protocol, rule_action, egress, cidr_block, - icmp_code, icmp_type, port_range_from, - port_range_to): + def create_network_acl_entry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): network_acl = self.get_network_acl(network_acl_id) - if any(entry.egress == egress and entry.rule_number == rule_number for entry in network_acl.network_acl_entries): + if any( + entry.egress == egress and entry.rule_number == rule_number + for entry in network_acl.network_acl_entries + ): raise NetworkAclEntryAlreadyExistsError(rule_number) - network_acl_entry = NetworkAclEntry(self, network_acl_id, rule_number, - protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_entry = NetworkAclEntry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) network_acl.network_acl_entries.append(network_acl_entry) return network_acl_entry def delete_network_acl_entry(self, network_acl_id, rule_number, egress): network_acl = self.get_network_acl(network_acl_id) - entry = next(entry for entry in network_acl.network_acl_entries - if entry.egress == egress and entry.rule_number == rule_number) + entry = next( + entry + for entry in network_acl.network_acl_entries + if entry.egress == egress and entry.rule_number == rule_number + ) if entry is not None: network_acl.network_acl_entries.remove(entry) return entry - def replace_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, port_range_from, port_range_to): + def replace_network_acl_entry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): self.delete_network_acl_entry(network_acl_id, rule_number, egress) - network_acl_entry = self.create_network_acl_entry(network_acl_id, rule_number, - protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_entry = self.create_network_acl_entry( + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) return network_acl_entry - def replace_network_acl_association(self, association_id, - network_acl_id): + def replace_network_acl_association(self, association_id, network_acl_id): # lookup existing association for subnet and delete it - default_acl = next(value for key, value in self.network_acls.items() - if association_id in value.associations.keys()) + default_acl = next( + value + for key, value in self.network_acls.items() + if association_id in value.associations.keys() + ) subnet_id = None for key, value in default_acl.associations.items(): @@ -3874,24 +4591,27 @@ class NetworkAclBackend(object): break new_assoc_id = random_network_acl_subnet_association_id() - association = NetworkAclAssociation(self, - new_assoc_id, - subnet_id, - network_acl_id) + association = NetworkAclAssociation( + self, new_assoc_id, subnet_id, network_acl_id + ) new_acl = self.get_network_acl(network_acl_id) new_acl.associations[new_assoc_id] = association return association def associate_default_network_acl_with_subnet(self, subnet_id, vpc_id): association_id = random_network_acl_subnet_association_id() - acl = next(acl for acl in self.network_acls.values() if acl.default and acl.vpc_id == vpc_id) - acl.associations[association_id] = NetworkAclAssociation(self, association_id, - subnet_id, acl.id) + acl = next( + acl + for acl in self.network_acls.values() + if acl.default and acl.vpc_id == vpc_id + ) + acl.associations[association_id] = NetworkAclAssociation( + self, association_id, subnet_id, acl.id + ) class NetworkAclAssociation(object): - def __init__(self, ec2_backend, new_association_id, - subnet_id, network_acl_id): + def __init__(self, ec2_backend, new_association_id, subnet_id, network_acl_id): self.ec2_backend = ec2_backend self.id = new_association_id self.new_association_id = new_association_id @@ -3907,7 +4627,7 @@ class NetworkAcl(TaggedEC2Resource): self.vpc_id = vpc_id self.network_acl_entries = [] self.associations = {} - self.default = 'true' if default is True else 'false' + self.default = "true" if default is True else "false" def get_filter_value(self, filter_name): if filter_name == "default": @@ -3920,14 +4640,25 @@ class NetworkAcl(TaggedEC2Resource): return [assoc.subnet_id for assoc in self.associations.values()] else: return super(NetworkAcl, self).get_filter_value( - filter_name, 'DescribeNetworkAcls') + filter_name, "DescribeNetworkAcls" + ) class NetworkAclEntry(TaggedEC2Resource): - def __init__(self, ec2_backend, network_acl_id, rule_number, - protocol, rule_action, egress, cidr_block, - icmp_code, icmp_type, port_range_from, - port_range_to): + def __init__( + self, + ec2_backend, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): self.ec2_backend = ec2_backend self.network_acl_id = network_acl_id self.rule_number = rule_number @@ -3950,8 +4681,9 @@ class VpnGateway(TaggedEC2Resource): super(VpnGateway, self).__init__() def get_filter_value(self, filter_name): - return super(VpnGateway, self).get_filter_value( - filter_name, 'DescribeVpnGateways') + return super(VpnGateway, self).get_filter_value( + filter_name, "DescribeVpnGateways" + ) class VpnGatewayAttachment(object): @@ -3966,7 +4698,7 @@ class VpnGatewayBackend(object): self.vpn_gateways = {} super(VpnGatewayBackend, self).__init__() - def create_vpn_gateway(self, type='ipsec.1'): + def create_vpn_gateway(self, type="ipsec.1"): vpn_gateway_id = random_vpn_gateway_id() vpn_gateway = VpnGateway(self, vpn_gateway_id, type) self.vpn_gateways[vpn_gateway_id] = vpn_gateway @@ -3985,7 +4717,7 @@ class VpnGatewayBackend(object): def attach_vpn_gateway(self, vpn_gateway_id, vpc_id): vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) self.get_vpc(vpc_id) - attachment = VpnGatewayAttachment(vpc_id, state='attached') + attachment = VpnGatewayAttachment(vpc_id, state="attached") vpn_gateway.attachments[vpc_id] = attachment return attachment @@ -4015,8 +4747,9 @@ class CustomerGateway(TaggedEC2Resource): super(CustomerGateway, self).__init__() def get_filter_value(self, filter_name): - return super(CustomerGateway, self).get_filter_value( - filter_name, 'DescribeCustomerGateways') + return super(CustomerGateway, self).get_filter_value( + filter_name, "DescribeCustomerGateways" + ) class CustomerGatewayBackend(object): @@ -4024,10 +4757,11 @@ class CustomerGatewayBackend(object): self.customer_gateways = {} super(CustomerGatewayBackend, self).__init__() - def create_customer_gateway(self, type='ipsec.1', ip_address=None, bgp_asn=None): + def create_customer_gateway(self, type="ipsec.1", ip_address=None, bgp_asn=None): customer_gateway_id = random_customer_gateway_id() customer_gateway = CustomerGateway( - self, customer_gateway_id, type, ip_address, bgp_asn) + self, customer_gateway_id, type, ip_address, bgp_asn + ) self.customer_gateways[customer_gateway_id] = customer_gateway return customer_gateway @@ -4036,8 +4770,7 @@ class CustomerGatewayBackend(object): return generic_filter(filters, customer_gateways) def get_customer_gateway(self, customer_gateway_id): - customer_gateway = self.customer_gateways.get( - customer_gateway_id, None) + customer_gateway = self.customer_gateways.get(customer_gateway_id, None) if not customer_gateway: raise InvalidCustomerGatewayIdError(customer_gateway_id) return customer_gateway @@ -4055,7 +4788,7 @@ class NatGateway(object): self.id = random_nat_gateway_id() self.subnet_id = subnet_id self.allocation_id = allocation_id - self.state = 'available' + self.state = "available" self.private_ip = random_private_ip() # protected properties @@ -4063,11 +4796,11 @@ class NatGateway(object): self._backend = backend # NOTE: this is the core of NAT Gateways creation self._eni = self._backend.create_network_interface( - backend.get_subnet(self.subnet_id), self.private_ip) + backend.get_subnet(self.subnet_id), self.private_ip + ) # associate allocation with ENI - self._backend.associate_address( - eni=self._eni, allocation_id=self.allocation_id) + self._backend.associate_address(eni=self._eni, allocation_id=self.allocation_id) @property def vpc_id(self): @@ -4088,11 +4821,13 @@ class NatGateway(object): return eips[0].public_ip @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] nat_gateway = ec2_backend.create_nat_gateway( - cloudformation_json['Properties']['SubnetId'], - cloudformation_json['Properties']['AllocationId'], + cloudformation_json["Properties"]["SubnetId"], + cloudformation_json["Properties"]["AllocationId"], ) return nat_gateway @@ -4103,7 +4838,35 @@ class NatGatewayBackend(object): super(NatGatewayBackend, self).__init__() def get_all_nat_gateways(self, filters): - return self.nat_gateways.values() + nat_gateways = self.nat_gateways.values() + + if filters is not None: + if filters.get("nat-gateway-id") is not None: + nat_gateways = [ + nat_gateway + for nat_gateway in nat_gateways + if nat_gateway.id in filters["nat-gateway-id"] + ] + if filters.get("vpc-id") is not None: + nat_gateways = [ + nat_gateway + for nat_gateway in nat_gateways + if nat_gateway.vpc_id in filters["vpc-id"] + ] + if filters.get("subnet-id") is not None: + nat_gateways = [ + nat_gateway + for nat_gateway in nat_gateways + if nat_gateway.subnet_id in filters["subnet-id"] + ] + if filters.get("state") is not None: + nat_gateways = [ + nat_gateway + for nat_gateway in nat_gateways + if nat_gateway.state in filters["state"] + ] + + return nat_gateways def create_nat_gateway(self, subnet_id, allocation_id): nat_gateway = NatGateway(self, subnet_id, allocation_id) @@ -4157,11 +4920,12 @@ class LaunchTemplate(TaggedEC2Resource): return self.latest_version().number def get_filter_value(self, filter_name): - if filter_name == 'launch-template-name': + if filter_name == "launch-template-name": return self.name else: return super(LaunchTemplate, self).get_filter_value( - filter_name, "DescribeLaunchTemplates") + filter_name, "DescribeLaunchTemplates" + ) class LaunchTemplateBackend(object): @@ -4186,7 +4950,9 @@ class LaunchTemplateBackend(object): def get_launch_template_by_name(self, name): return self.get_launch_template(self.launch_template_name_to_ids[name]) - def get_launch_templates(self, template_names=None, template_ids=None, filters=None): + def get_launch_templates( + self, template_names=None, template_ids=None, filters=None + ): if template_names and not template_ids: template_ids = [] for name in template_names: @@ -4200,16 +4966,35 @@ class LaunchTemplateBackend(object): return generic_filter(filters, templates) -class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, - RegionsAndZonesBackend, SecurityGroupBackend, AmiBackend, - VPCBackend, SubnetBackend, SubnetRouteTableAssociationBackend, - NetworkInterfaceBackend, VPNConnectionBackend, - VPCPeeringConnectionBackend, - RouteTableBackend, RouteBackend, InternetGatewayBackend, - VPCGatewayAttachmentBackend, SpotFleetBackend, - SpotRequestBackend, ElasticAddressBackend, KeyPairBackend, - DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend, - CustomerGatewayBackend, NatGatewayBackend, LaunchTemplateBackend): +class EC2Backend( + BaseBackend, + InstanceBackend, + TagBackend, + EBSBackend, + RegionsAndZonesBackend, + SecurityGroupBackend, + AmiBackend, + VPCBackend, + SubnetBackend, + SubnetRouteTableAssociationBackend, + NetworkInterfaceBackend, + VPNConnectionBackend, + VPCPeeringConnectionBackend, + RouteTableBackend, + RouteBackend, + InternetGatewayBackend, + VPCGatewayAttachmentBackend, + SpotFleetBackend, + SpotRequestBackend, + ElasticAddressBackend, + KeyPairBackend, + DHCPOptionsSetBackend, + NetworkAclBackend, + VpnGatewayBackend, + CustomerGatewayBackend, + NatGatewayBackend, + LaunchTemplateBackend, +): def __init__(self, region_name): self.region_name = region_name super(EC2Backend, self).__init__() @@ -4220,20 +5005,20 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, # docs.aws.amazon.com/AmazonVPC/latest/UserGuide/default-vpc.html # if not self.vpcs: - vpc = self.create_vpc('172.31.0.0/16') + vpc = self.create_vpc("172.31.0.0/16") else: # For now this is included for potential # backward-compatibility issues vpc = self.vpcs.values()[0] # Create default subnet for each availability zone - ip, _ = vpc.cidr_block.split('/') - ip = ip.split('.') + ip, _ = vpc.cidr_block.split("/") + ip = ip.split(".") ip[2] = 0 for zone in self.describe_availability_zones(): az_name = zone.name - cidr_block = '.'.join(str(i) for i in ip) + '/20' + cidr_block = ".".join(str(i) for i in ip) + "/20" self.create_subnet(vpc.id, cidr_block, availability_zone=az_name) ip[2] += 16 @@ -4253,49 +5038,51 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, def do_resources_exist(self, resource_ids): for resource_id in resource_ids: resource_prefix = get_prefix(resource_id) - if resource_prefix == EC2_RESOURCE_TO_PREFIX['customer-gateway']: + if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]: self.get_customer_gateway(customer_gateway_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['dhcp-options']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["dhcp-options"]: self.describe_dhcp_options(options_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['image']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["image"]: self.describe_images(ami_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['instance']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["instance"]: self.get_instance_by_id(instance_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['internet-gateway']: - self.describe_internet_gateways( - internet_gateway_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['launch-template']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["internet-gateway"]: + self.describe_internet_gateways(internet_gateway_ids=[resource_id]) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["launch-template"]: self.get_launch_template(resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-acl']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["network-acl"]: self.get_all_network_acls() - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]: self.describe_network_interfaces( - filters={'network-interface-id': resource_id}) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['reserved-instance']: - self.raise_not_implemented_error('DescribeReservedInstances') - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['route-table']: + filters={"network-interface-id": resource_id} + ) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["reserved-instance"]: + self.raise_not_implemented_error("DescribeReservedInstances") + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["route-table"]: self.get_route_table(route_table_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['security-group']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["security-group"]: self.describe_security_groups(group_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['snapshot']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["snapshot"]: self.get_snapshot(snapshot_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['spot-instance-request']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["spot-instance-request"]: self.describe_spot_instance_requests( - filters={'spot-instance-request-id': resource_id}) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['subnet']: + filters={"spot-instance-request-id": resource_id} + ) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["subnet"]: self.get_subnet(subnet_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['volume']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["volume"]: self.get_volume(volume_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpc']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpc"]: self.get_vpc(vpc_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpc-peering-connection']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]: self.get_vpc_peering_connection(vpc_pcx_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpn-connection']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpn-connection"]: self.describe_vpn_connections(vpn_connection_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpn-gateway']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpn-gateway"]: self.get_vpn_gateway(vpn_gateway_id=resource_id) return True -ec2_backends = {region.name: EC2Backend(region.name) - for region in RegionsAndZonesBackend.regions} +ec2_backends = { + region.name: EC2Backend(region.name) for region in RegionsAndZonesBackend.regions +} diff --git a/moto/ec2/responses/__init__.py b/moto/ec2/responses/__init__.py index d0648eb50..21cbf8249 100644 --- a/moto/ec2/responses/__init__.py +++ b/moto/ec2/responses/__init__.py @@ -70,10 +70,10 @@ class EC2Response( Windows, NatGateways, ): - @property def ec2_backend(self): from moto.ec2.models import ec2_backends + return ec2_backends[self.region] @property diff --git a/moto/ec2/responses/account_attributes.py b/moto/ec2/responses/account_attributes.py index 8a5b9a4b0..068a7c395 100644 --- a/moto/ec2/responses/account_attributes.py +++ b/moto/ec2/responses/account_attributes.py @@ -3,13 +3,12 @@ from moto.core.responses import BaseResponse class AccountAttributes(BaseResponse): - def describe_account_attributes(self): template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT) return template.render() -DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = u""" +DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE diff --git a/moto/ec2/responses/amazon_dev_pay.py b/moto/ec2/responses/amazon_dev_pay.py index 14df3f004..982b7f4a3 100644 --- a/moto/ec2/responses/amazon_dev_pay.py +++ b/moto/ec2/responses/amazon_dev_pay.py @@ -3,7 +3,7 @@ from moto.core.responses import BaseResponse class AmazonDevPay(BaseResponse): - def confirm_product_instance(self): raise NotImplementedError( - 'AmazonDevPay.confirm_product_instance is not yet implemented') + "AmazonDevPay.confirm_product_instance is not yet implemented" + ) diff --git a/moto/ec2/responses/amis.py b/moto/ec2/responses/amis.py index 17e1e228d..6736a7175 100755 --- a/moto/ec2/responses/amis.py +++ b/moto/ec2/responses/amis.py @@ -4,76 +4,83 @@ from moto.ec2.utils import filters_from_querystring class AmisResponse(BaseResponse): - def create_image(self): - name = self.querystring.get('Name')[0] - description = self._get_param('Description', if_none='') - instance_id = self._get_param('InstanceId') - if self.is_not_dryrun('CreateImage'): + name = self.querystring.get("Name")[0] + description = self._get_param("Description", if_none="") + instance_id = self._get_param("InstanceId") + if self.is_not_dryrun("CreateImage"): image = self.ec2_backend.create_image( - instance_id, name, description, context=self) + instance_id, name, description, context=self + ) template = self.response_template(CREATE_IMAGE_RESPONSE) return template.render(image=image) def copy_image(self): - source_image_id = self._get_param('SourceImageId') - source_region = self._get_param('SourceRegion') - name = self._get_param('Name') - description = self._get_param('Description') - if self.is_not_dryrun('CopyImage'): + source_image_id = self._get_param("SourceImageId") + source_region = self._get_param("SourceRegion") + name = self._get_param("Name") + description = self._get_param("Description") + if self.is_not_dryrun("CopyImage"): image = self.ec2_backend.copy_image( - source_image_id, source_region, name, description) + source_image_id, source_region, name, description + ) template = self.response_template(COPY_IMAGE_RESPONSE) return template.render(image=image) def deregister_image(self): - ami_id = self._get_param('ImageId') - if self.is_not_dryrun('DeregisterImage'): + ami_id = self._get_param("ImageId") + if self.is_not_dryrun("DeregisterImage"): success = self.ec2_backend.deregister_image(ami_id) template = self.response_template(DEREGISTER_IMAGE_RESPONSE) return template.render(success=str(success).lower()) def describe_images(self): - ami_ids = self._get_multi_param('ImageId') + ami_ids = self._get_multi_param("ImageId") filters = filters_from_querystring(self.querystring) - owners = self._get_multi_param('Owner') - exec_users = self._get_multi_param('ExecutableBy') + owners = self._get_multi_param("Owner") + exec_users = self._get_multi_param("ExecutableBy") images = self.ec2_backend.describe_images( - ami_ids=ami_ids, filters=filters, exec_users=exec_users, - owners=owners, context=self) + ami_ids=ami_ids, + filters=filters, + exec_users=exec_users, + owners=owners, + context=self, + ) template = self.response_template(DESCRIBE_IMAGES_RESPONSE) return template.render(images=images) def describe_image_attribute(self): - ami_id = self._get_param('ImageId') + ami_id = self._get_param("ImageId") groups = self.ec2_backend.get_launch_permission_groups(ami_id) users = self.ec2_backend.get_launch_permission_users(ami_id) template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE) return template.render(ami_id=ami_id, groups=groups, users=users) def modify_image_attribute(self): - ami_id = self._get_param('ImageId') - operation_type = self._get_param('OperationType') - group = self._get_param('UserGroup.1') - user_ids = self._get_multi_param('UserId') - if self.is_not_dryrun('ModifyImageAttribute'): - if (operation_type == 'add'): + ami_id = self._get_param("ImageId") + operation_type = self._get_param("OperationType") + group = self._get_param("UserGroup.1") + user_ids = self._get_multi_param("UserId") + if self.is_not_dryrun("ModifyImageAttribute"): + if operation_type == "add": self.ec2_backend.add_launch_permission( - ami_id, user_ids=user_ids, group=group) - elif (operation_type == 'remove'): + ami_id, user_ids=user_ids, group=group + ) + elif operation_type == "remove": self.ec2_backend.remove_launch_permission( - ami_id, user_ids=user_ids, group=group) + ami_id, user_ids=user_ids, group=group + ) return MODIFY_IMAGE_ATTRIBUTE_RESPONSE def register_image(self): - if self.is_not_dryrun('RegisterImage'): - raise NotImplementedError( - 'AMIs.register_image is not yet implemented') + if self.is_not_dryrun("RegisterImage"): + raise NotImplementedError("AMIs.register_image is not yet implemented") def reset_image_attribute(self): - if self.is_not_dryrun('ResetImageAttribute'): + if self.is_not_dryrun("ResetImageAttribute"): raise NotImplementedError( - 'AMIs.reset_image_attribute is not yet implemented') + "AMIs.reset_image_attribute is not yet implemented" + ) CREATE_IMAGE_RESPONSE = """ diff --git a/moto/ec2/responses/availability_zones_and_regions.py b/moto/ec2/responses/availability_zones_and_regions.py index a6e35a89c..d63e2f4ad 100644 --- a/moto/ec2/responses/availability_zones_and_regions.py +++ b/moto/ec2/responses/availability_zones_and_regions.py @@ -3,14 +3,13 @@ from moto.core.responses import BaseResponse class AvailabilityZonesAndRegions(BaseResponse): - def describe_availability_zones(self): zones = self.ec2_backend.describe_availability_zones() template = self.response_template(DESCRIBE_ZONES_RESPONSE) return template.render(zones=zones) def describe_regions(self): - region_names = self._get_multi_param('RegionName') + region_names = self._get_multi_param("RegionName") regions = self.ec2_backend.describe_regions(region_names) template = self.response_template(DESCRIBE_REGIONS_RESPONSE) return template.render(regions=regions) diff --git a/moto/ec2/responses/customer_gateways.py b/moto/ec2/responses/customer_gateways.py index 866b93045..65b93cc2e 100644 --- a/moto/ec2/responses/customer_gateways.py +++ b/moto/ec2/responses/customer_gateways.py @@ -4,21 +4,20 @@ from moto.ec2.utils import filters_from_querystring class CustomerGateways(BaseResponse): - def create_customer_gateway(self): # raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented') - type = self._get_param('Type') - ip_address = self._get_param('IpAddress') - bgp_asn = self._get_param('BgpAsn') + type = self._get_param("Type") + ip_address = self._get_param("IpAddress") + bgp_asn = self._get_param("BgpAsn") customer_gateway = self.ec2_backend.create_customer_gateway( - type, ip_address=ip_address, bgp_asn=bgp_asn) + type, ip_address=ip_address, bgp_asn=bgp_asn + ) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) return template.render(customer_gateway=customer_gateway) def delete_customer_gateway(self): - customer_gateway_id = self._get_param('CustomerGatewayId') - delete_status = self.ec2_backend.delete_customer_gateway( - customer_gateway_id) + customer_gateway_id = self._get_param("CustomerGatewayId") + delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) return template.render(customer_gateway=delete_status) diff --git a/moto/ec2/responses/dhcp_options.py b/moto/ec2/responses/dhcp_options.py index 1f740d14b..868ab85cf 100644 --- a/moto/ec2/responses/dhcp_options.py +++ b/moto/ec2/responses/dhcp_options.py @@ -1,15 +1,12 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.ec2.utils import ( - filters_from_querystring, - dhcp_configuration_from_querystring) +from moto.ec2.utils import filters_from_querystring, dhcp_configuration_from_querystring class DHCPOptions(BaseResponse): - def associate_dhcp_options(self): - dhcp_opt_id = self._get_param('DhcpOptionsId') - vpc_id = self._get_param('VpcId') + dhcp_opt_id = self._get_param("DhcpOptionsId") + vpc_id = self._get_param("VpcId") dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0] vpc = self.ec2_backend.get_vpc(vpc_id) @@ -35,14 +32,14 @@ class DHCPOptions(BaseResponse): domain_name=domain_name, ntp_servers=ntp_servers, netbios_name_servers=netbios_name_servers, - netbios_node_type=netbios_node_type + netbios_node_type=netbios_node_type, ) template = self.response_template(CREATE_DHCP_OPTIONS_RESPONSE) return template.render(dhcp_options_set=dhcp_options_set) def delete_dhcp_options(self): - dhcp_opt_id = self._get_param('DhcpOptionsId') + dhcp_opt_id = self._get_param("DhcpOptionsId") delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id) template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE) return template.render(delete_status=delete_status) @@ -50,13 +47,12 @@ class DHCPOptions(BaseResponse): def describe_dhcp_options(self): dhcp_opt_ids = self._get_multi_param("DhcpOptionsId") filters = filters_from_querystring(self.querystring) - dhcp_opts = self.ec2_backend.get_all_dhcp_options( - dhcp_opt_ids, filters) + dhcp_opts = self.ec2_backend.get_all_dhcp_options(dhcp_opt_ids, filters) template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE) return template.render(dhcp_options=dhcp_opts) -CREATE_DHCP_OPTIONS_RESPONSE = u""" +CREATE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -92,14 +88,14 @@ CREATE_DHCP_OPTIONS_RESPONSE = u""" """ -DELETE_DHCP_OPTIONS_RESPONSE = u""" +DELETE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE {{delete_status}} """ -DESCRIBE_DHCP_OPTIONS_RESPONSE = u""" +DESCRIBE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -139,7 +135,7 @@ DESCRIBE_DHCP_OPTIONS_RESPONSE = u""" """ -ASSOCIATE_DHCP_OPTIONS_RESPONSE = u""" +ASSOCIATE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE true diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index acd37b283..d11470242 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -4,137 +4,148 @@ from moto.ec2.utils import filters_from_querystring class ElasticBlockStore(BaseResponse): - def attach_volume(self): - volume_id = self._get_param('VolumeId') - instance_id = self._get_param('InstanceId') - device_path = self._get_param('Device') - if self.is_not_dryrun('AttachVolume'): + volume_id = self._get_param("VolumeId") + instance_id = self._get_param("InstanceId") + device_path = self._get_param("Device") + if self.is_not_dryrun("AttachVolume"): attachment = self.ec2_backend.attach_volume( - volume_id, instance_id, device_path) + volume_id, instance_id, device_path + ) template = self.response_template(ATTACHED_VOLUME_RESPONSE) return template.render(attachment=attachment) def copy_snapshot(self): - source_snapshot_id = self._get_param('SourceSnapshotId') - source_region = self._get_param('SourceRegion') - description = self._get_param('Description') - if self.is_not_dryrun('CopySnapshot'): + source_snapshot_id = self._get_param("SourceSnapshotId") + source_region = self._get_param("SourceRegion") + description = self._get_param("Description") + if self.is_not_dryrun("CopySnapshot"): snapshot = self.ec2_backend.copy_snapshot( - source_snapshot_id, source_region, description) + source_snapshot_id, source_region, description + ) template = self.response_template(COPY_SNAPSHOT_RESPONSE) return template.render(snapshot=snapshot) def create_snapshot(self): - volume_id = self._get_param('VolumeId') - description = self._get_param('Description') + volume_id = self._get_param("VolumeId") + description = self._get_param("Description") tags = self._parse_tag_specification("TagSpecification") - snapshot_tags = tags.get('snapshot', {}) - if self.is_not_dryrun('CreateSnapshot'): + snapshot_tags = tags.get("snapshot", {}) + if self.is_not_dryrun("CreateSnapshot"): snapshot = self.ec2_backend.create_snapshot(volume_id, description) snapshot.add_tags(snapshot_tags) template = self.response_template(CREATE_SNAPSHOT_RESPONSE) return template.render(snapshot=snapshot) def create_volume(self): - size = self._get_param('Size') - zone = self._get_param('AvailabilityZone') - snapshot_id = self._get_param('SnapshotId') + size = self._get_param("Size") + zone = self._get_param("AvailabilityZone") + snapshot_id = self._get_param("SnapshotId") tags = self._parse_tag_specification("TagSpecification") - volume_tags = tags.get('volume', {}) - encrypted = self._get_param('Encrypted', if_none=False) - if self.is_not_dryrun('CreateVolume'): - volume = self.ec2_backend.create_volume( - size, zone, snapshot_id, encrypted) + volume_tags = tags.get("volume", {}) + encrypted = self._get_param("Encrypted", if_none=False) + if self.is_not_dryrun("CreateVolume"): + volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted) volume.add_tags(volume_tags) template = self.response_template(CREATE_VOLUME_RESPONSE) return template.render(volume=volume) def delete_snapshot(self): - snapshot_id = self._get_param('SnapshotId') - if self.is_not_dryrun('DeleteSnapshot'): + snapshot_id = self._get_param("SnapshotId") + if self.is_not_dryrun("DeleteSnapshot"): self.ec2_backend.delete_snapshot(snapshot_id) return DELETE_SNAPSHOT_RESPONSE def delete_volume(self): - volume_id = self._get_param('VolumeId') - if self.is_not_dryrun('DeleteVolume'): + volume_id = self._get_param("VolumeId") + if self.is_not_dryrun("DeleteVolume"): self.ec2_backend.delete_volume(volume_id) return DELETE_VOLUME_RESPONSE def describe_snapshots(self): filters = filters_from_querystring(self.querystring) - snapshot_ids = self._get_multi_param('SnapshotId') - snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters) + snapshot_ids = self._get_multi_param("SnapshotId") + snapshots = self.ec2_backend.describe_snapshots( + snapshot_ids=snapshot_ids, filters=filters + ) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) return template.render(snapshots=snapshots) def describe_volumes(self): filters = filters_from_querystring(self.querystring) - volume_ids = self._get_multi_param('VolumeId') - volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters) + volume_ids = self._get_multi_param("VolumeId") + volumes = self.ec2_backend.describe_volumes( + volume_ids=volume_ids, filters=filters + ) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) return template.render(volumes=volumes) def describe_volume_attribute(self): raise NotImplementedError( - 'ElasticBlockStore.describe_volume_attribute is not yet implemented') + "ElasticBlockStore.describe_volume_attribute is not yet implemented" + ) def describe_volume_status(self): raise NotImplementedError( - 'ElasticBlockStore.describe_volume_status is not yet implemented') + "ElasticBlockStore.describe_volume_status is not yet implemented" + ) def detach_volume(self): - volume_id = self._get_param('VolumeId') - instance_id = self._get_param('InstanceId') - device_path = self._get_param('Device') - if self.is_not_dryrun('DetachVolume'): + volume_id = self._get_param("VolumeId") + instance_id = self._get_param("InstanceId") + device_path = self._get_param("Device") + if self.is_not_dryrun("DetachVolume"): attachment = self.ec2_backend.detach_volume( - volume_id, instance_id, device_path) + volume_id, instance_id, device_path + ) template = self.response_template(DETATCH_VOLUME_RESPONSE) return template.render(attachment=attachment) def enable_volume_io(self): - if self.is_not_dryrun('EnableVolumeIO'): + if self.is_not_dryrun("EnableVolumeIO"): raise NotImplementedError( - 'ElasticBlockStore.enable_volume_io is not yet implemented') + "ElasticBlockStore.enable_volume_io is not yet implemented" + ) def import_volume(self): - if self.is_not_dryrun('ImportVolume'): + if self.is_not_dryrun("ImportVolume"): raise NotImplementedError( - 'ElasticBlockStore.import_volume is not yet implemented') + "ElasticBlockStore.import_volume is not yet implemented" + ) def describe_snapshot_attribute(self): - snapshot_id = self._get_param('SnapshotId') - groups = self.ec2_backend.get_create_volume_permission_groups( - snapshot_id) - template = self.response_template( - DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE) + snapshot_id = self._get_param("SnapshotId") + groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id) + template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE) return template.render(snapshot_id=snapshot_id, groups=groups) def modify_snapshot_attribute(self): - snapshot_id = self._get_param('SnapshotId') - operation_type = self._get_param('OperationType') - group = self._get_param('UserGroup.1') - user_id = self._get_param('UserId.1') - if self.is_not_dryrun('ModifySnapshotAttribute'): - if (operation_type == 'add'): + snapshot_id = self._get_param("SnapshotId") + operation_type = self._get_param("OperationType") + group = self._get_param("UserGroup.1") + user_id = self._get_param("UserId.1") + if self.is_not_dryrun("ModifySnapshotAttribute"): + if operation_type == "add": self.ec2_backend.add_create_volume_permission( - snapshot_id, user_id=user_id, group=group) - elif (operation_type == 'remove'): + snapshot_id, user_id=user_id, group=group + ) + elif operation_type == "remove": self.ec2_backend.remove_create_volume_permission( - snapshot_id, user_id=user_id, group=group) + snapshot_id, user_id=user_id, group=group + ) return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE def modify_volume_attribute(self): - if self.is_not_dryrun('ModifyVolumeAttribute'): + if self.is_not_dryrun("ModifyVolumeAttribute"): raise NotImplementedError( - 'ElasticBlockStore.modify_volume_attribute is not yet implemented') + "ElasticBlockStore.modify_volume_attribute is not yet implemented" + ) def reset_snapshot_attribute(self): - if self.is_not_dryrun('ResetSnapshotAttribute'): + if self.is_not_dryrun("ResetSnapshotAttribute"): raise NotImplementedError( - 'ElasticBlockStore.reset_snapshot_attribute is not yet implemented') + "ElasticBlockStore.reset_snapshot_attribute is not yet implemented" + ) CREATE_VOLUME_RESPONSE = """ diff --git a/moto/ec2/responses/elastic_ip_addresses.py b/moto/ec2/responses/elastic_ip_addresses.py index 6e1c9fe38..e25922706 100644 --- a/moto/ec2/responses/elastic_ip_addresses.py +++ b/moto/ec2/responses/elastic_ip_addresses.py @@ -4,14 +4,14 @@ from moto.ec2.utils import filters_from_querystring class ElasticIPAddresses(BaseResponse): - def allocate_address(self): - domain = self._get_param('Domain', if_none='standard') - reallocate_address = self._get_param('Address', if_none=None) - if self.is_not_dryrun('AllocateAddress'): + domain = self._get_param("Domain", if_none="standard") + reallocate_address = self._get_param("Address", if_none=None) + if self.is_not_dryrun("AllocateAddress"): if reallocate_address: address = self.ec2_backend.allocate_address( - domain, address=reallocate_address) + domain, address=reallocate_address + ) else: address = self.ec2_backend.allocate_address(domain) template = self.response_template(ALLOCATE_ADDRESS_RESPONSE) @@ -21,73 +21,92 @@ class ElasticIPAddresses(BaseResponse): instance = eni = None if "InstanceId" in self.querystring: - instance = self.ec2_backend.get_instance( - self._get_param('InstanceId')) + instance = self.ec2_backend.get_instance(self._get_param("InstanceId")) elif "NetworkInterfaceId" in self.querystring: eni = self.ec2_backend.get_network_interface( - self._get_param('NetworkInterfaceId')) + self._get_param("NetworkInterfaceId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.") + "MissingParameter", + "Invalid request, expect InstanceId/NetworkId parameter.", + ) reassociate = False if "AllowReassociation" in self.querystring: - reassociate = self._get_param('AllowReassociation') == "true" + reassociate = self._get_param("AllowReassociation") == "true" - if self.is_not_dryrun('AssociateAddress'): + if self.is_not_dryrun("AssociateAddress"): if instance or eni: if "PublicIp" in self.querystring: eip = self.ec2_backend.associate_address( - instance=instance, eni=eni, - address=self._get_param('PublicIp'), reassociate=reassociate) + instance=instance, + eni=eni, + address=self._get_param("PublicIp"), + reassociate=reassociate, + ) elif "AllocationId" in self.querystring: eip = self.ec2_backend.associate_address( - instance=instance, eni=eni, - allocation_id=self._get_param('AllocationId'), reassociate=reassociate) + instance=instance, + eni=eni, + allocation_id=self._get_param("AllocationId"), + reassociate=reassociate, + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AllocationId parameter.", + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect either instance or ENI.") + "MissingParameter", + "Invalid request, expect either instance or ENI.", + ) template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE) return template.render(address=eip) def describe_addresses(self): - allocation_ids = self._get_multi_param('AllocationId') - public_ips = self._get_multi_param('PublicIp') + allocation_ids = self._get_multi_param("AllocationId") + public_ips = self._get_multi_param("PublicIp") filters = filters_from_querystring(self.querystring) addresses = self.ec2_backend.describe_addresses( - allocation_ids, public_ips, filters) + allocation_ids, public_ips, filters + ) template = self.response_template(DESCRIBE_ADDRESS_RESPONSE) return template.render(addresses=addresses) def disassociate_address(self): - if self.is_not_dryrun('DisAssociateAddress'): + if self.is_not_dryrun("DisAssociateAddress"): if "PublicIp" in self.querystring: self.ec2_backend.disassociate_address( - address=self._get_param('PublicIp')) + address=self._get_param("PublicIp") + ) elif "AssociationId" in self.querystring: self.ec2_backend.disassociate_address( - association_id=self._get_param('AssociationId')) + association_id=self._get_param("AssociationId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AssociationId parameter.", + ) return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render() def release_address(self): - if self.is_not_dryrun('ReleaseAddress'): + if self.is_not_dryrun("ReleaseAddress"): if "PublicIp" in self.querystring: - self.ec2_backend.release_address( - address=self._get_param('PublicIp')) + self.ec2_backend.release_address(address=self._get_param("PublicIp")) elif "AllocationId" in self.querystring: self.ec2_backend.release_address( - allocation_id=self._get_param('AllocationId')) + allocation_id=self._get_param("AllocationId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AllocationId parameter.", + ) return self.response_template(RELEASE_ADDRESS_RESPONSE).render() diff --git a/moto/ec2/responses/elastic_network_interfaces.py b/moto/ec2/responses/elastic_network_interfaces.py index 9c37e70da..6761b294e 100644 --- a/moto/ec2/responses/elastic_network_interfaces.py +++ b/moto/ec2/responses/elastic_network_interfaces.py @@ -4,71 +4,70 @@ from moto.ec2.utils import filters_from_querystring class ElasticNetworkInterfaces(BaseResponse): - def create_network_interface(self): - subnet_id = self._get_param('SubnetId') - private_ip_address = self._get_param('PrivateIpAddress') - groups = self._get_multi_param('SecurityGroupId') + subnet_id = self._get_param("SubnetId") + private_ip_address = self._get_param("PrivateIpAddress") + private_ip_addresses = self._get_multi_param("PrivateIpAddresses") + groups = self._get_multi_param("SecurityGroupId") subnet = self.ec2_backend.get_subnet(subnet_id) - description = self._get_param('Description') - if self.is_not_dryrun('CreateNetworkInterface'): + description = self._get_param("Description") + if self.is_not_dryrun("CreateNetworkInterface"): eni = self.ec2_backend.create_network_interface( - subnet, private_ip_address, groups, description) - template = self.response_template( - CREATE_NETWORK_INTERFACE_RESPONSE) + subnet, private_ip_address, private_ip_addresses, groups, description + ) + template = self.response_template(CREATE_NETWORK_INTERFACE_RESPONSE) return template.render(eni=eni) def delete_network_interface(self): - eni_id = self._get_param('NetworkInterfaceId') - if self.is_not_dryrun('DeleteNetworkInterface'): + eni_id = self._get_param("NetworkInterfaceId") + if self.is_not_dryrun("DeleteNetworkInterface"): self.ec2_backend.delete_network_interface(eni_id) - template = self.response_template( - DELETE_NETWORK_INTERFACE_RESPONSE) + template = self.response_template(DELETE_NETWORK_INTERFACE_RESPONSE) return template.render() def describe_network_interface_attribute(self): raise NotImplementedError( - 'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented') + "ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented" + ) def describe_network_interfaces(self): - eni_ids = self._get_multi_param('NetworkInterfaceId') + eni_ids = self._get_multi_param("NetworkInterfaceId") filters = filters_from_querystring(self.querystring) enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters) template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE) return template.render(enis=enis) def attach_network_interface(self): - eni_id = self._get_param('NetworkInterfaceId') - instance_id = self._get_param('InstanceId') - device_index = self._get_param('DeviceIndex') - if self.is_not_dryrun('AttachNetworkInterface'): + eni_id = self._get_param("NetworkInterfaceId") + instance_id = self._get_param("InstanceId") + device_index = self._get_param("DeviceIndex") + if self.is_not_dryrun("AttachNetworkInterface"): attachment_id = self.ec2_backend.attach_network_interface( - eni_id, instance_id, device_index) - template = self.response_template( - ATTACH_NETWORK_INTERFACE_RESPONSE) + eni_id, instance_id, device_index + ) + template = self.response_template(ATTACH_NETWORK_INTERFACE_RESPONSE) return template.render(attachment_id=attachment_id) def detach_network_interface(self): - attachment_id = self._get_param('AttachmentId') - if self.is_not_dryrun('DetachNetworkInterface'): + attachment_id = self._get_param("AttachmentId") + if self.is_not_dryrun("DetachNetworkInterface"): self.ec2_backend.detach_network_interface(attachment_id) - template = self.response_template( - DETACH_NETWORK_INTERFACE_RESPONSE) + template = self.response_template(DETACH_NETWORK_INTERFACE_RESPONSE) return template.render() def modify_network_interface_attribute(self): # Currently supports modifying one and only one security group - eni_id = self._get_param('NetworkInterfaceId') - group_id = self._get_param('SecurityGroupId.1') - if self.is_not_dryrun('ModifyNetworkInterface'): - self.ec2_backend.modify_network_interface_attribute( - eni_id, group_id) + eni_id = self._get_param("NetworkInterfaceId") + group_id = self._get_param("SecurityGroupId.1") + if self.is_not_dryrun("ModifyNetworkInterface"): + self.ec2_backend.modify_network_interface_attribute(eni_id, group_id) return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE def reset_network_interface_attribute(self): - if self.is_not_dryrun('ResetNetworkInterface'): + if self.is_not_dryrun("ResetNetworkInterface"): raise NotImplementedError( - 'ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented') + "ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented" + ) CREATE_NETWORK_INTERFACE_RESPONSE = """ diff --git a/moto/ec2/responses/general.py b/moto/ec2/responses/general.py index 262d9f8ea..5dcd73358 100644 --- a/moto/ec2/responses/general.py +++ b/moto/ec2/responses/general.py @@ -3,20 +3,19 @@ from moto.core.responses import BaseResponse class General(BaseResponse): - def get_console_output(self): - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") if not instance_id: # For compatibility with boto. # See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599 - instance_id = self._get_multi_param('InstanceId')[0] + instance_id = self._get_multi_param("InstanceId")[0] instance = self.ec2_backend.get_instance(instance_id) template = self.response_template(GET_CONSOLE_OUTPUT_RESULT) return template.render(instance=instance) -GET_CONSOLE_OUTPUT_RESULT = ''' +GET_CONSOLE_OUTPUT_RESULT = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ instance.id }} @@ -29,4 +28,4 @@ R0hNRU0gYXZhaWxhYmxlLgo3MjdNQiBMT1dNRU0gYXZhaWxhYmxlLgpOWCAoRXhlY3V0ZSBEaXNh YmxlKSBwcm90ZWN0aW9uOiBhY3RpdmUKSVJRIGxvY2t1cCBkZXRlY3Rpb24gZGlzYWJsZWQKQnVp bHQgMSB6b25lbGlzdHMKS2VybmVsIGNvbW1hbmQgbGluZTogcm9vdD0vZGV2L3NkYTEgcm8gNApF bmFibGluZyBmYXN0IEZQVSBzYXZlIGFuZCByZXN0b3JlLi4uIGRvbmUuCg== -''' +""" diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 82c2b1997..b9e572d29 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -4,19 +4,20 @@ from boto.ec2.instancetype import InstanceType from moto.autoscaling import autoscaling_backends from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from moto.ec2.utils import filters_from_querystring, \ - dict_from_querystring +from moto.ec2.utils import filters_from_querystring, dict_from_querystring +from moto.elbv2 import elbv2_backends +from moto.core import ACCOUNT_ID class InstanceResponse(BaseResponse): - def describe_instances(self): filter_dict = filters_from_querystring(self.querystring) - instance_ids = self._get_multi_param('InstanceId') + instance_ids = self._get_multi_param("InstanceId") token = self._get_param("NextToken") if instance_ids: reservations = self.ec2_backend.get_reservations_by_instance_ids( - instance_ids, filters=filter_dict) + instance_ids, filters=filter_dict + ) else: reservations = self.ec2_backend.all_reservations(filters=filter_dict) @@ -25,80 +26,99 @@ class InstanceResponse(BaseResponse): start = reservation_ids.index(token) + 1 else: start = 0 - max_results = int(self._get_param('MaxResults', 100)) - reservations_resp = reservations[start:start + max_results] + max_results = int(self._get_param("MaxResults", 100)) + reservations_resp = reservations[start : start + max_results] next_token = None if max_results and len(reservations) > (start + max_results): next_token = reservations_resp[-1].id template = self.response_template(EC2_DESCRIBE_INSTANCES) - return template.render(reservations=reservations_resp, next_token=next_token).replace('True', 'true').replace('False', 'false') + return ( + template.render(reservations=reservations_resp, next_token=next_token) + .replace("True", "true") + .replace("False", "false") + ) def run_instances(self): - min_count = int(self._get_param('MinCount', if_none='1')) - image_id = self._get_param('ImageId') - owner_id = self._get_param('OwnerId') - user_data = self._get_param('UserData') - security_group_names = self._get_multi_param('SecurityGroup') - security_group_ids = self._get_multi_param('SecurityGroupId') + min_count = int(self._get_param("MinCount", if_none="1")) + image_id = self._get_param("ImageId") + owner_id = self._get_param("OwnerId") + user_data = self._get_param("UserData") + security_group_names = self._get_multi_param("SecurityGroup") + security_group_ids = self._get_multi_param("SecurityGroupId") nics = dict_from_querystring("NetworkInterface", self.querystring) - instance_type = self._get_param('InstanceType', if_none='m1.small') - placement = self._get_param('Placement.AvailabilityZone') - subnet_id = self._get_param('SubnetId') - private_ip = self._get_param('PrivateIpAddress') - associate_public_ip = self._get_param('AssociatePublicIpAddress') - key_name = self._get_param('KeyName') - ebs_optimized = self._get_param('EbsOptimized') - instance_initiated_shutdown_behavior = self._get_param("InstanceInitiatedShutdownBehavior") + instance_type = self._get_param("InstanceType", if_none="m1.small") + placement = self._get_param("Placement.AvailabilityZone") + subnet_id = self._get_param("SubnetId") + private_ip = self._get_param("PrivateIpAddress") + associate_public_ip = self._get_param("AssociatePublicIpAddress") + key_name = self._get_param("KeyName") + ebs_optimized = self._get_param("EbsOptimized") + instance_initiated_shutdown_behavior = self._get_param( + "InstanceInitiatedShutdownBehavior" + ) tags = self._parse_tag_specification("TagSpecification") region_name = self.region - if self.is_not_dryrun('RunInstance'): + if self.is_not_dryrun("RunInstance"): new_reservation = self.ec2_backend.add_instances( - image_id, min_count, user_data, security_group_names, - instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id, - owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids, - nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip, - tags=tags, ebs_optimized=ebs_optimized, instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior) + image_id, + min_count, + user_data, + security_group_names, + instance_type=instance_type, + placement=placement, + region_name=region_name, + subnet_id=subnet_id, + owner_id=owner_id, + key_name=key_name, + security_group_ids=security_group_ids, + nics=nics, + private_ip=private_ip, + associate_public_ip=associate_public_ip, + tags=tags, + ebs_optimized=ebs_optimized, + instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior, + ) template = self.response_template(EC2_RUN_INSTANCES) return template.render(reservation=new_reservation) def terminate_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('TerminateInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("TerminateInstance"): instances = self.ec2_backend.terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids) + elbv2_backends[self.region].notify_terminate_instances(instance_ids) template = self.response_template(EC2_TERMINATE_INSTANCES) return template.render(instances=instances) def reboot_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('RebootInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("RebootInstance"): instances = self.ec2_backend.reboot_instances(instance_ids) template = self.response_template(EC2_REBOOT_INSTANCES) return template.render(instances=instances) def stop_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('StopInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("StopInstance"): instances = self.ec2_backend.stop_instances(instance_ids) template = self.response_template(EC2_STOP_INSTANCES) return template.render(instances=instances) def start_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('StartInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("StartInstance"): instances = self.ec2_backend.start_instances(instance_ids) template = self.response_template(EC2_START_INSTANCES) return template.render(instances=instances) def describe_instance_status(self): - instance_ids = self._get_multi_param('InstanceId') - include_all_instances = self._get_param('IncludeAllInstances') == 'true' + instance_ids = self._get_multi_param("InstanceId") + include_all_instances = self._get_param("IncludeAllInstances") == "true" if instance_ids: - instances = self.ec2_backend.get_multi_instances_by_id( - instance_ids) + instances = self.ec2_backend.get_multi_instances_by_id(instance_ids) elif include_all_instances: instances = self.ec2_backend.all_instances() else: @@ -108,40 +128,45 @@ class InstanceResponse(BaseResponse): return template.render(instances=instances) def describe_instance_types(self): - instance_types = [InstanceType( - name='t1.micro', cores=1, memory=644874240, disk=0)] + instance_types = [ + InstanceType(name="t1.micro", cores=1, memory=644874240, disk=0) + ] template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES) return template.render(instance_types=instance_types) def describe_instance_attribute(self): # TODO this and modify below should raise IncorrectInstanceState if # instance not in stopped state - attribute = self._get_param('Attribute') - instance_id = self._get_param('InstanceId') + attribute = self._get_param("Attribute") + instance_id = self._get_param("InstanceId") instance, value = self.ec2_backend.describe_instance_attribute( - instance_id, attribute) + instance_id, attribute + ) if attribute == "groupSet": - template = self.response_template( - EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) + template = self.response_template(EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) else: template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) return template.render(instance=instance, attribute=attribute, value=value) def modify_instance_attribute(self): - handlers = [self._dot_value_instance_attribute_handler, - self._block_device_mapping_handler, - self._security_grp_instance_attribute_handler] + handlers = [ + self._dot_value_instance_attribute_handler, + self._block_device_mapping_handler, + self._security_grp_instance_attribute_handler, + ] for handler in handlers: success = handler() if success: return success - msg = "This specific call to ModifyInstanceAttribute has not been" \ - " implemented in Moto yet. Feel free to open an issue at" \ - " https://github.com/spulec/moto/issues" + msg = ( + "This specific call to ModifyInstanceAttribute has not been" + " implemented in Moto yet. Feel free to open an issue at" + " https://github.com/spulec/moto/issues" + ) raise NotImplementedError(msg) def _block_device_mapping_handler(self): @@ -164,8 +189,8 @@ class InstanceResponse(BaseResponse): configuration, but it should be trivial to add anything else. """ mapping_counter = 1 - mapping_device_name_fmt = 'BlockDeviceMapping.%s.DeviceName' - mapping_del_on_term_fmt = 'BlockDeviceMapping.%s.Ebs.DeleteOnTermination' + mapping_device_name_fmt = "BlockDeviceMapping.%s.DeviceName" + mapping_del_on_term_fmt = "BlockDeviceMapping.%s.Ebs.DeleteOnTermination" while True: mapping_device_name = mapping_device_name_fmt % mapping_counter if mapping_device_name not in self.querystring.keys(): @@ -173,15 +198,14 @@ class InstanceResponse(BaseResponse): mapping_del_on_term = mapping_del_on_term_fmt % mapping_counter del_on_term_value_str = self.querystring[mapping_del_on_term][0] - del_on_term_value = True if 'true' == del_on_term_value_str else False + del_on_term_value = True if "true" == del_on_term_value_str else False device_name_value = self.querystring[mapping_device_name][0] - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") instance = self.ec2_backend.get_instance(instance_id) - if self.is_not_dryrun('ModifyInstanceAttribute'): - block_device_type = instance.block_device_mapping[ - device_name_value] + if self.is_not_dryrun("ModifyInstanceAttribute"): + block_device_type = instance.block_device_mapping[device_name_value] block_device_type.delete_on_termination = del_on_term_value # +1 for the next device @@ -193,39 +217,43 @@ class InstanceResponse(BaseResponse): def _dot_value_instance_attribute_handler(self): attribute_key = None for key, value in self.querystring.items(): - if '.Value' in key: + if ".Value" in key: attribute_key = key break if not attribute_key: return - if self.is_not_dryrun('Modify' + attribute_key.split(".")[0]): + if self.is_not_dryrun("Modify" + attribute_key.split(".")[0]): value = self.querystring.get(attribute_key)[0] - normalized_attribute = camelcase_to_underscores( - attribute_key.split(".")[0]) - instance_id = self._get_param('InstanceId') + normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0]) + instance_id = self._get_param("InstanceId") self.ec2_backend.modify_instance_attribute( - instance_id, normalized_attribute, value) + instance_id, normalized_attribute, value + ) return EC2_MODIFY_INSTANCE_ATTRIBUTE def _security_grp_instance_attribute_handler(self): new_security_grp_list = [] for key, value in self.querystring.items(): - if 'GroupId.' in key: + if "GroupId." in key: new_security_grp_list.append(self.querystring.get(key)[0]) - instance_id = self._get_param('InstanceId') - if self.is_not_dryrun('ModifyInstanceSecurityGroups'): + instance_id = self._get_param("InstanceId") + if self.is_not_dryrun("ModifyInstanceSecurityGroups"): self.ec2_backend.modify_instance_security_groups( - instance_id, new_security_grp_list) + instance_id, new_security_grp_list + ) return EC2_MODIFY_INSTANCE_ATTRIBUTE -EC2_RUN_INSTANCES = """ +EC2_RUN_INSTANCES = ( + """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ reservation.id }} - 123456789012 + """ + + ACCOUNT_ID + + """ sg-245f6a01 @@ -307,7 +335,9 @@ EC2_RUN_INSTANCES = """ in-use 1b:2b:3c:4d:5e:6f {{ nic.private_ip_address }} @@ -330,7 +360,9 @@ EC2_RUN_INSTANCES = """ {% endif %} @@ -340,7 +372,9 @@ EC2_RUN_INSTANCES = """ {% endif %} @@ -352,14 +386,18 @@ EC2_RUN_INSTANCES = """ +EC2_DESCRIBE_INSTANCES = ( + """ fdcdcab1-ae5c-489e-9c33-4637c5dda355 {% for reservation in reservations %} {{ reservation.id }} - 123456789012 + """ + + ACCOUNT_ID + + """ {% for group in reservation.dynamic_group_list %} @@ -452,7 +490,9 @@ EC2_DESCRIBE_INSTANCES = """ {% if instance.get_tags() %} {% for tag in instance.get_tags() %} @@ -475,7 +515,9 @@ EC2_DESCRIBE_INSTANCES = """ in-use 1b:2b:3c:4d:5e:6f {{ nic.private_ip_address }} @@ -502,7 +544,9 @@ EC2_DESCRIBE_INSTANCES = """ {% endif %} @@ -512,7 +556,9 @@ EC2_DESCRIBE_INSTANCES = """ {% endif %} @@ -530,6 +576,7 @@ EC2_DESCRIBE_INSTANCES = """ diff --git a/moto/ec2/responses/internet_gateways.py b/moto/ec2/responses/internet_gateways.py index ebea14adf..d232b3b05 100644 --- a/moto/ec2/responses/internet_gateways.py +++ b/moto/ec2/responses/internet_gateways.py @@ -1,29 +1,26 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.ec2.utils import ( - filters_from_querystring, -) +from moto.ec2.utils import filters_from_querystring class InternetGateways(BaseResponse): - def attach_internet_gateway(self): - igw_id = self._get_param('InternetGatewayId') - vpc_id = self._get_param('VpcId') - if self.is_not_dryrun('AttachInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + vpc_id = self._get_param("VpcId") + if self.is_not_dryrun("AttachInternetGateway"): self.ec2_backend.attach_internet_gateway(igw_id, vpc_id) template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE) return template.render() def create_internet_gateway(self): - if self.is_not_dryrun('CreateInternetGateway'): + if self.is_not_dryrun("CreateInternetGateway"): igw = self.ec2_backend.create_internet_gateway() template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE) return template.render(internet_gateway=igw) def delete_internet_gateway(self): - igw_id = self._get_param('InternetGatewayId') - if self.is_not_dryrun('DeleteInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + if self.is_not_dryrun("DeleteInternetGateway"): self.ec2_backend.delete_internet_gateway(igw_id) template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE) return template.render() @@ -33,10 +30,10 @@ class InternetGateways(BaseResponse): if "InternetGatewayId.1" in self.querystring: igw_ids = self._get_multi_param("InternetGatewayId") igws = self.ec2_backend.describe_internet_gateways( - igw_ids, filters=filter_dict) + igw_ids, filters=filter_dict + ) else: - igws = self.ec2_backend.describe_internet_gateways( - filters=filter_dict) + igws = self.ec2_backend.describe_internet_gateways(filters=filter_dict) template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE) return template.render(internet_gateways=igws) @@ -44,20 +41,20 @@ class InternetGateways(BaseResponse): def detach_internet_gateway(self): # TODO validate no instances with EIPs in VPC before detaching # raise else DependencyViolationError() - igw_id = self._get_param('InternetGatewayId') - vpc_id = self._get_param('VpcId') - if self.is_not_dryrun('DetachInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + vpc_id = self._get_param("VpcId") + if self.is_not_dryrun("DetachInternetGateway"): self.ec2_backend.detach_internet_gateway(igw_id, vpc_id) template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE) return template.render() -ATTACH_INTERNET_GATEWAY_RESPONSE = u""" +ATTACH_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ -CREATE_INTERNET_GATEWAY_RESPONSE = u""" +CREATE_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ internet_gateway.id }} @@ -75,12 +72,12 @@ CREATE_INTERNET_GATEWAY_RESPONSE = u""" +DELETE_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ -DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u""" 59dbff89-35bd-4eac-99ed-be587EXAMPLE @@ -112,7 +109,7 @@ DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u""" """ -DETACH_INTERNET_GATEWAY_RESPONSE = u""" +DETACH_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ diff --git a/moto/ec2/responses/ip_addresses.py b/moto/ec2/responses/ip_addresses.py index fab5cbddc..789abfdec 100644 --- a/moto/ec2/responses/ip_addresses.py +++ b/moto/ec2/responses/ip_addresses.py @@ -4,13 +4,14 @@ from moto.core.responses import BaseResponse class IPAddresses(BaseResponse): - def assign_private_ip_addresses(self): - if self.is_not_dryrun('AssignPrivateIPAddress'): + if self.is_not_dryrun("AssignPrivateIPAddress"): raise NotImplementedError( - 'IPAddresses.assign_private_ip_addresses is not yet implemented') + "IPAddresses.assign_private_ip_addresses is not yet implemented" + ) def unassign_private_ip_addresses(self): - if self.is_not_dryrun('UnAssignPrivateIPAddress'): + if self.is_not_dryrun("UnAssignPrivateIPAddress"): raise NotImplementedError( - 'IPAddresses.unassign_private_ip_addresses is not yet implemented') + "IPAddresses.unassign_private_ip_addresses is not yet implemented" + ) diff --git a/moto/ec2/responses/key_pairs.py b/moto/ec2/responses/key_pairs.py index d927bddda..fa2e60904 100644 --- a/moto/ec2/responses/key_pairs.py +++ b/moto/ec2/responses/key_pairs.py @@ -5,32 +5,32 @@ from moto.ec2.utils import filters_from_querystring class KeyPairs(BaseResponse): - def create_key_pair(self): - name = self._get_param('KeyName') - if self.is_not_dryrun('CreateKeyPair'): + name = self._get_param("KeyName") + if self.is_not_dryrun("CreateKeyPair"): keypair = self.ec2_backend.create_key_pair(name) template = self.response_template(CREATE_KEY_PAIR_RESPONSE) return template.render(keypair=keypair) def delete_key_pair(self): - name = self._get_param('KeyName') - if self.is_not_dryrun('DeleteKeyPair'): - success = six.text_type( - self.ec2_backend.delete_key_pair(name)).lower() - return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success) + name = self._get_param("KeyName") + if self.is_not_dryrun("DeleteKeyPair"): + success = six.text_type(self.ec2_backend.delete_key_pair(name)).lower() + return self.response_template(DELETE_KEY_PAIR_RESPONSE).render( + success=success + ) def describe_key_pairs(self): - names = self._get_multi_param('KeyName') + names = self._get_multi_param("KeyName") filters = filters_from_querystring(self.querystring) keypairs = self.ec2_backend.describe_key_pairs(names, filters) template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE) return template.render(keypairs=keypairs) def import_key_pair(self): - name = self._get_param('KeyName') - material = self._get_param('PublicKeyMaterial') - if self.is_not_dryrun('ImportKeyPair'): + name = self._get_param("KeyName") + material = self._get_param("PublicKeyMaterial") + if self.is_not_dryrun("ImportKeyPair"): keypair = self.ec2_backend.import_key_pair(name, material) template = self.response_template(IMPORT_KEYPAIR_RESPONSE) return template.render(keypair=keypair) diff --git a/moto/ec2/responses/launch_templates.py b/moto/ec2/responses/launch_templates.py index a8d92a928..22faba539 100644 --- a/moto/ec2/responses/launch_templates.py +++ b/moto/ec2/responses/launch_templates.py @@ -10,9 +10,9 @@ from xml.dom import minidom def xml_root(name): - root = ElementTree.Element(name, { - "xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/" - }) + root = ElementTree.Element( + name, {"xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/"} + ) request_id = str(uuid.uuid4()) + "example" ElementTree.SubElement(root, "requestId").text = request_id @@ -22,10 +22,10 @@ def xml_root(name): def xml_serialize(tree, key, value): name = key[0].lower() + key[1:] if isinstance(value, list): - if name[-1] == 's': + if name[-1] == "s": name = name[:-1] - name = name + 'Set' + name = name + "Set" node = ElementTree.SubElement(tree, name) @@ -36,17 +36,19 @@ def xml_serialize(tree, key, value): xml_serialize(node, dictkey, dictvalue) elif isinstance(value, list): for item in value: - xml_serialize(node, 'item', item) + xml_serialize(node, "item", item) elif value is None: pass else: - raise NotImplementedError("Don't know how to serialize \"{}\" to xml".format(value.__class__)) + raise NotImplementedError( + 'Don\'t know how to serialize "{}" to xml'.format(value.__class__) + ) def pretty_xml(tree): - rough = ElementTree.tostring(tree, 'utf-8') + rough = ElementTree.tostring(tree, "utf-8") parsed = minidom.parseString(rough) - return parsed.toprettyxml(indent=' ') + return parsed.toprettyxml(indent=" ") def parse_object(raw_data): @@ -92,68 +94,87 @@ def parse_lists(data): class LaunchTemplates(BaseResponse): def create_launch_template(self): - name = self._get_param('LaunchTemplateName') - version_description = self._get_param('VersionDescription') + name = self._get_param("LaunchTemplateName") + version_description = self._get_param("VersionDescription") tag_spec = self._parse_tag_specification("TagSpecification") - raw_template_data = self._get_dict_param('LaunchTemplateData.') + raw_template_data = self._get_dict_param("LaunchTemplateData.") parsed_template_data = parse_object(raw_template_data) - if self.is_not_dryrun('CreateLaunchTemplate'): + if self.is_not_dryrun("CreateLaunchTemplate"): if tag_spec: - if 'TagSpecifications' not in parsed_template_data: - parsed_template_data['TagSpecifications'] = [] + if "TagSpecifications" not in parsed_template_data: + parsed_template_data["TagSpecifications"] = [] converted_tag_spec = [] for resource_type, tags in six.iteritems(tag_spec): - converted_tag_spec.append({ - "ResourceType": resource_type, - "Tags": [{"Key": key, "Value": value} for key, value in six.iteritems(tags)], - }) + converted_tag_spec.append( + { + "ResourceType": resource_type, + "Tags": [ + {"Key": key, "Value": value} + for key, value in six.iteritems(tags) + ], + } + ) - parsed_template_data['TagSpecifications'].extend(converted_tag_spec) + parsed_template_data["TagSpecifications"].extend(converted_tag_spec) - template = self.ec2_backend.create_launch_template(name, version_description, parsed_template_data) + template = self.ec2_backend.create_launch_template( + name, version_description, parsed_template_data + ) version = template.default_version() tree = xml_root("CreateLaunchTemplateResponse") - xml_serialize(tree, "launchTemplate", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersionNumber": template.default_version_number, - "latestVersionNumber": version.number, - "launchTemplateId": template.id, - "launchTemplateName": template.name - }) + xml_serialize( + tree, + "launchTemplate", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersionNumber": template.default_version_number, + "latestVersionNumber": version.number, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + }, + ) return pretty_xml(tree) def create_launch_template_version(self): - name = self._get_param('LaunchTemplateName') - tmpl_id = self._get_param('LaunchTemplateId') + name = self._get_param("LaunchTemplateName") + tmpl_id = self._get_param("LaunchTemplateId") if name: template = self.ec2_backend.get_launch_template_by_name(name) if tmpl_id: template = self.ec2_backend.get_launch_template(tmpl_id) - version_description = self._get_param('VersionDescription') + version_description = self._get_param("VersionDescription") - raw_template_data = self._get_dict_param('LaunchTemplateData.') + raw_template_data = self._get_dict_param("LaunchTemplateData.") template_data = parse_object(raw_template_data) - if self.is_not_dryrun('CreateLaunchTemplate'): + if self.is_not_dryrun("CreateLaunchTemplate"): version = template.create_version(template_data, version_description) tree = xml_root("CreateLaunchTemplateVersionResponse") - xml_serialize(tree, "launchTemplateVersion", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersion": template.is_default(version), - "launchTemplateData": version.data, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - "versionDescription": version.description, - "versionNumber": version.number, - }) + xml_serialize( + tree, + "launchTemplateVersion", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersion": template.is_default(version), + "launchTemplateData": version.data, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + "versionDescription": version.description, + "versionNumber": version.number, + }, + ) return pretty_xml(tree) # def delete_launch_template(self): @@ -163,8 +184,8 @@ class LaunchTemplates(BaseResponse): # pass def describe_launch_template_versions(self): - name = self._get_param('LaunchTemplateName') - template_id = self._get_param('LaunchTemplateId') + name = self._get_param("LaunchTemplateName") + template_id = self._get_param("LaunchTemplateId") if name: template = self.ec2_backend.get_launch_template_by_name(name) if template_id: @@ -177,12 +198,15 @@ class LaunchTemplates(BaseResponse): filters = filters_from_querystring(self.querystring) if filters: - raise FilterNotImplementedError("all filters", "DescribeLaunchTemplateVersions") + raise FilterNotImplementedError( + "all filters", "DescribeLaunchTemplateVersions" + ) - if self.is_not_dryrun('DescribeLaunchTemplateVersions'): - tree = ElementTree.Element("DescribeLaunchTemplateVersionsResponse", { - "xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/", - }) + if self.is_not_dryrun("DescribeLaunchTemplateVersions"): + tree = ElementTree.Element( + "DescribeLaunchTemplateVersionsResponse", + {"xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/"}, + ) request_id = ElementTree.SubElement(tree, "requestId") request_id.text = "65cadec1-b364-4354-8ca8-4176dexample" @@ -209,16 +233,22 @@ class LaunchTemplates(BaseResponse): ret_versions = ret_versions[:max_results] for version in ret_versions: - xml_serialize(versions_node, "item", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersion": True, - "launchTemplateData": version.data, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - "versionDescription": version.description, - "versionNumber": version.number, - }) + xml_serialize( + versions_node, + "item", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersion": True, + "launchTemplateData": version.data, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + "versionDescription": version.description, + "versionNumber": version.number, + }, + ) return pretty_xml(tree) @@ -232,19 +262,29 @@ class LaunchTemplates(BaseResponse): tree = ElementTree.Element("DescribeLaunchTemplatesResponse") templates_node = ElementTree.SubElement(tree, "launchTemplates") - templates = self.ec2_backend.get_launch_templates(template_names=template_names, template_ids=template_ids, filters=filters) + templates = self.ec2_backend.get_launch_templates( + template_names=template_names, + template_ids=template_ids, + filters=filters, + ) templates = templates[:max_results] for template in templates: - xml_serialize(templates_node, "item", { - "createTime": template.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersionNumber": template.default_version_number, - "latestVersionNumber": template.latest_version_number, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - }) + xml_serialize( + templates_node, + "item", + { + "createTime": template.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersionNumber": template.default_version_number, + "latestVersionNumber": template.latest_version_number, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + }, + ) return pretty_xml(tree) diff --git a/moto/ec2/responses/monitoring.py b/moto/ec2/responses/monitoring.py index 2024abe7e..4ef8db087 100644 --- a/moto/ec2/responses/monitoring.py +++ b/moto/ec2/responses/monitoring.py @@ -3,13 +3,14 @@ from moto.core.responses import BaseResponse class Monitoring(BaseResponse): - def monitor_instances(self): - if self.is_not_dryrun('MonitorInstances'): + if self.is_not_dryrun("MonitorInstances"): raise NotImplementedError( - 'Monitoring.monitor_instances is not yet implemented') + "Monitoring.monitor_instances is not yet implemented" + ) def unmonitor_instances(self): - if self.is_not_dryrun('UnMonitorInstances'): + if self.is_not_dryrun("UnMonitorInstances"): raise NotImplementedError( - 'Monitoring.unmonitor_instances is not yet implemented') + "Monitoring.unmonitor_instances is not yet implemented" + ) diff --git a/moto/ec2/responses/nat_gateways.py b/moto/ec2/responses/nat_gateways.py index ce9479e82..efa5c2656 100644 --- a/moto/ec2/responses/nat_gateways.py +++ b/moto/ec2/responses/nat_gateways.py @@ -4,17 +4,17 @@ from moto.ec2.utils import filters_from_querystring class NatGateways(BaseResponse): - def create_nat_gateway(self): - subnet_id = self._get_param('SubnetId') - allocation_id = self._get_param('AllocationId') + subnet_id = self._get_param("SubnetId") + allocation_id = self._get_param("AllocationId") nat_gateway = self.ec2_backend.create_nat_gateway( - subnet_id=subnet_id, allocation_id=allocation_id) + subnet_id=subnet_id, allocation_id=allocation_id + ) template = self.response_template(CREATE_NAT_GATEWAY) return template.render(nat_gateway=nat_gateway) def delete_nat_gateway(self): - nat_gateway_id = self._get_param('NatGatewayId') + nat_gateway_id = self._get_param("NatGatewayId") nat_gateway = self.ec2_backend.delete_nat_gateway(nat_gateway_id) template = self.response_template(DELETE_NAT_GATEWAY_RESPONSE) return template.render(nat_gateway=nat_gateway) diff --git a/moto/ec2/responses/network_acls.py b/moto/ec2/responses/network_acls.py index 97f370306..8d89e6065 100644 --- a/moto/ec2/responses/network_acls.py +++ b/moto/ec2/responses/network_acls.py @@ -4,82 +4,95 @@ from moto.ec2.utils import filters_from_querystring class NetworkACLs(BaseResponse): - def create_network_acl(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") network_acl = self.ec2_backend.create_network_acl(vpc_id) template = self.response_template(CREATE_NETWORK_ACL_RESPONSE) return template.render(network_acl=network_acl) def create_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - protocol = self._get_param('Protocol') - rule_action = self._get_param('RuleAction') - egress = self._get_param('Egress') - cidr_block = self._get_param('CidrBlock') - icmp_code = self._get_param('Icmp.Code') - icmp_type = self._get_param('Icmp.Type') - port_range_from = self._get_param('PortRange.From') - port_range_to = self._get_param('PortRange.To') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + protocol = self._get_param("Protocol") + rule_action = self._get_param("RuleAction") + egress = self._get_param("Egress") + cidr_block = self._get_param("CidrBlock") + icmp_code = self._get_param("Icmp.Code") + icmp_type = self._get_param("Icmp.Type") + port_range_from = self._get_param("PortRange.From") + port_range_to = self._get_param("PortRange.To") network_acl_entry = self.ec2_backend.create_network_acl_entry( - network_acl_id, rule_number, protocol, rule_action, - egress, cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) template = self.response_template(CREATE_NETWORK_ACL_ENTRY_RESPONSE) return template.render(network_acl_entry=network_acl_entry) def delete_network_acl(self): - network_acl_id = self._get_param('NetworkAclId') + network_acl_id = self._get_param("NetworkAclId") self.ec2_backend.delete_network_acl(network_acl_id) template = self.response_template(DELETE_NETWORK_ACL_ASSOCIATION) return template.render() def delete_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - egress = self._get_param('Egress') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + egress = self._get_param("Egress") self.ec2_backend.delete_network_acl_entry(network_acl_id, rule_number, egress) template = self.response_template(DELETE_NETWORK_ACL_ENTRY_RESPONSE) return template.render() def replace_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - protocol = self._get_param('Protocol') - rule_action = self._get_param('RuleAction') - egress = self._get_param('Egress') - cidr_block = self._get_param('CidrBlock') - icmp_code = self._get_param('Icmp.Code') - icmp_type = self._get_param('Icmp.Type') - port_range_from = self._get_param('PortRange.From') - port_range_to = self._get_param('PortRange.To') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + protocol = self._get_param("Protocol") + rule_action = self._get_param("RuleAction") + egress = self._get_param("Egress") + cidr_block = self._get_param("CidrBlock") + icmp_code = self._get_param("Icmp.Code") + icmp_type = self._get_param("Icmp.Type") + port_range_from = self._get_param("PortRange.From") + port_range_to = self._get_param("PortRange.To") self.ec2_backend.replace_network_acl_entry( - network_acl_id, rule_number, protocol, rule_action, - egress, cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) template = self.response_template(REPLACE_NETWORK_ACL_ENTRY_RESPONSE) return template.render() def describe_network_acls(self): - network_acl_ids = self._get_multi_param('NetworkAclId') + network_acl_ids = self._get_multi_param("NetworkAclId") filters = filters_from_querystring(self.querystring) - network_acls = self.ec2_backend.get_all_network_acls( - network_acl_ids, filters) + network_acls = self.ec2_backend.get_all_network_acls(network_acl_ids, filters) template = self.response_template(DESCRIBE_NETWORK_ACL_RESPONSE) return template.render(network_acls=network_acls) def replace_network_acl_association(self): - association_id = self._get_param('AssociationId') - network_acl_id = self._get_param('NetworkAclId') + association_id = self._get_param("AssociationId") + network_acl_id = self._get_param("NetworkAclId") association = self.ec2_backend.replace_network_acl_association( - association_id, - network_acl_id + association_id, network_acl_id ) template = self.response_template(REPLACE_NETWORK_ACL_ASSOCIATION) return template.render(association=association) diff --git a/moto/ec2/responses/placement_groups.py b/moto/ec2/responses/placement_groups.py index 06930f700..2a7ade653 100644 --- a/moto/ec2/responses/placement_groups.py +++ b/moto/ec2/responses/placement_groups.py @@ -3,17 +3,19 @@ from moto.core.responses import BaseResponse class PlacementGroups(BaseResponse): - def create_placement_group(self): - if self.is_not_dryrun('CreatePlacementGroup'): + if self.is_not_dryrun("CreatePlacementGroup"): raise NotImplementedError( - 'PlacementGroups.create_placement_group is not yet implemented') + "PlacementGroups.create_placement_group is not yet implemented" + ) def delete_placement_group(self): - if self.is_not_dryrun('DeletePlacementGroup'): + if self.is_not_dryrun("DeletePlacementGroup"): raise NotImplementedError( - 'PlacementGroups.delete_placement_group is not yet implemented') + "PlacementGroups.delete_placement_group is not yet implemented" + ) def describe_placement_groups(self): raise NotImplementedError( - 'PlacementGroups.describe_placement_groups is not yet implemented') + "PlacementGroups.describe_placement_groups is not yet implemented" + ) diff --git a/moto/ec2/responses/reserved_instances.py b/moto/ec2/responses/reserved_instances.py index 07bd6661e..23a2b8715 100644 --- a/moto/ec2/responses/reserved_instances.py +++ b/moto/ec2/responses/reserved_instances.py @@ -3,30 +3,35 @@ from moto.core.responses import BaseResponse class ReservedInstances(BaseResponse): - def cancel_reserved_instances_listing(self): - if self.is_not_dryrun('CancelReservedInstances'): + if self.is_not_dryrun("CancelReservedInstances"): raise NotImplementedError( - 'ReservedInstances.cancel_reserved_instances_listing is not yet implemented') + "ReservedInstances.cancel_reserved_instances_listing is not yet implemented" + ) def create_reserved_instances_listing(self): - if self.is_not_dryrun('CreateReservedInstances'): + if self.is_not_dryrun("CreateReservedInstances"): raise NotImplementedError( - 'ReservedInstances.create_reserved_instances_listing is not yet implemented') + "ReservedInstances.create_reserved_instances_listing is not yet implemented" + ) def describe_reserved_instances(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances is not yet implemented') + "ReservedInstances.describe_reserved_instances is not yet implemented" + ) def describe_reserved_instances_listings(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances_listings is not yet implemented') + "ReservedInstances.describe_reserved_instances_listings is not yet implemented" + ) def describe_reserved_instances_offerings(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances_offerings is not yet implemented') + "ReservedInstances.describe_reserved_instances_offerings is not yet implemented" + ) def purchase_reserved_instances_offering(self): - if self.is_not_dryrun('PurchaseReservedInstances'): + if self.is_not_dryrun("PurchaseReservedInstances"): raise NotImplementedError( - 'ReservedInstances.purchase_reserved_instances_offering is not yet implemented') + "ReservedInstances.purchase_reserved_instances_offering is not yet implemented" + ) diff --git a/moto/ec2/responses/route_tables.py b/moto/ec2/responses/route_tables.py index 3878f325d..b5d65f831 100644 --- a/moto/ec2/responses/route_tables.py +++ b/moto/ec2/responses/route_tables.py @@ -4,89 +4,96 @@ from moto.ec2.utils import filters_from_querystring class RouteTables(BaseResponse): - def associate_route_table(self): - route_table_id = self._get_param('RouteTableId') - subnet_id = self._get_param('SubnetId') + route_table_id = self._get_param("RouteTableId") + subnet_id = self._get_param("SubnetId") association_id = self.ec2_backend.associate_route_table( - route_table_id, subnet_id) + route_table_id, subnet_id + ) template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE) return template.render(association_id=association_id) def create_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') - gateway_id = self._get_param('GatewayId') - instance_id = self._get_param('InstanceId') - interface_id = self._get_param('NetworkInterfaceId') - pcx_id = self._get_param('VpcPeeringConnectionId') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") + gateway_id = self._get_param("GatewayId") + instance_id = self._get_param("InstanceId") + nat_gateway_id = self._get_param("NatGatewayId") + interface_id = self._get_param("NetworkInterfaceId") + pcx_id = self._get_param("VpcPeeringConnectionId") - self.ec2_backend.create_route(route_table_id, destination_cidr_block, - gateway_id=gateway_id, - instance_id=instance_id, - interface_id=interface_id, - vpc_peering_connection_id=pcx_id) + self.ec2_backend.create_route( + route_table_id, + destination_cidr_block, + gateway_id=gateway_id, + instance_id=instance_id, + nat_gateway_id=nat_gateway_id, + interface_id=interface_id, + vpc_peering_connection_id=pcx_id, + ) template = self.response_template(CREATE_ROUTE_RESPONSE) return template.render() def create_route_table(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") route_table = self.ec2_backend.create_route_table(vpc_id) template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE) return template.render(route_table=route_table) def delete_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") self.ec2_backend.delete_route(route_table_id, destination_cidr_block) template = self.response_template(DELETE_ROUTE_RESPONSE) return template.render() def delete_route_table(self): - route_table_id = self._get_param('RouteTableId') + route_table_id = self._get_param("RouteTableId") self.ec2_backend.delete_route_table(route_table_id) template = self.response_template(DELETE_ROUTE_TABLE_RESPONSE) return template.render() def describe_route_tables(self): - route_table_ids = self._get_multi_param('RouteTableId') + route_table_ids = self._get_multi_param("RouteTableId") filters = filters_from_querystring(self.querystring) - route_tables = self.ec2_backend.get_all_route_tables( - route_table_ids, filters) + route_tables = self.ec2_backend.get_all_route_tables(route_table_ids, filters) template = self.response_template(DESCRIBE_ROUTE_TABLES_RESPONSE) return template.render(route_tables=route_tables) def disassociate_route_table(self): - association_id = self._get_param('AssociationId') + association_id = self._get_param("AssociationId") self.ec2_backend.disassociate_route_table(association_id) template = self.response_template(DISASSOCIATE_ROUTE_TABLE_RESPONSE) return template.render() def replace_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') - gateway_id = self._get_param('GatewayId') - instance_id = self._get_param('InstanceId') - interface_id = self._get_param('NetworkInterfaceId') - pcx_id = self._get_param('VpcPeeringConnectionId') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") + gateway_id = self._get_param("GatewayId") + instance_id = self._get_param("InstanceId") + interface_id = self._get_param("NetworkInterfaceId") + pcx_id = self._get_param("VpcPeeringConnectionId") - self.ec2_backend.replace_route(route_table_id, destination_cidr_block, - gateway_id=gateway_id, - instance_id=instance_id, - interface_id=interface_id, - vpc_peering_connection_id=pcx_id) + self.ec2_backend.replace_route( + route_table_id, + destination_cidr_block, + gateway_id=gateway_id, + instance_id=instance_id, + interface_id=interface_id, + vpc_peering_connection_id=pcx_id, + ) template = self.response_template(REPLACE_ROUTE_RESPONSE) return template.render() def replace_route_table_association(self): - route_table_id = self._get_param('RouteTableId') - association_id = self._get_param('AssociationId') + route_table_id = self._get_param("RouteTableId") + association_id = self._get_param("AssociationId") new_association_id = self.ec2_backend.replace_route_table_association( - association_id, route_table_id) - template = self.response_template( - REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE) + association_id, route_table_id + ) + template = self.response_template(REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE) return template.render(association_id=new_association_id) @@ -168,6 +175,10 @@ DESCRIBE_ROUTE_TABLES_RESPONSE = """ CreateRoute blackhole {% endif %} + {% if route.nat_gateway %} + {{ route.nat_gateway.id }} + active + {% endif %} {% endfor %} diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 4aecfcf78..6f2926f61 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse from moto.ec2.utils import filters_from_querystring +from moto.core import ACCOUNT_ID def try_parse_int(value, default=None): @@ -12,37 +13,35 @@ def try_parse_int(value, default=None): def parse_sg_attributes_from_dict(sg_attributes): - ip_protocol = sg_attributes.get('IpProtocol', [None])[0] - from_port = sg_attributes.get('FromPort', [None])[0] - to_port = sg_attributes.get('ToPort', [None])[0] + ip_protocol = sg_attributes.get("IpProtocol", [None])[0] + from_port = sg_attributes.get("FromPort", [None])[0] + to_port = sg_attributes.get("ToPort", [None])[0] ip_ranges = [] - ip_ranges_tree = sg_attributes.get('IpRanges') or {} + ip_ranges_tree = sg_attributes.get("IpRanges") or {} for ip_range_idx in sorted(ip_ranges_tree.keys()): - ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0]) + ip_ranges.append(ip_ranges_tree[ip_range_idx]["CidrIp"][0]) source_groups = [] source_group_ids = [] - groups_tree = sg_attributes.get('Groups') or {} + groups_tree = sg_attributes.get("Groups") or {} for group_idx in sorted(groups_tree.keys()): group_dict = groups_tree[group_idx] - if 'GroupId' in group_dict: - source_group_ids.append(group_dict['GroupId'][0]) - elif 'GroupName' in group_dict: - source_groups.append(group_dict['GroupName'][0]) + if "GroupId" in group_dict: + source_group_ids.append(group_dict["GroupId"][0]) + elif "GroupName" in group_dict: + source_groups.append(group_dict["GroupName"][0]) return ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids class SecurityGroups(BaseResponse): - def _process_rules_from_querystring(self): - group_name_or_id = (self._get_param('GroupName') or - self._get_param('GroupId')) + group_name_or_id = self._get_param("GroupName") or self._get_param("GroupId") querytree = {} for key, value in self.querystring.items(): - key_splitted = key.split('.') + key_splitted = key.split(".") key_splitted = [try_parse_int(e, e) for e in key_splitted] d = querytree @@ -52,41 +51,70 @@ class SecurityGroups(BaseResponse): d = d[subkey] d[key_splitted[-1]] = value - if 'IpPermissions' not in querytree: + if "IpPermissions" not in querytree: # Handle single rule syntax - ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(querytree) - yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, - source_groups, source_group_ids) + ( + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) = parse_sg_attributes_from_dict(querytree) + yield ( + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) - ip_permissions = querytree.get('IpPermissions') or {} + ip_permissions = querytree.get("IpPermissions") or {} for ip_permission_idx in sorted(ip_permissions.keys()): ip_permission = ip_permissions[ip_permission_idx] - ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(ip_permission) + ( + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) = parse_sg_attributes_from_dict(ip_permission) - yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, - source_groups, source_group_ids) + yield ( + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) def authorize_security_group_egress(self): - if self.is_not_dryrun('GrantSecurityGroupEgress'): + if self.is_not_dryrun("GrantSecurityGroupEgress"): for args in self._process_rules_from_querystring(): self.ec2_backend.authorize_security_group_egress(*args) return AUTHORIZE_SECURITY_GROUP_EGRESS_RESPONSE def authorize_security_group_ingress(self): - if self.is_not_dryrun('GrantSecurityGroupIngress'): + if self.is_not_dryrun("GrantSecurityGroupIngress"): for args in self._process_rules_from_querystring(): self.ec2_backend.authorize_security_group_ingress(*args) return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE def create_security_group(self): - name = self._get_param('GroupName') - description = self._get_param('GroupDescription') - vpc_id = self._get_param('VpcId') + name = self._get_param("GroupName") + description = self._get_param("GroupDescription") + vpc_id = self._get_param("VpcId") - if self.is_not_dryrun('CreateSecurityGroup'): + if self.is_not_dryrun("CreateSecurityGroup"): group = self.ec2_backend.create_security_group( - name, description, vpc_id=vpc_id) + name, description, vpc_id=vpc_id + ) template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE) return template.render(group=group) @@ -95,10 +123,10 @@ class SecurityGroups(BaseResponse): # See # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html - name = self._get_param('GroupName') - sg_id = self._get_param('GroupId') + name = self._get_param("GroupName") + sg_id = self._get_param("GroupId") - if self.is_not_dryrun('DeleteSecurityGroup'): + if self.is_not_dryrun("DeleteSecurityGroup"): if name: self.ec2_backend.delete_security_group(name) elif sg_id: @@ -112,16 +140,14 @@ class SecurityGroups(BaseResponse): filters = filters_from_querystring(self.querystring) groups = self.ec2_backend.describe_security_groups( - group_ids=group_ids, - groupnames=groupnames, - filters=filters + group_ids=group_ids, groupnames=groupnames, filters=filters ) template = self.response_template(DESCRIBE_SECURITY_GROUPS_RESPONSE) return template.render(groups=groups) def revoke_security_group_egress(self): - if self.is_not_dryrun('RevokeSecurityGroupEgress'): + if self.is_not_dryrun("RevokeSecurityGroupEgress"): for args in self._process_rules_from_querystring(): success = self.ec2_backend.revoke_security_group_egress(*args) if not success: @@ -129,7 +155,7 @@ class SecurityGroups(BaseResponse): return REVOKE_SECURITY_GROUP_EGRESS_RESPONSE def revoke_security_group_ingress(self): - if self.is_not_dryrun('RevokeSecurityGroupIngress'): + if self.is_not_dryrun("RevokeSecurityGroupIngress"): for args in self._process_rules_from_querystring(): self.ec2_backend.revoke_security_group_ingress(*args) return REVOKE_SECURITY_GROUP_INGRESS_REPONSE @@ -146,12 +172,15 @@ DELETE_GROUP_RESPONSE = """ +DESCRIBE_SECURITY_GROUPS_RESPONSE = ( + """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {% for group in groups %} - 123456789012 + """ + + ACCOUNT_ID + + """ {{ group.id }} {{ group.name }} {{ group.description }} @@ -171,7 +200,9 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = """ {{ source_group.id }} {{ source_group.name }} @@ -200,7 +231,9 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = """ {{ source_group.id }} {{ source_group.name }} @@ -230,6 +263,7 @@ DESCRIBE_SECURITY_GROUPS_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE diff --git a/moto/ec2/responses/spot_fleets.py b/moto/ec2/responses/spot_fleets.py index bb9aeb4ca..b7de85323 100644 --- a/moto/ec2/responses/spot_fleets.py +++ b/moto/ec2/responses/spot_fleets.py @@ -3,12 +3,12 @@ from moto.core.responses import BaseResponse class SpotFleets(BaseResponse): - def cancel_spot_fleet_requests(self): spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") terminate_instances = self._get_param("TerminateInstances") spot_fleets = self.ec2_backend.cancel_spot_fleet_requests( - spot_fleet_request_ids, terminate_instances) + spot_fleet_request_ids, terminate_instances + ) template = self.response_template(CANCEL_SPOT_FLEETS_TEMPLATE) return template.render(spot_fleets=spot_fleets) @@ -16,37 +16,42 @@ class SpotFleets(BaseResponse): spot_fleet_request_id = self._get_param("SpotFleetRequestId") spot_requests = self.ec2_backend.describe_spot_fleet_instances( - spot_fleet_request_id) - template = self.response_template( - DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE) - return template.render(spot_request_id=spot_fleet_request_id, spot_requests=spot_requests) + spot_fleet_request_id + ) + template = self.response_template(DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE) + return template.render( + spot_request_id=spot_fleet_request_id, spot_requests=spot_requests + ) def describe_spot_fleet_requests(self): spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") - requests = self.ec2_backend.describe_spot_fleet_requests( - spot_fleet_request_ids) + requests = self.ec2_backend.describe_spot_fleet_requests(spot_fleet_request_ids) template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE) return template.render(requests=requests) def modify_spot_fleet_request(self): spot_fleet_request_id = self._get_param("SpotFleetRequestId") target_capacity = self._get_int_param("TargetCapacity") - terminate_instances = self._get_param("ExcessCapacityTerminationPolicy", if_none="Default") + terminate_instances = self._get_param( + "ExcessCapacityTerminationPolicy", if_none="Default" + ) successful = self.ec2_backend.modify_spot_fleet_request( - spot_fleet_request_id, target_capacity, terminate_instances) + spot_fleet_request_id, target_capacity, terminate_instances + ) template = self.response_template(MODIFY_SPOT_FLEET_REQUEST_TEMPLATE) return template.render(successful=successful) def request_spot_fleet(self): spot_config = self._get_dict_param("SpotFleetRequestConfig.") - spot_price = spot_config.get('spot_price') - target_capacity = spot_config['target_capacity'] - iam_fleet_role = spot_config['iam_fleet_role'] - allocation_strategy = spot_config['allocation_strategy'] + spot_price = spot_config.get("spot_price") + target_capacity = spot_config["target_capacity"] + iam_fleet_role = spot_config["iam_fleet_role"] + allocation_strategy = spot_config["allocation_strategy"] launch_specs = self._get_list_prefix( - "SpotFleetRequestConfig.LaunchSpecifications") + "SpotFleetRequestConfig.LaunchSpecifications" + ) request = self.ec2_backend.request_spot_fleet( spot_price=spot_price, diff --git a/moto/ec2/responses/spot_instances.py b/moto/ec2/responses/spot_instances.py index b0e80a320..392ad9524 100644 --- a/moto/ec2/responses/spot_instances.py +++ b/moto/ec2/responses/spot_instances.py @@ -4,64 +4,61 @@ from moto.ec2.utils import filters_from_querystring class SpotInstances(BaseResponse): - def cancel_spot_instance_requests(self): - request_ids = self._get_multi_param('SpotInstanceRequestId') - if self.is_not_dryrun('CancelSpotInstance'): - requests = self.ec2_backend.cancel_spot_instance_requests( - request_ids) + request_ids = self._get_multi_param("SpotInstanceRequestId") + if self.is_not_dryrun("CancelSpotInstance"): + requests = self.ec2_backend.cancel_spot_instance_requests(request_ids) template = self.response_template(CANCEL_SPOT_INSTANCES_TEMPLATE) return template.render(requests=requests) def create_spot_datafeed_subscription(self): - if self.is_not_dryrun('CreateSpotDatafeedSubscription'): + if self.is_not_dryrun("CreateSpotDatafeedSubscription"): raise NotImplementedError( - 'SpotInstances.create_spot_datafeed_subscription is not yet implemented') + "SpotInstances.create_spot_datafeed_subscription is not yet implemented" + ) def delete_spot_datafeed_subscription(self): - if self.is_not_dryrun('DeleteSpotDatafeedSubscription'): + if self.is_not_dryrun("DeleteSpotDatafeedSubscription"): raise NotImplementedError( - 'SpotInstances.delete_spot_datafeed_subscription is not yet implemented') + "SpotInstances.delete_spot_datafeed_subscription is not yet implemented" + ) def describe_spot_datafeed_subscription(self): raise NotImplementedError( - 'SpotInstances.describe_spot_datafeed_subscription is not yet implemented') + "SpotInstances.describe_spot_datafeed_subscription is not yet implemented" + ) def describe_spot_instance_requests(self): filters = filters_from_querystring(self.querystring) - requests = self.ec2_backend.describe_spot_instance_requests( - filters=filters) + requests = self.ec2_backend.describe_spot_instance_requests(filters=filters) template = self.response_template(DESCRIBE_SPOT_INSTANCES_TEMPLATE) return template.render(requests=requests) def describe_spot_price_history(self): raise NotImplementedError( - 'SpotInstances.describe_spot_price_history is not yet implemented') + "SpotInstances.describe_spot_price_history is not yet implemented" + ) def request_spot_instances(self): - price = self._get_param('SpotPrice') - image_id = self._get_param('LaunchSpecification.ImageId') - count = self._get_int_param('InstanceCount', 1) - type = self._get_param('Type', 'one-time') - valid_from = self._get_param('ValidFrom') - valid_until = self._get_param('ValidUntil') - launch_group = self._get_param('LaunchGroup') - availability_zone_group = self._get_param('AvailabilityZoneGroup') - key_name = self._get_param('LaunchSpecification.KeyName') - security_groups = self._get_multi_param( - 'LaunchSpecification.SecurityGroup') - user_data = self._get_param('LaunchSpecification.UserData') - instance_type = self._get_param( - 'LaunchSpecification.InstanceType', 'm1.small') - placement = self._get_param( - 'LaunchSpecification.Placement.AvailabilityZone') - kernel_id = self._get_param('LaunchSpecification.KernelId') - ramdisk_id = self._get_param('LaunchSpecification.RamdiskId') - monitoring_enabled = self._get_param( - 'LaunchSpecification.Monitoring.Enabled') - subnet_id = self._get_param('LaunchSpecification.SubnetId') + price = self._get_param("SpotPrice") + image_id = self._get_param("LaunchSpecification.ImageId") + count = self._get_int_param("InstanceCount", 1) + type = self._get_param("Type", "one-time") + valid_from = self._get_param("ValidFrom") + valid_until = self._get_param("ValidUntil") + launch_group = self._get_param("LaunchGroup") + availability_zone_group = self._get_param("AvailabilityZoneGroup") + key_name = self._get_param("LaunchSpecification.KeyName") + security_groups = self._get_multi_param("LaunchSpecification.SecurityGroup") + user_data = self._get_param("LaunchSpecification.UserData") + instance_type = self._get_param("LaunchSpecification.InstanceType", "m1.small") + placement = self._get_param("LaunchSpecification.Placement.AvailabilityZone") + kernel_id = self._get_param("LaunchSpecification.KernelId") + ramdisk_id = self._get_param("LaunchSpecification.RamdiskId") + monitoring_enabled = self._get_param("LaunchSpecification.Monitoring.Enabled") + subnet_id = self._get_param("LaunchSpecification.SubnetId") - if self.is_not_dryrun('RequestSpotInstance'): + if self.is_not_dryrun("RequestSpotInstance"): requests = self.ec2_backend.request_spot_instances( price=price, image_id=image_id, diff --git a/moto/ec2/responses/subnets.py b/moto/ec2/responses/subnets.py index 0412d9e8b..e11984e52 100644 --- a/moto/ec2/responses/subnets.py +++ b/moto/ec2/responses/subnets.py @@ -6,44 +6,42 @@ from moto.ec2.utils import filters_from_querystring class Subnets(BaseResponse): - def create_subnet(self): - vpc_id = self._get_param('VpcId') - cidr_block = self._get_param('CidrBlock') + vpc_id = self._get_param("VpcId") + cidr_block = self._get_param("CidrBlock") availability_zone = self._get_param( - 'AvailabilityZone', if_none=random.choice( - self.ec2_backend.describe_availability_zones()).name) + "AvailabilityZone", + if_none=random.choice(self.ec2_backend.describe_availability_zones()).name, + ) subnet = self.ec2_backend.create_subnet( - vpc_id, - cidr_block, - availability_zone, - context=self, + vpc_id, cidr_block, availability_zone, context=self ) template = self.response_template(CREATE_SUBNET_RESPONSE) return template.render(subnet=subnet) def delete_subnet(self): - subnet_id = self._get_param('SubnetId') + subnet_id = self._get_param("SubnetId") subnet = self.ec2_backend.delete_subnet(subnet_id) template = self.response_template(DELETE_SUBNET_RESPONSE) return template.render(subnet=subnet) def describe_subnets(self): - subnet_ids = self._get_multi_param('SubnetId') + subnet_ids = self._get_multi_param("SubnetId") filters = filters_from_querystring(self.querystring) subnets = self.ec2_backend.get_all_subnets(subnet_ids, filters) template = self.response_template(DESCRIBE_SUBNETS_RESPONSE) return template.render(subnets=subnets) def modify_subnet_attribute(self): - subnet_id = self._get_param('SubnetId') + subnet_id = self._get_param("SubnetId") - for attribute in ('MapPublicIpOnLaunch', 'AssignIpv6AddressOnCreation'): - if self.querystring.get('%s.Value' % attribute): + for attribute in ("MapPublicIpOnLaunch", "AssignIpv6AddressOnCreation"): + if self.querystring.get("%s.Value" % attribute): attr_name = camelcase_to_underscores(attribute) - attr_value = self.querystring.get('%s.Value' % attribute)[0] + attr_value = self.querystring.get("%s.Value" % attribute)[0] self.ec2_backend.modify_subnet_attribute( - subnet_id, attr_name, attr_value) + subnet_id, attr_name, attr_value + ) return MODIFY_SUBNET_ATTRIBUTE_RESPONSE @@ -55,7 +53,7 @@ CREATE_SUBNET_RESPONSE = """ pending {{ subnet.vpc_id }} {{ subnet.cidr_block }} - 251 + {{ subnet.available_ip_addresses }} {{ subnet._availability_zone.name }} {{ subnet._availability_zone.zone_id }} {{ subnet.default_for_az }} @@ -83,7 +81,7 @@ DESCRIBE_SUBNETS_RESPONSE = """ available {{ subnet.vpc_id }} {{ subnet.cidr_block }} - 251 + {{ subnet.available_ip_addresses }} {{ subnet._availability_zone.name }} {{ subnet._availability_zone.zone_id }} {{ subnet.default_for_az }} diff --git a/moto/ec2/responses/tags.py b/moto/ec2/responses/tags.py index 65d3da255..5290b7409 100644 --- a/moto/ec2/responses/tags.py +++ b/moto/ec2/responses/tags.py @@ -6,21 +6,20 @@ from moto.ec2.utils import tags_from_query_string, filters_from_querystring class TagResponse(BaseResponse): - def create_tags(self): - resource_ids = self._get_multi_param('ResourceId') + resource_ids = self._get_multi_param("ResourceId") validate_resource_ids(resource_ids) self.ec2_backend.do_resources_exist(resource_ids) tags = tags_from_query_string(self.querystring) - if self.is_not_dryrun('CreateTags'): + if self.is_not_dryrun("CreateTags"): self.ec2_backend.create_tags(resource_ids, tags) return CREATE_RESPONSE def delete_tags(self): - resource_ids = self._get_multi_param('ResourceId') + resource_ids = self._get_multi_param("ResourceId") validate_resource_ids(resource_ids) tags = tags_from_query_string(self.querystring) - if self.is_not_dryrun('DeleteTags'): + if self.is_not_dryrun("DeleteTags"): self.ec2_backend.delete_tags(resource_ids, tags) return DELETE_RESPONSE diff --git a/moto/ec2/responses/virtual_private_gateways.py b/moto/ec2/responses/virtual_private_gateways.py index 75de31b93..ce30aa9b2 100644 --- a/moto/ec2/responses/virtual_private_gateways.py +++ b/moto/ec2/responses/virtual_private_gateways.py @@ -4,25 +4,21 @@ from moto.ec2.utils import filters_from_querystring class VirtualPrivateGateways(BaseResponse): - def attach_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') - vpc_id = self._get_param('VpcId') - attachment = self.ec2_backend.attach_vpn_gateway( - vpn_gateway_id, - vpc_id - ) + vpn_gateway_id = self._get_param("VpnGatewayId") + vpc_id = self._get_param("VpcId") + attachment = self.ec2_backend.attach_vpn_gateway(vpn_gateway_id, vpc_id) template = self.response_template(ATTACH_VPN_GATEWAY_RESPONSE) return template.render(attachment=attachment) def create_vpn_gateway(self): - type = self._get_param('Type') + type = self._get_param("Type") vpn_gateway = self.ec2_backend.create_vpn_gateway(type) template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE) return template.render(vpn_gateway=vpn_gateway) def delete_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') + vpn_gateway_id = self._get_param("VpnGatewayId") vpn_gateway = self.ec2_backend.delete_vpn_gateway(vpn_gateway_id) template = self.response_template(DELETE_VPN_GATEWAY_RESPONSE) return template.render(vpn_gateway=vpn_gateway) @@ -34,12 +30,9 @@ class VirtualPrivateGateways(BaseResponse): return template.render(vpn_gateways=vpn_gateways) def detach_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') - vpc_id = self._get_param('VpcId') - attachment = self.ec2_backend.detach_vpn_gateway( - vpn_gateway_id, - vpc_id - ) + vpn_gateway_id = self._get_param("VpnGatewayId") + vpc_id = self._get_param("VpcId") + attachment = self.ec2_backend.detach_vpn_gateway(vpn_gateway_id, vpc_id) template = self.response_template(DETACH_VPN_GATEWAY_RESPONSE) return template.render(attachment=attachment) diff --git a/moto/ec2/responses/vm_export.py b/moto/ec2/responses/vm_export.py index 6fdf59ba3..a4c831fcb 100644 --- a/moto/ec2/responses/vm_export.py +++ b/moto/ec2/responses/vm_export.py @@ -3,15 +3,15 @@ from moto.core.responses import BaseResponse class VMExport(BaseResponse): - def cancel_export_task(self): - raise NotImplementedError( - 'VMExport.cancel_export_task is not yet implemented') + raise NotImplementedError("VMExport.cancel_export_task is not yet implemented") def create_instance_export_task(self): raise NotImplementedError( - 'VMExport.create_instance_export_task is not yet implemented') + "VMExport.create_instance_export_task is not yet implemented" + ) def describe_export_tasks(self): raise NotImplementedError( - 'VMExport.describe_export_tasks is not yet implemented') + "VMExport.describe_export_tasks is not yet implemented" + ) diff --git a/moto/ec2/responses/vm_import.py b/moto/ec2/responses/vm_import.py index 8c2ba138c..50f77c66c 100644 --- a/moto/ec2/responses/vm_import.py +++ b/moto/ec2/responses/vm_import.py @@ -3,19 +3,18 @@ from moto.core.responses import BaseResponse class VMImport(BaseResponse): - def cancel_conversion_task(self): raise NotImplementedError( - 'VMImport.cancel_conversion_task is not yet implemented') + "VMImport.cancel_conversion_task is not yet implemented" + ) def describe_conversion_tasks(self): raise NotImplementedError( - 'VMImport.describe_conversion_tasks is not yet implemented') + "VMImport.describe_conversion_tasks is not yet implemented" + ) def import_instance(self): - raise NotImplementedError( - 'VMImport.import_instance is not yet implemented') + raise NotImplementedError("VMImport.import_instance is not yet implemented") def import_volume(self): - raise NotImplementedError( - 'VMImport.import_volume is not yet implemented') + raise NotImplementedError("VMImport.import_volume is not yet implemented") diff --git a/moto/ec2/responses/vpc_peering_connections.py b/moto/ec2/responses/vpc_peering_connections.py index 68bae72da..3bf86af8a 100644 --- a/moto/ec2/responses/vpc_peering_connections.py +++ b/moto/ec2/responses/vpc_peering_connections.py @@ -1,50 +1,48 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse +from moto.core import ACCOUNT_ID class VPCPeeringConnections(BaseResponse): - def create_vpc_peering_connection(self): - peer_region = self._get_param('PeerRegion') + peer_region = self._get_param("PeerRegion") if peer_region == self.region or peer_region is None: - peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId')) + peer_vpc = self.ec2_backend.get_vpc(self._get_param("PeerVpcId")) else: - peer_vpc = self.ec2_backend.get_cross_vpc(self._get_param('PeerVpcId'), peer_region) - vpc = self.ec2_backend.get_vpc(self._get_param('VpcId')) + peer_vpc = self.ec2_backend.get_cross_vpc( + self._get_param("PeerVpcId"), peer_region + ) + vpc = self.ec2_backend.get_vpc(self._get_param("VpcId")) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) - template = self.response_template( - CREATE_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(CREATE_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def delete_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - DELETE_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(DELETE_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def describe_vpc_peering_connections(self): vpc_pcxs = self.ec2_backend.get_all_vpc_peering_connections() - template = self.response_template( - DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) + template = self.response_template(DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) return template.render(vpc_pcxs=vpc_pcxs) def accept_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def reject_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - REJECT_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(REJECT_VPC_PEERING_CONNECTION_RESPONSE) return template.render() -CREATE_VPC_PEERING_CONNECTION_RESPONSE = """ +CREATE_VPC_PEERING_CONNECTION_RESPONSE = ( + """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -60,7 +58,9 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = """ - 123456789012 + """ + + ACCOUNT_ID + + """ {{ vpc_pcx.peer_vpc.id }} @@ -72,8 +72,10 @@ CREATE_VPC_PEERING_CONNECTION_RESPONSE = """ """ +) -DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = """ +DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = ( + """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -86,7 +88,9 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = """ {{ vpc_pcx.vpc.cidr_block }} - 123456789012 + """ + + ACCOUNT_ID + + """ {{ vpc_pcx.peer_vpc.id }} {{ vpc_pcx.peer_vpc.cidr_block }} @@ -105,6 +109,7 @@ DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE = """ """ +) DELETE_VPC_PEERING_CONNECTION_RESPONSE = """ @@ -113,7 +118,8 @@ DELETE_VPC_PEERING_CONNECTION_RESPONSE = """ """ -ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = """ +ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = ( + """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -124,7 +130,9 @@ ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = """ {{ vpc_pcx.vpc.cidr_block }} - 123456789012 + """ + + ACCOUNT_ID + + """ {{ vpc_pcx.peer_vpc.id }} {{ vpc_pcx.peer_vpc.cidr_block }} @@ -141,6 +149,7 @@ ACCEPT_VPC_PEERING_CONNECTION_RESPONSE = """ """ +) REJECT_VPC_PEERING_CONNECTION_RESPONSE = """ diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 88673d863..0fd198378 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -5,74 +5,163 @@ from moto.ec2.utils import filters_from_querystring class VPCs(BaseResponse): + def _get_doc_date(self): + return ( + "2013-10-15" + if "Boto/" in self.headers.get("user-agent", "") + else "2016-11-15" + ) def create_vpc(self): - cidr_block = self._get_param('CidrBlock') - instance_tenancy = self._get_param('InstanceTenancy', if_none='default') - amazon_provided_ipv6_cidr_blocks = self._get_param('AmazonProvidedIpv6CidrBlock') - vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy, - amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks) - doc_date = '2013-10-15' if 'Boto/' in self.headers.get('user-agent', '') else '2016-11-15' + cidr_block = self._get_param("CidrBlock") + instance_tenancy = self._get_param("InstanceTenancy", if_none="default") + amazon_provided_ipv6_cidr_blocks = self._get_param( + "AmazonProvidedIpv6CidrBlock" + ) + vpc = self.ec2_backend.create_vpc( + cidr_block, + instance_tenancy, + amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks, + ) + doc_date = self._get_doc_date() template = self.response_template(CREATE_VPC_RESPONSE) return template.render(vpc=vpc, doc_date=doc_date) def delete_vpc(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") vpc = self.ec2_backend.delete_vpc(vpc_id) template = self.response_template(DELETE_VPC_RESPONSE) return template.render(vpc=vpc) def describe_vpcs(self): - vpc_ids = self._get_multi_param('VpcId') + vpc_ids = self._get_multi_param("VpcId") filters = filters_from_querystring(self.querystring) vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters) - doc_date = '2013-10-15' if 'Boto/' in self.headers.get('user-agent', '') else '2016-11-15' + doc_date = ( + "2013-10-15" + if "Boto/" in self.headers.get("user-agent", "") + else "2016-11-15" + ) template = self.response_template(DESCRIBE_VPCS_RESPONSE) return template.render(vpcs=vpcs, doc_date=doc_date) def describe_vpc_attribute(self): - vpc_id = self._get_param('VpcId') - attribute = self._get_param('Attribute') + vpc_id = self._get_param("VpcId") + attribute = self._get_param("Attribute") attr_name = camelcase_to_underscores(attribute) value = self.ec2_backend.describe_vpc_attribute(vpc_id, attr_name) template = self.response_template(DESCRIBE_VPC_ATTRIBUTE_RESPONSE) return template.render(vpc_id=vpc_id, attribute=attribute, value=value) - def modify_vpc_attribute(self): - vpc_id = self._get_param('VpcId') + def describe_vpc_classic_link_dns_support(self): + vpc_ids = self._get_multi_param("VpcIds") + filters = filters_from_querystring(self.querystring) + vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters) + doc_date = self._get_doc_date() + template = self.response_template( + DESCRIBE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE + ) + return template.render(vpcs=vpcs, doc_date=doc_date) - for attribute in ('EnableDnsSupport', 'EnableDnsHostnames'): - if self.querystring.get('%s.Value' % attribute): + def enable_vpc_classic_link_dns_support(self): + vpc_id = self._get_param("VpcId") + classic_link_dns_supported = self.ec2_backend.enable_vpc_classic_link_dns_support( + vpc_id=vpc_id + ) + doc_date = self._get_doc_date() + template = self.response_template(ENABLE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE) + return template.render( + classic_link_dns_supported=classic_link_dns_supported, doc_date=doc_date + ) + + def disable_vpc_classic_link_dns_support(self): + vpc_id = self._get_param("VpcId") + classic_link_dns_supported = self.ec2_backend.disable_vpc_classic_link_dns_support( + vpc_id=vpc_id + ) + doc_date = self._get_doc_date() + template = self.response_template(DISABLE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE) + return template.render( + classic_link_dns_supported=classic_link_dns_supported, doc_date=doc_date + ) + + def describe_vpc_classic_link(self): + vpc_ids = self._get_multi_param("VpcId") + filters = filters_from_querystring(self.querystring) + vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters) + doc_date = self._get_doc_date() + template = self.response_template(DESCRIBE_VPC_CLASSIC_LINK_RESPONSE) + return template.render(vpcs=vpcs, doc_date=doc_date) + + def enable_vpc_classic_link(self): + vpc_id = self._get_param("VpcId") + classic_link_enabled = self.ec2_backend.enable_vpc_classic_link(vpc_id=vpc_id) + doc_date = self._get_doc_date() + template = self.response_template(ENABLE_VPC_CLASSIC_LINK_RESPONSE) + return template.render( + classic_link_enabled=classic_link_enabled, doc_date=doc_date + ) + + def disable_vpc_classic_link(self): + vpc_id = self._get_param("VpcId") + classic_link_enabled = self.ec2_backend.disable_vpc_classic_link(vpc_id=vpc_id) + doc_date = self._get_doc_date() + template = self.response_template(DISABLE_VPC_CLASSIC_LINK_RESPONSE) + return template.render( + classic_link_enabled=classic_link_enabled, doc_date=doc_date + ) + + def modify_vpc_attribute(self): + vpc_id = self._get_param("VpcId") + + for attribute in ("EnableDnsSupport", "EnableDnsHostnames"): + if self.querystring.get("%s.Value" % attribute): attr_name = camelcase_to_underscores(attribute) - attr_value = self.querystring.get('%s.Value' % attribute)[0] - self.ec2_backend.modify_vpc_attribute( - vpc_id, attr_name, attr_value) + attr_value = self.querystring.get("%s.Value" % attribute)[0] + self.ec2_backend.modify_vpc_attribute(vpc_id, attr_name, attr_value) return MODIFY_VPC_ATTRIBUTE_RESPONSE def associate_vpc_cidr_block(self): - vpc_id = self._get_param('VpcId') - amazon_provided_ipv6_cidr_blocks = self._get_param('AmazonProvidedIpv6CidrBlock') + vpc_id = self._get_param("VpcId") + amazon_provided_ipv6_cidr_blocks = self._get_param( + "AmazonProvidedIpv6CidrBlock" + ) # todo test on AWS if can create an association for IPV4 and IPV6 in the same call? - cidr_block = self._get_param('CidrBlock') if not amazon_provided_ipv6_cidr_blocks else None - value = self.ec2_backend.associate_vpc_cidr_block(vpc_id, cidr_block, amazon_provided_ipv6_cidr_blocks) + cidr_block = ( + self._get_param("CidrBlock") + if not amazon_provided_ipv6_cidr_blocks + else None + ) + value = self.ec2_backend.associate_vpc_cidr_block( + vpc_id, cidr_block, amazon_provided_ipv6_cidr_blocks + ) if not amazon_provided_ipv6_cidr_blocks: render_template = ASSOCIATE_VPC_CIDR_BLOCK_RESPONSE else: render_template = IPV6_ASSOCIATE_VPC_CIDR_BLOCK_RESPONSE template = self.response_template(render_template) - return template.render(vpc_id=vpc_id, value=value, cidr_block=value['cidr_block'], - association_id=value['association_id'], cidr_block_state='associating') + return template.render( + vpc_id=vpc_id, + value=value, + cidr_block=value["cidr_block"], + association_id=value["association_id"], + cidr_block_state="associating", + ) def disassociate_vpc_cidr_block(self): - association_id = self._get_param('AssociationId') + association_id = self._get_param("AssociationId") value = self.ec2_backend.disassociate_vpc_cidr_block(association_id) - if "::" in value.get('cidr_block', ''): + if "::" in value.get("cidr_block", ""): render_template = IPV6_DISASSOCIATE_VPC_CIDR_BLOCK_RESPONSE else: render_template = DISASSOCIATE_VPC_CIDR_BLOCK_RESPONSE template = self.response_template(render_template) - return template.render(vpc_id=value['vpc_id'], cidr_block=value['cidr_block'], - association_id=value['association_id'], cidr_block_state='disassociating') + return template.render( + vpc_id=value["vpc_id"], + cidr_block=value["cidr_block"], + association_id=value["association_id"], + cidr_block_state="disassociating", + ) CREATE_VPC_RESPONSE = """ @@ -121,6 +210,56 @@ CREATE_VPC_RESPONSE = """ """ +DESCRIBE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + + {% for vpc in vpcs %} + + {{ vpc.id }} + {{ vpc.classic_link_dns_supported }} + + {% endfor %} + +""" + +ENABLE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + {{ classic_link_dns_supported }} +""" + +DISABLE_VPC_CLASSIC_LINK_DNS_SUPPORT_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + {{ classic_link_dns_supported }} +""" + +DESCRIBE_VPC_CLASSIC_LINK_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + + {% for vpc in vpcs %} + + {{ vpc.id }} + {{ vpc.classic_link_enabled }} + + {% endfor %} + +""" + +ENABLE_VPC_CLASSIC_LINK_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + {{ classic_link_enabled }} +""" + +DISABLE_VPC_CLASSIC_LINK_RESPONSE = """ + + 7a62c442-3484-4f42-9342-6942EXAMPLE + {{ classic_link_enabled }} +""" + DESCRIBE_VPCS_RESPONSE = """ 7a62c442-3484-4f42-9342-6942EXAMPLE diff --git a/moto/ec2/responses/vpn_connections.py b/moto/ec2/responses/vpn_connections.py index 276e3ca99..9ddd4d7d9 100644 --- a/moto/ec2/responses/vpn_connections.py +++ b/moto/ec2/responses/vpn_connections.py @@ -4,29 +4,29 @@ from moto.ec2.utils import filters_from_querystring class VPNConnections(BaseResponse): - def create_vpn_connection(self): - type = self._get_param('Type') - cgw_id = self._get_param('CustomerGatewayId') - vgw_id = self._get_param('VPNGatewayId') - static_routes = self._get_param('StaticRoutesOnly') + type = self._get_param("Type") + cgw_id = self._get_param("CustomerGatewayId") + vgw_id = self._get_param("VPNGatewayId") + static_routes = self._get_param("StaticRoutesOnly") vpn_connection = self.ec2_backend.create_vpn_connection( - type, cgw_id, vgw_id, static_routes_only=static_routes) + type, cgw_id, vgw_id, static_routes_only=static_routes + ) template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connection=vpn_connection) def delete_vpn_connection(self): - vpn_connection_id = self._get_param('VpnConnectionId') - vpn_connection = self.ec2_backend.delete_vpn_connection( - vpn_connection_id) + vpn_connection_id = self._get_param("VpnConnectionId") + vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id) template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connection=vpn_connection) def describe_vpn_connections(self): - vpn_connection_ids = self._get_multi_param('VpnConnectionId') + vpn_connection_ids = self._get_multi_param("VpnConnectionId") filters = filters_from_querystring(self.querystring) vpn_connections = self.ec2_backend.get_all_vpn_connections( - vpn_connection_ids=vpn_connection_ids, filters=filters) + vpn_connection_ids=vpn_connection_ids, filters=filters + ) template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connections=vpn_connections) diff --git a/moto/ec2/responses/windows.py b/moto/ec2/responses/windows.py index 13dfa9b67..14b2b0666 100644 --- a/moto/ec2/responses/windows.py +++ b/moto/ec2/responses/windows.py @@ -3,19 +3,16 @@ from moto.core.responses import BaseResponse class Windows(BaseResponse): - def bundle_instance(self): - raise NotImplementedError( - 'Windows.bundle_instance is not yet implemented') + raise NotImplementedError("Windows.bundle_instance is not yet implemented") def cancel_bundle_task(self): - raise NotImplementedError( - 'Windows.cancel_bundle_task is not yet implemented') + raise NotImplementedError("Windows.cancel_bundle_task is not yet implemented") def describe_bundle_tasks(self): raise NotImplementedError( - 'Windows.describe_bundle_tasks is not yet implemented') + "Windows.describe_bundle_tasks is not yet implemented" + ) def get_password_data(self): - raise NotImplementedError( - 'Windows.get_password_data is not yet implemented') + raise NotImplementedError("Windows.get_password_data is not yet implemented") diff --git a/moto/ec2/urls.py b/moto/ec2/urls.py index 241ab7133..b83a9e950 100644 --- a/moto/ec2/urls.py +++ b/moto/ec2/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import EC2Response -url_bases = [ - "https?://ec2.(.+).amazonaws.com(|.cn)", -] +url_bases = ["https?://ec2.(.+).amazonaws.com(|.cn)"] -url_paths = { - '{0}/': EC2Response.dispatch, -} +url_paths = {"{0}/": EC2Response.dispatch} diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index e67cb39f4..2301248c1 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -15,173 +15,171 @@ from sshpubkeys.keys import SSHKey EC2_RESOURCE_TO_PREFIX = { - 'customer-gateway': 'cgw', - 'dhcp-options': 'dopt', - 'image': 'ami', - 'instance': 'i', - 'internet-gateway': 'igw', - 'launch-template': 'lt', - 'nat-gateway': 'nat', - 'network-acl': 'acl', - 'network-acl-subnet-assoc': 'aclassoc', - 'network-interface': 'eni', - 'network-interface-attachment': 'eni-attach', - 'reserved-instance': 'uuid4', - 'route-table': 'rtb', - 'route-table-association': 'rtbassoc', - 'security-group': 'sg', - 'snapshot': 'snap', - 'spot-instance-request': 'sir', - 'spot-fleet-request': 'sfr', - 'subnet': 'subnet', - 'reservation': 'r', - 'volume': 'vol', - 'vpc': 'vpc', - 'vpc-cidr-association-id': 'vpc-cidr-assoc', - 'vpc-elastic-ip': 'eipalloc', - 'vpc-elastic-ip-association': 'eipassoc', - 'vpc-peering-connection': 'pcx', - 'vpn-connection': 'vpn', - 'vpn-gateway': 'vgw'} + "customer-gateway": "cgw", + "dhcp-options": "dopt", + "image": "ami", + "instance": "i", + "internet-gateway": "igw", + "launch-template": "lt", + "nat-gateway": "nat", + "network-acl": "acl", + "network-acl-subnet-assoc": "aclassoc", + "network-interface": "eni", + "network-interface-attachment": "eni-attach", + "reserved-instance": "uuid4", + "route-table": "rtb", + "route-table-association": "rtbassoc", + "security-group": "sg", + "snapshot": "snap", + "spot-instance-request": "sir", + "spot-fleet-request": "sfr", + "subnet": "subnet", + "reservation": "r", + "volume": "vol", + "vpc": "vpc", + "vpc-cidr-association-id": "vpc-cidr-assoc", + "vpc-elastic-ip": "eipalloc", + "vpc-elastic-ip-association": "eipassoc", + "vpc-peering-connection": "pcx", + "vpn-connection": "vpn", + "vpn-gateway": "vgw", +} EC2_PREFIX_TO_RESOURCE = dict((v, k) for (k, v) in EC2_RESOURCE_TO_PREFIX.items()) def random_resource_id(size=8): - chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] - resource_id = ''.join(six.text_type(random.choice(chars)) for x in range(size)) + chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] + resource_id = "".join(six.text_type(random.choice(chars)) for _ in range(size)) return resource_id -def random_id(prefix='', size=8): - return '{0}-{1}'.format(prefix, random_resource_id(size)) +def random_id(prefix="", size=8): + return "{0}-{1}".format(prefix, random_resource_id(size)) def random_ami_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['image']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["image"]) def random_instance_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['instance'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["instance"], size=17) def random_reservation_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['reservation']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["reservation"]) def random_security_group_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['security-group']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"]) def random_snapshot_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['snapshot']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["snapshot"]) def random_spot_request_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['spot-instance-request']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["spot-instance-request"]) def random_spot_fleet_request_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['spot-fleet-request']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["spot-fleet-request"]) def random_subnet_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['subnet']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["subnet"]) def random_subnet_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['route-table-association']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["route-table-association"]) def random_network_acl_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-acl']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl"]) def random_network_acl_subnet_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-acl-subnet-assoc']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"]) def random_vpn_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpn-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"]) def random_vpn_connection_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpn-connection']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"]) def random_customer_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['customer-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["customer-gateway"]) def random_volume_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['volume']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["volume"]) def random_vpc_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"]) def random_vpc_cidr_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-cidr-association-id']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"]) def random_vpc_peering_connection_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-peering-connection']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]) def random_eip_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-elastic-ip-association']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-elastic-ip-association"]) def random_internet_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['internet-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["internet-gateway"]) def random_route_table_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['route-table']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["route-table"]) def random_eip_allocation_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-elastic-ip']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-elastic-ip"]) def random_dhcp_option_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['dhcp-options']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["dhcp-options"]) def random_eni_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-interface']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-interface"]) def random_eni_attach_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-interface-attachment']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-interface-attachment"]) def random_nat_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['nat-gateway'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["nat-gateway"], size=17) def random_launch_template_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['launch-template'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["launch-template"], size=17) def random_public_ip(): - return '54.214.{0}.{1}'.format(random.choice(range(255)), - random.choice(range(255))) + return "54.214.{0}.{1}".format(random.choice(range(255)), random.choice(range(255))) def random_private_ip(): - return '10.{0}.{1}.{2}'.format(random.choice(range(255)), - random.choice(range(255)), - random.choice(range(255))) + return "10.{0}.{1}.{2}".format( + random.choice(range(255)), random.choice(range(255)), random.choice(range(255)) + ) def random_ip(): return "127.{0}.{1}.{2}".format( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255) + random.randint(0, 255), random.randint(0, 255), random.randint(0, 255) ) @@ -194,13 +192,13 @@ def generate_route_id(route_table_id, cidr_block): def split_route_id(route_id): - values = route_id.split('~') + values = route_id.split("~") return values[0], values[1] def tags_from_query_string(querystring_dict): - prefix = 'Tag' - suffix = 'Key' + prefix = "Tag" + suffix = "Key" response_values = {} for key, value in querystring_dict.items(): if key.startswith(prefix) and key.endswith(suffix): @@ -208,14 +206,13 @@ def tags_from_query_string(querystring_dict): tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0] tag_value_key = "Tag.{0}.Value".format(tag_index) if tag_value_key in querystring_dict: - response_values[tag_key] = querystring_dict.get(tag_value_key)[ - 0] + response_values[tag_key] = querystring_dict.get(tag_value_key)[0] else: response_values[tag_key] = None return response_values -def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration'): +def dhcp_configuration_from_querystring(querystring, option="DhcpConfiguration"): """ turn: {u'AWSAccessKeyId': [u'the_key'], @@ -234,7 +231,7 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration' {u'domain-name': [u'example.com'], u'domain-name-servers': [u'10.0.0.6', u'10.0.0.7']} """ - key_needle = re.compile(u'{0}.[0-9]+.Key'.format(option), re.UNICODE) + key_needle = re.compile("{0}.[0-9]+.Key".format(option), re.UNICODE) response_values = {} for key, value in querystring.items(): @@ -243,8 +240,7 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration' key_index = key.split(".")[1] value_index = 1 while True: - value_key = u'{0}.{1}.Value.{2}'.format( - option, key_index, value_index) + value_key = "{0}.{1}.Value.{2}".format(option, key_index, value_index) if value_key in querystring: values.extend(querystring[value_key]) else: @@ -261,8 +257,11 @@ def filters_from_querystring(querystring_dict): if match: filter_index = match.groups()[0] value_prefix = "Filter.{0}.Value".format(filter_index) - filter_values = [filter_value[0] for filter_key, filter_value in querystring_dict.items() if - filter_key.startswith(value_prefix)] + filter_values = [ + filter_value[0] + for filter_key, filter_value in querystring_dict.items() + if filter_key.startswith(value_prefix) + ] response_values[value[0]] = filter_values return response_values @@ -283,7 +282,7 @@ def dict_from_querystring(parameter, querystring_dict): def get_object_value(obj, attr): - keys = attr.split('.') + keys = attr.split(".") val = obj for key in keys: if hasattr(val, key): @@ -301,36 +300,37 @@ def get_object_value(obj, attr): def is_tag_filter(filter_name): - return (filter_name.startswith('tag:') or - filter_name.startswith('tag-value') or - filter_name.startswith('tag-key')) + return ( + filter_name.startswith("tag:") + or filter_name.startswith("tag-value") + or filter_name.startswith("tag-key") + ) def get_obj_tag(obj, filter_name): - tag_name = filter_name.replace('tag:', '', 1) - tags = dict((tag['key'], tag['value']) for tag in obj.get_tags()) + tag_name = filter_name.replace("tag:", "", 1) + tags = dict((tag["key"], tag["value"]) for tag in obj.get_tags()) return tags.get(tag_name) def get_obj_tag_names(obj): - tags = set((tag['key'] for tag in obj.get_tags())) + tags = set((tag["key"] for tag in obj.get_tags())) return tags def get_obj_tag_values(obj): - tags = set((tag['value'] for tag in obj.get_tags())) + tags = set((tag["value"] for tag in obj.get_tags())) return tags def tag_filter_matches(obj, filter_name, filter_values): - regex_filters = [re.compile(simple_aws_filter_to_re(f)) - for f in filter_values] - if filter_name == 'tag-key': + regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values] + if filter_name == "tag-key": tag_values = get_obj_tag_names(obj) - elif filter_name == 'tag-value': + elif filter_name == "tag-value": tag_values = get_obj_tag_values(obj) else: - tag_values = [get_obj_tag(obj, filter_name) or ''] + tag_values = [get_obj_tag(obj, filter_name) or ""] for tag_value in tag_values: if any(regex.match(tag_value) for regex in regex_filters): @@ -340,22 +340,22 @@ def tag_filter_matches(obj, filter_name, filter_values): filter_dict_attribute_mapping = { - 'instance-state-name': 'state', - 'instance-id': 'id', - 'state-reason-code': '_state_reason.code', - 'source-dest-check': 'source_dest_check', - 'vpc-id': 'vpc_id', - 'group-id': 'security_groups.id', - 'instance.group-id': 'security_groups.id', - 'instance.group-name': 'security_groups.name', - 'instance-type': 'instance_type', - 'private-ip-address': 'private_ip', - 'ip-address': 'public_ip', - 'availability-zone': 'placement', - 'architecture': 'architecture', - 'image-id': 'image_id', - 'network-interface.private-dns-name': 'private_dns', - 'private-dns-name': 'private_dns' + "instance-state-name": "state", + "instance-id": "id", + "state-reason-code": "_state_reason.code", + "source-dest-check": "source_dest_check", + "vpc-id": "vpc_id", + "group-id": "security_groups.id", + "instance.group-id": "security_groups.id", + "instance.group-name": "security_groups.name", + "instance-type": "instance_type", + "private-ip-address": "private_ip", + "ip-address": "public_ip", + "availability-zone": "placement", + "architecture": "architecture", + "image-id": "image_id", + "network-interface.private-dns-name": "private_dns", + "private-dns-name": "private_dns", } @@ -372,8 +372,9 @@ def passes_filter_dict(instance, filter_dict): return False else: raise NotImplementedError( - "Filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues" % - filter_name) + "Filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues" + % filter_name + ) return True @@ -418,7 +419,8 @@ def passes_igw_filter_dict(igw, filter_dict): else: raise NotImplementedError( "Internet Gateway filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues", - filter_name) + filter_name, + ) return True @@ -445,7 +447,9 @@ def is_filter_matching(obj, filter, filter_value): try: value = set(value) - return (value and value.issubset(filter_value)) or value.issuperset(filter_value) + return (value and value.issubset(filter_value)) or value.issuperset( + filter_value + ) except TypeError: return value in filter_value @@ -453,47 +457,48 @@ def is_filter_matching(obj, filter, filter_value): def generic_filter(filters, objects): if filters: for (_filter, _filter_value) in filters.items(): - objects = [obj for obj in objects if is_filter_matching( - obj, _filter, _filter_value)] + objects = [ + obj + for obj in objects + if is_filter_matching(obj, _filter, _filter_value) + ] return objects def simple_aws_filter_to_re(filter_string): - tmp_filter = filter_string.replace('\?', '[?]') - tmp_filter = tmp_filter.replace('\*', '[*]') + tmp_filter = filter_string.replace(r"\?", "[?]") + tmp_filter = tmp_filter.replace(r"\*", "[*]") tmp_filter = fnmatch.translate(tmp_filter) return tmp_filter def random_key_pair(): private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend()) + public_exponent=65537, key_size=2048, backend=default_backend() + ) private_key_material = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) public_key_fingerprint = rsa_public_key_fingerprint(private_key.public_key()) return { - 'fingerprint': public_key_fingerprint, - 'material': private_key_material.decode('ascii') + "fingerprint": public_key_fingerprint, + "material": private_key_material.decode("ascii"), } def get_prefix(resource_id): - resource_id_prefix, separator, after = resource_id.partition('-') - if resource_id_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: - if after.startswith('attach'): - resource_id_prefix = EC2_RESOURCE_TO_PREFIX[ - 'network-interface-attachment'] + resource_id_prefix, separator, after = resource_id.partition("-") + if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]: + if after.startswith("attach"): + resource_id_prefix = EC2_RESOURCE_TO_PREFIX["network-interface-attachment"] if resource_id_prefix not in EC2_RESOURCE_TO_PREFIX.values(): - uuid4hex = re.compile( - '[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z', re.I) + uuid4hex = re.compile(r"[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z", re.I) if uuid4hex.match(resource_id) is not None: - resource_id_prefix = EC2_RESOURCE_TO_PREFIX['reserved-instance'] + resource_id_prefix = EC2_RESOURCE_TO_PREFIX["reserved-instance"] else: return None return resource_id_prefix @@ -504,13 +509,13 @@ def is_valid_resource_id(resource_id): resource_id_prefix = get_prefix(resource_id) if resource_id_prefix not in valid_prefixes: return False - resource_id_pattern = resource_id_prefix + '-[0-9a-f]{8}' + resource_id_pattern = resource_id_prefix + "-[0-9a-f]{8}" resource_pattern_re = re.compile(resource_id_pattern) return resource_pattern_re.match(resource_id) is not None def is_valid_cidr(cird): - cidr_pattern = '^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])(\/(\d|[1-2]\d|3[0-2]))$' + cidr_pattern = r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])(\/(\d|[1-2]\d|3[0-2]))$" cidr_pattern_re = re.compile(cidr_pattern) return cidr_pattern_re.match(cird) is not None @@ -528,20 +533,20 @@ def generate_instance_identity_document(instance): """ document = { - 'devPayProductCodes': None, - 'availabilityZone': instance.placement['AvailabilityZone'], - 'privateIp': instance.private_ip_address, - 'version': '2010-8-31', - 'region': instance.placement['AvailabilityZone'][:-1], - 'instanceId': instance.id, - 'billingProducts': None, - 'instanceType': instance.instance_type, - 'accountId': '012345678910', - 'pendingTime': '2015-11-19T16:32:11Z', - 'imageId': instance.image_id, - 'kernelId': instance.kernel_id, - 'ramdiskId': instance.ramdisk_id, - 'architecture': instance.architecture, + "devPayProductCodes": None, + "availabilityZone": instance.placement["AvailabilityZone"], + "privateIp": instance.private_ip_address, + "version": "2010-8-31", + "region": instance.placement["AvailabilityZone"][:-1], + "instanceId": instance.id, + "billingProducts": None, + "instanceType": instance.instance_type, + "accountId": "012345678910", + "pendingTime": "2015-11-19T16:32:11Z", + "imageId": instance.image_id, + "kernelId": instance.kernel_id, + "ramdiskId": instance.ramdisk_id, + "architecture": instance.architecture, } return document @@ -555,10 +560,10 @@ def rsa_public_key_parse(key_material): decoded_key = base64.b64decode(key_material).decode("ascii") public_key = SSHKey(decoded_key) except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError): - raise ValueError('bad key') + raise ValueError("bad key") if not public_key.rsa: - raise ValueError('bad key') + raise ValueError("bad key") return public_key.rsa @@ -566,7 +571,8 @@ def rsa_public_key_parse(key_material): def rsa_public_key_fingerprint(rsa_public_key): key_data = rsa_public_key.public_bytes( encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo) + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) fingerprint_hex = hashlib.md5(key_data).hexdigest() - fingerprint = re.sub(r'([a-f0-9]{2})(?!$)', r'\1:', fingerprint_hex) + fingerprint = re.sub(r"([a-f0-9]{2})(?!$)", r"\1:", fingerprint_hex) return fingerprint diff --git a/moto/ecr/__init__.py b/moto/ecr/__init__.py index 56b2cacbb..e90cd9e4c 100644 --- a/moto/ecr/__init__.py +++ b/moto/ecr/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ecr_backends from ..core.models import base_decorator, deprecated_base_decorator -ecr_backend = ecr_backends['us-east-1'] +ecr_backend = ecr_backends["us-east-1"] mock_ecr = base_decorator(ecr_backends) mock_ecr_deprecated = deprecated_base_decorator(ecr_backends) diff --git a/moto/ecr/exceptions.py b/moto/ecr/exceptions.py index f7b951b53..9b55f0589 100644 --- a/moto/ecr/exceptions.py +++ b/moto/ecr/exceptions.py @@ -9,7 +9,8 @@ class RepositoryNotFoundException(RESTError): super(RepositoryNotFoundException, self).__init__( error_type="RepositoryNotFoundException", message="The repository with name '{0}' does not exist in the registry " - "with id '{1}'".format(repository_name, registry_id)) + "with id '{1}'".format(repository_name, registry_id), + ) class ImageNotFoundException(RESTError): @@ -19,4 +20,7 @@ class ImageNotFoundException(RESTError): super(ImageNotFoundException, self).__init__( error_type="ImageNotFoundException", message="The image with imageId {0} does not exist within the repository with name '{1}' " - "in the registry with id '{2}'".format(image_id, repository_name, registry_id)) + "in the registry with id '{2}'".format( + image_id, repository_name, registry_id + ), + ) diff --git a/moto/ecr/models.py b/moto/ecr/models.py index b03f25dee..f84df79aa 100644 --- a/moto/ecr/models.py +++ b/moto/ecr/models.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals import hashlib import re -from copy import copy from datetime import datetime from random import random @@ -12,26 +11,26 @@ from moto.core import BaseBackend, BaseModel from moto.ec2 import ec2_backends from moto.ecr.exceptions import ImageNotFoundException, RepositoryNotFoundException -DEFAULT_REGISTRY_ID = '012345678910' +DEFAULT_REGISTRY_ID = "012345678910" class BaseObject(BaseModel): - def camelCase(self, key): words = [] - for i, word in enumerate(key.split('_')): + for i, word in enumerate(key.split("_")): if i > 0: words.append(word.title()) else: words.append(word) - return ''.join(words) + return "".join(words) def gen_response_object(self): - response_object = copy(self.__dict__) - for key, value in response_object.items(): - if '_' in key: + response_object = dict() + for key, value in self.__dict__.items(): + if "_" in key: response_object[self.camelCase(key)] = value - del response_object[key] + else: + response_object[key] = value return response_object @property @@ -40,15 +39,16 @@ class BaseObject(BaseModel): class Repository(BaseObject): - def __init__(self, repository_name): self.registry_id = DEFAULT_REGISTRY_ID - self.arn = 'arn:aws:ecr:us-east-1:{0}:repository/{1}'.format( - self.registry_id, repository_name) + self.arn = "arn:aws:ecr:us-east-1:{0}:repository/{1}".format( + self.registry_id, repository_name + ) self.name = repository_name # self.created = datetime.utcnow() - self.uri = '{0}.dkr.ecr.us-east-1.amazonaws.com/{1}'.format( - self.registry_id, repository_name) + self.uri = "{0}.dkr.ecr.us-east-1.amazonaws.com/{1}".format( + self.registry_id, repository_name + ) self.images = [] @property @@ -59,38 +59,45 @@ class Repository(BaseObject): def response_object(self): response_object = self.gen_response_object() - response_object['registryId'] = self.registry_id - response_object['repositoryArn'] = self.arn - response_object['repositoryName'] = self.name - response_object['repositoryUri'] = self.uri + response_object["registryId"] = self.registry_id + response_object["repositoryArn"] = self.arn + response_object["repositoryName"] = self.name + response_object["repositoryUri"] = self.uri # response_object['createdAt'] = self.created - del response_object['arn'], response_object['name'], response_object['images'] + del response_object["arn"], response_object["name"], response_object["images"] return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ecr_backend = ecr_backends[region_name] return ecr_backend.create_repository( # RepositoryName is optional in CloudFormation, thus create a random # name if necessary repository_name=properties.get( - 'RepositoryName', 'ecrrepository{0}'.format(int(random() * 10 ** 6))), + "RepositoryName", "ecrrepository{0}".format(int(random() * 10 ** 6)) + ) ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - if original_resource.name != properties['RepositoryName']: + if original_resource.name != properties["RepositoryName"]: ecr_backend = ecr_backends[region_name] ecr_backend.delete_cluster(original_resource.arn) return ecr_backend.create_repository( # RepositoryName is optional in CloudFormation, thus create a # random name if necessary repository_name=properties.get( - 'RepositoryName', 'RepositoryName{0}'.format(int(random() * 10 ** 6))), + "RepositoryName", + "RepositoryName{0}".format(int(random() * 10 ** 6)), + ) ) else: # no-op when nothing changed between old and new resources @@ -98,8 +105,9 @@ class Repository(BaseObject): class Image(BaseObject): - - def __init__(self, tag, manifest, repository, digest=None, registry_id=DEFAULT_REGISTRY_ID): + def __init__( + self, tag, manifest, repository, digest=None, registry_id=DEFAULT_REGISTRY_ID + ): self.image_tag = tag self.image_tags = [tag] if tag is not None else [] self.image_manifest = manifest @@ -110,8 +118,10 @@ class Image(BaseObject): self.image_pushed_at = str(datetime.utcnow().isoformat()) def _create_digest(self): - image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6)) - self.image_digest = "sha256:%s" % hashlib.sha256(image_contents.encode('utf-8')).hexdigest() + image_contents = "docker_image{0}".format(int(random() * 10 ** 6)) + self.image_digest = ( + "sha256:%s" % hashlib.sha256(image_contents.encode("utf-8")).hexdigest() + ) def get_image_digest(self): if not self.image_digest: @@ -135,54 +145,61 @@ class Image(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['imageId'] = {} - response_object['imageId']['imageTag'] = self.image_tag - response_object['imageId']['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageId"] = {} + response_object["imageId"]["imageTag"] = self.image_tag + response_object["imageId"]["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_list_object(self): response_object = self.gen_response_object() - response_object['imageTag'] = self.image_tag - response_object['imageDigest'] = "i don't know" - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageTag"] = self.image_tag + response_object["imageDigest"] = "i don't know" + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_describe_object(self): response_object = self.gen_response_object() - response_object['imageTags'] = self.image_tags - response_object['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - response_object['imageSizeInBytes'] = self.image_size_in_bytes - response_object['imagePushedAt'] = self.image_pushed_at + response_object["imageTags"] = self.image_tags + response_object["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + response_object["imageSizeInBytes"] = self.image_size_in_bytes + response_object["imagePushedAt"] = self.image_pushed_at return {k: v for k, v in response_object.items() if v is not None and v != []} @property def response_batch_get_image(self): response_object = {} - response_object['imageId'] = {} - response_object['imageId']['imageTag'] = self.image_tag - response_object['imageId']['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageId"] = {} + response_object["imageId"]["imageTag"] = self.image_tag + response_object["imageId"]["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_batch_delete_image(self): response_object = {} - response_object['imageDigest'] = self.get_image_digest() - response_object['imageTag'] = self.image_tag - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageDigest"] = self.get_image_digest() + response_object["imageTag"] = self.image_tag + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } class ECRBackend(BaseBackend): - def __init__(self): self.repositories = {} @@ -193,7 +210,9 @@ class ECRBackend(BaseBackend): if repository_names: for repository_name in repository_names: if repository_name not in self.repositories: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) repositories = [] for repository in self.repositories.values(): @@ -218,7 +237,9 @@ class ECRBackend(BaseBackend): if repository_name in self.repositories: return self.repositories.pop(repository_name) else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) def list_images(self, repository_name, registry_id=None): """ @@ -235,7 +256,9 @@ class ECRBackend(BaseBackend): found = True if not found: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) images = [] for image in repository.images: @@ -247,26 +270,34 @@ class ECRBackend(BaseBackend): if repository_name in self.repositories: repository = self.repositories[repository_name] else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) if image_ids: response = set() for image_id in image_ids: found = False for image in repository.images: - if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or - ('imageTag' in image_id and image_id['imageTag'] in image.image_tags)): + if ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ) or ( + "imageTag" in image_id + and image_id["imageTag"] in image.image_tags + ): found = True response.add(image) if not found: image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % ( - image_id.get('imageDigest', 'null'), - image_id.get('imageTag', 'null'), + image_id.get("imageDigest", "null"), + image_id.get("imageTag", "null"), ) raise ImageNotFoundException( image_id=image_id_representation, repository_name=repository_name, - registry_id=registry_id or DEFAULT_REGISTRY_ID) + registry_id=registry_id or DEFAULT_REGISTRY_ID, + ) else: response = [] @@ -281,7 +312,12 @@ class ECRBackend(BaseBackend): else: raise Exception("{0} is not a repository".format(repository_name)) - existing_images = list(filter(lambda x: x.response_object['imageManifest'] == image_manifest, repository.images)) + existing_images = list( + filter( + lambda x: x.response_object["imageManifest"] == image_manifest, + repository.images, + ) + ) if not existing_images: # this image is not in ECR yet image = Image(image_tag, image_manifest, repository_name) @@ -292,36 +328,47 @@ class ECRBackend(BaseBackend): existing_images[0].update_tag(image_tag) return existing_images[0] - def batch_get_image(self, repository_name, registry_id=None, image_ids=None, accepted_media_types=None): + def batch_get_image( + self, + repository_name, + registry_id=None, + image_ids=None, + accepted_media_types=None, + ): if repository_name in self.repositories: repository = self.repositories[repository_name] else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) if not image_ids: - raise ParamValidationError(msg='Missing required parameter in input: "imageIds"') + raise ParamValidationError( + msg='Missing required parameter in input: "imageIds"' + ) - response = { - 'images': [], - 'failures': [], - } + response = {"images": [], "failures": []} for image_id in image_ids: found = False for image in repository.images: - if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or - ('imageTag' in image_id and image.image_tag == image_id['imageTag'])): + if ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ) or ( + "imageTag" in image_id and image.image_tag == image_id["imageTag"] + ): found = True - response['images'].append(image.response_batch_get_image) + response["images"].append(image.response_batch_get_image) if not found: - response['failures'].append({ - 'imageId': { - 'imageTag': image_id.get('imageTag', 'null') - }, - 'failureCode': 'ImageNotFound', - 'failureReason': 'Requested image not found' - }) + response["failures"].append( + { + "imageId": {"imageTag": image_id.get("imageTag", "null")}, + "failureCode": "ImageNotFound", + "failureReason": "Requested image not found", + } + ) return response @@ -338,10 +385,7 @@ class ECRBackend(BaseBackend): msg='Missing required parameter in input: "imageIds"' ) - response = { - "imageIds": [], - "failures": [] - } + response = {"imageIds": [], "failures": []} for image_id in image_ids: image_found = False @@ -377,8 +421,8 @@ class ECRBackend(BaseBackend): # Search by matching both digest and tag if "imageDigest" in image_id and "imageTag" in image_id: if ( - image_id["imageDigest"] == image.get_image_digest() and - image_id["imageTag"] in image.image_tags + image_id["imageDigest"] == image.get_image_digest() + and image_id["imageTag"] in image.image_tags ): image_found = True for image_tag in reversed(image.image_tags): @@ -390,7 +434,10 @@ class ECRBackend(BaseBackend): del repository.images[num] # Search by matching digest - elif "imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]: + elif ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ): image_found = True for image_tag in reversed(image.image_tags): repository.images[num].image_tag = image_tag @@ -399,7 +446,9 @@ class ECRBackend(BaseBackend): del repository.images[num] # Search by matching tag - elif "imageTag" in image_id and image_id["imageTag"] in image.image_tags: + elif ( + "imageTag" in image_id and image_id["imageTag"] in image.image_tags + ): image_found = True repository.images[num].image_tag = image_id["imageTag"] response["imageIds"].append(image.response_batch_delete_image) @@ -416,10 +465,14 @@ class ECRBackend(BaseBackend): } if "imageDigest" in image_id: - failure_response["imageId"]["imageDigest"] = image_id.get("imageDigest", "null") + failure_response["imageId"]["imageDigest"] = image_id.get( + "imageDigest", "null" + ) if "imageTag" in image_id: - failure_response["imageId"]["imageTag"] = image_id.get("imageTag", "null") + failure_response["imageId"]["imageTag"] = image_id.get( + "imageTag", "null" + ) response["failures"].append(failure_response) diff --git a/moto/ecr/responses.py b/moto/ecr/responses.py index f758176ad..37078b878 100644 --- a/moto/ecr/responses.py +++ b/moto/ecr/responses.py @@ -24,148 +24,154 @@ class ECRResponse(BaseResponse): return self.request_params.get(param, None) def create_repository(self): - repository_name = self._get_param('repositoryName') + repository_name = self._get_param("repositoryName") if repository_name is None: - repository_name = 'default' + repository_name = "default" repository = self.ecr_backend.create_repository(repository_name) - return json.dumps({ - 'repository': repository.response_object - }) + return json.dumps({"repository": repository.response_object}) def describe_repositories(self): - describe_repositories_name = self._get_param('repositoryNames') - registry_id = self._get_param('registryId') + describe_repositories_name = self._get_param("repositoryNames") + registry_id = self._get_param("registryId") repositories = self.ecr_backend.describe_repositories( - repository_names=describe_repositories_name, registry_id=registry_id) - return json.dumps({ - 'repositories': repositories, - 'failures': [] - }) + repository_names=describe_repositories_name, registry_id=registry_id + ) + return json.dumps({"repositories": repositories, "failures": []}) def delete_repository(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") repository = self.ecr_backend.delete_repository(repository_str, registry_id) - return json.dumps({ - 'repository': repository.response_object - }) + return json.dumps({"repository": repository.response_object}) def put_image(self): - repository_str = self._get_param('repositoryName') - image_manifest = self._get_param('imageManifest') - image_tag = self._get_param('imageTag') + repository_str = self._get_param("repositoryName") + image_manifest = self._get_param("imageManifest") + image_tag = self._get_param("imageTag") image = self.ecr_backend.put_image(repository_str, image_manifest, image_tag) - return json.dumps({ - 'image': image.response_object - }) + return json.dumps({"image": image.response_object}) def list_images(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") images = self.ecr_backend.list_images(repository_str, registry_id) - return json.dumps({ - 'imageIds': [image.response_list_object for image in images], - }) + return json.dumps( + {"imageIds": [image.response_list_object for image in images]} + ) def describe_images(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') - images = self.ecr_backend.describe_images(repository_str, registry_id, image_ids) - return json.dumps({ - 'imageDetails': [image.response_describe_object for image in images], - }) + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") + images = self.ecr_backend.describe_images( + repository_str, registry_id, image_ids + ) + return json.dumps( + {"imageDetails": [image.response_describe_object for image in images]} + ) def batch_check_layer_availability(self): - if self.is_not_dryrun('BatchCheckLayerAvailability'): + if self.is_not_dryrun("BatchCheckLayerAvailability"): raise NotImplementedError( - 'ECR.batch_check_layer_availability is not yet implemented') + "ECR.batch_check_layer_availability is not yet implemented" + ) def batch_delete_image(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") - response = self.ecr_backend.batch_delete_image(repository_str, registry_id, image_ids) + response = self.ecr_backend.batch_delete_image( + repository_str, registry_id, image_ids + ) return json.dumps(response) def batch_get_image(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') - accepted_media_types = self._get_param('acceptedMediaTypes') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") + accepted_media_types = self._get_param("acceptedMediaTypes") - response = self.ecr_backend.batch_get_image(repository_str, registry_id, image_ids, accepted_media_types) + response = self.ecr_backend.batch_get_image( + repository_str, registry_id, image_ids, accepted_media_types + ) return json.dumps(response) def can_paginate(self): - if self.is_not_dryrun('CanPaginate'): - raise NotImplementedError( - 'ECR.can_paginate is not yet implemented') + if self.is_not_dryrun("CanPaginate"): + raise NotImplementedError("ECR.can_paginate is not yet implemented") def complete_layer_upload(self): - if self.is_not_dryrun('CompleteLayerUpload'): + if self.is_not_dryrun("CompleteLayerUpload"): raise NotImplementedError( - 'ECR.complete_layer_upload is not yet implemented') + "ECR.complete_layer_upload is not yet implemented" + ) def delete_repository_policy(self): - if self.is_not_dryrun('DeleteRepositoryPolicy'): + if self.is_not_dryrun("DeleteRepositoryPolicy"): raise NotImplementedError( - 'ECR.delete_repository_policy is not yet implemented') + "ECR.delete_repository_policy is not yet implemented" + ) def generate_presigned_url(self): - if self.is_not_dryrun('GeneratePresignedUrl'): + if self.is_not_dryrun("GeneratePresignedUrl"): raise NotImplementedError( - 'ECR.generate_presigned_url is not yet implemented') + "ECR.generate_presigned_url is not yet implemented" + ) def get_authorization_token(self): - registry_ids = self._get_param('registryIds') + registry_ids = self._get_param("registryIds") if not registry_ids: registry_ids = [DEFAULT_REGISTRY_ID] auth_data = [] for registry_id in registry_ids: - password = '{}-auth-token'.format(registry_id) - auth_token = b64encode("AWS:{}".format(password).encode('ascii')).decode() - auth_data.append({ - 'authorizationToken': auth_token, - 'expiresAt': time.mktime(datetime(2015, 1, 1).timetuple()), - 'proxyEndpoint': 'https://{}.dkr.ecr.{}.amazonaws.com'.format(registry_id, self.region) - }) - return json.dumps({'authorizationData': auth_data}) + password = "{}-auth-token".format(registry_id) + auth_token = b64encode("AWS:{}".format(password).encode("ascii")).decode() + auth_data.append( + { + "authorizationToken": auth_token, + "expiresAt": time.mktime(datetime(2015, 1, 1).timetuple()), + "proxyEndpoint": "https://{}.dkr.ecr.{}.amazonaws.com".format( + registry_id, self.region + ), + } + ) + return json.dumps({"authorizationData": auth_data}) def get_download_url_for_layer(self): - if self.is_not_dryrun('GetDownloadUrlForLayer'): + if self.is_not_dryrun("GetDownloadUrlForLayer"): raise NotImplementedError( - 'ECR.get_download_url_for_layer is not yet implemented') + "ECR.get_download_url_for_layer is not yet implemented" + ) def get_paginator(self): - if self.is_not_dryrun('GetPaginator'): - raise NotImplementedError( - 'ECR.get_paginator is not yet implemented') + if self.is_not_dryrun("GetPaginator"): + raise NotImplementedError("ECR.get_paginator is not yet implemented") def get_repository_policy(self): - if self.is_not_dryrun('GetRepositoryPolicy'): + if self.is_not_dryrun("GetRepositoryPolicy"): raise NotImplementedError( - 'ECR.get_repository_policy is not yet implemented') + "ECR.get_repository_policy is not yet implemented" + ) def get_waiter(self): - if self.is_not_dryrun('GetWaiter'): - raise NotImplementedError( - 'ECR.get_waiter is not yet implemented') + if self.is_not_dryrun("GetWaiter"): + raise NotImplementedError("ECR.get_waiter is not yet implemented") def initiate_layer_upload(self): - if self.is_not_dryrun('InitiateLayerUpload'): + if self.is_not_dryrun("InitiateLayerUpload"): raise NotImplementedError( - 'ECR.initiate_layer_upload is not yet implemented') + "ECR.initiate_layer_upload is not yet implemented" + ) def set_repository_policy(self): - if self.is_not_dryrun('SetRepositoryPolicy'): + if self.is_not_dryrun("SetRepositoryPolicy"): raise NotImplementedError( - 'ECR.set_repository_policy is not yet implemented') + "ECR.set_repository_policy is not yet implemented" + ) def upload_layer_part(self): - if self.is_not_dryrun('UploadLayerPart'): - raise NotImplementedError( - 'ECR.upload_layer_part is not yet implemented') + if self.is_not_dryrun("UploadLayerPart"): + raise NotImplementedError("ECR.upload_layer_part is not yet implemented") diff --git a/moto/ecr/urls.py b/moto/ecr/urls.py index 5b12cd843..a25874e43 100644 --- a/moto/ecr/urls.py +++ b/moto/ecr/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import ECRResponse -url_bases = [ - "https?://ecr.(.+).amazonaws.com", - "https?://api.ecr.(.+).amazonaws.com", -] +url_bases = ["https?://ecr.(.+).amazonaws.com", "https?://api.ecr.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ECRResponse.dispatch, -} +url_paths = {"{0}/$": ECRResponse.dispatch} diff --git a/moto/ecs/__init__.py b/moto/ecs/__init__.py index 8fb3dd41e..3048838be 100644 --- a/moto/ecs/__init__.py +++ b/moto/ecs/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ecs_backends from ..core.models import base_decorator, deprecated_base_decorator -ecs_backend = ecs_backends['us-east-1'] +ecs_backend = ecs_backends["us-east-1"] mock_ecs = base_decorator(ecs_backends) mock_ecs_deprecated = deprecated_base_decorator(ecs_backends) diff --git a/moto/ecs/exceptions.py b/moto/ecs/exceptions.py index 6e329f227..d08066192 100644 --- a/moto/ecs/exceptions.py +++ b/moto/ecs/exceptions.py @@ -9,7 +9,7 @@ class ServiceNotFoundException(RESTError): super(ServiceNotFoundException, self).__init__( error_type="ServiceNotFoundException", message="The service {0} does not exist".format(service_name), - template='error_json', + template="error_json", ) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 863cfc49e..c9dc998ee 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -12,27 +12,23 @@ from moto.core.utils import unix_time from moto.ec2 import ec2_backends from copy import copy -from .exceptions import ( - ServiceNotFoundException, - TaskDefinitionNotFoundException -) +from .exceptions import ServiceNotFoundException, TaskDefinitionNotFoundException class BaseObject(BaseModel): - def camelCase(self, key): words = [] - for i, word in enumerate(key.split('_')): + for i, word in enumerate(key.split("_")): if i > 0: words.append(word.title()) else: words.append(word) - return ''.join(words) + return "".join(words) def gen_response_object(self): response_object = copy(self.__dict__) for key, value in self.__dict__.items(): - if '_' in key: + if "_" in key: response_object[self.camelCase(key)] = value del response_object[key] return response_object @@ -43,16 +39,17 @@ class BaseObject(BaseModel): class Cluster(BaseObject): - - def __init__(self, cluster_name): + def __init__(self, cluster_name, region_name): self.active_services_count = 0 - self.arn = 'arn:aws:ecs:us-east-1:012345678910:cluster/{0}'.format( - cluster_name) + self.arn = "arn:aws:ecs:{0}:012345678910:cluster/{1}".format( + region_name, cluster_name + ) self.name = cluster_name self.pending_tasks_count = 0 self.registered_container_instances_count = 0 self.running_tasks_count = 0 - self.status = 'ACTIVE' + self.status = "ACTIVE" + self.region_name = region_name @property def physical_resource_id(self): @@ -61,16 +58,18 @@ class Cluster(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['clusterArn'] = self.arn - response_object['clusterName'] = self.name - del response_object['arn'], response_object['name'] + response_object["clusterArn"] = self.arn + response_object["clusterName"] = self.name + del response_object["arn"], response_object["name"] return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): # if properties is not provided, cloudformation will use the default values for all properties - if 'Properties' in cloudformation_json: - properties = cloudformation_json['Properties'] + if "Properties" in cloudformation_json: + properties = cloudformation_json["Properties"] else: properties = {} @@ -79,21 +78,25 @@ class Cluster(BaseObject): # ClusterName is optional in CloudFormation, thus create a random # name if necessary cluster_name=properties.get( - 'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), + "ClusterName", "ecscluster{0}".format(int(random() * 10 ** 6)) + ) ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - if original_resource.name != properties['ClusterName']: + if original_resource.name != properties["ClusterName"]: ecs_backend = ecs_backends[region_name] ecs_backend.delete_cluster(original_resource.arn) return ecs_backend.create_cluster( # ClusterName is optional in CloudFormation, thus create a # random name if necessary cluster_name=properties.get( - 'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), + "ClusterName", "ecscluster{0}".format(int(random() * 10 ** 6)) + ) ) else: # no-op when nothing changed between old and new resources @@ -101,18 +104,27 @@ class Cluster(BaseObject): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class TaskDefinition(BaseObject): - - def __init__(self, family, revision, container_definitions, volumes=None, tags=None): + def __init__( + self, + family, + revision, + container_definitions, + region_name, + volumes=None, + tags=None, + ): self.family = family self.revision = revision - self.arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/{0}:{1}'.format( - family, revision) + self.arn = "arn:aws:ecs:{0}:012345678910:task-definition/{1}:{2}".format( + region_name, family, revision + ) self.container_definitions = container_definitions self.tags = tags if tags is not None else [] if volumes is None: @@ -123,9 +135,9 @@ class TaskDefinition(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['taskDefinitionArn'] = response_object['arn'] - del response_object['arn'] - del response_object['tags'] + response_object["taskDefinitionArn"] = response_object["arn"] + del response_object["arn"] + del response_object["tags"] return response_object @property @@ -133,55 +145,74 @@ class TaskDefinition(BaseObject): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] family = properties.get( - 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) - container_definitions = properties['ContainerDefinitions'] - volumes = properties.get('Volumes') + "Family", "task-definition-{0}".format(int(random() * 10 ** 6)) + ) + container_definitions = properties["ContainerDefinitions"] + volumes = properties.get("Volumes") ecs_backend = ecs_backends[region_name] return ecs_backend.register_task_definition( - family=family, container_definitions=container_definitions, volumes=volumes) + family=family, container_definitions=container_definitions, volumes=volumes + ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] family = properties.get( - 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) - container_definitions = properties['ContainerDefinitions'] - volumes = properties.get('Volumes') - if (original_resource.family != family or - original_resource.container_definitions != container_definitions or - original_resource.volumes != volumes): - # currently TaskRoleArn isn't stored at TaskDefinition - # instances + "Family", "task-definition-{0}".format(int(random() * 10 ** 6)) + ) + container_definitions = properties["ContainerDefinitions"] + volumes = properties.get("Volumes") + if ( + original_resource.family != family + or original_resource.container_definitions != container_definitions + or original_resource.volumes != volumes + ): + # currently TaskRoleArn isn't stored at TaskDefinition + # instances ecs_backend = ecs_backends[region_name] ecs_backend.deregister_task_definition(original_resource.arn) return ecs_backend.register_task_definition( - family=family, container_definitions=container_definitions, volumes=volumes) + family=family, + container_definitions=container_definitions, + volumes=volumes, + ) else: # no-op when nothing changed between old and new resources return original_resource class Task(BaseObject): - - def __init__(self, cluster, task_definition, container_instance_arn, - resource_requirements, overrides={}, started_by=''): + def __init__( + self, + cluster, + task_definition, + container_instance_arn, + resource_requirements, + overrides={}, + started_by="", + ): self.cluster_arn = cluster.arn - self.task_arn = 'arn:aws:ecs:us-east-1:012345678910:task/{0}'.format( - str(uuid.uuid4())) + self.task_arn = "arn:aws:ecs:{0}:012345678910:task/{1}".format( + cluster.region_name, str(uuid.uuid4()) + ) self.container_instance_arn = container_instance_arn - self.last_status = 'RUNNING' - self.desired_status = 'RUNNING' + self.last_status = "RUNNING" + self.desired_status = "RUNNING" self.task_definition_arn = task_definition.arn self.overrides = overrides self.containers = [] self.started_by = started_by - self.stopped_reason = '' + self.stopped_reason = "" self.resource_requirements = resource_requirements @property @@ -191,31 +222,43 @@ class Task(BaseObject): class Service(BaseObject): - - def __init__(self, cluster, service_name, task_definition, desired_count, load_balancers=None, scheduling_strategy=None): + def __init__( + self, + cluster, + service_name, + task_definition, + desired_count, + load_balancers=None, + scheduling_strategy=None, + tags=None, + ): self.cluster_arn = cluster.arn - self.arn = 'arn:aws:ecs:us-east-1:012345678910:service/{0}'.format( - service_name) + self.arn = "arn:aws:ecs:{0}:012345678910:service/{1}".format( + cluster.region_name, service_name + ) self.name = service_name - self.status = 'ACTIVE' + self.status = "ACTIVE" self.running_count = 0 self.task_definition = task_definition.arn self.desired_count = desired_count self.events = [] self.deployments = [ { - 'createdAt': datetime.now(pytz.utc), - 'desiredCount': self.desired_count, - 'id': 'ecs-svc/{}'.format(randint(0, 32**12)), - 'pendingCount': self.desired_count, - 'runningCount': 0, - 'status': 'PRIMARY', - 'taskDefinition': task_definition.arn, - 'updatedAt': datetime.now(pytz.utc), + "createdAt": datetime.now(pytz.utc), + "desiredCount": self.desired_count, + "id": "ecs-svc/{}".format(randint(0, 32 ** 12)), + "pendingCount": self.desired_count, + "runningCount": 0, + "status": "PRIMARY", + "taskDefinition": task_definition.arn, + "updatedAt": datetime.now(pytz.utc), } ] self.load_balancers = load_balancers if load_balancers is not None else [] - self.scheduling_strategy = scheduling_strategy if scheduling_strategy is not None else 'REPLICA' + self.scheduling_strategy = ( + scheduling_strategy if scheduling_strategy is not None else "REPLICA" + ) + self.tags = tags if tags is not None else [] self.pending_count = 0 @property @@ -225,189 +268,223 @@ class Service(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - del response_object['name'], response_object['arn'] - response_object['serviceName'] = self.name - response_object['serviceArn'] = self.arn - response_object['schedulingStrategy'] = self.scheduling_strategy + del response_object["name"], response_object["arn"], response_object["tags"] + response_object["serviceName"] = self.name + response_object["serviceArn"] = self.arn + response_object["schedulingStrategy"] = self.scheduling_strategy - for deployment in response_object['deployments']: - if isinstance(deployment['createdAt'], datetime): - deployment['createdAt'] = unix_time(deployment['createdAt'].replace(tzinfo=None)) - if isinstance(deployment['updatedAt'], datetime): - deployment['updatedAt'] = unix_time(deployment['updatedAt'].replace(tzinfo=None)) + for deployment in response_object["deployments"]: + if isinstance(deployment["createdAt"], datetime): + deployment["createdAt"] = unix_time( + deployment["createdAt"].replace(tzinfo=None) + ) + if isinstance(deployment["updatedAt"], datetime): + deployment["updatedAt"] = unix_time( + deployment["updatedAt"].replace(tzinfo=None) + ) return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - if isinstance(properties['Cluster'], Cluster): - cluster = properties['Cluster'].name + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + if isinstance(properties["Cluster"], Cluster): + cluster = properties["Cluster"].name else: - cluster = properties['Cluster'] - if isinstance(properties['TaskDefinition'], TaskDefinition): - task_definition = properties['TaskDefinition'].family + cluster = properties["Cluster"] + if isinstance(properties["TaskDefinition"], TaskDefinition): + task_definition = properties["TaskDefinition"].family else: - task_definition = properties['TaskDefinition'] - service_name = '{0}Service{1}'.format(cluster, int(random() * 10 ** 6)) - desired_count = properties['DesiredCount'] + task_definition = properties["TaskDefinition"] + service_name = "{0}Service{1}".format(cluster, int(random() * 10 ** 6)) + desired_count = properties["DesiredCount"] # TODO: LoadBalancers # TODO: Role ecs_backend = ecs_backends[region_name] return ecs_backend.create_service( - cluster, service_name, task_definition, desired_count) + cluster, service_name, task_definition, desired_count + ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - if isinstance(properties['Cluster'], Cluster): - cluster_name = properties['Cluster'].name + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + if isinstance(properties["Cluster"], Cluster): + cluster_name = properties["Cluster"].name else: - cluster_name = properties['Cluster'] - if isinstance(properties['TaskDefinition'], TaskDefinition): - task_definition = properties['TaskDefinition'].family + cluster_name = properties["Cluster"] + if isinstance(properties["TaskDefinition"], TaskDefinition): + task_definition = properties["TaskDefinition"].family else: - task_definition = properties['TaskDefinition'] - desired_count = properties['DesiredCount'] + task_definition = properties["TaskDefinition"] + desired_count = properties["DesiredCount"] ecs_backend = ecs_backends[region_name] service_name = original_resource.name - if original_resource.cluster_arn != Cluster(cluster_name).arn: + if original_resource.cluster_arn != Cluster(cluster_name, region_name).arn: # TODO: LoadBalancers # TODO: Role ecs_backend.delete_service(cluster_name, service_name) - new_service_name = '{0}Service{1}'.format( - cluster_name, int(random() * 10 ** 6)) + new_service_name = "{0}Service{1}".format( + cluster_name, int(random() * 10 ** 6) + ) return ecs_backend.create_service( - cluster_name, new_service_name, task_definition, desired_count) + cluster_name, new_service_name, task_definition, desired_count + ) else: - return ecs_backend.update_service(cluster_name, service_name, task_definition, desired_count) + return ecs_backend.update_service( + cluster_name, service_name, task_definition, desired_count + ) def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Name': + + if attribute_name == "Name": return self.name raise UnformattedGetAttTemplateException() class ContainerInstance(BaseObject): - def __init__(self, ec2_instance_id, region_name): self.ec2_instance_id = ec2_instance_id self.agent_connected = True - self.status = 'ACTIVE' + self.status = "ACTIVE" self.registered_resources = [ - {'doubleValue': 0.0, - 'integerValue': 4096, - 'longValue': 0, - 'name': 'CPU', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 7482, - 'longValue': 0, - 'name': 'MEMORY', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS', - 'stringSetValue': ['22', '2376', '2375', '51678', '51679'], - 'type': 'STRINGSET'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS_UDP', - 'stringSetValue': [], - 'type': 'STRINGSET'}] - self.container_instance_arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format( - str(uuid.uuid4())) + { + "doubleValue": 0.0, + "integerValue": 4096, + "longValue": 0, + "name": "CPU", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 7482, + "longValue": 0, + "name": "MEMORY", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS", + "stringSetValue": ["22", "2376", "2375", "51678", "51679"], + "type": "STRINGSET", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS_UDP", + "stringSetValue": [], + "type": "STRINGSET", + }, + ] + self.container_instance_arn = "arn:aws:ecs:{0}:012345678910:container-instance/{1}".format( + region_name, str(uuid.uuid4()) + ) self.pending_tasks_count = 0 self.remaining_resources = [ - {'doubleValue': 0.0, - 'integerValue': 4096, - 'longValue': 0, - 'name': 'CPU', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 7482, - 'longValue': 0, - 'name': 'MEMORY', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS', - 'stringSetValue': ['22', '2376', '2375', '51678', '51679'], - 'type': 'STRINGSET'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS_UDP', - 'stringSetValue': [], - 'type': 'STRINGSET'} + { + "doubleValue": 0.0, + "integerValue": 4096, + "longValue": 0, + "name": "CPU", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 7482, + "longValue": 0, + "name": "MEMORY", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS", + "stringSetValue": ["22", "2376", "2375", "51678", "51679"], + "type": "STRINGSET", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS_UDP", + "stringSetValue": [], + "type": "STRINGSET", + }, ] self.running_tasks_count = 0 self.version_info = { - 'agentVersion': "1.0.0", - 'agentHash': '4023248', - 'dockerVersion': 'DockerVersion: 1.5.0' + "agentVersion": "1.0.0", + "agentHash": "4023248", + "dockerVersion": "DockerVersion: 1.5.0", } ec2_backend = ec2_backends[region_name] ec2_instance = ec2_backend.get_instance(ec2_instance_id) self.attributes = { - 'ecs.ami-id': ec2_instance.image_id, - 'ecs.availability-zone': ec2_instance.placement, - 'ecs.instance-type': ec2_instance.instance_type, - 'ecs.os-type': ec2_instance.platform if ec2_instance.platform == 'windows' else 'linux' # options are windows and linux, linux is default + "ecs.ami-id": ec2_instance.image_id, + "ecs.availability-zone": ec2_instance.placement, + "ecs.instance-type": ec2_instance.instance_type, + "ecs.os-type": ec2_instance.platform + if ec2_instance.platform == "windows" + else "linux", # options are windows and linux, linux is default } @property def response_object(self): response_object = self.gen_response_object() - response_object['attributes'] = [self._format_attribute(name, value) for name, value in response_object['attributes'].items()] + response_object["attributes"] = [ + self._format_attribute(name, value) + for name, value in response_object["attributes"].items() + ] return response_object def _format_attribute(self, name, value): - formatted_attr = { - 'name': name, - } + formatted_attr = {"name": name} if value is not None: - formatted_attr['value'] = value + formatted_attr["value"] = value return formatted_attr class ClusterFailure(BaseObject): - def __init__(self, reason, cluster_name): + def __init__(self, reason, cluster_name, region_name): self.reason = reason - self.arn = "arn:aws:ecs:us-east-1:012345678910:cluster/{0}".format( - cluster_name) + self.arn = "arn:aws:ecs:{0}:012345678910:cluster/{1}".format( + region_name, cluster_name + ) @property def response_object(self): response_object = self.gen_response_object() - response_object['reason'] = self.reason - response_object['arn'] = self.arn + response_object["reason"] = self.reason + response_object["arn"] = self.arn return response_object class ContainerInstanceFailure(BaseObject): - - def __init__(self, reason, container_instance_id): + def __init__(self, reason, container_instance_id, region_name): self.reason = reason - self.arn = "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format( - container_instance_id) + self.arn = "arn:aws:ecs:{0}:012345678910:container-instance/{1}".format( + region_name, container_instance_id + ) @property def response_object(self): response_object = self.gen_response_object() - response_object['reason'] = self.reason - response_object['arn'] = self.arn + response_object["reason"] = self.reason + response_object["arn"] = self.arn return response_object class EC2ContainerServiceBackend(BaseBackend): - def __init__(self, region_name): super(EC2ContainerServiceBackend, self).__init__() self.clusters = {} @@ -423,22 +500,24 @@ class EC2ContainerServiceBackend(BaseBackend): self.__init__(region_name) def describe_task_definition(self, task_definition_str): - task_definition_name = task_definition_str.split('/')[-1] - if ':' in task_definition_name: - family, revision = task_definition_name.split(':') + task_definition_name = task_definition_str.split("/")[-1] + if ":" in task_definition_name: + family, revision = task_definition_name.split(":") revision = int(revision) else: family = task_definition_name revision = self._get_last_task_definition_revision_id(family) - if family in self.task_definitions and revision in self.task_definitions[family]: + if ( + family in self.task_definitions + and revision in self.task_definitions[family] + ): return self.task_definitions[family][revision] else: - raise Exception( - "{0} is not a task_definition".format(task_definition_name)) + raise Exception("{0} is not a task_definition".format(task_definition_name)) def create_cluster(self, cluster_name): - cluster = Cluster(cluster_name) + cluster = Cluster(cluster_name, self.region_name) self.clusters[cluster_name] = cluster return cluster @@ -452,26 +531,29 @@ class EC2ContainerServiceBackend(BaseBackend): list_clusters = [] failures = [] if list_clusters_name is None: - if 'default' in self.clusters: - list_clusters.append(self.clusters['default'].response_object) + if "default" in self.clusters: + list_clusters.append(self.clusters["default"].response_object) else: for cluster in list_clusters_name: - cluster_name = cluster.split('/')[-1] + cluster_name = cluster.split("/")[-1] if cluster_name in self.clusters: - list_clusters.append( - self.clusters[cluster_name].response_object) + list_clusters.append(self.clusters[cluster_name].response_object) else: - failures.append(ClusterFailure('MISSING', cluster_name)) + failures.append( + ClusterFailure("MISSING", cluster_name, self.region_name) + ) return list_clusters, failures def delete_cluster(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: return self.clusters.pop(cluster_name) else: raise Exception("{0} is not a cluster".format(cluster_name)) - def register_task_definition(self, family, container_definitions, volumes, tags=None): + def register_task_definition( + self, family, container_definitions, volumes, tags=None + ): if family in self.task_definitions: last_id = self._get_last_task_definition_revision_id(family) revision = (last_id or 0) + 1 @@ -479,35 +561,38 @@ class EC2ContainerServiceBackend(BaseBackend): self.task_definitions[family] = {} revision = 1 task_definition = TaskDefinition( - family, revision, container_definitions, volumes, tags) + family, revision, container_definitions, self.region_name, volumes, tags + ) self.task_definitions[family][revision] = task_definition return task_definition - def list_task_definitions(self): - """ - Filtering not implemented - """ + def list_task_definitions(self, family_prefix): task_arns = [] for task_definition_list in self.task_definitions.values(): - task_arns.extend([ - task_definition.arn - for task_definition in task_definition_list.values() - ]) + task_arns.extend( + [ + task_definition.arn + for task_definition in task_definition_list.values() + if family_prefix is None or task_definition.family == family_prefix + ] + ) return task_arns def deregister_task_definition(self, task_definition_str): - task_definition_name = task_definition_str.split('/')[-1] - family, revision = task_definition_name.split(':') + task_definition_name = task_definition_str.split("/")[-1] + family, revision = task_definition_name.split(":") revision = int(revision) - if family in self.task_definitions and revision in self.task_definitions[family]: + if ( + family in self.task_definitions + and revision in self.task_definitions[family] + ): return self.task_definitions[family].pop(revision) else: - raise Exception( - "{0} is not a task_definition".format(task_definition_name)) + raise Exception("{0} is not a task_definition".format(task_definition_name)) def run_task(self, cluster_str, task_definition_str, count, overrides, started_by): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -517,24 +602,42 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks[cluster_name] = {} tasks = [] container_instances = list( - self.container_instances.get(cluster_name, {}).keys()) + self.container_instances.get(cluster_name, {}).keys() + ) if not container_instances: raise Exception("No instances found in cluster {}".format(cluster_name)) - active_container_instances = [x for x in container_instances if - self.container_instances[cluster_name][x].status == 'ACTIVE'] - resource_requirements = self._calculate_task_resource_requirements(task_definition) + active_container_instances = [ + x + for x in container_instances + if self.container_instances[cluster_name][x].status == "ACTIVE" + ] + resource_requirements = self._calculate_task_resource_requirements( + task_definition + ) # TODO: return event about unable to place task if not able to place enough tasks to meet count placed_count = 0 for container_instance in active_container_instances: - container_instance = self.container_instances[cluster_name][container_instance] + container_instance = self.container_instances[cluster_name][ + container_instance + ] container_instance_arn = container_instance.container_instance_arn try_to_place = True while try_to_place: - can_be_placed, message = self._can_be_placed(container_instance, resource_requirements) + can_be_placed, message = self._can_be_placed( + container_instance, resource_requirements + ) if can_be_placed: - task = Task(cluster, task_definition, container_instance_arn, - resource_requirements, overrides or {}, started_by or '') - self.update_container_instance_resources(container_instance, resource_requirements) + task = Task( + cluster, + task_definition, + container_instance_arn, + resource_requirements, + overrides or {}, + started_by or "", + ) + self.update_container_instance_resources( + container_instance, resource_requirements + ) tasks.append(task) self.tasks[cluster_name][task.task_arn] = task placed_count += 1 @@ -551,23 +654,33 @@ class EC2ContainerServiceBackend(BaseBackend): # cloudformation uses capitalized properties, while boto uses all lower case # CPU is optional - resource_requirements["CPU"] += container_definition.get('cpu', - container_definition.get('Cpu', 0)) + resource_requirements["CPU"] += container_definition.get( + "cpu", container_definition.get("Cpu", 0) + ) # either memory or memory reservation must be provided - if 'Memory' in container_definition or 'MemoryReservation' in container_definition: + if ( + "Memory" in container_definition + or "MemoryReservation" in container_definition + ): resource_requirements["MEMORY"] += container_definition.get( - "Memory", container_definition.get('MemoryReservation')) + "Memory", container_definition.get("MemoryReservation") + ) else: resource_requirements["MEMORY"] += container_definition.get( - "memory", container_definition.get('memoryReservation')) + "memory", container_definition.get("memoryReservation") + ) - port_mapping_key = 'PortMappings' if 'PortMappings' in container_definition else 'portMappings' + port_mapping_key = ( + "PortMappings" + if "PortMappings" in container_definition + else "portMappings" + ) for port_mapping in container_definition.get(port_mapping_key, []): - if 'hostPort' in port_mapping: - resource_requirements["PORTS"].append(port_mapping.get('hostPort')) - elif 'HostPort' in port_mapping: - resource_requirements["PORTS"].append(port_mapping.get('HostPort')) + if "hostPort" in port_mapping: + resource_requirements["PORTS"].append(port_mapping.get("hostPort")) + elif "HostPort" in port_mapping: + resource_requirements["PORTS"].append(port_mapping.get("HostPort")) return resource_requirements @@ -602,8 +715,15 @@ class EC2ContainerServiceBackend(BaseBackend): return False, "Port clash" return True, "Can be placed" - def start_task(self, cluster_str, task_definition_str, container_instances, overrides, started_by): - cluster_name = cluster_str.split('/')[-1] + def start_task( + self, + cluster_str, + task_definition_str, + container_instances, + overrides, + started_by, + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -615,22 +735,31 @@ class EC2ContainerServiceBackend(BaseBackend): if not container_instances: raise Exception("No container instance list provided") - container_instance_ids = [x.split('/')[-1] - for x in container_instances] - resource_requirements = self._calculate_task_resource_requirements(task_definition) + container_instance_ids = [x.split("/")[-1] for x in container_instances] + resource_requirements = self._calculate_task_resource_requirements( + task_definition + ) for container_instance_id in container_instance_ids: container_instance = self.container_instances[cluster_name][ container_instance_id ] - task = Task(cluster, task_definition, container_instance.container_instance_arn, - resource_requirements, overrides or {}, started_by or '') + task = Task( + cluster, + task_definition, + container_instance.container_instance_arn, + resource_requirements, + overrides or {}, + started_by or "", + ) tasks.append(task) - self.update_container_instance_resources(container_instance, resource_requirements) + self.update_container_instance_resources( + container_instance, resource_requirements + ) self.tasks[cluster_name][task.task_arn] = task return tasks def describe_tasks(self, cluster_str, tasks): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -641,58 +770,88 @@ class EC2ContainerServiceBackend(BaseBackend): for cluster, cluster_tasks in self.tasks.items(): for task_arn, task in cluster_tasks.items(): task_id = task_arn.split("/")[-1] - if task_arn in tasks or task.task_arn in tasks or any(task_id in task for task in tasks): + if ( + task_arn in tasks + or task.task_arn in tasks + or any(task_id in task for task in tasks) + ): response.append(task) return response - def list_tasks(self, cluster_str, container_instance, family, started_by, service_name, desiredStatus): + def list_tasks( + self, + cluster_str, + container_instance, + family, + started_by, + service_name, + desiredStatus, + ): filtered_tasks = [] for cluster, tasks in self.tasks.items(): for arn, task in tasks.items(): filtered_tasks.append(task) if cluster_str: - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) filtered_tasks = list( - filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks)) + filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks) + ) if container_instance: - filtered_tasks = list(filter( - lambda t: container_instance in t.container_instance_arn, filtered_tasks)) + filtered_tasks = list( + filter( + lambda t: container_instance in t.container_instance_arn, + filtered_tasks, + ) + ) if started_by: filtered_tasks = list( - filter(lambda t: started_by == t.started_by, filtered_tasks)) + filter(lambda t: started_by == t.started_by, filtered_tasks) + ) return [t.task_arn for t in filtered_tasks] def stop_task(self, cluster_str, task_str, reason): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) if not task_str: raise Exception("A task ID or ARN is required") - task_id = task_str.split('/')[-1] + task_id = task_str.split("/")[-1] tasks = self.tasks.get(cluster_name, None) if not tasks: - raise Exception( - "Cluster {} has no registered tasks".format(cluster_name)) + raise Exception("Cluster {} has no registered tasks".format(cluster_name)) for task in tasks.keys(): if task.endswith(task_id): container_instance_arn = tasks[task].container_instance_arn - container_instance = self.container_instances[cluster_name][container_instance_arn.split('/')[-1]] - self.update_container_instance_resources(container_instance, tasks[task].resource_requirements, - removing=True) - tasks[task].last_status = 'STOPPED' - tasks[task].desired_status = 'STOPPED' + container_instance = self.container_instances[cluster_name][ + container_instance_arn.split("/")[-1] + ] + self.update_container_instance_resources( + container_instance, tasks[task].resource_requirements, removing=True + ) + tasks[task].last_status = "STOPPED" + tasks[task].desired_status = "STOPPED" tasks[task].stopped_reason = reason return tasks[task] - raise Exception("Could not find task {} on cluster {}".format( - task_str, cluster_name)) + raise Exception( + "Could not find task {} on cluster {}".format(task_str, cluster_name) + ) - def create_service(self, cluster_str, service_name, task_definition_str, desired_count, load_balancers=None, scheduling_strategy=None): - cluster_name = cluster_str.split('/')[-1] + def create_service( + self, + cluster_str, + service_name, + task_definition_str, + desired_count, + load_balancers=None, + scheduling_strategy=None, + tags=None, + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -700,52 +859,70 @@ class EC2ContainerServiceBackend(BaseBackend): task_definition = self.describe_task_definition(task_definition_str) desired_count = desired_count if desired_count is not None else 0 - service = Service(cluster, service_name, - task_definition, desired_count, load_balancers, scheduling_strategy) - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + service = Service( + cluster, + service_name, + task_definition, + desired_count, + load_balancers, + scheduling_strategy, + tags, + ) + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) self.services[cluster_service_pair] = service return service def list_services(self, cluster_str, scheduling_strategy=None): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] service_arns = [] for key, value in self.services.items(): - if cluster_name + ':' in key: + if cluster_name + ":" in key: service = self.services[key] - if scheduling_strategy is None or service.scheduling_strategy == scheduling_strategy: + if ( + scheduling_strategy is None + or service.scheduling_strategy == scheduling_strategy + ): service_arns.append(service.arn) return sorted(service_arns) def describe_services(self, cluster_str, service_names_or_arns): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] result = [] - for existing_service_name, existing_service_obj in sorted(self.services.items()): + for existing_service_name, existing_service_obj in sorted( + self.services.items() + ): for requested_name_or_arn in service_names_or_arns: - cluster_service_pair = '{0}:{1}'.format( - cluster_name, requested_name_or_arn) - if cluster_service_pair == existing_service_name or existing_service_obj.arn == requested_name_or_arn: + cluster_service_pair = "{0}:{1}".format( + cluster_name, requested_name_or_arn + ) + if ( + cluster_service_pair == existing_service_name + or existing_service_obj.arn == requested_name_or_arn + ): result.append(existing_service_obj) return result - def update_service(self, cluster_str, service_name, task_definition_str, desired_count): - cluster_name = cluster_str.split('/')[-1] - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + def update_service( + self, cluster_str, service_name, task_definition_str, desired_count + ): + cluster_name = cluster_str.split("/")[-1] + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) if cluster_service_pair in self.services: if task_definition_str is not None: self.describe_task_definition(task_definition_str) self.services[ - cluster_service_pair].task_definition = task_definition_str + cluster_service_pair + ].task_definition = task_definition_str if desired_count is not None: - self.services[ - cluster_service_pair].desired_count = desired_count + self.services[cluster_service_pair].desired_count = desired_count return self.services[cluster_service_pair] else: raise ServiceNotFoundException(service_name) def delete_service(self, cluster_name, service_name): - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) if cluster_service_pair in self.services: service = self.services[cluster_service_pair] if service.desired_count > 0: @@ -753,80 +930,110 @@ class EC2ContainerServiceBackend(BaseBackend): else: return self.services.pop(cluster_service_pair) else: - raise Exception("cluster {0} or service {1} does not exist".format( - cluster_name, service_name)) + raise Exception( + "cluster {0} or service {1} does not exist".format( + cluster_name, service_name + ) + ) def register_container_instance(self, cluster_str, ec2_instance_id): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) container_instance = ContainerInstance(ec2_instance_id, self.region_name) if not self.container_instances.get(cluster_name): self.container_instances[cluster_name] = {} - container_instance_id = container_instance.container_instance_arn.split( - '/')[-1] + container_instance_id = container_instance.container_instance_arn.split("/")[-1] self.container_instances[cluster_name][ - container_instance_id] = container_instance + container_instance_id + ] = container_instance self.clusters[cluster_name].registered_container_instances_count += 1 return container_instance def list_container_instances(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] container_instances_values = self.container_instances.get( - cluster_name, {}).values() + cluster_name, {} + ).values() container_instances = [ - ci.container_instance_arn for ci in container_instances_values] + ci.container_instance_arn for ci in container_instances_values + ] return sorted(container_instances) def describe_container_instances(self, cluster_str, list_container_instance_ids): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) + if not list_container_instance_ids: + raise JsonRESTError( + "InvalidParameterException", "Container instance cannot be empty" + ) failures = [] container_instance_objects = [] for container_instance_id in list_container_instance_ids: - container_instance_id = container_instance_id.split('/')[-1] - container_instance = self.container_instances[ - cluster_name].get(container_instance_id, None) + container_instance_id = container_instance_id.split("/")[-1] + container_instance = self.container_instances[cluster_name].get( + container_instance_id, None + ) if container_instance is not None: container_instance_objects.append(container_instance) else: - failures.append(ContainerInstanceFailure( - 'MISSING', container_instance_id)) + failures.append( + ContainerInstanceFailure( + "MISSING", container_instance_id, self.region_name + ) + ) return container_instance_objects, failures - def update_container_instances_state(self, cluster_str, list_container_instance_ids, status): - cluster_name = cluster_str.split('/')[-1] + def update_container_instances_state( + self, cluster_str, list_container_instance_ids, status + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) status = status.upper() - if status not in ['ACTIVE', 'DRAINING']: - raise Exception("An error occurred (InvalidParameterException) when calling the UpdateContainerInstancesState operation: \ - Container instances status should be one of [ACTIVE,DRAINING]") + if status not in ["ACTIVE", "DRAINING"]: + raise Exception( + "An error occurred (InvalidParameterException) when calling the UpdateContainerInstancesState operation: \ + Container instances status should be one of [ACTIVE,DRAINING]" + ) failures = [] container_instance_objects = [] - list_container_instance_ids = [x.split('/')[-1] - for x in list_container_instance_ids] + list_container_instance_ids = [ + x.split("/")[-1] for x in list_container_instance_ids + ] for container_instance_id in list_container_instance_ids: - container_instance = self.container_instances[cluster_name].get(container_instance_id, None) + container_instance = self.container_instances[cluster_name].get( + container_instance_id, None + ) if container_instance is not None: container_instance.status = status container_instance_objects.append(container_instance) else: - failures.append(ContainerInstanceFailure('MISSING', container_instance_id)) + failures.append( + ContainerInstanceFailure( + "MISSING", container_instance_id, self.region_name + ) + ) return container_instance_objects, failures - def update_container_instance_resources(self, container_instance, task_resources, removing=False): + def update_container_instance_resources( + self, container_instance, task_resources, removing=False + ): resource_multiplier = 1 if removing: resource_multiplier = -1 for resource in container_instance.remaining_resources: if resource.get("name") == "CPU": - resource["integerValue"] -= task_resources.get('CPU') * resource_multiplier + resource["integerValue"] -= ( + task_resources.get("CPU") * resource_multiplier + ) elif resource.get("name") == "MEMORY": - resource["integerValue"] -= task_resources.get('MEMORY') * resource_multiplier + resource["integerValue"] -= ( + task_resources.get("MEMORY") * resource_multiplier + ) elif resource.get("name") == "PORTS": for port in task_resources.get("PORTS"): if removing: @@ -837,11 +1044,13 @@ class EC2ContainerServiceBackend(BaseBackend): def deregister_container_instance(self, cluster_str, container_instance_str, force): failures = [] - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) - container_instance_id = container_instance_str.split('/')[-1] - container_instance = self.container_instances[cluster_name].get(container_instance_id) + container_instance_id = container_instance_str.split("/")[-1] + container_instance = self.container_instances[cluster_name].get( + container_instance_id + ) if container_instance is None: raise Exception("{0} is not a container id in the cluster") if not force and container_instance.running_tasks_count > 0: @@ -849,53 +1058,86 @@ class EC2ContainerServiceBackend(BaseBackend): # Currently assume that people might want to do something based around deregistered instances # with tasks left running on them - but nothing if no tasks were running already elif force and container_instance.running_tasks_count > 0: - if not self.container_instances.get('orphaned'): - self.container_instances['orphaned'] = {} - self.container_instances['orphaned'][container_instance_id] = container_instance - del(self.container_instances[cluster_name][container_instance_id]) + if not self.container_instances.get("orphaned"): + self.container_instances["orphaned"] = {} + self.container_instances["orphaned"][ + container_instance_id + ] = container_instance + del self.container_instances[cluster_name][container_instance_id] self._respond_to_cluster_state_update(cluster_str) return container_instance, failures def _respond_to_cluster_state_update(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) pass def put_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError('ClusterNotFoundException', 'Cluster not found', status=400) + raise JsonRESTError( + "ClusterNotFoundException", "Cluster not found", status=400 + ) if attributes is None: - raise JsonRESTError('InvalidParameterException', 'attributes value is required') + raise JsonRESTError( + "InvalidParameterException", "attributes value is required" + ) for attr in attributes: - self._put_attribute(cluster_name, attr['name'], attr.get('value'), attr.get('targetId'), attr.get('targetType')) + self._put_attribute( + cluster_name, + attr["name"], + attr.get("value"), + attr.get("targetId"), + attr.get("targetType"), + ) - def _put_attribute(self, cluster_name, name, value=None, target_id=None, target_type=None): + def _put_attribute( + self, cluster_name, name, value=None, target_id=None, target_type=None + ): if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): instance.attributes[name] = value elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit('/', 1)[-1] + arn = target_id.rsplit("/", 1)[-1] self.container_instances[cluster_name][arn].attributes[name] = value except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) else: # targetId is container uuid, targetType must be container-instance try: - if target_type != 'container-instance': - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + if target_type != "container-instance": + raise JsonRESTError( + "TargetNotFoundException", + "Could not find {0}".format(target_id), + ) - self.container_instances[cluster_name][target_id].attributes[name] = value + self.container_instances[cluster_name][target_id].attributes[ + name + ] = value except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) - def list_attributes(self, target_type, cluster_name=None, attr_name=None, attr_value=None, max_results=None, next_token=None): - if target_type != 'container-instance': - raise JsonRESTError('InvalidParameterException', 'targetType must be container-instance') + def list_attributes( + self, + target_type, + cluster_name=None, + attr_name=None, + attr_value=None, + max_results=None, + next_token=None, + ): + if target_type != "container-instance": + raise JsonRESTError( + "InvalidParameterException", "targetType must be container-instance" + ) filters = [lambda x: True] @@ -911,21 +1153,40 @@ class EC2ContainerServiceBackend(BaseBackend): for cluster_name, cobj in self.container_instances.items(): for container_instance in cobj.values(): for key, value in container_instance.attributes.items(): - all_attrs.append((cluster_name, container_instance.container_instance_arn, key, value)) + all_attrs.append( + ( + cluster_name, + container_instance.container_instance_arn, + key, + value, + ) + ) return filter(lambda x: all(f(x) for f in filters), all_attrs) def delete_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError('ClusterNotFoundException', 'Cluster not found', status=400) + raise JsonRESTError( + "ClusterNotFoundException", "Cluster not found", status=400 + ) if attributes is None: - raise JsonRESTError('InvalidParameterException', 'attributes value is required') + raise JsonRESTError( + "InvalidParameterException", "attributes value is required" + ) for attr in attributes: - self._delete_attribute(cluster_name, attr['name'], attr.get('value'), attr.get('targetId'), attr.get('targetType')) + self._delete_attribute( + cluster_name, + attr["name"], + attr.get("value"), + attr.get("targetId"), + attr.get("targetType"), + ) - def _delete_attribute(self, cluster_name, name, value=None, target_id=None, target_type=None): + def _delete_attribute( + self, cluster_name, name, value=None, target_id=None, target_type=None + ): if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): if name in instance.attributes and instance.attributes[name] == value: @@ -933,47 +1194,68 @@ class EC2ContainerServiceBackend(BaseBackend): elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit('/', 1)[-1] + arn = target_id.rsplit("/", 1)[-1] instance = self.container_instances[cluster_name][arn] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) else: # targetId is container uuid, targetType must be container-instance try: - if target_type != 'container-instance': - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + if target_type != "container-instance": + raise JsonRESTError( + "TargetNotFoundException", + "Could not find {0}".format(target_id), + ) instance = self.container_instances[cluster_name][target_id] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) - def list_task_definition_families(self, family_prefix=None, status=None, max_results=None, next_token=None): + def list_task_definition_families( + self, family_prefix=None, status=None, max_results=None, next_token=None + ): for task_fam in self.task_definitions: if family_prefix is not None and not task_fam.startswith(family_prefix): continue yield task_fam - def list_tags_for_resource(self, resource_arn): - """Currently only implemented for task definitions""" + @staticmethod + def _parse_resource_arn(resource_arn): match = re.match( "^arn:aws:ecs:(?P[^:]+):(?P[^:]+):(?P[^:]+)/(?P.*)$", - resource_arn) + resource_arn, + ) if not match: - raise JsonRESTError('InvalidParameterException', 'The ARN provided is invalid.') + raise JsonRESTError( + "InvalidParameterException", "The ARN provided is invalid." + ) + return match.groupdict() - service = match.group("service") - if service == "task-definition": + def list_tags_for_resource(self, resource_arn): + """Currently implemented only for task definitions and services""" + parsed_arn = self._parse_resource_arn(resource_arn) + if parsed_arn["service"] == "task-definition": for task_definition in self.task_definitions.values(): for revision in task_definition.values(): if revision.arn == resource_arn: return revision.tags else: raise TaskDefinitionNotFoundException() + elif parsed_arn["service"] == "service": + for service in self.services.values(): + if service.arn == resource_arn: + return service.tags + else: + raise ServiceNotFoundException(service_name=parsed_arn["id"]) raise NotImplementedError() def _get_last_task_definition_revision_id(self, family): @@ -981,6 +1263,46 @@ class EC2ContainerServiceBackend(BaseBackend): if definitions: return max(definitions.keys()) + def tag_resource(self, resource_arn, tags): + """Currently implemented only for services""" + parsed_arn = self._parse_resource_arn(resource_arn) + if parsed_arn["service"] == "service": + for service in self.services.values(): + if service.arn == resource_arn: + service.tags = self._merge_tags(service.tags, tags) + return {} + else: + raise ServiceNotFoundException(service_name=parsed_arn["id"]) + raise NotImplementedError() + + def _merge_tags(self, existing_tags, new_tags): + merged_tags = new_tags + new_keys = self._get_keys(new_tags) + for existing_tag in existing_tags: + if existing_tag["key"] not in new_keys: + merged_tags.append(existing_tag) + return merged_tags + + @staticmethod + def _get_keys(tags): + return [tag["key"] for tag in tags] + + def untag_resource(self, resource_arn, tag_keys): + """Currently implemented only for services""" + parsed_arn = self._parse_resource_arn(resource_arn) + if parsed_arn["service"] == "service": + for service in self.services.values(): + if service.arn == resource_arn: + service.tags = [ + tag for tag in service.tags if tag["key"] not in tag_keys + ] + return {} + else: + raise ServiceNotFoundException(service_name=parsed_arn["id"]) + raise NotImplementedError() + available_regions = boto3.session.Session().get_available_regions("ecs") -ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions} +ecs_backends = { + region: EC2ContainerServiceBackend(region) for region in available_regions +} diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index abb79ea78..d08bded2c 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -6,7 +6,6 @@ from .models import ecs_backends class EC2ContainerServiceResponse(BaseResponse): - @property def ecs_backend(self): """ @@ -28,294 +27,316 @@ class EC2ContainerServiceResponse(BaseResponse): return self.request_params.get(param, if_none) def create_cluster(self): - cluster_name = self._get_param('clusterName') + cluster_name = self._get_param("clusterName") if cluster_name is None: - cluster_name = 'default' + cluster_name = "default" cluster = self.ecs_backend.create_cluster(cluster_name) - return json.dumps({ - 'cluster': cluster.response_object - }) + return json.dumps({"cluster": cluster.response_object}) def list_clusters(self): cluster_arns = self.ecs_backend.list_clusters() - return json.dumps({ - 'clusterArns': cluster_arns - # 'nextToken': str(uuid.uuid4()) - }) + return json.dumps( + { + "clusterArns": cluster_arns + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_clusters(self): - list_clusters_name = self._get_param('clusters') + list_clusters_name = self._get_param("clusters") clusters, failures = self.ecs_backend.describe_clusters(list_clusters_name) - return json.dumps({ - 'clusters': clusters, - 'failures': [cluster.response_object for cluster in failures] - }) + return json.dumps( + { + "clusters": clusters, + "failures": [cluster.response_object for cluster in failures], + } + ) def delete_cluster(self): - cluster_str = self._get_param('cluster') + cluster_str = self._get_param("cluster") cluster = self.ecs_backend.delete_cluster(cluster_str) - return json.dumps({ - 'cluster': cluster.response_object - }) + return json.dumps({"cluster": cluster.response_object}) def register_task_definition(self): - family = self._get_param('family') - container_definitions = self._get_param('containerDefinitions') - volumes = self._get_param('volumes') - tags = self._get_param('tags') + family = self._get_param("family") + container_definitions = self._get_param("containerDefinitions") + volumes = self._get_param("volumes") + tags = self._get_param("tags") task_definition = self.ecs_backend.register_task_definition( - family, container_definitions, volumes, tags) - return json.dumps({ - 'taskDefinition': task_definition.response_object - }) + family, container_definitions, volumes, tags + ) + return json.dumps({"taskDefinition": task_definition.response_object}) def list_task_definitions(self): - task_definition_arns = self.ecs_backend.list_task_definitions() - return json.dumps({ - 'taskDefinitionArns': task_definition_arns - # 'nextToken': str(uuid.uuid4()) - }) + family_prefix = self._get_param("familyPrefix") + task_definition_arns = self.ecs_backend.list_task_definitions(family_prefix) + return json.dumps( + { + "taskDefinitionArns": task_definition_arns + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_task_definition(self): - task_definition_str = self._get_param('taskDefinition') + task_definition_str = self._get_param("taskDefinition") data = self.ecs_backend.describe_task_definition(task_definition_str) - return json.dumps({ - 'taskDefinition': data.response_object, - 'failures': [] - }) + return json.dumps({"taskDefinition": data.response_object, "failures": []}) def deregister_task_definition(self): - task_definition_str = self._get_param('taskDefinition') + task_definition_str = self._get_param("taskDefinition") task_definition = self.ecs_backend.deregister_task_definition( - task_definition_str) - return json.dumps({ - 'taskDefinition': task_definition.response_object - }) + task_definition_str + ) + return json.dumps({"taskDefinition": task_definition.response_object}) def run_task(self): - cluster_str = self._get_param('cluster') - overrides = self._get_param('overrides') - task_definition_str = self._get_param('taskDefinition') - count = self._get_int_param('count') - started_by = self._get_param('startedBy') + cluster_str = self._get_param("cluster") + overrides = self._get_param("overrides") + task_definition_str = self._get_param("taskDefinition") + count = self._get_int_param("count") + started_by = self._get_param("startedBy") tasks = self.ecs_backend.run_task( - cluster_str, task_definition_str, count, overrides, started_by) - return json.dumps({ - 'tasks': [task.response_object for task in tasks], - 'failures': [] - }) + cluster_str, task_definition_str, count, overrides, started_by + ) + return json.dumps( + {"tasks": [task.response_object for task in tasks], "failures": []} + ) def describe_tasks(self): - cluster = self._get_param('cluster') - tasks = self._get_param('tasks') + cluster = self._get_param("cluster") + tasks = self._get_param("tasks") data = self.ecs_backend.describe_tasks(cluster, tasks) - return json.dumps({ - 'tasks': [task.response_object for task in data], - 'failures': [] - }) + return json.dumps( + {"tasks": [task.response_object for task in data], "failures": []} + ) def start_task(self): - cluster_str = self._get_param('cluster') - overrides = self._get_param('overrides') - task_definition_str = self._get_param('taskDefinition') - container_instances = self._get_param('containerInstances') - started_by = self._get_param('startedBy') + cluster_str = self._get_param("cluster") + overrides = self._get_param("overrides") + task_definition_str = self._get_param("taskDefinition") + container_instances = self._get_param("containerInstances") + started_by = self._get_param("startedBy") tasks = self.ecs_backend.start_task( - cluster_str, task_definition_str, container_instances, overrides, started_by) - return json.dumps({ - 'tasks': [task.response_object for task in tasks], - 'failures': [] - }) + cluster_str, task_definition_str, container_instances, overrides, started_by + ) + return json.dumps( + {"tasks": [task.response_object for task in tasks], "failures": []} + ) def list_tasks(self): - cluster_str = self._get_param('cluster') - container_instance = self._get_param('containerInstance') - family = self._get_param('family') - started_by = self._get_param('startedBy') - service_name = self._get_param('serviceName') - desiredStatus = self._get_param('desiredStatus') + cluster_str = self._get_param("cluster") + container_instance = self._get_param("containerInstance") + family = self._get_param("family") + started_by = self._get_param("startedBy") + service_name = self._get_param("serviceName") + desiredStatus = self._get_param("desiredStatus") task_arns = self.ecs_backend.list_tasks( - cluster_str, container_instance, family, started_by, service_name, desiredStatus) - return json.dumps({ - 'taskArns': task_arns - }) + cluster_str, + container_instance, + family, + started_by, + service_name, + desiredStatus, + ) + return json.dumps({"taskArns": task_arns}) def stop_task(self): - cluster_str = self._get_param('cluster') - task = self._get_param('task') - reason = self._get_param('reason') + cluster_str = self._get_param("cluster") + task = self._get_param("task") + reason = self._get_param("reason") task = self.ecs_backend.stop_task(cluster_str, task, reason) - return json.dumps({ - 'task': task.response_object - }) + return json.dumps({"task": task.response_object}) def create_service(self): - cluster_str = self._get_param('cluster') - service_name = self._get_param('serviceName') - task_definition_str = self._get_param('taskDefinition') - desired_count = self._get_int_param('desiredCount') - load_balancers = self._get_param('loadBalancers') - scheduling_strategy = self._get_param('schedulingStrategy') + cluster_str = self._get_param("cluster") + service_name = self._get_param("serviceName") + task_definition_str = self._get_param("taskDefinition") + desired_count = self._get_int_param("desiredCount") + load_balancers = self._get_param("loadBalancers") + scheduling_strategy = self._get_param("schedulingStrategy") + tags = self._get_param("tags") service = self.ecs_backend.create_service( - cluster_str, service_name, task_definition_str, desired_count, load_balancers, scheduling_strategy) - return json.dumps({ - 'service': service.response_object - }) + cluster_str, + service_name, + task_definition_str, + desired_count, + load_balancers, + scheduling_strategy, + tags, + ) + return json.dumps({"service": service.response_object}) def list_services(self): - cluster_str = self._get_param('cluster') - scheduling_strategy = self._get_param('schedulingStrategy') + cluster_str = self._get_param("cluster") + scheduling_strategy = self._get_param("schedulingStrategy") service_arns = self.ecs_backend.list_services(cluster_str, scheduling_strategy) - return json.dumps({ - 'serviceArns': service_arns - # , - # 'nextToken': str(uuid.uuid4()) - }) + return json.dumps( + { + "serviceArns": service_arns + # , + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_services(self): - cluster_str = self._get_param('cluster') - service_names = self._get_param('services') - services = self.ecs_backend.describe_services( - cluster_str, service_names) - return json.dumps({ - 'services': [service.response_object for service in services], - 'failures': [] - }) + cluster_str = self._get_param("cluster") + service_names = self._get_param("services") + services = self.ecs_backend.describe_services(cluster_str, service_names) + return json.dumps( + { + "services": [service.response_object for service in services], + "failures": [], + } + ) def update_service(self): - cluster_str = self._get_param('cluster') - service_name = self._get_param('service') - task_definition = self._get_param('taskDefinition') - desired_count = self._get_int_param('desiredCount') + cluster_str = self._get_param("cluster") + service_name = self._get_param("service") + task_definition = self._get_param("taskDefinition") + desired_count = self._get_int_param("desiredCount") service = self.ecs_backend.update_service( - cluster_str, service_name, task_definition, desired_count) - return json.dumps({ - 'service': service.response_object - }) + cluster_str, service_name, task_definition, desired_count + ) + return json.dumps({"service": service.response_object}) def delete_service(self): - service_name = self._get_param('service') - cluster_name = self._get_param('cluster') + service_name = self._get_param("service") + cluster_name = self._get_param("cluster") service = self.ecs_backend.delete_service(cluster_name, service_name) - return json.dumps({ - 'service': service.response_object - }) + return json.dumps({"service": service.response_object}) def register_container_instance(self): - cluster_str = self._get_param('cluster') - instance_identity_document_str = self._get_param( - 'instanceIdentityDocument') + cluster_str = self._get_param("cluster") + instance_identity_document_str = self._get_param("instanceIdentityDocument") instance_identity_document = json.loads(instance_identity_document_str) ec2_instance_id = instance_identity_document["instanceId"] container_instance = self.ecs_backend.register_container_instance( - cluster_str, ec2_instance_id) - return json.dumps({ - 'containerInstance': container_instance.response_object - }) + cluster_str, ec2_instance_id + ) + return json.dumps({"containerInstance": container_instance.response_object}) def deregister_container_instance(self): - cluster_str = self._get_param('cluster') + cluster_str = self._get_param("cluster") if not cluster_str: - cluster_str = 'default' - container_instance_str = self._get_param('containerInstance') - force = self._get_param('force') + cluster_str = "default" + container_instance_str = self._get_param("containerInstance") + force = self._get_param("force") container_instance, failures = self.ecs_backend.deregister_container_instance( cluster_str, container_instance_str, force ) - return json.dumps({ - 'containerInstance': container_instance.response_object - }) + return json.dumps({"containerInstance": container_instance.response_object}) def list_container_instances(self): - cluster_str = self._get_param('cluster') - container_instance_arns = self.ecs_backend.list_container_instances( - cluster_str) - return json.dumps({ - 'containerInstanceArns': container_instance_arns - }) + cluster_str = self._get_param("cluster") + container_instance_arns = self.ecs_backend.list_container_instances(cluster_str) + return json.dumps({"containerInstanceArns": container_instance_arns}) def describe_container_instances(self): - cluster_str = self._get_param('cluster') - list_container_instance_arns = self._get_param('containerInstances') + cluster_str = self._get_param("cluster") + list_container_instance_arns = self._get_param("containerInstances") container_instances, failures = self.ecs_backend.describe_container_instances( - cluster_str, list_container_instance_arns) - return json.dumps({ - 'failures': [ci.response_object for ci in failures], - 'containerInstances': [ci.response_object for ci in container_instances] - }) + cluster_str, list_container_instance_arns + ) + return json.dumps( + { + "failures": [ci.response_object for ci in failures], + "containerInstances": [ + ci.response_object for ci in container_instances + ], + } + ) def update_container_instances_state(self): - cluster_str = self._get_param('cluster') - list_container_instance_arns = self._get_param('containerInstances') - status_str = self._get_param('status') - container_instances, failures = self.ecs_backend.update_container_instances_state(cluster_str, - list_container_instance_arns, - status_str) - return json.dumps({ - 'failures': [ci.response_object for ci in failures], - 'containerInstances': [ci.response_object for ci in container_instances] - }) + cluster_str = self._get_param("cluster") + list_container_instance_arns = self._get_param("containerInstances") + status_str = self._get_param("status") + ( + container_instances, + failures, + ) = self.ecs_backend.update_container_instances_state( + cluster_str, list_container_instance_arns, status_str + ) + return json.dumps( + { + "failures": [ci.response_object for ci in failures], + "containerInstances": [ + ci.response_object for ci in container_instances + ], + } + ) def put_attributes(self): - cluster_name = self._get_param('cluster') - attributes = self._get_param('attributes') + cluster_name = self._get_param("cluster") + attributes = self._get_param("attributes") self.ecs_backend.put_attributes(cluster_name, attributes) - return json.dumps({'attributes': attributes}) + return json.dumps({"attributes": attributes}) def list_attributes(self): - cluster_name = self._get_param('cluster') - attr_name = self._get_param('attributeName') - attr_value = self._get_param('attributeValue') - target_type = self._get_param('targetType') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + cluster_name = self._get_param("cluster") + attr_name = self._get_param("attributeName") + attr_value = self._get_param("attributeValue") + target_type = self._get_param("targetType") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") - results = self.ecs_backend.list_attributes(target_type, cluster_name, attr_name, attr_value, max_results, next_token) + results = self.ecs_backend.list_attributes( + target_type, cluster_name, attr_name, attr_value, max_results, next_token + ) # Result will be [item will be {0 cluster_name, 1 arn, 2 name, 3 value}] formatted_results = [] for _, arn, name, value in results: - tmp_result = { - 'name': name, - 'targetId': arn - } + tmp_result = {"name": name, "targetId": arn} if value is not None: - tmp_result['value'] = value + tmp_result["value"] = value formatted_results.append(tmp_result) - return json.dumps({'attributes': formatted_results}) + return json.dumps({"attributes": formatted_results}) def delete_attributes(self): - cluster_name = self._get_param('cluster') - attributes = self._get_param('attributes') + cluster_name = self._get_param("cluster") + attributes = self._get_param("attributes") self.ecs_backend.delete_attributes(cluster_name, attributes) - return json.dumps({'attributes': attributes}) + return json.dumps({"attributes": attributes}) def discover_poll_endpoint(self): # Here are the arguments, this api is used by the ecs client so obviously no decent # documentation. Hence I've responded with valid but useless data # cluster_name = self._get_param('cluster') # instance = self._get_param('containerInstance') - return json.dumps({ - 'endpoint': 'http://localhost', - 'telemetryEndpoint': 'http://localhost' - }) + return json.dumps( + {"endpoint": "http://localhost", "telemetryEndpoint": "http://localhost"} + ) def list_task_definition_families(self): - family_prefix = self._get_param('familyPrefix') - status = self._get_param('status') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + family_prefix = self._get_param("familyPrefix") + status = self._get_param("status") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") - results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token) + results = self.ecs_backend.list_task_definition_families( + family_prefix, status, max_results, next_token + ) - return json.dumps({'families': list(results)}) + return json.dumps({"families": list(results)}) def list_tags_for_resource(self): - resource_arn = self._get_param('resourceArn') + resource_arn = self._get_param("resourceArn") tags = self.ecs_backend.list_tags_for_resource(resource_arn) - return json.dumps({'tags': tags}) + return json.dumps({"tags": tags}) + + def tag_resource(self): + resource_arn = self._get_param("resourceArn") + tags = self._get_param("tags") + results = self.ecs_backend.tag_resource(resource_arn, tags) + return json.dumps(results) + + def untag_resource(self): + resource_arn = self._get_param("resourceArn") + tag_keys = self._get_param("tagKeys") + results = self.ecs_backend.untag_resource(resource_arn, tag_keys) + return json.dumps(results) diff --git a/moto/ecs/urls.py b/moto/ecs/urls.py index 1e0d5fbf9..a5adc5923 100644 --- a/moto/ecs/urls.py +++ b/moto/ecs/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import EC2ContainerServiceResponse -url_bases = [ - "https?://ecs.(.+).amazonaws.com", -] +url_bases = ["https?://ecs.(.+).amazonaws.com"] -url_paths = { - '{0}/$': EC2ContainerServiceResponse.dispatch, -} +url_paths = {"{0}/$": EC2ContainerServiceResponse.dispatch} diff --git a/moto/elb/__init__.py b/moto/elb/__init__.py index e25f2d486..d3627ed6d 100644 --- a/moto/elb/__init__.py +++ b/moto/elb/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import elb_backends from ..core.models import base_decorator, deprecated_base_decorator -elb_backend = elb_backends['us-east-1'] +elb_backend = elb_backends["us-east-1"] mock_elb = base_decorator(elb_backends) mock_elb_deprecated = deprecated_base_decorator(elb_backends) diff --git a/moto/elb/exceptions.py b/moto/elb/exceptions.py index 3ea6a1642..d41a66e3f 100644 --- a/moto/elb/exceptions.py +++ b/moto/elb/exceptions.py @@ -7,68 +7,66 @@ class ELBClientError(RESTError): class DuplicateTagKeysError(ELBClientError): - def __init__(self, cidr): super(DuplicateTagKeysError, self).__init__( - "DuplicateTagKeys", - "Tag key was specified more than once: {0}" - .format(cidr)) + "DuplicateTagKeys", "Tag key was specified more than once: {0}".format(cidr) + ) class LoadBalancerNotFoundError(ELBClientError): - def __init__(self, cidr): super(LoadBalancerNotFoundError, self).__init__( "LoadBalancerNotFound", - "The specified load balancer does not exist: {0}" - .format(cidr)) + "The specified load balancer does not exist: {0}".format(cidr), + ) class TooManyTagsError(ELBClientError): - def __init__(self): super(TooManyTagsError, self).__init__( "LoadBalancerNotFound", - "The quota for the number of tags that can be assigned to a load balancer has been reached") + "The quota for the number of tags that can be assigned to a load balancer has been reached", + ) class BadHealthCheckDefinition(ELBClientError): - def __init__(self): super(BadHealthCheckDefinition, self).__init__( "ValidationError", - "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") + "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL", + ) class DuplicateListenerError(ELBClientError): - def __init__(self, name, port): super(DuplicateListenerError, self).__init__( "DuplicateListener", - "A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId" - .format(name, port)) + "A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId".format( + name, port + ), + ) class DuplicateLoadBalancerName(ELBClientError): - def __init__(self, name): super(DuplicateLoadBalancerName, self).__init__( "DuplicateLoadBalancerName", - "The specified load balancer name already exists for this account: {0}" - .format(name)) + "The specified load balancer name already exists for this account: {0}".format( + name + ), + ) class EmptyListenersError(ELBClientError): - def __init__(self): super(EmptyListenersError, self).__init__( - "ValidationError", - "Listeners cannot be empty") + "ValidationError", "Listeners cannot be empty" + ) class InvalidSecurityGroupError(ELBClientError): - def __init__(self): super(InvalidSecurityGroupError, self).__init__( "ValidationError", - "One or more of the specified security groups do not exist.") + "One or more of the specified security groups do not exist.", + ) diff --git a/moto/elb/models.py b/moto/elb/models.py index 8781620f1..f77811623 100644 --- a/moto/elb/models.py +++ b/moto/elb/models.py @@ -8,10 +8,7 @@ from boto.ec2.elb.attributes import ( AccessLogAttribute, CrossZoneLoadBalancingAttribute, ) -from boto.ec2.elb.policies import ( - Policies, - OtherPolicy, -) +from boto.ec2.elb.policies import Policies, OtherPolicy from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.ec2.models import ec2_backends @@ -27,20 +24,19 @@ from .exceptions import ( class FakeHealthCheck(BaseModel): - - def __init__(self, timeout, healthy_threshold, unhealthy_threshold, - interval, target): + def __init__( + self, timeout, healthy_threshold, unhealthy_threshold, interval, target + ): self.timeout = timeout self.healthy_threshold = healthy_threshold self.unhealthy_threshold = unhealthy_threshold self.interval = interval self.target = target - if not target.startswith(('HTTP', 'TCP', 'HTTPS', 'SSL')): + if not target.startswith(("HTTP", "TCP", "HTTPS", "SSL")): raise BadHealthCheckDefinition class FakeListener(BaseModel): - def __init__(self, load_balancer_port, instance_port, protocol, ssl_certificate_id): self.load_balancer_port = load_balancer_port self.instance_port = instance_port @@ -49,22 +45,38 @@ class FakeListener(BaseModel): self.policy_names = [] def __repr__(self): - return "FakeListener(lbp: %s, inp: %s, pro: %s, cid: %s, policies: %s)" % (self.load_balancer_port, self.instance_port, self.protocol, self.ssl_certificate_id, self.policy_names) + return "FakeListener(lbp: %s, inp: %s, pro: %s, cid: %s, policies: %s)" % ( + self.load_balancer_port, + self.instance_port, + self.protocol, + self.ssl_certificate_id, + self.policy_names, + ) class FakeBackend(BaseModel): - def __init__(self, instance_port): self.instance_port = instance_port self.policy_names = [] def __repr__(self): - return "FakeBackend(inp: %s, policies: %s)" % (self.instance_port, self.policy_names) + return "FakeBackend(inp: %s, policies: %s)" % ( + self.instance_port, + self.policy_names, + ) class FakeLoadBalancer(BaseModel): - - def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None, security_groups=None): + def __init__( + self, + name, + zones, + ports, + scheme="internet-facing", + vpc_id=None, + subnets=None, + security_groups=None, + ): self.name = name self.health_check = None self.instance_ids = [] @@ -80,47 +92,49 @@ class FakeLoadBalancer(BaseModel): self.policies.lb_cookie_stickiness_policies = [] self.security_groups = security_groups or [] self.subnets = subnets or [] - self.vpc_id = vpc_id or 'vpc-56e10e3d' + self.vpc_id = vpc_id or "vpc-56e10e3d" self.tags = {} self.dns_name = "%s.us-east-1.elb.amazonaws.com" % (name) for port in ports: listener = FakeListener( - protocol=(port.get('protocol') or port['Protocol']), + protocol=(port.get("protocol") or port["Protocol"]), load_balancer_port=( - port.get('load_balancer_port') or port['LoadBalancerPort']), - instance_port=( - port.get('instance_port') or port['InstancePort']), + port.get("load_balancer_port") or port["LoadBalancerPort"] + ), + instance_port=(port.get("instance_port") or port["InstancePort"]), ssl_certificate_id=port.get( - 'ssl_certificate_id', port.get('SSLCertificateId')), + "ssl_certificate_id", port.get("SSLCertificateId") + ), ) self.listeners.append(listener) # it is unclear per the AWS documentation as to when or how backend # information gets set, so let's guess and set it here *shrug* backend = FakeBackend( - instance_port=( - port.get('instance_port') or port['InstancePort']), + instance_port=(port.get("instance_port") or port["InstancePort"]) ) self.backends.append(backend) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elb_backend = elb_backends[region_name] new_elb = elb_backend.create_load_balancer( - name=properties.get('LoadBalancerName', resource_name), - zones=properties.get('AvailabilityZones', []), - ports=properties['Listeners'], - scheme=properties.get('Scheme', 'internet-facing'), + name=properties.get("LoadBalancerName", resource_name), + zones=properties.get("AvailabilityZones", []), + ports=properties["Listeners"], + scheme=properties.get("Scheme", "internet-facing"), ) - instance_ids = properties.get('Instances', []) + instance_ids = properties.get("Instances", []) for instance_id in instance_ids: elb_backend.register_instances(new_elb.name, [instance_id]) - policies = properties.get('Policies', []) + policies = properties.get("Policies", []) port_policies = {} for policy in policies: policy_name = policy["PolicyName"] @@ -134,29 +148,37 @@ class FakeLoadBalancer(BaseModel): for port, policies in port_policies.items(): elb_backend.set_load_balancer_policies_of_backend_server( - new_elb.name, port, list(policies)) + new_elb.name, port, list(policies) + ) - health_check = properties.get('HealthCheck') + health_check = properties.get("HealthCheck") if health_check: elb_backend.configure_health_check( load_balancer_name=new_elb.name, - timeout=health_check['Timeout'], - healthy_threshold=health_check['HealthyThreshold'], - unhealthy_threshold=health_check['UnhealthyThreshold'], - interval=health_check['Interval'], - target=health_check['Target'], + timeout=health_check["Timeout"], + healthy_threshold=health_check["HealthyThreshold"], + unhealthy_threshold=health_check["UnhealthyThreshold"], + interval=health_check["Interval"], + target=health_check["Target"], ) return new_elb @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): elb_backend = elb_backends[region_name] try: elb_backend.delete_load_balancer(resource_name) @@ -169,20 +191,25 @@ class FakeLoadBalancer(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'CanonicalHostedZoneName': + + if attribute_name == "CanonicalHostedZoneName": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"') - elif attribute_name == 'CanonicalHostedZoneNameID': + '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"' + ) + elif attribute_name == "CanonicalHostedZoneNameID": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"') - elif attribute_name == 'DNSName': + '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"' + ) + elif attribute_name == "DNSName": return self.dns_name - elif attribute_name == 'SourceSecurityGroup.GroupName': + elif attribute_name == "SourceSecurityGroup.GroupName": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"') - elif attribute_name == 'SourceSecurityGroup.OwnerAlias': + '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"' + ) + elif attribute_name == "SourceSecurityGroup.OwnerAlias": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"') + '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"' + ) raise UnformattedGetAttTemplateException() @classmethod @@ -220,12 +247,11 @@ class FakeLoadBalancer(BaseModel): del self.tags[key] def delete(self, region): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ elb_backends[region].delete_load_balancer(self.name) class ELBBackend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.load_balancers = OrderedDict() @@ -235,7 +261,15 @@ class ELBBackend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def create_load_balancer(self, name, zones, ports, scheme='internet-facing', subnets=None, security_groups=None): + def create_load_balancer( + self, + name, + zones, + ports, + scheme="internet-facing", + subnets=None, + security_groups=None, + ): vpc_id = None ec2_backend = ec2_backends[self.region_name] if subnets: @@ -257,7 +291,8 @@ class ELBBackend(BaseBackend): scheme=scheme, subnets=subnets, security_groups=security_groups, - vpc_id=vpc_id) + vpc_id=vpc_id, + ) self.load_balancers[name] = new_load_balancer return new_load_balancer @@ -265,10 +300,10 @@ class ELBBackend(BaseBackend): balancer = self.load_balancers.get(name, None) if balancer: for port in ports: - protocol = port['protocol'] - instance_port = port['instance_port'] - lb_port = port['load_balancer_port'] - ssl_certificate_id = port.get('ssl_certificate_id') + protocol = port["protocol"] + instance_port = port["instance_port"] + lb_port = port["load_balancer_port"] + ssl_certificate_id = port.get("ssl_certificate_id") for listener in balancer.listeners: if lb_port == listener.load_balancer_port: if protocol != listener.protocol: @@ -279,8 +314,11 @@ class ELBBackend(BaseBackend): raise DuplicateListenerError(name, lb_port) break else: - balancer.listeners.append(FakeListener( - lb_port, instance_port, protocol, ssl_certificate_id)) + balancer.listeners.append( + FakeListener( + lb_port, instance_port, protocol, ssl_certificate_id + ) + ) return balancer @@ -288,7 +326,8 @@ class ELBBackend(BaseBackend): balancers = self.load_balancers.values() if names: matched_balancers = [ - balancer for balancer in balancers if balancer.name in names] + balancer for balancer in balancers if balancer.name in names + ] if len(names) != len(matched_balancers): missing_elb = list(set(names) - set(matched_balancers))[0] raise LoadBalancerNotFoundError(missing_elb) @@ -315,7 +354,9 @@ class ELBBackend(BaseBackend): def get_load_balancer(self, load_balancer_name): return self.load_balancers.get(load_balancer_name) - def apply_security_groups_to_load_balancer(self, load_balancer_name, security_group_ids): + def apply_security_groups_to_load_balancer( + self, load_balancer_name, security_group_ids + ): load_balancer = self.load_balancers.get(load_balancer_name) ec2_backend = ec2_backends[self.region_name] for security_group_id in security_group_ids: @@ -323,22 +364,30 @@ class ELBBackend(BaseBackend): raise InvalidSecurityGroupError() load_balancer.security_groups = security_group_ids - def configure_health_check(self, load_balancer_name, timeout, - healthy_threshold, unhealthy_threshold, interval, - target): - check = FakeHealthCheck(timeout, healthy_threshold, unhealthy_threshold, - interval, target) + def configure_health_check( + self, + load_balancer_name, + timeout, + healthy_threshold, + unhealthy_threshold, + interval, + target, + ): + check = FakeHealthCheck( + timeout, healthy_threshold, unhealthy_threshold, interval, target + ) load_balancer = self.get_load_balancer(load_balancer_name) load_balancer.health_check = check return check - def set_load_balancer_listener_sslcertificate(self, name, lb_port, ssl_certificate_id): + def set_load_balancer_listener_sslcertificate( + self, name, lb_port, ssl_certificate_id + ): balancer = self.load_balancers.get(name, None) if balancer: for idx, listener in enumerate(balancer.listeners): if lb_port == listener.load_balancer_port: - balancer.listeners[ - idx].ssl_certificate_id = ssl_certificate_id + balancer.listeners[idx].ssl_certificate_id = ssl_certificate_id return balancer @@ -350,7 +399,10 @@ class ELBBackend(BaseBackend): def deregister_instances(self, load_balancer_name, instance_ids): load_balancer = self.get_load_balancer(load_balancer_name) new_instance_ids = [ - instance_id for instance_id in load_balancer.instance_ids if instance_id not in instance_ids] + instance_id + for instance_id in load_balancer.instance_ids + if instance_id not in instance_ids + ] load_balancer.instance_ids = new_instance_ids return load_balancer @@ -376,7 +428,9 @@ class ELBBackend(BaseBackend): def create_lb_other_policy(self, load_balancer_name, other_policy): load_balancer = self.get_load_balancer(load_balancer_name) - if other_policy.policy_name not in [p.policy_name for p in load_balancer.policies.other_policies]: + if other_policy.policy_name not in [ + p.policy_name for p in load_balancer.policies.other_policies + ]: load_balancer.policies.other_policies.append(other_policy) return load_balancer @@ -391,19 +445,27 @@ class ELBBackend(BaseBackend): load_balancer.policies.lb_cookie_stickiness_policies.append(policy) return load_balancer - def set_load_balancer_policies_of_backend_server(self, load_balancer_name, instance_port, policies): + def set_load_balancer_policies_of_backend_server( + self, load_balancer_name, instance_port, policies + ): load_balancer = self.get_load_balancer(load_balancer_name) - backend = [b for b in load_balancer.backends if int( - b.instance_port) == instance_port][0] + backend = [ + b for b in load_balancer.backends if int(b.instance_port) == instance_port + ][0] backend_idx = load_balancer.backends.index(backend) backend.policy_names = policies load_balancer.backends[backend_idx] = backend return load_balancer - def set_load_balancer_policies_of_listener(self, load_balancer_name, load_balancer_port, policies): + def set_load_balancer_policies_of_listener( + self, load_balancer_name, load_balancer_port, policies + ): load_balancer = self.get_load_balancer(load_balancer_name) - listener = [l for l in load_balancer.listeners if int( - l.load_balancer_port) == load_balancer_port][0] + listener = [ + l + for l in load_balancer.listeners + if int(l.load_balancer_port) == load_balancer_port + ][0] listener_idx = load_balancer.listeners.index(listener) listener.policy_names = policies load_balancer.listeners[listener_idx] = listener diff --git a/moto/elb/responses.py b/moto/elb/responses.py index b512f56e9..de21f23e7 100644 --- a/moto/elb/responses.py +++ b/moto/elb/responses.py @@ -5,10 +5,7 @@ from boto.ec2.elb.attributes import ( AccessLogAttribute, CrossZoneLoadBalancingAttribute, ) -from boto.ec2.elb.policies import ( - AppCookieStickinessPolicy, - OtherPolicy, -) +from boto.ec2.elb.policies import AppCookieStickinessPolicy, OtherPolicy from moto.core.responses import BaseResponse from .models import elb_backends @@ -16,16 +13,15 @@ from .exceptions import DuplicateTagKeysError, LoadBalancerNotFoundError class ELBResponse(BaseResponse): - @property def elb_backend(self): return elb_backends[self.region] def create_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") availability_zones = self._get_multi_param("AvailabilityZones.member") ports = self._get_list_prefix("Listeners.member") - scheme = self._get_param('Scheme') + scheme = self._get_param("Scheme") subnets = self._get_multi_param("Subnets.member") security_groups = self._get_multi_param("SecurityGroups.member") @@ -42,27 +38,29 @@ class ELBResponse(BaseResponse): return template.render(load_balancer=load_balancer) def create_load_balancer_listeners(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") ports = self._get_list_prefix("Listeners.member") self.elb_backend.create_load_balancer_listeners( - name=load_balancer_name, ports=ports) + name=load_balancer_name, ports=ports + ) - template = self.response_template( - CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE) + template = self.response_template(CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE) return template.render() def describe_load_balancers(self): names = self._get_multi_param("LoadBalancerNames.member") all_load_balancers = list(self.elb_backend.describe_load_balancers(names)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_names = [balancer.name for balancer in all_load_balancers] if marker: start = all_names.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('PageSize', 50) # the default is 400, but using 50 to make testing easier - load_balancers_resp = all_load_balancers[start:start + page_size] + page_size = self._get_int_param( + "PageSize", 50 + ) # the default is 400, but using 50 to make testing easier + load_balancers_resp = all_load_balancers[start : start + page_size] next_marker = None if len(all_load_balancers) > start + page_size: next_marker = load_balancers_resp[-1].name @@ -71,143 +69,158 @@ class ELBResponse(BaseResponse): return template.render(load_balancers=load_balancers_resp, marker=next_marker) def delete_load_balancer_listeners(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") ports = self._get_multi_param("LoadBalancerPorts.member") ports = [int(port) for port in ports] - self.elb_backend.delete_load_balancer_listeners( - load_balancer_name, ports) + self.elb_backend.delete_load_balancer_listeners(load_balancer_name, ports) template = self.response_template(DELETE_LOAD_BALANCER_LISTENERS) return template.render() def delete_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") self.elb_backend.delete_load_balancer(load_balancer_name) template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE) return template.render() def apply_security_groups_to_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") security_group_ids = self._get_multi_param("SecurityGroups.member") - self.elb_backend.apply_security_groups_to_load_balancer(load_balancer_name, security_group_ids) + self.elb_backend.apply_security_groups_to_load_balancer( + load_balancer_name, security_group_ids + ) template = self.response_template(APPLY_SECURITY_GROUPS_TEMPLATE) return template.render(security_group_ids=security_group_ids) def configure_health_check(self): check = self.elb_backend.configure_health_check( - load_balancer_name=self._get_param('LoadBalancerName'), - timeout=self._get_param('HealthCheck.Timeout'), - healthy_threshold=self._get_param('HealthCheck.HealthyThreshold'), - unhealthy_threshold=self._get_param( - 'HealthCheck.UnhealthyThreshold'), - interval=self._get_param('HealthCheck.Interval'), - target=self._get_param('HealthCheck.Target'), + load_balancer_name=self._get_param("LoadBalancerName"), + timeout=self._get_param("HealthCheck.Timeout"), + healthy_threshold=self._get_param("HealthCheck.HealthyThreshold"), + unhealthy_threshold=self._get_param("HealthCheck.UnhealthyThreshold"), + interval=self._get_param("HealthCheck.Interval"), + target=self._get_param("HealthCheck.Target"), ) template = self.response_template(CONFIGURE_HEALTH_CHECK_TEMPLATE) return template.render(check=check) def register_instances_with_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') - instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')] + load_balancer_name = self._get_param("LoadBalancerName") + instance_ids = [ + list(param.values())[0] + for param in self._get_list_prefix("Instances.member") + ] template = self.response_template(REGISTER_INSTANCES_TEMPLATE) load_balancer = self.elb_backend.register_instances( - load_balancer_name, instance_ids) + load_balancer_name, instance_ids + ) return template.render(load_balancer=load_balancer) def set_load_balancer_listener_ssl_certificate(self): - load_balancer_name = self._get_param('LoadBalancerName') - ssl_certificate_id = self.querystring['SSLCertificateId'][0] - lb_port = self.querystring['LoadBalancerPort'][0] + load_balancer_name = self._get_param("LoadBalancerName") + ssl_certificate_id = self.querystring["SSLCertificateId"][0] + lb_port = self.querystring["LoadBalancerPort"][0] self.elb_backend.set_load_balancer_listener_sslcertificate( - load_balancer_name, lb_port, ssl_certificate_id) + load_balancer_name, lb_port, ssl_certificate_id + ) template = self.response_template(SET_LOAD_BALANCER_SSL_CERTIFICATE) return template.render() def deregister_instances_from_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') - instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')] + load_balancer_name = self._get_param("LoadBalancerName") + instance_ids = [ + list(param.values())[0] + for param in self._get_list_prefix("Instances.member") + ] template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE) load_balancer = self.elb_backend.deregister_instances( - load_balancer_name, instance_ids) + load_balancer_name, instance_ids + ) return template.render(load_balancer=load_balancer) def describe_load_balancer_attributes(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) template = self.response_template(DESCRIBE_ATTRIBUTES_TEMPLATE) return template.render(attributes=load_balancer.attributes) def modify_load_balancer_attributes(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) cross_zone = self._get_dict_param( - "LoadBalancerAttributes.CrossZoneLoadBalancing.") + "LoadBalancerAttributes.CrossZoneLoadBalancing." + ) if cross_zone: attribute = CrossZoneLoadBalancingAttribute() attribute.enabled = cross_zone["enabled"] == "true" self.elb_backend.set_cross_zone_load_balancing_attribute( - load_balancer_name, attribute) + load_balancer_name, attribute + ) access_log = self._get_dict_param("LoadBalancerAttributes.AccessLog.") if access_log: attribute = AccessLogAttribute() attribute.enabled = access_log["enabled"] == "true" - attribute.s3_bucket_name = access_log['s3_bucket_name'] - attribute.s3_bucket_prefix = access_log['s3_bucket_prefix'] + attribute.s3_bucket_name = access_log["s3_bucket_name"] + attribute.s3_bucket_prefix = access_log["s3_bucket_prefix"] attribute.emit_interval = access_log["emit_interval"] - self.elb_backend.set_access_log_attribute( - load_balancer_name, attribute) + self.elb_backend.set_access_log_attribute(load_balancer_name, attribute) connection_draining = self._get_dict_param( - "LoadBalancerAttributes.ConnectionDraining.") + "LoadBalancerAttributes.ConnectionDraining." + ) if connection_draining: attribute = ConnectionDrainingAttribute() attribute.enabled = connection_draining["enabled"] == "true" attribute.timeout = connection_draining.get("timeout", 300) - self.elb_backend.set_connection_draining_attribute(load_balancer_name, attribute) + self.elb_backend.set_connection_draining_attribute( + load_balancer_name, attribute + ) connection_settings = self._get_dict_param( - "LoadBalancerAttributes.ConnectionSettings.") + "LoadBalancerAttributes.ConnectionSettings." + ) if connection_settings: attribute = ConnectionSettingAttribute() attribute.idle_timeout = connection_settings["idle_timeout"] self.elb_backend.set_connection_settings_attribute( - load_balancer_name, attribute) + load_balancer_name, attribute + ) template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE) - return template.render(load_balancer=load_balancer, attributes=load_balancer.attributes) + return template.render( + load_balancer=load_balancer, attributes=load_balancer.attributes + ) def create_load_balancer_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") other_policy = OtherPolicy() policy_name = self._get_param("PolicyName") other_policy.policy_name = policy_name - self.elb_backend.create_lb_other_policy( - load_balancer_name, other_policy) + self.elb_backend.create_lb_other_policy(load_balancer_name, other_policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def create_app_cookie_stickiness_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") policy = AppCookieStickinessPolicy() policy.policy_name = self._get_param("PolicyName") policy.cookie_name = self._get_param("CookieName") - self.elb_backend.create_app_cookie_stickiness_policy( - load_balancer_name, policy) + self.elb_backend.create_app_cookie_stickiness_policy(load_balancer_name, policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def create_lb_cookie_stickiness_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") policy = AppCookieStickinessPolicy() policy.policy_name = self._get_param("PolicyName") @@ -217,62 +230,68 @@ class ELBResponse(BaseResponse): else: policy.cookie_expiration_period = None - self.elb_backend.create_lb_cookie_stickiness_policy( - load_balancer_name, policy) + self.elb_backend.create_lb_cookie_stickiness_policy(load_balancer_name, policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def set_load_balancer_policies_of_listener(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) - load_balancer_port = int(self._get_param('LoadBalancerPort')) + load_balancer_port = int(self._get_param("LoadBalancerPort")) - mb_listener = [l for l in load_balancer.listeners if int( - l.load_balancer_port) == load_balancer_port] + mb_listener = [ + l + for l in load_balancer.listeners + if int(l.load_balancer_port) == load_balancer_port + ] if mb_listener: policies = self._get_multi_param("PolicyNames.member") self.elb_backend.set_load_balancer_policies_of_listener( - load_balancer_name, load_balancer_port, policies) + load_balancer_name, load_balancer_port, policies + ) # else: explode? template = self.response_template( - SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE) + SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE + ) return template.render() def set_load_balancer_policies_for_backend_server(self): - load_balancer_name = self.querystring.get('LoadBalancerName')[0] + load_balancer_name = self.querystring.get("LoadBalancerName")[0] load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) - instance_port = int(self.querystring.get('InstancePort')[0]) + instance_port = int(self.querystring.get("InstancePort")[0]) - mb_backend = [b for b in load_balancer.backends if int( - b.instance_port) == instance_port] + mb_backend = [ + b for b in load_balancer.backends if int(b.instance_port) == instance_port + ] if mb_backend: - policies = self._get_multi_param('PolicyNames.member') + policies = self._get_multi_param("PolicyNames.member") self.elb_backend.set_load_balancer_policies_of_backend_server( - load_balancer_name, instance_port, policies) + load_balancer_name, instance_port, policies + ) # else: explode? template = self.response_template( - SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE) + SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE + ) return template.render() def describe_instance_health(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") provided_instance_ids = [ list(param.values())[0] - for param in self._get_list_prefix('Instances.member') + for param in self._get_list_prefix("Instances.member") ] registered_instances_id = self.elb_backend.get_load_balancer( - load_balancer_name).instance_ids + load_balancer_name + ).instance_ids if len(provided_instance_ids) == 0: provided_instance_ids = registered_instances_id template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE) instances = [] for instance_id in provided_instance_ids: - state = "InService" \ - if instance_id in registered_instances_id\ - else "Unknown" + state = "InService" if instance_id in registered_instances_id else "Unknown" instances.append({"InstanceId": instance_id, "State": state}) return template.render(instances=instances) @@ -293,17 +312,18 @@ class ELBResponse(BaseResponse): def remove_tags(self): for key, value in self.querystring.items(): if "LoadBalancerNames.member" in key: - number = key.split('.')[2] + number = key.split(".")[2] load_balancer_name = self._get_param( - 'LoadBalancerNames.member.{0}'.format(number)) + "LoadBalancerNames.member.{0}".format(number) + ) elb = self.elb_backend.get_load_balancer(load_balancer_name) if not elb: raise LoadBalancerNotFoundError(load_balancer_name) - key = 'Tag.member.{0}.Key'.format(number) + key = "Tag.member.{0}.Key".format(number) for t_key, t_val in self.querystring.items(): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": elb.remove_tag(t_val[0]) template = self.response_template(REMOVE_TAGS_TEMPLATE) @@ -313,9 +333,10 @@ class ELBResponse(BaseResponse): elbs = [] for key, value in self.querystring.items(): if "LoadBalancerNames.member" in key: - number = key.split('.')[2] + number = key.split(".")[2] load_balancer_name = self._get_param( - 'LoadBalancerNames.member.{0}'.format(number)) + "LoadBalancerNames.member.{0}".format(number) + ) elb = self.elb_backend.get_load_balancer(load_balancer_name) if not elb: raise LoadBalancerNotFoundError(load_balancer_name) @@ -329,10 +350,10 @@ class ELBResponse(BaseResponse): tag_keys = [] for t_key, t_val in sorted(self.querystring.items()): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": tag_keys.extend(t_val) - elif t_key.split('.')[3] == 'Value': + elif t_key.split(".")[3] == "Value": tag_values.extend(t_val) counts = {} diff --git a/moto/elb/urls.py b/moto/elb/urls.py index 3d96e1892..bb7f1c7bf 100644 --- a/moto/elb/urls.py +++ b/moto/elb/urls.py @@ -16,29 +16,25 @@ def api_version_elb_backend(*args, **kwargs): """ request = args[0] - if hasattr(request, 'values'): + if hasattr(request, "values"): # boto3 - version = request.values.get('Version') + version = request.values.get("Version") elif isinstance(request, AWSPreparedRequest): # boto in-memory - version = parse_qs(request.body).get('Version')[0] + version = parse_qs(request.body).get("Version")[0] else: # boto in server mode request.parse_request() - version = request.querystring.get('Version')[0] + version = request.querystring.get("Version")[0] - if '2012-06-01' == version: + if "2012-06-01" == version: return ELBResponse.dispatch(*args, **kwargs) - elif '2015-12-01' == version: + elif "2015-12-01" == version: return ELBV2Response.dispatch(*args, **kwargs) else: raise Exception("Unknown ELB API version: {}".format(version)) -url_bases = [ - "https?://elasticloadbalancing.(.+).amazonaws.com", -] +url_bases = ["https?://elasticloadbalancing.(.+).amazonaws.com"] -url_paths = { - '{0}/$': api_version_elb_backend, -} +url_paths = {"{0}/$": api_version_elb_backend} diff --git a/moto/elbv2/__init__.py b/moto/elbv2/__init__.py index 21a6d06c6..61c4a37ff 100644 --- a/moto/elbv2/__init__.py +++ b/moto/elbv2/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import elbv2_backends from ..core.models import base_decorator -elb_backend = elbv2_backends['us-east-1'] +elb_backend = elbv2_backends["us-east-1"] mock_elbv2 = base_decorator(elbv2_backends) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index 11dcbcb21..8ea509d0d 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -7,186 +7,175 @@ class ELBClientError(RESTError): class DuplicateTagKeysError(ELBClientError): - def __init__(self, cidr): super(DuplicateTagKeysError, self).__init__( - "DuplicateTagKeys", - "Tag key was specified more than once: {0}" - .format(cidr)) + "DuplicateTagKeys", "Tag key was specified more than once: {0}".format(cidr) + ) class LoadBalancerNotFoundError(ELBClientError): - def __init__(self): super(LoadBalancerNotFoundError, self).__init__( - "LoadBalancerNotFound", - "The specified load balancer does not exist.") + "LoadBalancerNotFound", "The specified load balancer does not exist." + ) class ListenerNotFoundError(ELBClientError): - def __init__(self): super(ListenerNotFoundError, self).__init__( - "ListenerNotFound", - "The specified listener does not exist.") + "ListenerNotFound", "The specified listener does not exist." + ) class SubnetNotFoundError(ELBClientError): - def __init__(self): super(SubnetNotFoundError, self).__init__( - "SubnetNotFound", - "The specified subnet does not exist.") + "SubnetNotFound", "The specified subnet does not exist." + ) class TargetGroupNotFoundError(ELBClientError): - def __init__(self): super(TargetGroupNotFoundError, self).__init__( - "TargetGroupNotFound", - "The specified target group does not exist.") + "TargetGroupNotFound", "The specified target group does not exist." + ) class TooManyTagsError(ELBClientError): - def __init__(self): super(TooManyTagsError, self).__init__( "TooManyTagsError", - "The quota for the number of tags that can be assigned to a load balancer has been reached") + "The quota for the number of tags that can be assigned to a load balancer has been reached", + ) class BadHealthCheckDefinition(ELBClientError): - def __init__(self): super(BadHealthCheckDefinition, self).__init__( "ValidationError", - "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") + "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL", + ) class DuplicateListenerError(ELBClientError): - def __init__(self): super(DuplicateListenerError, self).__init__( - "DuplicateListener", - "A listener with the specified port already exists.") + "DuplicateListener", "A listener with the specified port already exists." + ) class DuplicateLoadBalancerName(ELBClientError): - def __init__(self): super(DuplicateLoadBalancerName, self).__init__( "DuplicateLoadBalancerName", - "A load balancer with the specified name already exists.") + "A load balancer with the specified name already exists.", + ) class DuplicateTargetGroupName(ELBClientError): - def __init__(self): super(DuplicateTargetGroupName, self).__init__( "DuplicateTargetGroupName", - "A target group with the specified name already exists.") + "A target group with the specified name already exists.", + ) class InvalidTargetError(ELBClientError): - def __init__(self): super(InvalidTargetError, self).__init__( "InvalidTarget", - "The specified target does not exist or is not in the same VPC as the target group.") + "The specified target does not exist or is not in the same VPC as the target group.", + ) class EmptyListenersError(ELBClientError): - def __init__(self): super(EmptyListenersError, self).__init__( - "ValidationError", - "Listeners cannot be empty") + "ValidationError", "Listeners cannot be empty" + ) class PriorityInUseError(ELBClientError): - def __init__(self): super(PriorityInUseError, self).__init__( - "PriorityInUse", - "The specified priority is in use.") + "PriorityInUse", "The specified priority is in use." + ) class InvalidConditionFieldError(ELBClientError): - def __init__(self, invalid_name): super(InvalidConditionFieldError, self).__init__( "ValidationError", - "Condition field '%s' must be one of '[path-pattern, host-header]" % (invalid_name)) + "Condition field '%s' must be one of '[path-pattern, host-header]" + % (invalid_name), + ) class InvalidConditionValueError(ELBClientError): - def __init__(self, msg): - super(InvalidConditionValueError, self).__init__( - "ValidationError", msg) + super(InvalidConditionValueError, self).__init__("ValidationError", msg) class InvalidActionTypeError(ELBClientError): - def __init__(self, invalid_name, index): super(InvalidActionTypeError, self).__init__( "ValidationError", - "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index) + "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" + % (invalid_name, index), ) class ActionTargetGroupNotFoundError(ELBClientError): - def __init__(self, arn): super(ActionTargetGroupNotFoundError, self).__init__( - "TargetGroupNotFound", - "Target group '%s' not found" % arn + "TargetGroupNotFound", "Target group '%s' not found" % arn ) class InvalidDescribeRulesRequest(ELBClientError): - def __init__(self, msg): - super(InvalidDescribeRulesRequest, self).__init__( - "ValidationError", msg - ) + super(InvalidDescribeRulesRequest, self).__init__("ValidationError", msg) class ResourceInUseError(ELBClientError): - def __init__(self, msg="A specified resource is in use"): - super(ResourceInUseError, self).__init__( - "ResourceInUse", msg) + super(ResourceInUseError, self).__init__("ResourceInUse", msg) class RuleNotFoundError(ELBClientError): - def __init__(self): super(RuleNotFoundError, self).__init__( - "RuleNotFound", - "The specified rule does not exist.") + "RuleNotFound", "The specified rule does not exist." + ) class DuplicatePriorityError(ELBClientError): - def __init__(self, invalid_value): super(DuplicatePriorityError, self).__init__( "ValidationError", - "Priority '%s' was provided multiple times" % invalid_value) + "Priority '%s' was provided multiple times" % invalid_value, + ) class InvalidTargetGroupNameError(ELBClientError): - def __init__(self, msg): - super(InvalidTargetGroupNameError, self).__init__( - "ValidationError", msg - ) + super(InvalidTargetGroupNameError, self).__init__("ValidationError", msg) class InvalidModifyRuleArgumentsError(ELBClientError): - def __init__(self): super(InvalidModifyRuleArgumentsError, self).__init__( - "ValidationError", - "Either conditions or actions must be specified" + "ValidationError", "Either conditions or actions must be specified" + ) + + +class InvalidStatusCodeActionTypeError(ELBClientError): + def __init__(self, msg): + super(InvalidStatusCodeActionTypeError, self).__init__("ValidationError", msg) + + +class InvalidLoadBalancerActionException(ELBClientError): + def __init__(self, msg): + super(InvalidLoadBalancerActionException, self).__init__( + "InvalidLoadBalancerAction", msg ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 7e73c7042..fdce9a8c2 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -3,10 +3,11 @@ from __future__ import unicode_literals import datetime import re from jinja2 import Template +from botocore.exceptions import ParamValidationError from moto.compat import OrderedDict from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import camelcase_to_underscores +from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase from moto.ec2.models import ec2_backends from moto.acm.models import acm_backends from .utils import make_arn_for_target_group @@ -31,13 +32,16 @@ from .exceptions import ( RuleNotFoundError, DuplicatePriorityError, InvalidTargetGroupNameError, - InvalidModifyRuleArgumentsError + InvalidModifyRuleArgumentsError, + InvalidStatusCodeActionTypeError, + InvalidLoadBalancerActionException, ) class FakeHealthStatus(BaseModel): - - def __init__(self, instance_id, port, health_port, status, reason=None, description=None): + def __init__( + self, instance_id, port, health_port, status, reason=None, description=None + ): self.instance_id = instance_id self.port = port self.health_port = health_port @@ -47,23 +51,25 @@ class FakeHealthStatus(BaseModel): class FakeTargetGroup(BaseModel): - HTTP_CODE_REGEX = re.compile(r'(?:(?:\d+-\d+|\d+),?)+') + HTTP_CODE_REGEX = re.compile(r"(?:(?:\d+-\d+|\d+),?)+") - def __init__(self, - name, - arn, - vpc_id, - protocol, - port, - healthcheck_protocol=None, - healthcheck_port=None, - healthcheck_path=None, - healthcheck_interval_seconds=None, - healthcheck_timeout_seconds=None, - healthy_threshold_count=None, - unhealthy_threshold_count=None, - matcher=None, - target_type=None): + def __init__( + self, + name, + arn, + vpc_id, + protocol, + port, + healthcheck_protocol=None, + healthcheck_port=None, + healthcheck_path=None, + healthcheck_interval_seconds=None, + healthcheck_timeout_seconds=None, + healthy_threshold_count=None, + unhealthy_threshold_count=None, + matcher=None, + target_type=None, + ): # TODO: default values differs when you add Network Load balancer self.name = name @@ -71,9 +77,9 @@ class FakeTargetGroup(BaseModel): self.vpc_id = vpc_id self.protocol = protocol self.port = port - self.healthcheck_protocol = healthcheck_protocol or 'HTTP' + self.healthcheck_protocol = healthcheck_protocol or "HTTP" self.healthcheck_port = healthcheck_port or str(self.port) - self.healthcheck_path = healthcheck_path or '/' + self.healthcheck_path = healthcheck_path or "/" self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30 self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5 self.healthy_threshold_count = healthy_threshold_count or 5 @@ -81,14 +87,14 @@ class FakeTargetGroup(BaseModel): self.load_balancer_arns = [] self.tags = {} if matcher is None: - self.matcher = {'HttpCode': '200'} + self.matcher = {"HttpCode": "200"} else: self.matcher = matcher self.target_type = target_type self.attributes = { - 'deregistration_delay.timeout_seconds': 300, - 'stickiness.enabled': 'false', + "deregistration_delay.timeout_seconds": 300, + "stickiness.enabled": "false", } self.targets = OrderedDict() @@ -99,41 +105,55 @@ class FakeTargetGroup(BaseModel): def register(self, targets): for target in targets: - self.targets[target['id']] = { - 'id': target['id'], - 'port': target.get('port', self.port), + self.targets[target["id"]] = { + "id": target["id"], + "port": target.get("port", self.port), } def deregister(self, targets): for target in targets: - t = self.targets.pop(target['id'], None) + t = self.targets.pop(target["id"], None) if not t: raise InvalidTargetError() + def deregister_terminated_instances(self, instance_ids): + for target_id in list(self.targets.keys()): + if target_id in instance_ids: + del self.targets[target_id] + def add_tag(self, key, value): if len(self.tags) >= 10 and key not in self.tags: raise TooManyTagsError() self.tags[key] = value def health_for(self, target, ec2_backend): - t = self.targets.get(target['id']) + t = self.targets.get(target["id"]) if t is None: raise InvalidTargetError() - if t['id'].startswith("i-"): # EC2 instance ID - instance = ec2_backend.get_instance_by_id(t['id']) + if t["id"].startswith("i-"): # EC2 instance ID + instance = ec2_backend.get_instance_by_id(t["id"]) if instance.state == "stopped": - return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'unused', 'Target.InvalidState', 'Target is in the stopped state') - return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy') + return FakeHealthStatus( + t["id"], + t["port"], + self.healthcheck_port, + "unused", + "Target.InvalidState", + "Target is in the stopped state", + ) + return FakeHealthStatus(t["id"], t["port"], self.healthcheck_port, "healthy") @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] - name = properties.get('Name') + name = properties.get("Name") vpc_id = properties.get("VpcId") - protocol = properties.get('Protocol') + protocol = properties.get("Protocol") port = properties.get("Port") healthcheck_protocol = properties.get("HealthCheckProtocol") healthcheck_port = properties.get("HealthCheckPort") @@ -164,8 +184,16 @@ class FakeTargetGroup(BaseModel): class FakeListener(BaseModel): - - def __init__(self, load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions): + def __init__( + self, + load_balancer_arn, + arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ): self.load_balancer_arn = load_balancer_arn self.arn = arn self.protocol = protocol.upper() @@ -178,9 +206,9 @@ class FakeListener(BaseModel): self._default_rule = FakeRule( listener_arn=self.arn, conditions=[], - priority='default', + priority="default", actions=default_actions, - is_default=True + is_default=True, ) @property @@ -196,11 +224,15 @@ class FakeListener(BaseModel): def register(self, rule): self._non_default_rules.append(rule) - self._non_default_rules = sorted(self._non_default_rules, key=lambda x: x.priority) + self._non_default_rules = sorted( + self._non_default_rules, key=lambda x: x.priority + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] load_balancer_arn = properties.get("LoadBalancerArn") @@ -211,16 +243,36 @@ class FakeListener(BaseModel): # transform default actions to confirm with the rest of the code and XML templates if "DefaultActions" in properties: default_actions = [] - for i, action in enumerate(properties['DefaultActions']): - action_type = action['Type'] - if action_type == 'forward': - default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) - elif action_type in ['redirect', 'authenticate-cognito']: - redirect_action = {'type': action_type} - key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig' - for redirect_config_key, redirect_config_value in action[key].items(): + for i, action in enumerate(properties["DefaultActions"]): + action_type = action["Type"] + if action_type == "forward": + default_actions.append( + { + "type": action_type, + "target_group_arn": action["TargetGroupArn"], + } + ) + elif action_type in [ + "redirect", + "authenticate-cognito", + "fixed-response", + ]: + redirect_action = {"type": action_type} + key = ( + underscores_to_camelcase( + action_type.capitalize().replace("-", "_") + ) + + "Config" + ) + for redirect_config_key, redirect_config_value in action[ + key + ].items(): # need to match the output of _get_list_prefix - redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value + redirect_action[ + camelcase_to_underscores(key) + + "._" + + camelcase_to_underscores(redirect_config_key) + ] = redirect_config_value default_actions.append(redirect_action) else: raise InvalidActionTypeError(action_type, i + 1) @@ -228,7 +280,8 @@ class FakeListener(BaseModel): default_actions = None listener = elbv2_backend.create_listener( - load_balancer_arn, protocol, port, ssl_policy, certificates, default_actions) + load_balancer_arn, protocol, port, ssl_policy, certificates, default_actions + ) return listener @@ -238,7 +291,8 @@ class FakeAction(BaseModel): self.type = data.get("type") def to_xml(self): - template = Template("""{{ action.type }} + template = Template( + """{{ action.type }} {% if action.type == "forward" %} {{ action.data["target_group_arn"] }} {% elif action.type == "redirect" %} @@ -253,16 +307,24 @@ class FakeAction(BaseModel): {{ action.data["authenticate_cognito_config._user_pool_client_id"] }} {{ action.data["authenticate_cognito_config._user_pool_domain"] }} + {% elif action.type == "fixed-response" %} + + {{ action.data["fixed_response_config._content_type"] }} + {{ action.data["fixed_response_config._message_body"] }} + {{ action.data["fixed_response_config._status_code"] }} + {% endif %} - """) + """ + ) return template.render(action=self) class FakeRule(BaseModel): - def __init__(self, listener_arn, conditions, priority, actions, is_default): self.listener_arn = listener_arn - self.arn = listener_arn.replace(':listener/', ':listener-rule/') + "/%s" % (id(self)) + self.arn = listener_arn.replace(":listener/", ":listener-rule/") + "/%s" % ( + id(self) + ) self.conditions = conditions self.priority = priority # int or 'default' self.actions = actions @@ -270,20 +332,36 @@ class FakeRule(BaseModel): class FakeBackend(BaseModel): - def __init__(self, instance_port): self.instance_port = instance_port self.policy_names = [] def __repr__(self): - return "FakeBackend(inp: %s, policies: %s)" % (self.instance_port, self.policy_names) + return "FakeBackend(inp: %s, policies: %s)" % ( + self.instance_port, + self.policy_names, + ) class FakeLoadBalancer(BaseModel): - VALID_ATTRS = {'access_logs.s3.enabled', 'access_logs.s3.bucket', 'access_logs.s3.prefix', - 'deletion_protection.enabled', 'idle_timeout.timeout_seconds'} + VALID_ATTRS = { + "access_logs.s3.enabled", + "access_logs.s3.bucket", + "access_logs.s3.prefix", + "deletion_protection.enabled", + "idle_timeout.timeout_seconds", + } - def __init__(self, name, security_groups, subnets, vpc_id, arn, dns_name, scheme='internet-facing'): + def __init__( + self, + name, + security_groups, + subnets, + vpc_id, + arn, + dns_name, + scheme="internet-facing", + ): self.name = name self.created_time = datetime.datetime.now() self.scheme = scheme @@ -295,13 +373,13 @@ class FakeLoadBalancer(BaseModel): self.arn = arn self.dns_name = dns_name - self.stack = 'ipv4' + self.stack = "ipv4" self.attrs = { - 'access_logs.s3.enabled': 'false', - 'access_logs.s3.bucket': None, - 'access_logs.s3.prefix': None, - 'deletion_protection.enabled': 'false', - 'idle_timeout.timeout_seconds': '60' + "access_logs.s3.enabled": "false", + "access_logs.s3.bucket": None, + "access_logs.s3.prefix": None, + "deletion_protection.enabled": "false", + "idle_timeout.timeout_seconds": "60", } @property @@ -321,25 +399,29 @@ class FakeLoadBalancer(BaseModel): del self.tags[key] def delete(self, region): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ elbv2_backends[region].delete_load_balancer(self.arn) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] - name = properties.get('Name', resource_name) + name = properties.get("Name", resource_name) security_groups = properties.get("SecurityGroups") - subnet_ids = properties.get('Subnets') - scheme = properties.get('Scheme', 'internet-facing') + subnet_ids = properties.get("Subnets") + scheme = properties.get("Scheme", "internet-facing") - load_balancer = elbv2_backend.create_load_balancer(name, security_groups, subnet_ids, scheme=scheme) + load_balancer = elbv2_backend.create_load_balancer( + name, security_groups, subnet_ids, scheme=scheme + ) return load_balancer def get_cfn_attribute(self, attribute_name): - ''' + """ Implemented attributes: * DNSName * LoadBalancerName @@ -350,25 +432,27 @@ class FakeLoadBalancer(BaseModel): * SecurityGroups This method is similar to models.py:FakeLoadBalancer.get_cfn_attribute() - ''' + """ from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + not_implemented_yet = [ - 'CanonicalHostedZoneID', - 'LoadBalancerFullName', - 'SecurityGroups', + "CanonicalHostedZoneID", + "LoadBalancerFullName", + "SecurityGroups", ] - if attribute_name == 'DNSName': + if attribute_name == "DNSName": return self.dns_name - elif attribute_name == 'LoadBalancerName': + elif attribute_name == "LoadBalancerName": return self.name elif attribute_name in not_implemented_yet: - raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "%s" ]"' % attribute_name) + raise NotImplementedError( + '"Fn::GetAtt" : [ "{0}" , "%s" ]"' % attribute_name + ) else: raise UnformattedGetAttTemplateException() class ELBv2Backend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.target_groups = OrderedDict() @@ -399,7 +483,9 @@ class ELBv2Backend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def create_load_balancer(self, name, security_groups, subnet_ids, scheme='internet-facing'): + def create_load_balancer( + self, name, security_groups, subnet_ids, scheme="internet-facing" + ): vpc_id = None subnets = [] if not subnet_ids: @@ -411,7 +497,9 @@ class ELBv2Backend(BaseBackend): subnets.append(subnet) vpc_id = subnets[0].vpc_id - arn = make_arn_for_load_balancer(account_id=1, name=name, region_name=self.region_name) + arn = make_arn_for_load_balancer( + account_id=1, name=name, region_name=self.region_name + ) dns_name = "%s-1.%s.elb.amazonaws.com" % (name, self.region_name) if arn in self.load_balancers: @@ -424,7 +512,8 @@ class ELBv2Backend(BaseBackend): scheme=scheme, subnets=subnets, vpc_id=vpc_id, - dns_name=dns_name) + dns_name=dns_name, + ) self.load_balancers[arn] = new_load_balancer return new_load_balancer @@ -437,13 +526,13 @@ class ELBv2Backend(BaseBackend): # validate conditions for condition in conditions: - field = condition['field'] - if field not in ['path-pattern', 'host-header']: + field = condition["field"] + if field not in ["path-pattern", "host-header"]: raise InvalidConditionFieldError(field) - values = condition['values'] + values = condition["values"] if len(values) == 0: - raise InvalidConditionValueError('A condition value must be specified') + raise InvalidConditionValueError("A condition value must be specified") if len(values) > 1: raise InvalidConditionValueError( "The '%s' field contains too many values; the limit is '1'" % field @@ -469,36 +558,67 @@ class ELBv2Backend(BaseBackend): def _validate_actions(self, actions): # validate Actions - target_group_arns = [target_group.arn for target_group in self.target_groups.values()] + target_group_arns = [ + target_group.arn for target_group in self.target_groups.values() + ] for i, action in enumerate(actions): index = i + 1 action_type = action.type - if action_type == 'forward': - action_target_group_arn = action.data['target_group_arn'] + if action_type == "forward": + action_target_group_arn = action.data["target_group_arn"] if action_target_group_arn not in target_group_arns: raise ActionTargetGroupNotFoundError(action_target_group_arn) - elif action_type in ['redirect', 'authenticate-cognito']: + elif action_type == "fixed-response": + self._validate_fixed_response_action(action, i, index) + elif action_type in ["redirect", "authenticate-cognito"]: pass else: raise InvalidActionTypeError(action_type, index) + def _validate_fixed_response_action(self, action, i, index): + status_code = action.data.get("fixed_response_config._status_code") + if status_code is None: + raise ParamValidationError( + report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' + % i + ) + if not re.match(r"^(2|4|5)\d\d$", status_code): + raise InvalidStatusCodeActionTypeError( + "1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \ +Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" + % (status_code, index) + ) + content_type = action.data["fixed_response_config._content_type"] + if content_type and content_type not in [ + "text/plain", + "text/css", + "text/html", + "application/javascript", + "application/json", + ]: + raise InvalidLoadBalancerActionException( + "The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'" + ) + def create_target_group(self, name, **kwargs): if len(name) > 32: raise InvalidTargetGroupNameError( "Target group name '%s' cannot be longer than '32' characters" % name ) - if not re.match('^[a-zA-Z0-9\-]+$', name): + if not re.match("^[a-zA-Z0-9\-]+$", name): raise InvalidTargetGroupNameError( - "Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" % name + "Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" + % name ) # undocumented validation - if not re.match('(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$', name): + if not re.match("(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$", name): raise InvalidTargetGroupNameError( - "1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" % name + "1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" + % name ) - if name.startswith('-') or name.endswith('-'): + if name.startswith("-") or name.endswith("-"): raise InvalidTargetGroupNameError( "Target group name '%s' cannot begin or end with '-'" % name ) @@ -506,25 +626,51 @@ class ELBv2Backend(BaseBackend): if target_group.name == name: raise DuplicateTargetGroupName() - valid_protocols = ['HTTPS', 'HTTP', 'TCP'] - if kwargs.get('healthcheck_protocol') and kwargs['healthcheck_protocol'] not in valid_protocols: + valid_protocols = ["HTTPS", "HTTP", "TCP"] + if ( + kwargs.get("healthcheck_protocol") + and kwargs["healthcheck_protocol"] not in valid_protocols + ): raise InvalidConditionValueError( "Value {} at 'healthCheckProtocol' failed to satisfy constraint: " - "Member must satisfy enum value set: {}".format(kwargs['healthcheck_protocol'], valid_protocols)) - if kwargs.get('protocol') and kwargs['protocol'] not in valid_protocols: + "Member must satisfy enum value set: {}".format( + kwargs["healthcheck_protocol"], valid_protocols + ) + ) + if kwargs.get("protocol") and kwargs["protocol"] not in valid_protocols: raise InvalidConditionValueError( "Value {} at 'protocol' failed to satisfy constraint: " - "Member must satisfy enum value set: {}".format(kwargs['protocol'], valid_protocols)) + "Member must satisfy enum value set: {}".format( + kwargs["protocol"], valid_protocols + ) + ) - if kwargs.get('matcher') and FakeTargetGroup.HTTP_CODE_REGEX.match(kwargs['matcher']['HttpCode']) is None: - raise RESTError('InvalidParameterValue', 'HttpCode must be like 200 | 200-399 | 200,201 ...') + if ( + kwargs.get("matcher") + and FakeTargetGroup.HTTP_CODE_REGEX.match(kwargs["matcher"]["HttpCode"]) + is None + ): + raise RESTError( + "InvalidParameterValue", + "HttpCode must be like 200 | 200-399 | 200,201 ...", + ) - arn = make_arn_for_target_group(account_id=1, name=name, region_name=self.region_name) + arn = make_arn_for_target_group( + account_id=1, name=name, region_name=self.region_name + ) target_group = FakeTargetGroup(name, arn, **kwargs) self.target_groups[target_group.arn] = target_group return target_group - def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions): + def create_listener( + self, + load_balancer_arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ): default_actions = [FakeAction(action) for action in default_actions] balancer = self.load_balancers.get(load_balancer_arn) if balancer is None: @@ -534,12 +680,23 @@ class ELBv2Backend(BaseBackend): self._validate_actions(default_actions) - arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self)) - listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions) + arn = load_balancer_arn.replace(":loadbalancer/", ":listener/") + "/%s%s" % ( + port, + id(self), + ) + listener = FakeListener( + load_balancer_arn, + arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ) balancer.listeners[listener.arn] = listener for action in default_actions: - if action.type == 'forward': - target_group = self.target_groups[action.data['target_group_arn']] + if action.type == "forward": + target_group = self.target_groups[action.data["target_group_arn"]] target_group.load_balancer_arns.append(load_balancer_arn) return listener @@ -581,7 +738,7 @@ class ELBv2Backend(BaseBackend): ) if listener_arn is not None and rule_arns is not None: raise InvalidDescribeRulesRequest( - 'Listener rule ARNs and a listener ARN cannot be specified at the same time' + "Listener rule ARNs and a listener ARN cannot be specified at the same time" ) if listener_arn: listener = self.describe_listeners(None, [listener_arn])[0] @@ -601,8 +758,11 @@ class ELBv2Backend(BaseBackend): if load_balancer_arn: if load_balancer_arn not in self.load_balancers: raise LoadBalancerNotFoundError() - return [tg for tg in self.target_groups.values() - if load_balancer_arn in tg.load_balancer_arns] + return [ + tg + for tg in self.target_groups.values() + if load_balancer_arn in tg.load_balancer_arns + ] if target_group_arns: try: @@ -662,7 +822,9 @@ class ELBv2Backend(BaseBackend): if self._any_listener_using(target_group_arn): raise ResourceInUseError( "The target group '{}' is currently in use by a listener or a rule".format( - target_group_arn)) + target_group_arn + ) + ) del self.target_groups[target_group_arn] return target_group @@ -685,16 +847,19 @@ class ELBv2Backend(BaseBackend): if conditions: for condition in conditions: - field = condition['field'] - if field not in ['path-pattern', 'host-header']: + field = condition["field"] + if field not in ["path-pattern", "host-header"]: raise InvalidConditionFieldError(field) - values = condition['values'] + values = condition["values"] if len(values) == 0: - raise InvalidConditionValueError('A condition value must be specified') + raise InvalidConditionValueError( + "A condition value must be specified" + ) if len(values) > 1: raise InvalidConditionValueError( - "The '%s' field contains too many values; the limit is '1'" % field + "The '%s' field contains too many values; the limit is '1'" + % field ) # TODO: check pattern of value for 'host-header' # TODO: check pattern of value for 'path-pattern' @@ -735,16 +900,18 @@ class ELBv2Backend(BaseBackend): def set_rule_priorities(self, rule_priorities): # validate - priorities = [rule_priority['priority'] for rule_priority in rule_priorities] + priorities = [rule_priority["priority"] for rule_priority in rule_priorities] for priority in set(priorities): if priorities.count(priority) > 1: raise DuplicatePriorityError(priority) # validate for rule_priority in rule_priorities: - given_rule_arn = rule_priority['rule_arn'] - priority = rule_priority['priority'] - _given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn]) + given_rule_arn = rule_priority["rule_arn"] + priority = rule_priority["priority"] + _given_rules = self.describe_rules( + listener_arn=None, rule_arns=[given_rule_arn] + ) if not _given_rules: raise RuleNotFoundError() given_rule = _given_rules[0] @@ -756,9 +923,11 @@ class ELBv2Backend(BaseBackend): # modify modified_rules = [] for rule_priority in rule_priorities: - given_rule_arn = rule_priority['rule_arn'] - priority = rule_priority['priority'] - _given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn]) + given_rule_arn = rule_priority["rule_arn"] + priority = rule_priority["priority"] + _given_rules = self.describe_rules( + listener_arn=None, rule_arns=[given_rule_arn] + ) if not _given_rules: raise RuleNotFoundError() given_rule = _given_rules[0] @@ -767,15 +936,21 @@ class ELBv2Backend(BaseBackend): return modified_rules def set_ip_address_type(self, arn, ip_type): - if ip_type not in ('internal', 'dualstack'): - raise RESTError('InvalidParameterValue', 'IpAddressType must be either internal | dualstack') + if ip_type not in ("internal", "dualstack"): + raise RESTError( + "InvalidParameterValue", + "IpAddressType must be either internal | dualstack", + ) balancer = self.load_balancers.get(arn) if balancer is None: raise LoadBalancerNotFoundError() - if ip_type == 'dualstack' and balancer.scheme == 'internal': - raise RESTError('InvalidConfigurationRequest', 'Internal load balancers cannot be dualstack') + if ip_type == "dualstack" and balancer.scheme == "internal": + raise RESTError( + "InvalidConfigurationRequest", + "Internal load balancers cannot be dualstack", + ) balancer.stack = ip_type @@ -787,7 +962,10 @@ class ELBv2Backend(BaseBackend): # Check all security groups exist for sec_group_id in sec_groups: if self.ec2_backend.get_security_group_from_id(sec_group_id) is None: - raise RESTError('InvalidSecurityGroup', 'Security group {0} does not exist'.format(sec_group_id)) + raise RESTError( + "InvalidSecurityGroup", + "Security group {0} does not exist".format(sec_group_id), + ) balancer.security_groups = sec_groups @@ -803,7 +981,10 @@ class ELBv2Backend(BaseBackend): subnet = self.ec2_backend.get_subnet(subnet) if subnet.availability_zone in sub_zone_list: - raise RESTError('InvalidConfigurationRequest', 'More than 1 subnet cannot be specified for 1 availability zone') + raise RESTError( + "InvalidConfigurationRequest", + "More than 1 subnet cannot be specified for 1 availability zone", + ) sub_zone_list[subnet.availability_zone] = subnet.id subnet_objects.append(subnet) @@ -811,7 +992,10 @@ class ELBv2Backend(BaseBackend): raise SubnetNotFoundError() if len(sub_zone_list) < 2: - raise RESTError('InvalidConfigurationRequest', 'More than 1 availability zone must be specified') + raise RESTError( + "InvalidConfigurationRequest", + "More than 1 availability zone must be specified", + ) balancer.subnets = subnet_objects @@ -824,7 +1008,9 @@ class ELBv2Backend(BaseBackend): for key in attrs: if key not in FakeLoadBalancer.VALID_ATTRS: - raise RESTError('InvalidConfigurationRequest', 'Key {0} not valid'.format(key)) + raise RESTError( + "InvalidConfigurationRequest", "Key {0} not valid".format(key) + ) balancer.attrs.update(attrs) return balancer.attrs @@ -836,17 +1022,33 @@ class ELBv2Backend(BaseBackend): return balancer.attrs - def modify_target_group(self, arn, health_check_proto=None, health_check_port=None, health_check_path=None, health_check_interval=None, - health_check_timeout=None, healthy_threshold_count=None, unhealthy_threshold_count=None, http_codes=None): + def modify_target_group( + self, + arn, + health_check_proto=None, + health_check_port=None, + health_check_path=None, + health_check_interval=None, + health_check_timeout=None, + healthy_threshold_count=None, + unhealthy_threshold_count=None, + http_codes=None, + ): target_group = self.target_groups.get(arn) if target_group is None: raise TargetGroupNotFoundError() - if http_codes is not None and FakeTargetGroup.HTTP_CODE_REGEX.match(http_codes) is None: - raise RESTError('InvalidParameterValue', 'HttpCode must be like 200 | 200-399 | 200,201 ...') + if ( + http_codes is not None + and FakeTargetGroup.HTTP_CODE_REGEX.match(http_codes) is None + ): + raise RESTError( + "InvalidParameterValue", + "HttpCode must be like 200 | 200-399 | 200,201 ...", + ) if http_codes is not None: - target_group.matcher['HttpCode'] = http_codes + target_group.matcher["HttpCode"] = http_codes if health_check_interval is not None: target_group.healthcheck_interval_seconds = health_check_interval if health_check_path is not None: @@ -864,7 +1066,15 @@ class ELBv2Backend(BaseBackend): return target_group - def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None): + def modify_listener( + self, + arn, + port=None, + protocol=None, + ssl_policy=None, + certificates=None, + default_actions=None, + ): default_actions = [FakeAction(action) for action in default_actions] for load_balancer in self.load_balancers.values(): if arn in load_balancer.listeners: @@ -884,33 +1094,46 @@ class ELBv2Backend(BaseBackend): listener.port = port if protocol is not None: - if protocol not in ('HTTP', 'HTTPS', 'TCP'): - raise RESTError('UnsupportedProtocol', 'Protocol {0} is not supported'.format(protocol)) + if protocol not in ("HTTP", "HTTPS", "TCP"): + raise RESTError( + "UnsupportedProtocol", + "Protocol {0} is not supported".format(protocol), + ) # HTTPS checks - if protocol == 'HTTPS': + if protocol == "HTTPS": # HTTPS # Might already be HTTPS so may not provide certs - if certificates is None and listener.protocol != 'HTTPS': - raise RESTError('InvalidConfigurationRequest', 'Certificates must be provided for HTTPS') + if certificates is None and listener.protocol != "HTTPS": + raise RESTError( + "InvalidConfigurationRequest", + "Certificates must be provided for HTTPS", + ) # Check certificates exist if certificates is not None: default_cert = None all_certs = set() # for SNI for cert in certificates: - if cert['is_default'] == 'true': - default_cert = cert['certificate_arn'] + if cert["is_default"] == "true": + default_cert = cert["certificate_arn"] try: - self.acm_backend.get_certificate(cert['certificate_arn']) + self.acm_backend.get_certificate(cert["certificate_arn"]) except Exception: - raise RESTError('CertificateNotFound', 'Certificate {0} not found'.format(cert['certificate_arn'])) + raise RESTError( + "CertificateNotFound", + "Certificate {0} not found".format( + cert["certificate_arn"] + ), + ) - all_certs.add(cert['certificate_arn']) + all_certs.add(cert["certificate_arn"]) if default_cert is None: - raise RESTError('InvalidConfigurationRequest', 'No default certificate') + raise RESTError( + "InvalidConfigurationRequest", "No default certificate" + ) listener.certificate = default_cert listener.certificates = list(all_certs) @@ -932,10 +1155,14 @@ class ELBv2Backend(BaseBackend): for listener in load_balancer.listeners.values(): for rule in listener.rules: for action in rule.actions: - if action.data.get('target_group_arn') == target_group_arn: + if action.data.get("target_group_arn") == target_group_arn: return True return False + def notify_terminate_instances(self, instance_ids): + for target_group in self.target_groups.values(): + target_group.deregister_terminated_instances(instance_ids) + elbv2_backends = {} for region in ec2_backends.keys(): diff --git a/moto/elbv2/responses.py b/moto/elbv2/responses.py index 25c23bb17..922de96d4 100644 --- a/moto/elbv2/responses.py +++ b/moto/elbv2/responses.py @@ -10,120 +10,120 @@ from .exceptions import TargetGroupNotFoundError SSL_POLICIES = [ { - 'name': 'ELBSecurityPolicy-2016-08', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} + "name": "ELBSecurityPolicy-2016-08", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-2-2017-01', - 'ssl_protocols': ['TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 5}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 8}, - {'name': 'AES128-GCM-SHA256', 'priority': 9}, - {'name': 'AES128-SHA256', 'priority': 10}, - {'name': 'AES256-GCM-SHA384', 'priority': 11}, - {'name': 'AES256-SHA256', 'priority': 12} - ] + "name": "ELBSecurityPolicy-TLS-1-2-2017-01", + "ssl_protocols": ["TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 5}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 8}, + {"name": "AES128-GCM-SHA256", "priority": 9}, + {"name": "AES128-SHA256", "priority": 10}, + {"name": "AES256-GCM-SHA384", "priority": 11}, + {"name": "AES256-SHA256", "priority": 12}, + ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-1-2017-01', - 'ssl_protocols': ['TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} - ] + "name": "ELBSecurityPolicy-TLS-1-1-2017-01", + "ssl_protocols": ["TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + ], }, { - 'name': 'ELBSecurityPolicy-2015-05', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} - ] + "name": "ELBSecurityPolicy-2015-05", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-0-2015-04', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18}, - {'name': 'DES-CBC3-SHA', 'priority': 19} - ] - } + "name": "ELBSecurityPolicy-TLS-1-0-2015-04", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + {"name": "DES-CBC3-SHA", "priority": 19}, + ], + }, ] @@ -134,10 +134,10 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_load_balancer(self): - load_balancer_name = self._get_param('Name') + load_balancer_name = self._get_param("Name") subnet_ids = self._get_multi_param("Subnets.member") security_groups = self._get_multi_param("SecurityGroups.member") - scheme = self._get_param('Scheme') + scheme = self._get_param("Scheme") load_balancer = self.elbv2_backend.create_load_balancer( name=load_balancer_name, @@ -151,43 +151,43 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_rule(self): - lister_arn = self._get_param('ListenerArn') - _conditions = self._get_list_prefix('Conditions.member') + lister_arn = self._get_param("ListenerArn") + _conditions = self._get_list_prefix("Conditions.member") conditions = [] for _condition in _conditions: condition = {} - condition['field'] = _condition['field'] + condition["field"] = _condition["field"] values = sorted( - [e for e in _condition.items() if e[0].startswith('values.member')], - key=lambda x: x[0] + [e for e in _condition.items() if e[0].startswith("values.member")], + key=lambda x: x[0], ) - condition['values'] = [e[1] for e in values] + condition["values"] = [e[1] for e in values] conditions.append(condition) - priority = self._get_int_param('Priority') - actions = self._get_list_prefix('Actions.member') + priority = self._get_int_param("Priority") + actions = self._get_list_prefix("Actions.member") rules = self.elbv2_backend.create_rule( listener_arn=lister_arn, conditions=conditions, priority=priority, - actions=actions + actions=actions, ) template = self.response_template(CREATE_RULE_TEMPLATE) return template.render(rules=rules) @amzn_request_id def create_target_group(self): - name = self._get_param('Name') - vpc_id = self._get_param('VpcId') - protocol = self._get_param('Protocol') - port = self._get_param('Port') - healthcheck_protocol = self._get_param('HealthCheckProtocol') - healthcheck_port = self._get_param('HealthCheckPort') - healthcheck_path = self._get_param('HealthCheckPath') - healthcheck_interval_seconds = self._get_param('HealthCheckIntervalSeconds') - healthcheck_timeout_seconds = self._get_param('HealthCheckTimeoutSeconds') - healthy_threshold_count = self._get_param('HealthyThresholdCount') - unhealthy_threshold_count = self._get_param('UnhealthyThresholdCount') - matcher = self._get_param('Matcher') + name = self._get_param("Name") + vpc_id = self._get_param("VpcId") + protocol = self._get_param("Protocol") + port = self._get_param("Port") + healthcheck_protocol = self._get_param("HealthCheckProtocol") + healthcheck_port = self._get_param("HealthCheckPort") + healthcheck_path = self._get_param("HealthCheckPath") + healthcheck_interval_seconds = self._get_param("HealthCheckIntervalSeconds") + healthcheck_timeout_seconds = self._get_param("HealthCheckTimeoutSeconds") + healthy_threshold_count = self._get_param("HealthyThresholdCount") + unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") + matcher = self._get_param("Matcher") target_group = self.elbv2_backend.create_target_group( name, @@ -209,16 +209,16 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_listener(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - protocol = self._get_param('Protocol') - port = self._get_param('Port') - ssl_policy = self._get_param('SslPolicy', 'ELBSecurityPolicy-2016-08') - certificates = self._get_list_prefix('Certificates.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + protocol = self._get_param("Protocol") + port = self._get_param("Port") + ssl_policy = self._get_param("SslPolicy", "ELBSecurityPolicy-2016-08") + certificates = self._get_list_prefix("Certificates.member") if certificates: - certificate = certificates[0].get('certificate_arn') + certificate = certificates[0].get("certificate_arn") else: certificate = None - default_actions = self._get_list_prefix('DefaultActions.member') + default_actions = self._get_list_prefix("DefaultActions.member") listener = self.elbv2_backend.create_listener( load_balancer_arn=load_balancer_arn, @@ -226,7 +226,8 @@ class ELBV2Response(BaseResponse): port=port, ssl_policy=ssl_policy, certificate=certificate, - default_actions=default_actions) + default_actions=default_actions, + ) template = self.response_template(CREATE_LISTENER_TEMPLATE) return template.render(listener=listener) @@ -235,15 +236,19 @@ class ELBV2Response(BaseResponse): def describe_load_balancers(self): arns = self._get_multi_param("LoadBalancerArns.member") names = self._get_multi_param("Names.member") - all_load_balancers = list(self.elbv2_backend.describe_load_balancers(arns, names)) - marker = self._get_param('Marker') + all_load_balancers = list( + self.elbv2_backend.describe_load_balancers(arns, names) + ) + marker = self._get_param("Marker") all_names = [balancer.name for balancer in all_load_balancers] if marker: start = all_names.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('PageSize', 50) # the default is 400, but using 50 to make testing easier - load_balancers_resp = all_load_balancers[start:start + page_size] + page_size = self._get_int_param( + "PageSize", 50 + ) # the default is 400, but using 50 to make testing easier + load_balancers_resp = all_load_balancers[start : start + page_size] next_marker = None if len(all_load_balancers) > start + page_size: next_marker = load_balancers_resp[-1].name @@ -253,18 +258,26 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_rules(self): - listener_arn = self._get_param('ListenerArn') - rule_arns = self._get_multi_param('RuleArns.member') if any(k for k in list(self.querystring.keys()) if k.startswith('RuleArns.member')) else None + listener_arn = self._get_param("ListenerArn") + rule_arns = ( + self._get_multi_param("RuleArns.member") + if any( + k + for k in list(self.querystring.keys()) + if k.startswith("RuleArns.member") + ) + else None + ) all_rules = self.elbv2_backend.describe_rules(listener_arn, rule_arns) all_arns = [rule.arn for rule in all_rules] - page_size = self._get_int_param('PageSize', 50) # set 50 for temporary + page_size = self._get_int_param("PageSize", 50) # set 50 for temporary - marker = self._get_param('Marker') + marker = self._get_param("Marker") if marker: start = all_arns.index(marker) + 1 else: start = 0 - rules_resp = all_rules[start:start + page_size] + rules_resp = all_rules[start : start + page_size] next_marker = None if len(all_rules) > start + page_size: @@ -274,17 +287,19 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_target_groups(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - target_group_arns = self._get_multi_param('TargetGroupArns.member') - names = self._get_multi_param('Names.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + target_group_arns = self._get_multi_param("TargetGroupArns.member") + names = self._get_multi_param("Names.member") - target_groups = self.elbv2_backend.describe_target_groups(load_balancer_arn, target_group_arns, names) + target_groups = self.elbv2_backend.describe_target_groups( + load_balancer_arn, target_group_arns, names + ) template = self.response_template(DESCRIBE_TARGET_GROUPS_TEMPLATE) return template.render(target_groups=target_groups) @amzn_request_id def describe_target_group_attributes(self): - target_group_arn = self._get_param('TargetGroupArn') + target_group_arn = self._get_param("TargetGroupArn") target_group = self.elbv2_backend.target_groups.get(target_group_arn) if not target_group: raise TargetGroupNotFoundError() @@ -293,73 +308,73 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_listeners(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - listener_arns = self._get_multi_param('ListenerArns.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + listener_arns = self._get_multi_param("ListenerArns.member") if not load_balancer_arn and not listener_arns: raise LoadBalancerNotFoundError() - listeners = self.elbv2_backend.describe_listeners(load_balancer_arn, listener_arns) + listeners = self.elbv2_backend.describe_listeners( + load_balancer_arn, listener_arns + ) template = self.response_template(DESCRIBE_LISTENERS_TEMPLATE) return template.render(listeners=listeners) @amzn_request_id def delete_load_balancer(self): - arn = self._get_param('LoadBalancerArn') + arn = self._get_param("LoadBalancerArn") self.elbv2_backend.delete_load_balancer(arn) template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE) return template.render() @amzn_request_id def delete_rule(self): - arn = self._get_param('RuleArn') + arn = self._get_param("RuleArn") self.elbv2_backend.delete_rule(arn) template = self.response_template(DELETE_RULE_TEMPLATE) return template.render() @amzn_request_id def delete_target_group(self): - arn = self._get_param('TargetGroupArn') + arn = self._get_param("TargetGroupArn") self.elbv2_backend.delete_target_group(arn) template = self.response_template(DELETE_TARGET_GROUP_TEMPLATE) return template.render() @amzn_request_id def delete_listener(self): - arn = self._get_param('ListenerArn') + arn = self._get_param("ListenerArn") self.elbv2_backend.delete_listener(arn) template = self.response_template(DELETE_LISTENER_TEMPLATE) return template.render() @amzn_request_id def modify_rule(self): - rule_arn = self._get_param('RuleArn') - _conditions = self._get_list_prefix('Conditions.member') + rule_arn = self._get_param("RuleArn") + _conditions = self._get_list_prefix("Conditions.member") conditions = [] for _condition in _conditions: condition = {} - condition['field'] = _condition['field'] + condition["field"] = _condition["field"] values = sorted( - [e for e in _condition.items() if e[0].startswith('values.member')], - key=lambda x: x[0] + [e for e in _condition.items() if e[0].startswith("values.member")], + key=lambda x: x[0], ) - condition['values'] = [e[1] for e in values] + condition["values"] = [e[1] for e in values] conditions.append(condition) - actions = self._get_list_prefix('Actions.member') + actions = self._get_list_prefix("Actions.member") rules = self.elbv2_backend.modify_rule( - rule_arn=rule_arn, - conditions=conditions, - actions=actions + rule_arn=rule_arn, conditions=conditions, actions=actions ) template = self.response_template(MODIFY_RULE_TEMPLATE) return template.render(rules=rules) @amzn_request_id def modify_target_group_attributes(self): - target_group_arn = self._get_param('TargetGroupArn') + target_group_arn = self._get_param("TargetGroupArn") target_group = self.elbv2_backend.target_groups.get(target_group_arn) attributes = { - attr['key']: attr['value'] - for attr in self._get_list_prefix('Attributes.member') + attr["key"]: attr["value"] + for attr in self._get_list_prefix("Attributes.member") } target_group.attributes.update(attributes) if not target_group: @@ -369,8 +384,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def register_targets(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") self.elbv2_backend.register_targets(target_group_arn, targets) template = self.response_template(REGISTER_TARGETS_TEMPLATE) @@ -378,8 +393,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def deregister_targets(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") self.elbv2_backend.deregister_targets(target_group_arn, targets) template = self.response_template(DEREGISTER_TARGETS_TEMPLATE) @@ -387,32 +402,34 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_target_health(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') - target_health_descriptions = self.elbv2_backend.describe_target_health(target_group_arn, targets) + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") + target_health_descriptions = self.elbv2_backend.describe_target_health( + target_group_arn, targets + ) template = self.response_template(DESCRIBE_TARGET_HEALTH_TEMPLATE) return template.render(target_health_descriptions=target_health_descriptions) @amzn_request_id def set_rule_priorities(self): - rule_priorities = self._get_list_prefix('RulePriorities.member') + rule_priorities = self._get_list_prefix("RulePriorities.member") for rule_priority in rule_priorities: - rule_priority['priority'] = int(rule_priority['priority']) + rule_priority["priority"] = int(rule_priority["priority"]) rules = self.elbv2_backend.set_rule_priorities(rule_priorities) template = self.response_template(SET_RULE_PRIORITIES_TEMPLATE) return template.render(rules=rules) @amzn_request_id def add_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') + resource_arns = self._get_multi_param("ResourceArns.member") for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -425,15 +442,15 @@ class ELBV2Response(BaseResponse): @amzn_request_id def remove_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') - tag_keys = self._get_multi_param('TagKeys.member') + resource_arns = self._get_multi_param("ResourceArns.member") + tag_keys = self._get_multi_param("TagKeys.member") for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -446,14 +463,14 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') + resource_arns = self._get_multi_param("ResourceArns.member") resources = [] for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -471,14 +488,14 @@ class ELBV2Response(BaseResponse): # page_size = self._get_int_param('PageSize') limits = { - 'application-load-balancers': 20, - 'target-groups': 3000, - 'targets-per-application-load-balancer': 30, - 'listeners-per-application-load-balancer': 50, - 'rules-per-application-load-balancer': 100, - 'network-load-balancers': 20, - 'targets-per-network-load-balancer': 200, - 'listeners-per-network-load-balancer': 50 + "application-load-balancers": 20, + "target-groups": 3000, + "targets-per-application-load-balancer": 30, + "listeners-per-application-load-balancer": 50, + "rules-per-application-load-balancer": 100, + "network-load-balancers": 20, + "targets-per-network-load-balancer": 200, + "listeners-per-network-load-balancer": 50, } template = self.response_template(DESCRIBE_LIMITS_TEMPLATE) @@ -486,22 +503,22 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_ssl_policies(self): - names = self._get_multi_param('Names.member.') + names = self._get_multi_param("Names.member.") # Supports paging but not worth implementing yet # marker = self._get_param('Marker') # page_size = self._get_int_param('PageSize') policies = SSL_POLICIES if names: - policies = filter(lambda policy: policy['name'] in names, policies) + policies = filter(lambda policy: policy["name"] in names, policies) template = self.response_template(DESCRIBE_SSL_POLICIES_TEMPLATE) return template.render(policies=policies) @amzn_request_id def set_ip_address_type(self): - arn = self._get_param('LoadBalancerArn') - ip_type = self._get_param('IpAddressType') + arn = self._get_param("LoadBalancerArn") + ip_type = self._get_param("IpAddressType") self.elbv2_backend.set_ip_address_type(arn, ip_type) @@ -510,8 +527,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def set_security_groups(self): - arn = self._get_param('LoadBalancerArn') - sec_groups = self._get_multi_param('SecurityGroups.member.') + arn = self._get_param("LoadBalancerArn") + sec_groups = self._get_multi_param("SecurityGroups.member.") self.elbv2_backend.set_security_groups(arn, sec_groups) @@ -520,8 +537,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def set_subnets(self): - arn = self._get_param('LoadBalancerArn') - subnets = self._get_multi_param('Subnets.member.') + arn = self._get_param("LoadBalancerArn") + subnets = self._get_multi_param("Subnets.member.") subnet_zone_list = self.elbv2_backend.set_subnets(arn, subnets) @@ -530,8 +547,10 @@ class ELBV2Response(BaseResponse): @amzn_request_id def modify_load_balancer_attributes(self): - arn = self._get_param('LoadBalancerArn') - attrs = self._get_map_prefix('Attributes.member', key_end='Key', value_end='Value') + arn = self._get_param("LoadBalancerArn") + attrs = self._get_map_prefix( + "Attributes.member", key_end="Key", value_end="Value" + ) all_attrs = self.elbv2_backend.modify_load_balancer_attributes(arn, attrs) @@ -540,7 +559,7 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_load_balancer_attributes(self): - arn = self._get_param('LoadBalancerArn') + arn = self._get_param("LoadBalancerArn") attrs = self.elbv2_backend.describe_load_balancer_attributes(arn) template = self.response_template(DESCRIBE_LOADBALANCER_ATTRS_TEMPLATE) @@ -548,37 +567,54 @@ class ELBV2Response(BaseResponse): @amzn_request_id def modify_target_group(self): - arn = self._get_param('TargetGroupArn') + arn = self._get_param("TargetGroupArn") - health_check_proto = self._get_param('HealthCheckProtocol') # 'HTTP' | 'HTTPS' | 'TCP', - health_check_port = self._get_param('HealthCheckPort') - health_check_path = self._get_param('HealthCheckPath') - health_check_interval = self._get_param('HealthCheckIntervalSeconds') - health_check_timeout = self._get_param('HealthCheckTimeoutSeconds') - healthy_threshold_count = self._get_param('HealthyThresholdCount') - unhealthy_threshold_count = self._get_param('UnhealthyThresholdCount') - http_codes = self._get_param('Matcher.HttpCode') + health_check_proto = self._get_param( + "HealthCheckProtocol" + ) # 'HTTP' | 'HTTPS' | 'TCP', + health_check_port = self._get_param("HealthCheckPort") + health_check_path = self._get_param("HealthCheckPath") + health_check_interval = self._get_param("HealthCheckIntervalSeconds") + health_check_timeout = self._get_param("HealthCheckTimeoutSeconds") + healthy_threshold_count = self._get_param("HealthyThresholdCount") + unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") + http_codes = self._get_param("Matcher.HttpCode") - target_group = self.elbv2_backend.modify_target_group(arn, health_check_proto, health_check_port, health_check_path, health_check_interval, - health_check_timeout, healthy_threshold_count, unhealthy_threshold_count, http_codes) + target_group = self.elbv2_backend.modify_target_group( + arn, + health_check_proto, + health_check_port, + health_check_path, + health_check_interval, + health_check_timeout, + healthy_threshold_count, + unhealthy_threshold_count, + http_codes, + ) template = self.response_template(MODIFY_TARGET_GROUP_TEMPLATE) return template.render(target_group=target_group) @amzn_request_id def modify_listener(self): - arn = self._get_param('ListenerArn') - port = self._get_param('Port') - protocol = self._get_param('Protocol') - ssl_policy = self._get_param('SslPolicy') - certificates = self._get_list_prefix('Certificates.member') - default_actions = self._get_list_prefix('DefaultActions.member') + arn = self._get_param("ListenerArn") + port = self._get_param("Port") + protocol = self._get_param("Protocol") + ssl_policy = self._get_param("SslPolicy") + certificates = self._get_list_prefix("Certificates.member") + default_actions = self._get_list_prefix("DefaultActions.member") # Should really move SSL Policies to models - if ssl_policy is not None and ssl_policy not in [item['name'] for item in SSL_POLICIES]: - raise RESTError('SSLPolicyNotFound', 'Policy {0} not found'.format(ssl_policy)) + if ssl_policy is not None and ssl_policy not in [ + item["name"] for item in SSL_POLICIES + ]: + raise RESTError( + "SSLPolicyNotFound", "Policy {0} not found".format(ssl_policy) + ) - listener = self.elbv2_backend.modify_listener(arn, port, protocol, ssl_policy, certificates, default_actions) + listener = self.elbv2_backend.modify_listener( + arn, port, protocol, ssl_policy, certificates, default_actions + ) template = self.response_template(MODIFY_LISTENER_TEMPLATE) return template.render(listener=listener) @@ -588,10 +624,10 @@ class ELBV2Response(BaseResponse): tag_keys = [] for t_key, t_val in sorted(self.querystring.items()): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": tag_keys.extend(t_val) - elif t_key.split('.')[3] == 'Value': + elif t_key.split(".")[3] == "Value": tag_values.extend(t_val) counts = {} diff --git a/moto/elbv2/urls.py b/moto/elbv2/urls.py index af51f7d3a..06b8f107e 100644 --- a/moto/elbv2/urls.py +++ b/moto/elbv2/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from ..elb.urls import api_version_elb_backend -url_bases = [ - "https?://elasticloadbalancing.(.+).amazonaws.com", -] +url_bases = ["https?://elasticloadbalancing.(.+).amazonaws.com"] -url_paths = { - '{0}/$': api_version_elb_backend, -} +url_paths = {"{0}/$": api_version_elb_backend} diff --git a/moto/elbv2/utils.py b/moto/elbv2/utils.py index 47a3e66d5..017878e2f 100644 --- a/moto/elbv2/utils.py +++ b/moto/elbv2/utils.py @@ -1,8 +1,10 @@ def make_arn_for_load_balancer(account_id, name, region_name): return "arn:aws:elasticloadbalancing:{}:{}:loadbalancer/{}/50dc6c495c0c9188".format( - region_name, account_id, name) + region_name, account_id, name + ) def make_arn_for_target_group(account_id, name, region_name): return "arn:aws:elasticloadbalancing:{}:{}:targetgroup/{}/50dc6c495c0c9188".format( - region_name, account_id, name) + region_name, account_id, name + ) diff --git a/moto/emr/__init__.py b/moto/emr/__init__.py index b4223f2cb..d35506271 100644 --- a/moto/emr/__init__.py +++ b/moto/emr/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import emr_backends from ..core.models import base_decorator, deprecated_base_decorator -emr_backend = emr_backends['us-east-1'] +emr_backend = emr_backends["us-east-1"] mock_emr = base_decorator(emr_backends) mock_emr_deprecated = deprecated_base_decorator(emr_backends) diff --git a/moto/emr/models.py b/moto/emr/models.py index 4b591acb1..b62ce7932 100644 --- a/moto/emr/models.py +++ b/moto/emr/models.py @@ -11,7 +11,6 @@ from .utils import random_instance_group_id, random_cluster_id, random_step_id class FakeApplication(BaseModel): - def __init__(self, name, version, args=None, additional_info=None): self.additional_info = additional_info or {} self.args = args or [] @@ -20,7 +19,6 @@ class FakeApplication(BaseModel): class FakeBootstrapAction(BaseModel): - def __init__(self, args, name, script_path): self.args = args or [] self.name = name @@ -28,20 +26,27 @@ class FakeBootstrapAction(BaseModel): class FakeInstanceGroup(BaseModel): - - def __init__(self, instance_count, instance_role, instance_type, - market='ON_DEMAND', name=None, id=None, bid_price=None): + def __init__( + self, + instance_count, + instance_role, + instance_type, + market="ON_DEMAND", + name=None, + id=None, + bid_price=None, + ): self.id = id or random_instance_group_id() self.bid_price = bid_price self.market = market if name is None: - if instance_role == 'MASTER': - name = 'master' - elif instance_role == 'CORE': - name = 'slave' + if instance_role == "MASTER": + name = "master" + elif instance_role == "CORE": + name = "slave" else: - name = 'Task instance group' + name = "Task instance group" self.name = name self.num_instances = instance_count self.role = instance_role @@ -51,21 +56,22 @@ class FakeInstanceGroup(BaseModel): self.start_datetime = datetime.now(pytz.utc) self.ready_datetime = datetime.now(pytz.utc) self.end_datetime = None - self.state = 'RUNNING' + self.state = "RUNNING" def set_instance_count(self, instance_count): self.num_instances = instance_count class FakeStep(BaseModel): - - def __init__(self, - state, - name='', - jar='', - args=None, - properties=None, - action_on_failure='TERMINATE_CLUSTER'): + def __init__( + self, + state, + name="", + jar="", + args=None, + properties=None, + action_on_failure="TERMINATE_CLUSTER", + ): self.id = random_step_id() self.action_on_failure = action_on_failure @@ -82,23 +88,24 @@ class FakeStep(BaseModel): class FakeCluster(BaseModel): - - def __init__(self, - emr_backend, - name, - log_uri, - job_flow_role, - service_role, - steps, - instance_attrs, - bootstrap_actions=None, - configurations=None, - cluster_id=None, - visible_to_all_users='false', - release_label=None, - requested_ami_version=None, - running_ami_version=None, - custom_ami_id=None): + def __init__( + self, + emr_backend, + name, + log_uri, + job_flow_role, + service_role, + steps, + instance_attrs, + bootstrap_actions=None, + configurations=None, + cluster_id=None, + visible_to_all_users="false", + release_label=None, + requested_ami_version=None, + running_ami_version=None, + custom_ami_id=None, + ): self.id = cluster_id or random_cluster_id() emr_backend.clusters[self.id] = self self.emr_backend = emr_backend @@ -106,7 +113,7 @@ class FakeCluster(BaseModel): self.applications = [] self.bootstrap_actions = [] - for bootstrap_action in (bootstrap_actions or []): + for bootstrap_action in bootstrap_actions or []: self.add_bootstrap_action(bootstrap_action) self.configurations = configurations or [] @@ -125,47 +132,68 @@ class FakeCluster(BaseModel): self.instance_group_ids = [] self.master_instance_group_id = None self.core_instance_group_id = None - if 'master_instance_type' in instance_attrs and instance_attrs['master_instance_type']: + if ( + "master_instance_type" in instance_attrs + and instance_attrs["master_instance_type"] + ): self.emr_backend.add_instance_groups( self.id, - [{'instance_count': 1, - 'instance_role': 'MASTER', - 'instance_type': instance_attrs['master_instance_type'], - 'market': 'ON_DEMAND', - 'name': 'master'}]) - if 'slave_instance_type' in instance_attrs and instance_attrs['slave_instance_type']: + [ + { + "instance_count": 1, + "instance_role": "MASTER", + "instance_type": instance_attrs["master_instance_type"], + "market": "ON_DEMAND", + "name": "master", + } + ], + ) + if ( + "slave_instance_type" in instance_attrs + and instance_attrs["slave_instance_type"] + ): self.emr_backend.add_instance_groups( self.id, - [{'instance_count': instance_attrs['instance_count'] - 1, - 'instance_role': 'CORE', - 'instance_type': instance_attrs['slave_instance_type'], - 'market': 'ON_DEMAND', - 'name': 'slave'}]) + [ + { + "instance_count": instance_attrs["instance_count"] - 1, + "instance_role": "CORE", + "instance_type": instance_attrs["slave_instance_type"], + "market": "ON_DEMAND", + "name": "slave", + } + ], + ) self.additional_master_security_groups = instance_attrs.get( - 'additional_master_security_groups') + "additional_master_security_groups" + ) self.additional_slave_security_groups = instance_attrs.get( - 'additional_slave_security_groups') - self.availability_zone = instance_attrs.get('availability_zone') - self.ec2_key_name = instance_attrs.get('ec2_key_name') - self.ec2_subnet_id = instance_attrs.get('ec2_subnet_id') - self.hadoop_version = instance_attrs.get('hadoop_version') + "additional_slave_security_groups" + ) + self.availability_zone = instance_attrs.get("availability_zone") + self.ec2_key_name = instance_attrs.get("ec2_key_name") + self.ec2_subnet_id = instance_attrs.get("ec2_subnet_id") + self.hadoop_version = instance_attrs.get("hadoop_version") self.keep_job_flow_alive_when_no_steps = instance_attrs.get( - 'keep_job_flow_alive_when_no_steps') + "keep_job_flow_alive_when_no_steps" + ) self.master_security_group = instance_attrs.get( - 'emr_managed_master_security_group') + "emr_managed_master_security_group" + ) self.service_access_security_group = instance_attrs.get( - 'service_access_security_group') + "service_access_security_group" + ) self.slave_security_group = instance_attrs.get( - 'emr_managed_slave_security_group') - self.termination_protected = instance_attrs.get( - 'termination_protected') + "emr_managed_slave_security_group" + ) + self.termination_protected = instance_attrs.get("termination_protected") self.release_label = release_label self.requested_ami_version = requested_ami_version self.running_ami_version = running_ami_version self.custom_ami_id = custom_ami_id - self.role = job_flow_role or 'EMRJobflowDefault' + self.role = job_flow_role or "EMRJobflowDefault" self.service_role = service_role self.creation_datetime = datetime.now(pytz.utc) @@ -194,42 +222,46 @@ class FakeCluster(BaseModel): return sum(group.num_instances for group in self.instance_groups) def start_cluster(self): - self.state = 'STARTING' + self.state = "STARTING" self.start_datetime = datetime.now(pytz.utc) def run_bootstrap_actions(self): - self.state = 'BOOTSTRAPPING' + self.state = "BOOTSTRAPPING" self.ready_datetime = datetime.now(pytz.utc) - self.state = 'WAITING' + self.state = "WAITING" if not self.steps: if not self.keep_job_flow_alive_when_no_steps: self.terminate() def terminate(self): - self.state = 'TERMINATING' + self.state = "TERMINATING" self.end_datetime = datetime.now(pytz.utc) - self.state = 'TERMINATED' + self.state = "TERMINATED" def add_applications(self, applications): - self.applications.extend([ - FakeApplication( - name=app.get('name', ''), - version=app.get('version', ''), - args=app.get('args', []), - additional_info=app.get('additiona_info', {})) - for app in applications]) + self.applications.extend( + [ + FakeApplication( + name=app.get("name", ""), + version=app.get("version", ""), + args=app.get("args", []), + additional_info=app.get("additiona_info", {}), + ) + for app in applications + ] + ) def add_bootstrap_action(self, bootstrap_action): self.bootstrap_actions.append(FakeBootstrapAction(**bootstrap_action)) def add_instance_group(self, instance_group): - if instance_group.role == 'MASTER': + if instance_group.role == "MASTER": if self.master_instance_group_id: - raise Exception('Cannot add another master instance group') + raise Exception("Cannot add another master instance group") self.master_instance_group_id = instance_group.id - if instance_group.role == 'CORE': + if instance_group.role == "CORE": if self.core_instance_group_id: - raise Exception('Cannot add another core instance group') + raise Exception("Cannot add another core instance group") self.core_instance_group_id = instance_group.id self.instance_group_ids.append(instance_group.id) @@ -238,12 +270,12 @@ class FakeCluster(BaseModel): for step in steps: if self.steps: # If we already have other steps, this one is pending - fake = FakeStep(state='PENDING', **step) + fake = FakeStep(state="PENDING", **step) else: - fake = FakeStep(state='STARTING', **step) + fake = FakeStep(state="STARTING", **step) self.steps.append(fake) added_steps.append(fake) - self.state = 'RUNNING' + self.state = "RUNNING" return added_steps def add_tags(self, tags): @@ -261,7 +293,6 @@ class FakeCluster(BaseModel): class ElasticMapReduceBackend(BaseBackend): - def __init__(self, region_name): super(ElasticMapReduceBackend, self).__init__() self.region_name = region_name @@ -296,12 +327,17 @@ class ElasticMapReduceBackend(BaseBackend): cluster = self.get_cluster(cluster_id) cluster.add_tags(tags) - def describe_job_flows(self, job_flow_ids=None, job_flow_states=None, created_after=None, created_before=None): + def describe_job_flows( + self, + job_flow_ids=None, + job_flow_states=None, + created_after=None, + created_before=None, + ): clusters = self.clusters.values() within_two_month = datetime.now(pytz.utc) - timedelta(days=60) - clusters = [ - c for c in clusters if c.creation_datetime >= within_two_month] + clusters = [c for c in clusters if c.creation_datetime >= within_two_month] if job_flow_ids: clusters = [c for c in clusters if c.id in job_flow_ids] @@ -309,12 +345,10 @@ class ElasticMapReduceBackend(BaseBackend): clusters = [c for c in clusters if c.state in job_flow_states] if created_after: created_after = dtparse(created_after) - clusters = [ - c for c in clusters if c.creation_datetime > created_after] + clusters = [c for c in clusters if c.creation_datetime > created_after] if created_before: created_before = dtparse(created_before) - clusters = [ - c for c in clusters if c.creation_datetime < created_before] + clusters = [c for c in clusters if c.creation_datetime < created_before] # Amazon EMR can return a maximum of 512 job flow descriptions return sorted(clusters, key=lambda x: x.id)[:512] @@ -328,12 +362,12 @@ class ElasticMapReduceBackend(BaseBackend): def get_cluster(self, cluster_id): if cluster_id in self.clusters: return self.clusters[cluster_id] - raise EmrError('ResourceNotFoundException', '', 'error_json') + raise EmrError("ResourceNotFoundException", "", "error_json") def get_instance_groups(self, instance_group_ids): return [ - group for group_id, group - in self.instance_groups.items() + group + for group_id, group in self.instance_groups.items() if group_id in instance_group_ids ] @@ -341,38 +375,43 @@ class ElasticMapReduceBackend(BaseBackend): max_items = 50 actions = self.clusters[cluster_id].bootstrap_actions start_idx = 0 if marker is None else int(marker) - marker = None if len(actions) <= start_idx + \ - max_items else str(start_idx + max_items) - return actions[start_idx:start_idx + max_items], marker + marker = ( + None + if len(actions) <= start_idx + max_items + else str(start_idx + max_items) + ) + return actions[start_idx : start_idx + max_items], marker - def list_clusters(self, cluster_states=None, created_after=None, - created_before=None, marker=None): + def list_clusters( + self, cluster_states=None, created_after=None, created_before=None, marker=None + ): max_items = 50 clusters = self.clusters.values() if cluster_states: clusters = [c for c in clusters if c.state in cluster_states] if created_after: created_after = dtparse(created_after) - clusters = [ - c for c in clusters if c.creation_datetime > created_after] + clusters = [c for c in clusters if c.creation_datetime > created_after] if created_before: created_before = dtparse(created_before) - clusters = [ - c for c in clusters if c.creation_datetime < created_before] + clusters = [c for c in clusters if c.creation_datetime < created_before] clusters = sorted(clusters, key=lambda x: x.id) start_idx = 0 if marker is None else int(marker) - marker = None if len(clusters) <= start_idx + \ - max_items else str(start_idx + max_items) - return clusters[start_idx:start_idx + max_items], marker + marker = ( + None + if len(clusters) <= start_idx + max_items + else str(start_idx + max_items) + ) + return clusters[start_idx : start_idx + max_items], marker def list_instance_groups(self, cluster_id, marker=None): max_items = 50 - groups = sorted(self.clusters[cluster_id].instance_groups, - key=lambda x: x.id) + groups = sorted(self.clusters[cluster_id].instance_groups, key=lambda x: x.id) start_idx = 0 if marker is None else int(marker) - marker = None if len(groups) <= start_idx + \ - max_items else str(start_idx + max_items) - return groups[start_idx:start_idx + max_items], marker + marker = ( + None if len(groups) <= start_idx + max_items else str(start_idx + max_items) + ) + return groups[start_idx : start_idx + max_items], marker def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None): max_items = 50 @@ -382,15 +421,16 @@ class ElasticMapReduceBackend(BaseBackend): if step_states: steps = [s for s in steps if s.state in step_states] start_idx = 0 if marker is None else int(marker) - marker = None if len(steps) <= start_idx + \ - max_items else str(start_idx + max_items) - return steps[start_idx:start_idx + max_items], marker + marker = ( + None if len(steps) <= start_idx + max_items else str(start_idx + max_items) + ) + return steps[start_idx : start_idx + max_items], marker def modify_instance_groups(self, instance_groups): result_groups = [] for instance_group in instance_groups: - group = self.instance_groups[instance_group['instance_group_id']] - group.set_instance_count(int(instance_group['instance_count'])) + group = self.instance_groups[instance_group["instance_group_id"]] + group.set_instance_count(int(instance_group["instance_count"])) return result_groups def remove_tags(self, cluster_id, tag_keys): diff --git a/moto/emr/responses.py b/moto/emr/responses.py index c807b5f54..94847ec8b 100644 --- a/moto/emr/responses.py +++ b/moto/emr/responses.py @@ -20,20 +20,27 @@ def generate_boto3_response(operation): determined to be from boto3. Pass the API action as a parameter. """ + def _boto3_request(method): @wraps(method) def f(self, *args, **kwargs): rendered = method(self, *args, **kwargs) - if 'json' in self.headers.get('Content-Type', []): + if "json" in self.headers.get("Content-Type", []): self.response_headers.update( - {'x-amzn-requestid': '2690d7eb-ed86-11dd-9877-6fad448a8419', - 'date': datetime.now(pytz.utc).strftime('%a, %d %b %Y %H:%M:%S %Z'), - 'content-type': 'application/x-amz-json-1.1'}) - resp = xml_to_json_response( - self.aws_service_spec, operation, rendered) - return '' if resp is None else json.dumps(resp) + { + "x-amzn-requestid": "2690d7eb-ed86-11dd-9877-6fad448a8419", + "date": datetime.now(pytz.utc).strftime( + "%a, %d %b %Y %H:%M:%S %Z" + ), + "content-type": "application/x-amz-json-1.1", + } + ) + resp = xml_to_json_response(self.aws_service_spec, operation, rendered) + return "" if resp is None else json.dumps(resp) return rendered + return f + return _boto3_request @@ -41,10 +48,12 @@ class ElasticMapReduceResponse(BaseResponse): # EMR end points are inconsistent in the placement of region name # in the URL, so parsing it out needs to be handled differently - region_regex = [re.compile(r'elasticmapreduce\.(.+?)\.amazonaws\.com'), - re.compile(r'(.+?)\.elasticmapreduce\.amazonaws\.com')] + region_regex = [ + re.compile(r"elasticmapreduce\.(.+?)\.amazonaws\.com"), + re.compile(r"(.+?)\.elasticmapreduce\.amazonaws\.com"), + ] - aws_service_spec = AWSServiceSpec('data/emr/2009-03-31/service-2.json') + aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json") def get_region_from_url(self, request, full_url): parsed = urlparse(full_url) @@ -58,28 +67,28 @@ class ElasticMapReduceResponse(BaseResponse): def backend(self): return emr_backends[self.region] - @generate_boto3_response('AddInstanceGroups') + @generate_boto3_response("AddInstanceGroups") def add_instance_groups(self): - jobflow_id = self._get_param('JobFlowId') - instance_groups = self._get_list_prefix('InstanceGroups.member') + jobflow_id = self._get_param("JobFlowId") + instance_groups = self._get_list_prefix("InstanceGroups.member") for item in instance_groups: - item['instance_count'] = int(item['instance_count']) - instance_groups = self.backend.add_instance_groups( - jobflow_id, instance_groups) + item["instance_count"] = int(item["instance_count"]) + instance_groups = self.backend.add_instance_groups(jobflow_id, instance_groups) template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups) - @generate_boto3_response('AddJobFlowSteps') + @generate_boto3_response("AddJobFlowSteps") def add_job_flow_steps(self): - job_flow_id = self._get_param('JobFlowId') + job_flow_id = self._get_param("JobFlowId") steps = self.backend.add_job_flow_steps( - job_flow_id, steps_from_query_string(self._get_list_prefix('Steps.member'))) + job_flow_id, steps_from_query_string(self._get_list_prefix("Steps.member")) + ) template = self.response_template(ADD_JOB_FLOW_STEPS_TEMPLATE) return template.render(steps=steps) - @generate_boto3_response('AddTags') + @generate_boto3_response("AddTags") def add_tags(self): - cluster_id = self._get_param('ResourceId') + cluster_id = self._get_param("ResourceId") tags = tags_from_query_string(self.querystring) self.backend.add_tags(cluster_id, tags) template = self.response_template(ADD_TAGS_TEMPLATE) @@ -94,235 +103,257 @@ class ElasticMapReduceResponse(BaseResponse): def delete_security_configuration(self): raise NotImplementedError - @generate_boto3_response('DescribeCluster') + @generate_boto3_response("DescribeCluster") def describe_cluster(self): - cluster_id = self._get_param('ClusterId') + cluster_id = self._get_param("ClusterId") cluster = self.backend.get_cluster(cluster_id) template = self.response_template(DESCRIBE_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - @generate_boto3_response('DescribeJobFlows') + @generate_boto3_response("DescribeJobFlows") def describe_job_flows(self): - created_after = self._get_param('CreatedAfter') - created_before = self._get_param('CreatedBefore') + created_after = self._get_param("CreatedAfter") + created_before = self._get_param("CreatedBefore") job_flow_ids = self._get_multi_param("JobFlowIds.member") - job_flow_states = self._get_multi_param('JobFlowStates.member') + job_flow_states = self._get_multi_param("JobFlowStates.member") clusters = self.backend.describe_job_flows( - job_flow_ids, job_flow_states, created_after, created_before) + job_flow_ids, job_flow_states, created_after, created_before + ) template = self.response_template(DESCRIBE_JOB_FLOWS_TEMPLATE) return template.render(clusters=clusters) def describe_security_configuration(self): raise NotImplementedError - @generate_boto3_response('DescribeStep') + @generate_boto3_response("DescribeStep") def describe_step(self): - cluster_id = self._get_param('ClusterId') - step_id = self._get_param('StepId') + cluster_id = self._get_param("ClusterId") + step_id = self._get_param("StepId") step = self.backend.describe_step(cluster_id, step_id) template = self.response_template(DESCRIBE_STEP_TEMPLATE) return template.render(step=step) - @generate_boto3_response('ListBootstrapActions') + @generate_boto3_response("ListBootstrapActions") def list_bootstrap_actions(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") bootstrap_actions, marker = self.backend.list_bootstrap_actions( - cluster_id, marker) + cluster_id, marker + ) template = self.response_template(LIST_BOOTSTRAP_ACTIONS_TEMPLATE) return template.render(bootstrap_actions=bootstrap_actions, marker=marker) - @generate_boto3_response('ListClusters') + @generate_boto3_response("ListClusters") def list_clusters(self): - cluster_states = self._get_multi_param('ClusterStates.member') - created_after = self._get_param('CreatedAfter') - created_before = self._get_param('CreatedBefore') - marker = self._get_param('Marker') + cluster_states = self._get_multi_param("ClusterStates.member") + created_after = self._get_param("CreatedAfter") + created_before = self._get_param("CreatedBefore") + marker = self._get_param("Marker") clusters, marker = self.backend.list_clusters( - cluster_states, created_after, created_before, marker) + cluster_states, created_after, created_before, marker + ) template = self.response_template(LIST_CLUSTERS_TEMPLATE) return template.render(clusters=clusters, marker=marker) - @generate_boto3_response('ListInstanceGroups') + @generate_boto3_response("ListInstanceGroups") def list_instance_groups(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") instance_groups, marker = self.backend.list_instance_groups( - cluster_id, marker=marker) + cluster_id, marker=marker + ) template = self.response_template(LIST_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups, marker=marker) def list_instances(self): raise NotImplementedError - @generate_boto3_response('ListSteps') + @generate_boto3_response("ListSteps") def list_steps(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') - step_ids = self._get_multi_param('StepIds.member') - step_states = self._get_multi_param('StepStates.member') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") + step_ids = self._get_multi_param("StepIds.member") + step_states = self._get_multi_param("StepStates.member") steps, marker = self.backend.list_steps( - cluster_id, marker=marker, step_ids=step_ids, step_states=step_states) + cluster_id, marker=marker, step_ids=step_ids, step_states=step_states + ) template = self.response_template(LIST_STEPS_TEMPLATE) return template.render(steps=steps, marker=marker) - @generate_boto3_response('ModifyInstanceGroups') + @generate_boto3_response("ModifyInstanceGroups") def modify_instance_groups(self): - instance_groups = self._get_list_prefix('InstanceGroups.member') + instance_groups = self._get_list_prefix("InstanceGroups.member") for item in instance_groups: - item['instance_count'] = int(item['instance_count']) + item["instance_count"] = int(item["instance_count"]) instance_groups = self.backend.modify_instance_groups(instance_groups) template = self.response_template(MODIFY_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups) - @generate_boto3_response('RemoveTags') + @generate_boto3_response("RemoveTags") def remove_tags(self): - cluster_id = self._get_param('ResourceId') - tag_keys = self._get_multi_param('TagKeys.member') + cluster_id = self._get_param("ResourceId") + tag_keys = self._get_multi_param("TagKeys.member") self.backend.remove_tags(cluster_id, tag_keys) template = self.response_template(REMOVE_TAGS_TEMPLATE) return template.render() - @generate_boto3_response('RunJobFlow') + @generate_boto3_response("RunJobFlow") def run_job_flow(self): instance_attrs = dict( - master_instance_type=self._get_param( - 'Instances.MasterInstanceType'), - slave_instance_type=self._get_param('Instances.SlaveInstanceType'), - instance_count=self._get_int_param('Instances.InstanceCount', 1), - ec2_key_name=self._get_param('Instances.Ec2KeyName'), - ec2_subnet_id=self._get_param('Instances.Ec2SubnetId'), - hadoop_version=self._get_param('Instances.HadoopVersion'), + master_instance_type=self._get_param("Instances.MasterInstanceType"), + slave_instance_type=self._get_param("Instances.SlaveInstanceType"), + instance_count=self._get_int_param("Instances.InstanceCount", 1), + ec2_key_name=self._get_param("Instances.Ec2KeyName"), + ec2_subnet_id=self._get_param("Instances.Ec2SubnetId"), + hadoop_version=self._get_param("Instances.HadoopVersion"), availability_zone=self._get_param( - 'Instances.Placement.AvailabilityZone', self.backend.region_name + 'a'), + "Instances.Placement.AvailabilityZone", self.backend.region_name + "a" + ), keep_job_flow_alive_when_no_steps=self._get_bool_param( - 'Instances.KeepJobFlowAliveWhenNoSteps', False), + "Instances.KeepJobFlowAliveWhenNoSteps", False + ), termination_protected=self._get_bool_param( - 'Instances.TerminationProtected', False), + "Instances.TerminationProtected", False + ), emr_managed_master_security_group=self._get_param( - 'Instances.EmrManagedMasterSecurityGroup'), + "Instances.EmrManagedMasterSecurityGroup" + ), emr_managed_slave_security_group=self._get_param( - 'Instances.EmrManagedSlaveSecurityGroup'), + "Instances.EmrManagedSlaveSecurityGroup" + ), service_access_security_group=self._get_param( - 'Instances.ServiceAccessSecurityGroup'), + "Instances.ServiceAccessSecurityGroup" + ), additional_master_security_groups=self._get_multi_param( - 'Instances.AdditionalMasterSecurityGroups.member.'), - additional_slave_security_groups=self._get_multi_param('Instances.AdditionalSlaveSecurityGroups.member.')) + "Instances.AdditionalMasterSecurityGroups.member." + ), + additional_slave_security_groups=self._get_multi_param( + "Instances.AdditionalSlaveSecurityGroups.member." + ), + ) kwargs = dict( - name=self._get_param('Name'), - log_uri=self._get_param('LogUri'), - job_flow_role=self._get_param('JobFlowRole'), - service_role=self._get_param('ServiceRole'), - steps=steps_from_query_string( - self._get_list_prefix('Steps.member')), - visible_to_all_users=self._get_bool_param( - 'VisibleToAllUsers', False), + name=self._get_param("Name"), + log_uri=self._get_param("LogUri"), + job_flow_role=self._get_param("JobFlowRole"), + service_role=self._get_param("ServiceRole"), + steps=steps_from_query_string(self._get_list_prefix("Steps.member")), + visible_to_all_users=self._get_bool_param("VisibleToAllUsers", False), instance_attrs=instance_attrs, ) - bootstrap_actions = self._get_list_prefix('BootstrapActions.member') + bootstrap_actions = self._get_list_prefix("BootstrapActions.member") if bootstrap_actions: for ba in bootstrap_actions: args = [] idx = 1 - keyfmt = 'script_bootstrap_action._args.member.{0}' + keyfmt = "script_bootstrap_action._args.member.{0}" key = keyfmt.format(idx) while key in ba: args.append(ba.pop(key)) idx += 1 key = keyfmt.format(idx) - ba['args'] = args - ba['script_path'] = ba.pop('script_bootstrap_action._path') - kwargs['bootstrap_actions'] = bootstrap_actions + ba["args"] = args + ba["script_path"] = ba.pop("script_bootstrap_action._path") + kwargs["bootstrap_actions"] = bootstrap_actions - configurations = self._get_list_prefix('Configurations.member') + configurations = self._get_list_prefix("Configurations.member") if configurations: for idx, config in enumerate(configurations, 1): for key in list(config.keys()): - if key.startswith('properties.'): + if key.startswith("properties."): config.pop(key) - config['properties'] = {} + config["properties"] = {} map_items = self._get_map_prefix( - 'Configurations.member.{0}.Properties.entry'.format(idx)) - config['properties'] = map_items + "Configurations.member.{0}.Properties.entry".format(idx) + ) + config["properties"] = map_items - kwargs['configurations'] = configurations + kwargs["configurations"] = configurations - release_label = self._get_param('ReleaseLabel') - ami_version = self._get_param('AmiVersion') + release_label = self._get_param("ReleaseLabel") + ami_version = self._get_param("AmiVersion") if release_label: - kwargs['release_label'] = release_label + kwargs["release_label"] = release_label if ami_version: message = ( - 'Only one AMI version and release label may be specified. ' - 'Provided AMI: {0}, release label: {1}.').format( - ami_version, release_label) - raise EmrError(error_type="ValidationException", - message=message, template='error_json') + "Only one AMI version and release label may be specified. " + "Provided AMI: {0}, release label: {1}." + ).format(ami_version, release_label) + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) else: if ami_version: - kwargs['requested_ami_version'] = ami_version - kwargs['running_ami_version'] = ami_version + kwargs["requested_ami_version"] = ami_version + kwargs["running_ami_version"] = ami_version else: - kwargs['running_ami_version'] = '1.0.0' + kwargs["running_ami_version"] = "1.0.0" - custom_ami_id = self._get_param('CustomAmiId') + custom_ami_id = self._get_param("CustomAmiId") if custom_ami_id: - kwargs['custom_ami_id'] = custom_ami_id - if release_label and release_label < 'emr-5.7.0': - message = 'Custom AMI is not allowed' - raise EmrError(error_type='ValidationException', - message=message, template='error_json') + kwargs["custom_ami_id"] = custom_ami_id + if release_label and release_label < "emr-5.7.0": + message = "Custom AMI is not allowed" + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) elif ami_version: - message = 'Custom AMI is not supported in this version of EMR' - raise EmrError(error_type='ValidationException', - message=message, template='error_json') + message = "Custom AMI is not supported in this version of EMR" + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) cluster = self.backend.run_job_flow(**kwargs) - applications = self._get_list_prefix('Applications.member') + applications = self._get_list_prefix("Applications.member") if applications: self.backend.add_applications(cluster.id, applications) else: self.backend.add_applications( - cluster.id, [{'Name': 'Hadoop', 'Version': '0.18'}]) + cluster.id, [{"Name": "Hadoop", "Version": "0.18"}] + ) - instance_groups = self._get_list_prefix( - 'Instances.InstanceGroups.member') + instance_groups = self._get_list_prefix("Instances.InstanceGroups.member") if instance_groups: for ig in instance_groups: - ig['instance_count'] = int(ig['instance_count']) + ig["instance_count"] = int(ig["instance_count"]) self.backend.add_instance_groups(cluster.id, instance_groups) - tags = self._get_list_prefix('Tags.member') + tags = self._get_list_prefix("Tags.member") if tags: self.backend.add_tags( - cluster.id, dict((d['key'], d['value']) for d in tags)) + cluster.id, dict((d["key"], d["value"]) for d in tags) + ) template = self.response_template(RUN_JOB_FLOW_TEMPLATE) return template.render(cluster=cluster) - @generate_boto3_response('SetTerminationProtection') + @generate_boto3_response("SetTerminationProtection") def set_termination_protection(self): - termination_protection = self._get_param('TerminationProtected') - job_ids = self._get_multi_param('JobFlowIds.member') - self.backend.set_termination_protection( - job_ids, termination_protection) + termination_protection = self._get_param("TerminationProtected") + job_ids = self._get_multi_param("JobFlowIds.member") + self.backend.set_termination_protection(job_ids, termination_protection) template = self.response_template(SET_TERMINATION_PROTECTION_TEMPLATE) return template.render() - @generate_boto3_response('SetVisibleToAllUsers') + @generate_boto3_response("SetVisibleToAllUsers") def set_visible_to_all_users(self): - visible_to_all_users = self._get_param('VisibleToAllUsers') - job_ids = self._get_multi_param('JobFlowIds.member') + visible_to_all_users = self._get_param("VisibleToAllUsers") + job_ids = self._get_multi_param("JobFlowIds.member") self.backend.set_visible_to_all_users(job_ids, visible_to_all_users) template = self.response_template(SET_VISIBLE_TO_ALL_USERS_TEMPLATE) return template.render() - @generate_boto3_response('TerminateJobFlows') + @generate_boto3_response("TerminateJobFlows") def terminate_job_flows(self): - job_ids = self._get_multi_param('JobFlowIds.member.') + job_ids = self._get_multi_param("JobFlowIds.member.") self.backend.terminate_job_flows(job_ids) template = self.response_template(TERMINATE_JOB_FLOWS_TEMPLATE) return template.render() diff --git a/moto/emr/urls.py b/moto/emr/urls.py index 870eaf9d7..81275135d 100644 --- a/moto/emr/urls.py +++ b/moto/emr/urls.py @@ -6,6 +6,4 @@ url_bases = [ "https?://elasticmapreduce.(.+).amazonaws.com", ] -url_paths = { - '{0}/$': ElasticMapReduceResponse.dispatch, -} +url_paths = {"{0}/$": ElasticMapReduceResponse.dispatch} diff --git a/moto/emr/utils.py b/moto/emr/utils.py index 4f12522cf..0f75995b8 100644 --- a/moto/emr/utils.py +++ b/moto/emr/utils.py @@ -7,24 +7,24 @@ import six def random_id(size=13): chars = list(range(10)) + list(string.ascii_uppercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def random_cluster_id(size=13): - return 'j-{0}'.format(random_id()) + return "j-{0}".format(random_id()) def random_step_id(size=13): - return 's-{0}'.format(random_id()) + return "s-{0}".format(random_id()) def random_instance_group_id(size=13): - return 'i-{0}'.format(random_id()) + return "i-{0}".format(random_id()) def tags_from_query_string(querystring_dict): - prefix = 'Tags' - suffix = 'Key' + prefix = "Tags" + suffix = "Key" response_values = {} for key, value in querystring_dict.items(): if key.startswith(prefix) and key.endswith(suffix): @@ -32,8 +32,7 @@ def tags_from_query_string(querystring_dict): tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0] tag_value_key = "Tags.{0}.Value".format(tag_index) if tag_value_key in querystring_dict: - response_values[tag_key] = querystring_dict.get(tag_value_key)[ - 0] + response_values[tag_key] = querystring_dict.get(tag_value_key)[0] else: response_values[tag_key] = None return response_values @@ -42,14 +41,15 @@ def tags_from_query_string(querystring_dict): def steps_from_query_string(querystring_dict): steps = [] for step in querystring_dict: - step['jar'] = step.pop('hadoop_jar_step._jar') - step['properties'] = dict((o['Key'], o['Value']) - for o in step.get('properties', [])) - step['args'] = [] + step["jar"] = step.pop("hadoop_jar_step._jar") + step["properties"] = dict( + (o["Key"], o["Value"]) for o in step.get("properties", []) + ) + step["args"] = [] idx = 1 - keyfmt = 'hadoop_jar_step._args.member.{0}' + keyfmt = "hadoop_jar_step._args.member.{0}" while keyfmt.format(idx) in step: - step['args'].append(step.pop(keyfmt.format(idx))) + step["args"].append(step.pop(keyfmt.format(idx))) idx += 1 steps.append(step) return steps diff --git a/moto/events/__init__.py b/moto/events/__init__.py index 5c93c59c8..8fd414325 100644 --- a/moto/events/__init__.py +++ b/moto/events/__init__.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals -from .models import events_backend +from .models import events_backends +from ..core.models import base_decorator -events_backends = {"global": events_backend} -mock_events = events_backend.decorator +events_backend = events_backends["us-east-1"] +mock_events = base_decorator(events_backends) diff --git a/moto/events/models.py b/moto/events/models.py index 2422e0b51..0298c7c69 100644 --- a/moto/events/models.py +++ b/moto/events/models.py @@ -1,44 +1,48 @@ import os import re import json +import boto3 from moto.core.exceptions import JsonRESTError from moto.core import BaseBackend, BaseModel +from moto.sts.models import ACCOUNT_ID class Rule(BaseModel): - def _generate_arn(self, name): - return 'arn:aws:events:us-west-2:111111111111:rule/' + name + return "arn:aws:events:{region_name}:111111111111:rule/{name}".format( + region_name=self.region_name, name=name + ) - def __init__(self, name, **kwargs): + def __init__(self, name, region_name, **kwargs): self.name = name - self.arn = kwargs.get('Arn') or self._generate_arn(name) - self.event_pattern = kwargs.get('EventPattern') - self.schedule_exp = kwargs.get('ScheduleExpression') - self.state = kwargs.get('State') or 'ENABLED' - self.description = kwargs.get('Description') - self.role_arn = kwargs.get('RoleArn') + self.region_name = region_name + self.arn = kwargs.get("Arn") or self._generate_arn(name) + self.event_pattern = kwargs.get("EventPattern") + self.schedule_exp = kwargs.get("ScheduleExpression") + self.state = kwargs.get("State") or "ENABLED" + self.description = kwargs.get("Description") + self.role_arn = kwargs.get("RoleArn") self.targets = [] def enable(self): - self.state = 'ENABLED' + self.state = "ENABLED" def disable(self): - self.state = 'DISABLED' + self.state = "DISABLED" # This song and dance for targets is because we need order for Limits and NextTokens, but can't use OrderedDicts # with Python 2.6, so tracking it with an array it is. def _check_target_exists(self, target_id): for i in range(0, len(self.targets)): - if target_id == self.targets[i]['Id']: + if target_id == self.targets[i]["Id"]: return i return None def put_targets(self, targets): # Not testing for valid ARNs. for target in targets: - index = self._check_target_exists(target['Id']) + index = self._check_target_exists(target["Id"]) if index is not None: self.targets[index] = target else: @@ -51,24 +55,71 @@ class Rule(BaseModel): self.targets.pop(index) -class EventsBackend(BaseBackend): - ACCOUNT_ID = re.compile(r'^(\d{1,12}|\*)$') - STATEMENT_ID = re.compile(r'^[a-zA-Z0-9-_]{1,64}$') +class EventBus(BaseModel): + def __init__(self, region_name, name): + self.region = region_name + self.name = name - def __init__(self): + self._permissions = {} + + @property + def arn(self): + return "arn:aws:events:{region}:{account_id}:event-bus/{name}".format( + region=self.region, account_id=ACCOUNT_ID, name=self.name + ) + + @property + def policy(self): + if not len(self._permissions): + return None + + policy = {"Version": "2012-10-17", "Statement": []} + + for sid, permission in self._permissions.items(): + policy["Statement"].append( + { + "Sid": sid, + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::{}:root".format(permission["Principal"]) + }, + "Action": permission["Action"], + "Resource": self.arn, + } + ) + + return json.dumps(policy) + + +class EventsBackend(BaseBackend): + ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$") + STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$") + + def __init__(self, region_name): self.rules = {} # This array tracks the order in which the rules have been added, since # 2.6 doesn't have OrderedDicts. self.rules_order = [] self.next_tokens = {} + self.region_name = region_name + self.event_buses = {} + self.event_sources = {} - self.permissions = {} + self._add_default_event_bus() + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def _add_default_event_bus(self): + self.event_buses["default"] = EventBus(self.region_name, "default") def _get_rule_by_index(self, i): return self.rules.get(self.rules_order[i]) def _gen_next_token(self, index): - token = os.urandom(128).encode('base64') + token = os.urandom(128).encode("base64") self.next_tokens[token] = index return token @@ -114,24 +165,25 @@ class EventsBackend(BaseBackend): return_obj = {} start_index, end_index, new_next_token = self._process_token_and_limits( - len(self.rules), next_token, limit) + len(self.rules), next_token, limit + ) for i in range(start_index, end_index): rule = self._get_rule_by_index(i) for target in rule.targets: - if target['Arn'] == target_arn: + if target["Arn"] == target_arn: matching_rules.append(rule.name) - return_obj['RuleNames'] = matching_rules + return_obj["RuleNames"] = matching_rules if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj def list_rules(self, prefix=None, next_token=None, limit=None): - match_string = '.*' + match_string = ".*" if prefix is not None: - match_string = '^' + prefix + match_string + match_string = "^" + prefix + match_string match_regex = re.compile(match_string) @@ -139,16 +191,17 @@ class EventsBackend(BaseBackend): return_obj = {} start_index, end_index, new_next_token = self._process_token_and_limits( - len(self.rules), next_token, limit) + len(self.rules), next_token, limit + ) for i in range(start_index, end_index): rule = self._get_rule_by_index(i) if match_regex.match(rule.name): matching_rules.append(rule) - return_obj['Rules'] = matching_rules + return_obj["Rules"] = matching_rules if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj @@ -158,7 +211,8 @@ class EventsBackend(BaseBackend): rule = self.rules[rule] start_index, end_index, new_next_token = self._process_token_and_limits( - len(rule.targets), next_token, limit) + len(rule.targets), next_token, limit + ) returned_targets = [] return_obj = {} @@ -166,14 +220,14 @@ class EventsBackend(BaseBackend): for i in range(start_index, end_index): returned_targets.append(rule.targets[i]) - return_obj['Targets'] = returned_targets + return_obj["Targets"] = returned_targets if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj def put_rule(self, name, **kwargs): - rule = Rule(name, **kwargs) + rule = Rule(name, self.region_name, **kwargs) self.rules[rule.name] = rule self.rules_order.append(rule.name) return rule.arn @@ -191,9 +245,9 @@ class EventsBackend(BaseBackend): num_events = len(events) if num_events < 1: - raise JsonRESTError('ValidationError', 'Need at least 1 event') + raise JsonRESTError("ValidationError", "Need at least 1 event") elif num_events > 10: - raise JsonRESTError('ValidationError', 'Can only submit 10 events at once') + raise JsonRESTError("ValidationError", "Can only submit 10 events at once") # We dont really need to store the events yet return [] @@ -210,42 +264,103 @@ class EventsBackend(BaseBackend): def test_event_pattern(self): raise NotImplementedError() - def put_permission(self, action, principal, statement_id): - if action is None or action != 'events:PutEvents': - raise JsonRESTError('InvalidParameterValue', 'Action must be PutEvents') + def put_permission(self, event_bus_name, action, principal, statement_id): + if not event_bus_name: + event_bus_name = "default" + + event_bus = self.describe_event_bus(event_bus_name) + + if action is None or action != "events:PutEvents": + raise JsonRESTError( + "ValidationException", + "Provided value in parameter 'action' is not supported.", + ) if principal is None or self.ACCOUNT_ID.match(principal) is None: - raise JsonRESTError('InvalidParameterValue', 'Principal must match ^(\d{1,12}|\*)$') + raise JsonRESTError( + "InvalidParameterValue", "Principal must match ^(\d{1,12}|\*)$" + ) if statement_id is None or self.STATEMENT_ID.match(statement_id) is None: - raise JsonRESTError('InvalidParameterValue', 'StatementId must match ^[a-zA-Z0-9-_]{1,64}$') + raise JsonRESTError( + "InvalidParameterValue", "StatementId must match ^[a-zA-Z0-9-_]{1,64}$" + ) - self.permissions[statement_id] = {'action': action, 'principal': principal} - - def remove_permission(self, statement_id): - try: - del self.permissions[statement_id] - except KeyError: - raise JsonRESTError('ResourceNotFoundException', 'StatementId not found') - - def describe_event_bus(self): - arn = "arn:aws:events:us-east-1:000000000000:event-bus/default" - statements = [] - for statement_id, data in self.permissions.items(): - statements.append({ - 'Sid': statement_id, - 'Effect': 'Allow', - 'Principal': {'AWS': 'arn:aws:iam::{0}:root'.format(data['principal'])}, - 'Action': data['action'], - 'Resource': arn - }) - policy = {'Version': '2012-10-17', 'Statement': statements} - policy_json = json.dumps(policy) - return { - 'Policy': policy_json, - 'Name': 'default', - 'Arn': arn + event_bus._permissions[statement_id] = { + "Action": action, + "Principal": principal, } + def remove_permission(self, event_bus_name, statement_id): + if not event_bus_name: + event_bus_name = "default" -events_backend = EventsBackend() + event_bus = self.describe_event_bus(event_bus_name) + + if not len(event_bus._permissions): + raise JsonRESTError( + "ResourceNotFoundException", "EventBus does not have a policy." + ) + + if not event_bus._permissions.pop(statement_id, None): + raise JsonRESTError( + "ResourceNotFoundException", + "Statement with the provided id does not exist.", + ) + + def describe_event_bus(self, name): + if not name: + name = "default" + + event_bus = self.event_buses.get(name) + + if not event_bus: + raise JsonRESTError( + "ResourceNotFoundException", "Event bus {} does not exist.".format(name) + ) + + return event_bus + + def create_event_bus(self, name, event_source_name): + if name in self.event_buses: + raise JsonRESTError( + "ResourceAlreadyExistsException", + "Event bus {} already exists.".format(name), + ) + + if not event_source_name and "/" in name: + raise JsonRESTError( + "ValidationException", "Event bus name must not contain '/'." + ) + + if event_source_name and event_source_name not in self.event_sources: + raise JsonRESTError( + "ResourceNotFoundException", + "Event source {} does not exist.".format(event_source_name), + ) + + self.event_buses[name] = EventBus(self.region_name, name) + + return self.event_buses[name] + + def list_event_buses(self, name_prefix): + if name_prefix: + return [ + event_bus + for event_bus in self.event_buses.values() + if event_bus.name.startswith(name_prefix) + ] + + return list(self.event_buses.values()) + + def delete_event_bus(self, name): + if name == "default": + raise JsonRESTError( + "ValidationException", "Cannot delete event bus default." + ) + + self.event_buses.pop(name, None) + + +available_regions = boto3.session.Session().get_available_regions("events") +events_backends = {region: EventsBackend(region) for region in available_regions} diff --git a/moto/events/responses.py b/moto/events/responses.py index f9cb9b5b5..b415564f8 100644 --- a/moto/events/responses.py +++ b/moto/events/responses.py @@ -2,25 +2,34 @@ import json import re from moto.core.responses import BaseResponse -from moto.events import events_backend +from moto.events import events_backends class EventsHandler(BaseResponse): + @property + def events_backend(self): + """ + Events Backend + + :return: Events Backend object + :rtype: moto.events.models.EventsBackend + """ + return events_backends[self.region] def _generate_rule_dict(self, rule): return { - 'Name': rule.name, - 'Arn': rule.arn, - 'EventPattern': rule.event_pattern, - 'State': rule.state, - 'Description': rule.description, - 'ScheduleExpression': rule.schedule_exp, - 'RoleArn': rule.role_arn + "Name": rule.name, + "Arn": rule.arn, + "EventPattern": rule.event_pattern, + "State": rule.state, + "Description": rule.description, + "ScheduleExpression": rule.schedule_exp, + "RoleArn": rule.role_arn, } @property def request_params(self): - if not hasattr(self, '_json_body'): + if not hasattr(self, "_json_body"): try: self._json_body = json.loads(self.body) except ValueError: @@ -30,127 +39,134 @@ class EventsHandler(BaseResponse): def _get_param(self, param, if_none=None): return self.request_params.get(param, if_none) - def error(self, type_, message='', status=400): + def error(self, type_, message="", status=400): headers = self.response_headers - headers['status'] = status - return json.dumps({'__type': type_, 'message': message}), headers, + headers["status"] = status + return json.dumps({"__type": type_, "message": message}), headers def delete_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') - events_backend.delete_rule(name) + return self.error("ValidationException", "Parameter Name is required.") + self.events_backend.delete_rule(name) - return '', self.response_headers + return "", self.response_headers def describe_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") - rule = events_backend.describe_rule(name) + rule = self.events_backend.describe_rule(name) if not rule: - return self.error('ResourceNotFoundException', 'Rule test does not exist.') + return self.error("ResourceNotFoundException", "Rule test does not exist.") rule_dict = self._generate_rule_dict(rule) return json.dumps(rule_dict), self.response_headers def disable_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") - if not events_backend.disable_rule(name): - return self.error('ResourceNotFoundException', 'Rule ' + name + ' does not exist.') + if not self.events_backend.disable_rule(name): + return self.error( + "ResourceNotFoundException", "Rule " + name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def enable_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") - if not events_backend.enable_rule(name): - return self.error('ResourceNotFoundException', 'Rule ' + name + ' does not exist.') + if not self.events_backend.enable_rule(name): + return self.error( + "ResourceNotFoundException", "Rule " + name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def generate_presigned_url(self): pass def list_rule_names_by_target(self): - target_arn = self._get_param('TargetArn') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + target_arn = self._get_param("TargetArn") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") if not target_arn: - return self.error('ValidationException', 'Parameter TargetArn is required.') + return self.error("ValidationException", "Parameter TargetArn is required.") - rule_names = events_backend.list_rule_names_by_target( - target_arn, next_token, limit) + rule_names = self.events_backend.list_rule_names_by_target( + target_arn, next_token, limit + ) return json.dumps(rule_names), self.response_headers def list_rules(self): - prefix = self._get_param('NamePrefix') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + prefix = self._get_param("NamePrefix") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") - rules = events_backend.list_rules(prefix, next_token, limit) - rules_obj = {'Rules': []} + rules = self.events_backend.list_rules(prefix, next_token, limit) + rules_obj = {"Rules": []} - for rule in rules['Rules']: - rules_obj['Rules'].append(self._generate_rule_dict(rule)) + for rule in rules["Rules"]: + rules_obj["Rules"].append(self._generate_rule_dict(rule)) - if rules.get('NextToken'): - rules_obj['NextToken'] = rules['NextToken'] + if rules.get("NextToken"): + rules_obj["NextToken"] = rules["NextToken"] return json.dumps(rules_obj), self.response_headers def list_targets_by_rule(self): - rule_name = self._get_param('Rule') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + rule_name = self._get_param("Rule") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") try: - targets = events_backend.list_targets_by_rule( - rule_name, next_token, limit) + targets = self.events_backend.list_targets_by_rule( + rule_name, next_token, limit + ) except KeyError: - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) return json.dumps(targets), self.response_headers def put_events(self): - events = self._get_param('Entries') + events = self._get_param("Entries") - failed_entries = events_backend.put_events(events) + failed_entries = self.events_backend.put_events(events) if failed_entries: - return json.dumps({ - 'FailedEntryCount': len(failed_entries), - 'Entries': failed_entries - }) + return json.dumps( + {"FailedEntryCount": len(failed_entries), "Entries": failed_entries} + ) - return '', self.response_headers + return "", self.response_headers def put_rule(self): - name = self._get_param('Name') - event_pattern = self._get_param('EventPattern') - sched_exp = self._get_param('ScheduleExpression') - state = self._get_param('State') - desc = self._get_param('Description') - role_arn = self._get_param('RoleArn') + name = self._get_param("Name") + event_pattern = self._get_param("EventPattern") + sched_exp = self._get_param("ScheduleExpression") + state = self._get_param("State") + desc = self._get_param("Description") + role_arn = self._get_param("RoleArn") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") if event_pattern: try: @@ -158,72 +174,126 @@ class EventsHandler(BaseResponse): except ValueError: # Not quite as informative as the real error, but it'll work # for now. - return self.error('InvalidEventPatternException', 'Event pattern is not valid.') + return self.error( + "InvalidEventPatternException", "Event pattern is not valid." + ) if sched_exp: - if not (re.match('^cron\(.*\)', sched_exp) or - re.match('^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)', sched_exp)): - return self.error('ValidationException', 'Parameter ScheduleExpression is not valid.') + if not ( + re.match("^cron\(.*\)", sched_exp) + or re.match( + "^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)", sched_exp + ) + ): + return self.error( + "ValidationException", "Parameter ScheduleExpression is not valid." + ) - rule_arn = events_backend.put_rule( + rule_arn = self.events_backend.put_rule( name, ScheduleExpression=sched_exp, EventPattern=event_pattern, State=state, Description=desc, - RoleArn=role_arn + RoleArn=role_arn, ) - return json.dumps({'RuleArn': rule_arn}), self.response_headers + return json.dumps({"RuleArn": rule_arn}), self.response_headers def put_targets(self): - rule_name = self._get_param('Rule') - targets = self._get_param('Targets') + rule_name = self._get_param("Rule") + targets = self._get_param("Targets") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") if not targets: - return self.error('ValidationException', 'Parameter Targets is required.') + return self.error("ValidationException", "Parameter Targets is required.") - if not events_backend.put_targets(rule_name, targets): - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + if not self.events_backend.put_targets(rule_name, targets): + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def remove_targets(self): - rule_name = self._get_param('Rule') - ids = self._get_param('Ids') + rule_name = self._get_param("Rule") + ids = self._get_param("Ids") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") if not ids: - return self.error('ValidationException', 'Parameter Ids is required.') + return self.error("ValidationException", "Parameter Ids is required.") - if not events_backend.remove_targets(rule_name, ids): - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + if not self.events_backend.remove_targets(rule_name, ids): + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def test_event_pattern(self): pass def put_permission(self): - action = self._get_param('Action') - principal = self._get_param('Principal') - statement_id = self._get_param('StatementId') + event_bus_name = self._get_param("EventBusName") + action = self._get_param("Action") + principal = self._get_param("Principal") + statement_id = self._get_param("StatementId") - events_backend.put_permission(action, principal, statement_id) + self.events_backend.put_permission( + event_bus_name, action, principal, statement_id + ) - return '' + return "" def remove_permission(self): - statement_id = self._get_param('StatementId') + event_bus_name = self._get_param("EventBusName") + statement_id = self._get_param("StatementId") - events_backend.remove_permission(statement_id) + self.events_backend.remove_permission(event_bus_name, statement_id) - return '' + return "" def describe_event_bus(self): - return json.dumps(events_backend.describe_event_bus()) + name = self._get_param("Name") + + event_bus = self.events_backend.describe_event_bus(name) + response = {"Name": event_bus.name, "Arn": event_bus.arn} + + if event_bus.policy: + response["Policy"] = event_bus.policy + + return json.dumps(response), self.response_headers + + def create_event_bus(self): + name = self._get_param("Name") + event_source_name = self._get_param("EventSourceName") + + event_bus = self.events_backend.create_event_bus(name, event_source_name) + + return json.dumps({"EventBusArn": event_bus.arn}), self.response_headers + + def list_event_buses(self): + name_prefix = self._get_param("NamePrefix") + # ToDo: add 'NextToken' & 'Limit' parameters + + response = [] + for event_bus in self.events_backend.list_event_buses(name_prefix): + event_bus_response = {"Name": event_bus.name, "Arn": event_bus.arn} + + if event_bus.policy: + event_bus_response["Policy"] = event_bus.policy + + response.append(event_bus_response) + + return json.dumps({"EventBuses": response}), self.response_headers + + def delete_event_bus(self): + name = self._get_param("Name") + + self.events_backend.delete_event_bus(name) + + return "", self.response_headers diff --git a/moto/events/urls.py b/moto/events/urls.py index a6e533b08..39e6a3462 100644 --- a/moto/events/urls.py +++ b/moto/events/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import EventsHandler -url_bases = [ - "https?://events.(.+).amazonaws.com" -] +url_bases = ["https?://events.(.+).amazonaws.com"] -url_paths = { - "{0}/": EventsHandler.dispatch, -} +url_paths = {"{0}/": EventsHandler.dispatch} diff --git a/moto/glacier/__init__.py b/moto/glacier/__init__.py index 1570fa7d4..270d580f5 100644 --- a/moto/glacier/__init__.py +++ b/moto/glacier/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import glacier_backends from ..core.models import base_decorator, deprecated_base_decorator -glacier_backend = glacier_backends['us-east-1'] +glacier_backend = glacier_backends["us-east-1"] mock_glacier = base_decorator(glacier_backends) mock_glacier_deprecated = deprecated_base_decorator(glacier_backends) diff --git a/moto/glacier/models.py b/moto/glacier/models.py index 2c16bc97d..6a3fc074d 100644 --- a/moto/glacier/models.py +++ b/moto/glacier/models.py @@ -25,7 +25,6 @@ class Job(BaseModel): class ArchiveJob(Job): - def __init__(self, job_id, tier, arn, archive_id): self.job_id = job_id self.tier = tier @@ -50,7 +49,7 @@ class ArchiveJob(Job): "StatusCode": "InProgress", "StatusMessage": None, "VaultARN": self.arn, - "Tier": self.tier + "Tier": self.tier, } if datetime.datetime.now() > self.et: d["Completed"] = True @@ -61,7 +60,6 @@ class ArchiveJob(Job): class InventoryJob(Job): - def __init__(self, job_id, tier, arn): self.job_id = job_id self.tier = tier @@ -83,7 +81,7 @@ class InventoryJob(Job): "StatusCode": "InProgress", "StatusMessage": None, "VaultARN": self.arn, - "Tier": self.tier + "Tier": self.tier, } if datetime.datetime.now() > self.et: d["Completed"] = True @@ -94,7 +92,6 @@ class InventoryJob(Job): class Vault(BaseModel): - def __init__(self, vault_name, region): self.st = datetime.datetime.now() self.vault_name = vault_name @@ -104,7 +101,9 @@ class Vault(BaseModel): @property def arn(self): - return "arn:aws:glacier:{0}:012345678901:vaults/{1}".format(self.region, self.vault_name) + return "arn:aws:glacier:{0}:012345678901:vaults/{1}".format( + self.region, self.vault_name + ) def to_dict(self): archives_size = 0 @@ -126,7 +125,9 @@ class Vault(BaseModel): self.archives[archive_id]["body"] = body self.archives[archive_id]["size"] = len(body) self.archives[archive_id]["sha256"] = hashlib.sha256(body).hexdigest() - self.archives[archive_id]["creation_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.000Z") + self.archives[archive_id]["creation_date"] = datetime.datetime.now().strftime( + "%Y-%m-%dT%H:%M:%S.000Z" + ) self.archives[archive_id]["description"] = description return archive_id @@ -142,7 +143,7 @@ class Vault(BaseModel): "ArchiveDescription": archive["description"], "CreationDate": archive["creation_date"], "Size": archive["size"], - "SHA256TreeHash": archive["sha256"] + "SHA256TreeHash": archive["sha256"], } archive_list.append(aobj) return archive_list @@ -180,7 +181,7 @@ class Vault(BaseModel): return { "VaultARN": self.arn, "InventoryDate": jobj["CompletionDate"], - "ArchiveList": archives + "ArchiveList": archives, } else: archive_body = self.get_archive_body(job.archive_id) @@ -188,7 +189,6 @@ class Vault(BaseModel): class GlacierBackend(BaseBackend): - def __init__(self, region_name): self.vaults = {} self.region_name = region_name diff --git a/moto/glacier/responses.py b/moto/glacier/responses.py index abdf83e4f..5a82be479 100644 --- a/moto/glacier/responses.py +++ b/moto/glacier/responses.py @@ -9,7 +9,6 @@ from .utils import region_from_glacier_url, vault_from_glacier_url class GlacierResponse(_TemplateEnvironmentMixin): - def __init__(self, backend): super(GlacierResponse, self).__init__() self.backend = backend @@ -22,14 +21,11 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _all_vault_response(self, request, full_url, headers): vaults = self.backend.list_vaules() - response = json.dumps({ - "Marker": None, - "VaultList": [ - vault.to_dict() for vault in vaults - ] - }) + response = json.dumps( + {"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]} + ) - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, response @classmethod @@ -44,16 +40,16 @@ class GlacierResponse(_TemplateEnvironmentMixin): querystring = parse_qs(parsed_url.query, keep_blank_values=True) vault_name = vault_from_glacier_url(full_url) - if method == 'GET': + if method == "GET": return self._vault_response_get(vault_name, querystring, headers) - elif method == 'PUT': + elif method == "PUT": return self._vault_response_put(vault_name, querystring, headers) - elif method == 'DELETE': + elif method == "DELETE": return self._vault_response_delete(vault_name, querystring, headers) def _vault_response_get(self, vault_name, querystring, headers): vault = self.backend.get_vault(vault_name) - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, json.dumps(vault.to_dict()) def _vault_response_put(self, vault_name, querystring, headers): @@ -72,40 +68,46 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _vault_archive_response(self, request, full_url, headers): method = request.method - if hasattr(request, 'body'): + if hasattr(request, "body"): body = request.body else: body = request.data description = "" - if 'x-amz-archive-description' in request.headers: - description = request.headers['x-amz-archive-description'] + if "x-amz-archive-description" in request.headers: + description = request.headers["x-amz-archive-description"] parsed_url = urlparse(full_url) querystring = parse_qs(parsed_url.query, keep_blank_values=True) vault_name = full_url.split("/")[-2] - if method == 'POST': - return self._vault_archive_response_post(vault_name, body, description, querystring, headers) + if method == "POST": + return self._vault_archive_response_post( + vault_name, body, description, querystring, headers + ) else: return 400, headers, "400 Bad Request" - def _vault_archive_response_post(self, vault_name, body, description, querystring, headers): + def _vault_archive_response_post( + self, vault_name, body, description, querystring, headers + ): vault = self.backend.get_vault(vault_name) vault_id = vault.create_archive(body, description) - headers['x-amz-archive-id'] = vault_id + headers["x-amz-archive-id"] = vault_id return 201, headers, "" @classmethod def vault_archive_individual_response(clazz, request, full_url, headers): region_name = region_from_glacier_url(full_url) response_instance = GlacierResponse(glacier_backends[region_name]) - return response_instance._vault_archive_individual_response(request, full_url, headers) + return response_instance._vault_archive_individual_response( + request, full_url, headers + ) def _vault_archive_individual_response(self, request, full_url, headers): method = request.method vault_name = full_url.split("/")[-3] archive_id = full_url.split("/")[-1] - if method == 'DELETE': + if method == "DELETE": vault = self.backend.get_vault(vault_name) vault.delete_archive(archive_id) return 204, headers, "" @@ -118,42 +120,47 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _vault_jobs_response(self, request, full_url, headers): method = request.method - if hasattr(request, 'body'): + if hasattr(request, "body"): body = request.body else: body = request.data account_id = full_url.split("/")[1] vault_name = full_url.split("/")[-2] - if method == 'GET': + if method == "GET": jobs = self.backend.list_jobs(vault_name) - headers['content-type'] = 'application/json' - return 200, headers, json.dumps({ - "JobList": [ - job.to_dict() for job in jobs - ], - "Marker": None, - }) - elif method == 'POST': + headers["content-type"] = "application/json" + return ( + 200, + headers, + json.dumps( + {"JobList": [job.to_dict() for job in jobs], "Marker": None} + ), + ) + elif method == "POST": json_body = json.loads(body.decode("utf-8")) - job_type = json_body['Type'] + job_type = json_body["Type"] archive_id = None - if 'ArchiveId' in json_body: - archive_id = json_body['ArchiveId'] - if 'Tier' in json_body: + if "ArchiveId" in json_body: + archive_id = json_body["ArchiveId"] + if "Tier" in json_body: tier = json_body["Tier"] else: tier = "Standard" job_id = self.backend.initiate_job(vault_name, job_type, tier, archive_id) - headers['x-amz-job-id'] = job_id - headers['Location'] = "/{0}/vaults/{1}/jobs/{2}".format(account_id, vault_name, job_id) + headers["x-amz-job-id"] = job_id + headers["Location"] = "/{0}/vaults/{1}/jobs/{2}".format( + account_id, vault_name, job_id + ) return 202, headers, "" @classmethod def vault_jobs_individual_response(clazz, request, full_url, headers): region_name = region_from_glacier_url(full_url) response_instance = GlacierResponse(glacier_backends[region_name]) - return response_instance._vault_jobs_individual_response(request, full_url, headers) + return response_instance._vault_jobs_individual_response( + request, full_url, headers + ) def _vault_jobs_individual_response(self, request, full_url, headers): vault_name = full_url.split("/")[-3] @@ -176,10 +183,10 @@ class GlacierResponse(_TemplateEnvironmentMixin): if vault.job_ready(job_id): output = vault.get_job_output(job_id) if isinstance(output, dict): - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, json.dumps(output) else: - headers['content-type'] = 'application/octet-stream' + headers["content-type"] = "application/octet-stream" return 200, headers, output else: return 404, headers, "404 Not Found" diff --git a/moto/glacier/urls.py b/moto/glacier/urls.py index 6038c2bb4..480b125af 100644 --- a/moto/glacier/urls.py +++ b/moto/glacier/urls.py @@ -1,16 +1,14 @@ from __future__ import unicode_literals from .responses import GlacierResponse -url_bases = [ - "https?://glacier.(.+).amazonaws.com", -] +url_bases = ["https?://glacier.(.+).amazonaws.com"] url_paths = { - '{0}/(?P.+)/vaults$': GlacierResponse.all_vault_response, - '{0}/(?P.+)/vaults/(?P[^/.]+)$': GlacierResponse.vault_response, - '{0}/(?P.+)/vaults/(?P.+)/archives$': GlacierResponse.vault_archive_response, - '{0}/(?P.+)/vaults/(?P.+)/archives/(?P.+)$': GlacierResponse.vault_archive_individual_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs$': GlacierResponse.vault_jobs_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs/(?P[^/.]+)$': GlacierResponse.vault_jobs_individual_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs/(?P.+)/output$': GlacierResponse.vault_jobs_output_response, + "{0}/(?P.+)/vaults$": GlacierResponse.all_vault_response, + "{0}/(?P.+)/vaults/(?P[^/.]+)$": GlacierResponse.vault_response, + "{0}/(?P.+)/vaults/(?P.+)/archives$": GlacierResponse.vault_archive_response, + "{0}/(?P.+)/vaults/(?P.+)/archives/(?P.+)$": GlacierResponse.vault_archive_individual_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs$": GlacierResponse.vault_jobs_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs/(?P[^/.]+)$": GlacierResponse.vault_jobs_individual_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs/(?P.+)/output$": GlacierResponse.vault_jobs_output_response, } diff --git a/moto/glacier/utils.py b/moto/glacier/utils.py index f4a869bf3..d6dd7c656 100644 --- a/moto/glacier/utils.py +++ b/moto/glacier/utils.py @@ -7,10 +7,10 @@ from six.moves.urllib.parse import urlparse def region_from_glacier_url(url): domain = urlparse(url).netloc - if '.' in domain: + if "." in domain: return domain.split(".")[1] else: - return 'us-east-1' + return "us-east-1" def vault_from_glacier_url(full_url): @@ -18,4 +18,6 @@ def vault_from_glacier_url(full_url): def get_job_id(): - return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(92)) + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(92) + ) diff --git a/moto/glue/exceptions.py b/moto/glue/exceptions.py index 8972adb35..c4b7048db 100644 --- a/moto/glue/exceptions.py +++ b/moto/glue/exceptions.py @@ -9,46 +9,38 @@ class GlueClientError(JsonRESTError): class AlreadyExistsException(GlueClientError): def __init__(self, typ): super(GlueClientError, self).__init__( - 'AlreadyExistsException', - '%s already exists.' % (typ), + "AlreadyExistsException", "%s already exists." % (typ) ) class DatabaseAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(DatabaseAlreadyExistsException, self).__init__('Database') + super(DatabaseAlreadyExistsException, self).__init__("Database") class TableAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(TableAlreadyExistsException, self).__init__('Table') + super(TableAlreadyExistsException, self).__init__("Table") class PartitionAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(PartitionAlreadyExistsException, self).__init__('Partition') + super(PartitionAlreadyExistsException, self).__init__("Partition") class EntityNotFoundException(GlueClientError): def __init__(self, msg): - super(GlueClientError, self).__init__( - 'EntityNotFoundException', - msg, - ) + super(GlueClientError, self).__init__("EntityNotFoundException", msg) class DatabaseNotFoundException(EntityNotFoundException): def __init__(self, db): - super(DatabaseNotFoundException, self).__init__( - 'Database %s not found.' % db, - ) + super(DatabaseNotFoundException, self).__init__("Database %s not found." % db) class TableNotFoundException(EntityNotFoundException): def __init__(self, tbl): - super(TableNotFoundException, self).__init__( - 'Table %s not found.' % tbl, - ) + super(TableNotFoundException, self).__init__("Table %s not found." % tbl) class PartitionNotFoundException(EntityNotFoundException): diff --git a/moto/glue/models.py b/moto/glue/models.py index 0989e0e9b..8f3396d9a 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -4,7 +4,7 @@ import time from moto.core import BaseBackend, BaseModel from moto.compat import OrderedDict -from.exceptions import ( +from .exceptions import ( JsonRESTError, DatabaseAlreadyExistsException, DatabaseNotFoundException, @@ -17,7 +17,6 @@ from.exceptions import ( class GlueBackend(BaseBackend): - def __init__(self): self.databases = OrderedDict() @@ -66,14 +65,12 @@ class GlueBackend(BaseBackend): class FakeDatabase(BaseModel): - def __init__(self, database_name): self.name = database_name self.tables = OrderedDict() class FakeTable(BaseModel): - def __init__(self, database_name, table_name, table_input): self.database_name = database_name self.name = table_name @@ -98,10 +95,7 @@ class FakeTable(BaseModel): raise VersionNotFoundException() def as_dict(self, version=-1): - obj = { - 'DatabaseName': self.database_name, - 'Name': self.name, - } + obj = {"DatabaseName": self.database_name, "Name": self.name} obj.update(self.get_version(version)) return obj @@ -124,7 +118,7 @@ class FakeTable(BaseModel): def update_partition(self, old_values, partiton_input): partition = FakePartition(self.database_name, self.name, partiton_input) key = str(partition.values) - if old_values == partiton_input['Values']: + if old_values == partiton_input["Values"]: # Altering a partition in place. Don't remove it so the order of # returned partitions doesn't change if key not in self.partitions: @@ -151,13 +145,13 @@ class FakePartition(BaseModel): self.database_name = database_name self.table_name = table_name self.partition_input = partiton_input - self.values = self.partition_input.get('Values', []) + self.values = self.partition_input.get("Values", []) def as_dict(self): obj = { - 'DatabaseName': self.database_name, - 'TableName': self.table_name, - 'CreationTime': self.creation_time, + "DatabaseName": self.database_name, + "TableName": self.table_name, + "CreationTime": self.creation_time, } obj.update(self.partition_input) return obj diff --git a/moto/glue/responses.py b/moto/glue/responses.py index 875513e7f..bf7b5776b 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -7,12 +7,11 @@ from .models import glue_backend from .exceptions import ( PartitionAlreadyExistsException, PartitionNotFoundException, - TableNotFoundException + TableNotFoundException, ) class GlueResponse(BaseResponse): - @property def glue_backend(self): return glue_backend @@ -22,94 +21,94 @@ class GlueResponse(BaseResponse): return json.loads(self.body) def create_database(self): - database_name = self.parameters['DatabaseInput']['Name'] + database_name = self.parameters["DatabaseInput"]["Name"] self.glue_backend.create_database(database_name) return "" def get_database(self): - database_name = self.parameters.get('Name') + database_name = self.parameters.get("Name") database = self.glue_backend.get_database(database_name) - return json.dumps({'Database': {'Name': database.name}}) + return json.dumps({"Database": {"Name": database.name}}) def create_table(self): - database_name = self.parameters.get('DatabaseName') - table_input = self.parameters.get('TableInput') - table_name = table_input.get('Name') + database_name = self.parameters.get("DatabaseName") + table_input = self.parameters.get("TableInput") + table_name = table_input.get("Name") self.glue_backend.create_table(database_name, table_name, table_input) return "" def get_table(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('Name') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("Name") table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({'Table': table.as_dict()}) + return json.dumps({"Table": table.as_dict()}) def update_table(self): - database_name = self.parameters.get('DatabaseName') - table_input = self.parameters.get('TableInput') - table_name = table_input.get('Name') + database_name = self.parameters.get("DatabaseName") + table_input = self.parameters.get("TableInput") + table_name = table_input.get("Name") table = self.glue_backend.get_table(database_name, table_name) table.update(table_input) return "" def get_table_versions(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({ - "TableVersions": [ - { - "Table": table.as_dict(version=n), - "VersionId": str(n + 1), - } for n in range(len(table.versions)) - ], - }) + return json.dumps( + { + "TableVersions": [ + {"Table": table.as_dict(version=n), "VersionId": str(n + 1)} + for n in range(len(table.versions)) + ] + } + ) def get_table_version(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) - ver_id = self.parameters.get('VersionId') + ver_id = self.parameters.get("VersionId") - return json.dumps({ - "TableVersion": { - "Table": table.as_dict(version=ver_id), - "VersionId": ver_id, - }, - }) + return json.dumps( + { + "TableVersion": { + "Table": table.as_dict(version=ver_id), + "VersionId": ver_id, + } + } + ) def get_tables(self): - database_name = self.parameters.get('DatabaseName') + database_name = self.parameters.get("DatabaseName") tables = self.glue_backend.get_tables(database_name) - return json.dumps({ - 'TableList': [ - table.as_dict() for table in tables - ] - }) + return json.dumps({"TableList": [table.as_dict() for table in tables]}) def delete_table(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('Name') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("Name") resp = self.glue_backend.delete_table(database_name, table_name) return json.dumps(resp) def batch_delete_table(self): - database_name = self.parameters.get('DatabaseName') + database_name = self.parameters.get("DatabaseName") errors = [] - for table_name in self.parameters.get('TablesToDelete'): + for table_name in self.parameters.get("TablesToDelete"): try: self.glue_backend.delete_table(database_name, table_name) except TableNotFoundException: - errors.append({ - "TableName": table_name, - "ErrorDetail": { - "ErrorCode": "EntityNotFoundException", - "ErrorMessage": "Table not found" + errors.append( + { + "TableName": table_name, + "ErrorDetail": { + "ErrorCode": "EntityNotFoundException", + "ErrorMessage": "Table not found", + }, } - }) + ) out = {} if errors: @@ -118,33 +117,31 @@ class GlueResponse(BaseResponse): return json.dumps(out) def get_partitions(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - if 'Expression' in self.parameters: - raise NotImplementedError("Expression filtering in get_partitions is not implemented in moto") + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + if "Expression" in self.parameters: + raise NotImplementedError( + "Expression filtering in get_partitions is not implemented in moto" + ) table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({ - 'Partitions': [ - p.as_dict() for p in table.get_partitions() - ] - }) + return json.dumps({"Partitions": [p.as_dict() for p in table.get_partitions()]}) def get_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - values = self.parameters.get('PartitionValues') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + values = self.parameters.get("PartitionValues") table = self.glue_backend.get_table(database_name, table_name) p = table.get_partition(values) - return json.dumps({'Partition': p.as_dict()}) + return json.dumps({"Partition": p.as_dict()}) def batch_get_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - partitions_to_get = self.parameters.get('PartitionsToGet') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + partitions_to_get = self.parameters.get("PartitionsToGet") table = self.glue_backend.get_table(database_name, table_name) @@ -156,12 +153,12 @@ class GlueResponse(BaseResponse): except PartitionNotFoundException: continue - return json.dumps({'Partitions': partitions}) + return json.dumps({"Partitions": partitions}) def create_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_input = self.parameters.get('PartitionInput') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_input = self.parameters.get("PartitionInput") table = self.glue_backend.get_table(database_name, table_name) table.create_partition(part_input) @@ -169,22 +166,24 @@ class GlueResponse(BaseResponse): return "" def batch_create_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) errors_output = [] - for part_input in self.parameters.get('PartitionInputList'): + for part_input in self.parameters.get("PartitionInputList"): try: table.create_partition(part_input) except PartitionAlreadyExistsException: - errors_output.append({ - 'PartitionValues': part_input['Values'], - 'ErrorDetail': { - 'ErrorCode': 'AlreadyExistsException', - 'ErrorMessage': 'Partition already exists.' + errors_output.append( + { + "PartitionValues": part_input["Values"], + "ErrorDetail": { + "ErrorCode": "AlreadyExistsException", + "ErrorMessage": "Partition already exists.", + }, } - }) + ) out = {} if errors_output: @@ -193,10 +192,10 @@ class GlueResponse(BaseResponse): return json.dumps(out) def update_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_input = self.parameters.get('PartitionInput') - part_to_update = self.parameters.get('PartitionValueList') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_input = self.parameters.get("PartitionInput") + part_to_update = self.parameters.get("PartitionValueList") table = self.glue_backend.get_table(database_name, table_name) table.update_partition(part_to_update, part_input) @@ -204,9 +203,9 @@ class GlueResponse(BaseResponse): return "" def delete_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_to_delete = self.parameters.get('PartitionValues') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_to_delete = self.parameters.get("PartitionValues") table = self.glue_backend.get_table(database_name, table_name) table.delete_partition(part_to_delete) @@ -214,26 +213,28 @@ class GlueResponse(BaseResponse): return "" def batch_delete_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) errors_output = [] - for part_input in self.parameters.get('PartitionsToDelete'): - values = part_input.get('Values') + for part_input in self.parameters.get("PartitionsToDelete"): + values = part_input.get("Values") try: table.delete_partition(values) except PartitionNotFoundException: - errors_output.append({ - 'PartitionValues': values, - 'ErrorDetail': { - 'ErrorCode': 'EntityNotFoundException', - 'ErrorMessage': 'Partition not found', + errors_output.append( + { + "PartitionValues": values, + "ErrorDetail": { + "ErrorCode": "EntityNotFoundException", + "ErrorMessage": "Partition not found", + }, } - }) + ) out = {} if errors_output: - out['Errors'] = errors_output + out["Errors"] = errors_output return json.dumps(out) diff --git a/moto/glue/urls.py b/moto/glue/urls.py index f3eaa9cad..2c7854732 100644 --- a/moto/glue/urls.py +++ b/moto/glue/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import GlueResponse -url_bases = [ - "https?://glue(.*).amazonaws.com" -] +url_bases = ["https?://glue(.*).amazonaws.com"] -url_paths = { - '{0}/$': GlueResponse.dispatch -} +url_paths = {"{0}/$": GlueResponse.dispatch} diff --git a/moto/iam/exceptions.py b/moto/iam/exceptions.py index ac08e0d88..1d0f3ca01 100644 --- a/moto/iam/exceptions.py +++ b/moto/iam/exceptions.py @@ -6,32 +6,28 @@ class IAMNotFoundException(RESTError): code = 404 def __init__(self, message): - super(IAMNotFoundException, self).__init__( - "NoSuchEntity", message) + super(IAMNotFoundException, self).__init__("NoSuchEntity", message) class IAMConflictException(RESTError): code = 409 - def __init__(self, code='Conflict', message=""): - super(IAMConflictException, self).__init__( - code, message) + def __init__(self, code="Conflict", message=""): + super(IAMConflictException, self).__init__(code, message) class IAMReportNotPresentException(RESTError): code = 410 def __init__(self, message): - super(IAMReportNotPresentException, self).__init__( - "ReportNotPresent", message) + super(IAMReportNotPresentException, self).__init__("ReportNotPresent", message) class IAMLimitExceededException(RESTError): code = 400 def __init__(self, message): - super(IAMLimitExceededException, self).__init__( - "LimitExceeded", message) + super(IAMLimitExceededException, self).__init__("LimitExceeded", message) class MalformedCertificate(RESTError): @@ -39,7 +35,8 @@ class MalformedCertificate(RESTError): def __init__(self, cert): super(MalformedCertificate, self).__init__( - 'MalformedCertificate', 'Certificate {cert} is malformed'.format(cert=cert)) + "MalformedCertificate", "Certificate {cert} is malformed".format(cert=cert) + ) class MalformedPolicyDocument(RESTError): @@ -47,7 +44,8 @@ class MalformedPolicyDocument(RESTError): def __init__(self, message=""): super(MalformedPolicyDocument, self).__init__( - 'MalformedPolicyDocument', message) + "MalformedPolicyDocument", message + ) class DuplicateTags(RESTError): @@ -55,16 +53,22 @@ class DuplicateTags(RESTError): def __init__(self): super(DuplicateTags, self).__init__( - 'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.') + "InvalidInput", + "Duplicate tag keys found. Please note that Tag keys are case insensitive.", + ) class TagKeyTooBig(RESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): + def __init__(self, tag, param="tags.X.member.key"): super(TagKeyTooBig, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 128.".format(tag, param)) + "ValidationError", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 128.".format( + tag, param + ), + ) class TagValueTooBig(RESTError): @@ -72,24 +76,62 @@ class TagValueTooBig(RESTError): def __init__(self, tag): super(TagValueTooBig, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " - "constraint: Member must have length less than or equal to 256.".format(tag)) + "ValidationError", + "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " + "constraint: Member must have length less than or equal to 256.".format( + tag + ), + ) class InvalidTagCharacters(RESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): - message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param) + def __init__(self, tag, param="tags.X.member.key"): + message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format( + tag, param + ) message += "constraint: Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" - super(InvalidTagCharacters, self).__init__('ValidationError', message) + super(InvalidTagCharacters, self).__init__("ValidationError", message) class TooManyTags(RESTError): code = 400 - def __init__(self, tags, param='tags'): + def __init__(self, tags, param="tags"): super(TooManyTags, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 50.".format(tags, param)) + "ValidationError", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 50.".format( + tags, param + ), + ) + + +class EntityAlreadyExists(RESTError): + code = 409 + + def __init__(self, message): + super(EntityAlreadyExists, self).__init__("EntityAlreadyExists", message) + + +class ValidationError(RESTError): + code = 400 + + def __init__(self, message): + super(ValidationError, self).__init__("ValidationError", message) + + +class InvalidInput(RESTError): + code = 400 + + def __init__(self, message): + super(InvalidInput, self).__init__("InvalidInput", message) + + +class NoSuchEntity(RESTError): + code = 404 + + def __init__(self, message): + super(NoSuchEntity, self).__init__("NoSuchEntity", message) diff --git a/moto/iam/models.py b/moto/iam/models.py index d76df8a28..5bbd9235d 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -1,5 +1,9 @@ from __future__ import unicode_literals import base64 +import hashlib +import os +import random +import string import sys from datetime import datetime import json @@ -7,27 +11,45 @@ import re from cryptography import x509 from cryptography.hazmat.backends import default_backend +from six.moves.urllib.parse import urlparse from moto.core.exceptions import RESTError -from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds +from moto.core import BaseBackend, BaseModel, ACCOUNT_ID +from moto.core.utils import ( + iso_8601_datetime_without_milliseconds, + iso_8601_datetime_with_milliseconds, +) from moto.iam.policy_validation import IAMPolicyDocumentValidator from .aws_managed_policies import aws_managed_policies_data -from .exceptions import IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, IAMLimitExceededException, \ - MalformedCertificate, DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig -from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id - -ACCOUNT_ID = 123456789012 +from .exceptions import ( + IAMNotFoundException, + IAMConflictException, + IAMReportNotPresentException, + IAMLimitExceededException, + MalformedCertificate, + DuplicateTags, + TagKeyTooBig, + InvalidTagCharacters, + TooManyTags, + TagValueTooBig, + EntityAlreadyExists, + ValidationError, + InvalidInput, + NoSuchEntity, +) +from .utils import ( + random_access_key, + random_alphanumeric, + random_resource_id, + random_policy_id, +) class MFADevice(object): """MFA Device class.""" - def __init__(self, - serial_number, - authentication_code_1, - authentication_code_2): + def __init__(self, serial_number, authentication_code_1, authentication_code_2): self.enable_date = datetime.utcnow() self.serial_number = serial_number self.authentication_code_1 = authentication_code_1 @@ -38,31 +60,60 @@ class MFADevice(object): return iso_8601_datetime_without_milliseconds(self.enable_date) +class VirtualMfaDevice(object): + def __init__(self, device_name): + self.serial_number = "arn:aws:iam::{0}:mfa{1}".format(ACCOUNT_ID, device_name) + + random_base32_string = "".join( + random.choice(string.ascii_uppercase + "234567") for _ in range(64) + ) + self.base32_string_seed = base64.b64encode( + random_base32_string.encode("ascii") + ).decode("ascii") + self.qr_code_png = base64.b64encode( + os.urandom(64) + ) # this would be a generated PNG + + self.enable_date = None + self.user_attribute = None + self.user = None + + @property + def enabled_iso_8601(self): + return iso_8601_datetime_without_milliseconds(self.enable_date) + + class Policy(BaseModel): is_attachable = False - def __init__(self, - name, - default_version_id=None, - description=None, - document=None, - path=None, - create_date=None, - update_date=None): + def __init__( + self, + name, + default_version_id=None, + description=None, + document=None, + path=None, + create_date=None, + update_date=None, + ): self.name = name self.attachment_count = 0 - self.description = description or '' + self.description = description or "" self.id = random_policy_id() - self.path = path or '/' + self.path = path or "/" if default_version_id: self.default_version_id = default_version_id - self.next_version_num = int(default_version_id.lstrip('v')) + 1 + self.next_version_num = int(default_version_id.lstrip("v")) + 1 else: - self.default_version_id = 'v1' + self.default_version_id = "v1" self.next_version_num = 2 - self.versions = [PolicyVersion(self.arn, document, True, self.default_version_id, update_date)] + self.versions = [ + PolicyVersion( + self.arn, document, True, self.default_version_id, update_date + ) + ] self.create_date = create_date if create_date is not None else datetime.utcnow() self.update_date = update_date if update_date is not None else datetime.utcnow() @@ -93,14 +144,94 @@ class SAMLProvider(BaseModel): return "arn:aws:iam::{0}:saml-provider/{1}".format(ACCOUNT_ID, self.name) -class PolicyVersion(object): +class OpenIDConnectProvider(BaseModel): + def __init__(self, url, thumbprint_list, client_id_list=None): + self._errors = [] + self._validate(url, thumbprint_list, client_id_list) - def __init__(self, - policy_arn, - document, - is_default=False, - version_id='v1', - create_date=None): + parsed_url = urlparse(url) + self.url = parsed_url.netloc + parsed_url.path + self.thumbprint_list = thumbprint_list + self.client_id_list = client_id_list + self.create_date = datetime.utcnow() + + @property + def arn(self): + return "arn:aws:iam::{0}:oidc-provider/{1}".format(ACCOUNT_ID, self.url) + + @property + def created_iso_8601(self): + return iso_8601_datetime_without_milliseconds(self.create_date) + + def _validate(self, url, thumbprint_list, client_id_list): + if any(len(client_id) > 255 for client_id in client_id_list): + self._errors.append( + self._format_error( + key="clientIDList", + value=client_id_list, + constraint="Member must satisfy constraint: " + "[Member must have length less than or equal to 255, " + "Member must have length greater than or equal to 1]", + ) + ) + + if any(len(thumbprint) > 40 for thumbprint in thumbprint_list): + self._errors.append( + self._format_error( + key="thumbprintList", + value=thumbprint_list, + constraint="Member must satisfy constraint: " + "[Member must have length less than or equal to 40, " + "Member must have length greater than or equal to 40]", + ) + ) + + if len(url) > 255: + self._errors.append( + self._format_error( + key="url", + value=url, + constraint="Member must have length less than or equal to 255", + ) + ) + + self._raise_errors() + + parsed_url = urlparse(url) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValidationError("Invalid Open ID Connect Provider URL") + + if len(thumbprint_list) > 5: + raise InvalidInput("Thumbprint list must contain fewer than 5 entries.") + + if len(client_id_list) > 100: + raise IAMLimitExceededException( + "Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100" + ) + + def _format_error(self, key, value, constraint): + return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format( + constraint=constraint, key=key, value=value + ) + + def _raise_errors(self): + if self._errors: + count = len(self._errors) + plural = "s" if len(self._errors) > 1 else "" + errors = "; ".join(self._errors) + self._errors = [] # reset collected errors + + raise ValidationError( + "{count} validation error{plural} detected: {errors}".format( + count=count, plural=plural, errors=errors + ) + ) + + +class PolicyVersion(object): + def __init__( + self, policy_arn, document, is_default=False, version_id="v1", create_date=None + ): self.policy_arn = policy_arn self.document = document or {} self.is_default = is_default @@ -136,23 +267,30 @@ class AWSManagedPolicy(ManagedPolicy): @classmethod def from_data(cls, name, data): - return cls(name, - default_version_id=data.get('DefaultVersionId'), - path=data.get('Path'), - document=json.dumps(data.get('Document')), - create_date=datetime.strptime(data.get('CreateDate'), "%Y-%m-%dT%H:%M:%S+00:00"), - update_date=datetime.strptime(data.get('UpdateDate'), "%Y-%m-%dT%H:%M:%S+00:00")) + return cls( + name, + default_version_id=data.get("DefaultVersionId"), + path=data.get("Path"), + document=json.dumps(data.get("Document")), + create_date=datetime.strptime( + data.get("CreateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + ), + update_date=datetime.strptime( + data.get("UpdateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + ), + ) @property def arn(self): - return 'arn:aws:iam::aws:policy{0}{1}'.format(self.path, self.name) + return "arn:aws:iam::aws:policy{0}{1}".format(self.path, self.name) # AWS defines some of its own managed policies and we periodically # import them via `make aws_managed_policies` aws_managed_policies = [ - AWSManagedPolicy.from_data(name, d) for name, d - in json.loads(aws_managed_policies_data).items()] + AWSManagedPolicy.from_data(name, d) + for name, d in json.loads(aws_managed_policies_data).items() +] class InlinePolicy(Policy): @@ -160,40 +298,53 @@ class InlinePolicy(Policy): class Role(BaseModel): - - def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary, description, tags): + def __init__( + self, + role_id, + name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + max_session_duration, + ): self.id = role_id self.name = name self.assume_role_policy_document = assume_role_policy_document - self.path = path or '/' + self.path = path or "/" self.policies = {} self.managed_policies = {} self.create_date = datetime.utcnow() self.tags = tags self.description = description self.permissions_boundary = permissions_boundary + self.max_session_duration = max_session_duration @property def created_iso_8601(self): return iso_8601_datetime_with_milliseconds(self.create_date) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] role = iam_backend.create_role( role_name=resource_name, - assume_role_policy_document=properties['AssumeRolePolicyDocument'], - path=properties.get('Path', '/'), - permissions_boundary=properties.get('PermissionsBoundary', ''), - description=properties.get('Description', ''), - tags=properties.get('Tags', {}) + assume_role_policy_document=properties["AssumeRolePolicyDocument"], + path=properties.get("Path", "/"), + permissions_boundary=properties.get("PermissionsBoundary", ""), + description=properties.get("Description", ""), + tags=properties.get("Tags", {}), + max_session_duration=properties.get("MaxSessionDuration", 3600), ) - policies = properties.get('Policies', []) + policies = properties.get("Policies", []) for policy in policies: - policy_name = policy['PolicyName'] - policy_json = policy['PolicyDocument'] + policy_name = policy["PolicyName"] + policy_json = policy["PolicyDocument"] role.put_policy(policy_name, policy_json) return role @@ -210,7 +361,8 @@ class Role(BaseModel): del self.policies[policy_name] except KeyError: raise IAMNotFoundException( - "The role policy with name {0} cannot be found.".format(policy_name)) + "The role policy with name {0} cannot be found.".format(policy_name) + ) @property def physical_resource_id(self): @@ -218,8 +370,9 @@ class Role(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': - raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') + + if attribute_name == "Arn": + return self.arn raise UnformattedGetAttTemplateException() def get_tags(self): @@ -227,11 +380,10 @@ class Role(BaseModel): class InstanceProfile(BaseModel): - def __init__(self, instance_profile_id, name, path, roles): self.id = instance_profile_id self.name = name - self.path = path or '/' + self.path = path or "/" self.roles = roles if roles else [] self.create_date = datetime.utcnow() @@ -240,19 +392,21 @@ class InstanceProfile(BaseModel): return iso_8601_datetime_with_milliseconds(self.create_date) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - role_ids = properties['Roles'] + role_ids = properties["Roles"] return iam_backend.create_instance_profile( - name=resource_name, - path=properties.get('Path', '/'), - role_ids=role_ids, + name=resource_name, path=properties.get("Path", "/"), role_ids=role_ids ) @property def arn(self): - return "arn:aws:iam::{0}:instance-profile{1}{2}".format(ACCOUNT_ID, self.path, self.name) + return "arn:aws:iam::{0}:instance-profile{1}{2}".format( + ACCOUNT_ID, self.path, self.name + ) @property def physical_resource_id(self): @@ -260,13 +414,13 @@ class InstanceProfile(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class Certificate(BaseModel): - def __init__(self, cert_name, cert_body, private_key, cert_chain=None, path=None): self.cert_name = cert_name self.cert_body = cert_body @@ -280,17 +434,18 @@ class Certificate(BaseModel): @property def arn(self): - return "arn:aws:iam::{0}:server-certificate{1}{2}".format(ACCOUNT_ID, self.path, self.cert_name) + return "arn:aws:iam::{0}:server-certificate{1}{2}".format( + ACCOUNT_ID, self.path, self.cert_name + ) class SigningCertificate(BaseModel): - def __init__(self, id, user_name, body): self.id = id self.user_name = user_name self.body = body self.upload_date = datetime.utcnow() - self.status = 'Active' + self.status = "Active" @property def uploaded_iso_8601(self): @@ -298,12 +453,11 @@ class SigningCertificate(BaseModel): class AccessKey(BaseModel): - def __init__(self, user_name): self.user_name = user_name self.access_key_id = "AKIA" + random_access_key() self.secret_access_key = random_alphanumeric(40) - self.status = 'Active' + self.status = "Active" self.create_date = datetime.utcnow() self.last_used = datetime.utcnow() @@ -317,14 +471,28 @@ class AccessKey(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'SecretAccessKey': + + if attribute_name == "SecretAccessKey": return self.secret_access_key raise UnformattedGetAttTemplateException() -class Group(BaseModel): +class SshPublicKey(BaseModel): + def __init__(self, user_name, ssh_public_key_body): + self.user_name = user_name + self.ssh_public_key_body = ssh_public_key_body + self.ssh_public_key_id = "APKA" + random_access_key() + self.fingerprint = hashlib.md5(ssh_public_key_body.encode()).hexdigest() + self.status = "Active" + self.upload_date = datetime.utcnow() - def __init__(self, name, path='/'): + @property + def uploaded_iso_8601(self): + return iso_8601_datetime_without_milliseconds(self.upload_date) + + +class Group(BaseModel): + def __init__(self, name, path="/"): self.name = name self.id = random_resource_id() self.path = path @@ -340,17 +508,20 @@ class Group(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') raise UnformattedGetAttTemplateException() @property def arn(self): - if self.path == '/': + if self.path == "/": return "arn:aws:iam::{0}:group/{1}".format(ACCOUNT_ID, self.name) else: - return "arn:aws:iam::{0}:group/{1}/{2}".format(ACCOUNT_ID, self.path, self.name) + return "arn:aws:iam::{0}:group/{1}/{2}".format( + ACCOUNT_ID, self.path, self.name + ) def get_policy(self, policy_name): try: @@ -359,9 +530,9 @@ class Group(BaseModel): raise IAMNotFoundException("Policy {0} not found".format(policy_name)) return { - 'policy_name': policy_name, - 'policy_document': policy_json, - 'group_name': self.name, + "policy_name": policy_name, + "policy_document": policy_json, + "group_name": self.name, } def put_policy(self, policy_name, policy_json): @@ -372,7 +543,6 @@ class Group(BaseModel): class User(BaseModel): - def __init__(self, name, path=None): self.name = name self.id = random_resource_id() @@ -382,6 +552,7 @@ class User(BaseModel): self.policies = {} self.managed_policies = {} self.access_keys = [] + self.ssh_public_keys = [] self.password = None self.password_reset_required = False self.signing_certificates = {} @@ -399,13 +570,12 @@ class User(BaseModel): try: policy_json = self.policies[policy_name] except KeyError: - raise IAMNotFoundException( - "Policy {0} not found".format(policy_name)) + raise IAMNotFoundException("Policy {0} not found".format(policy_name)) return { - 'policy_name': policy_name, - 'policy_document': policy_json, - 'user_name': self.name, + "policy_name": policy_name, + "policy_document": policy_json, + "user_name": self.name, } def put_policy(self, policy_name, policy_json): @@ -416,8 +586,7 @@ class User(BaseModel): def delete_policy(self, policy_name): if policy_name not in self.policies: - raise IAMNotFoundException( - "Policy {0} not found".format(policy_name)) + raise IAMNotFoundException("Policy {0} not found".format(policy_name)) del self.policies[policy_name] @@ -426,14 +595,11 @@ class User(BaseModel): self.access_keys.append(access_key) return access_key - def enable_mfa_device(self, - serial_number, - authentication_code_1, - authentication_code_2): + def enable_mfa_device( + self, serial_number, authentication_code_1, authentication_code_2 + ): self.mfa_devices[serial_number] = MFADevice( - serial_number, - authentication_code_1, - authentication_code_2 + serial_number, authentication_code_1, authentication_code_2 ) def get_all_access_keys(self): @@ -453,58 +619,296 @@ class User(BaseModel): return key else: raise IAMNotFoundException( - "The Access Key with id {0} cannot be found".format(access_key_id)) + "The Access Key with id {0} cannot be found".format(access_key_id) + ) + + def upload_ssh_public_key(self, ssh_public_key_body): + pubkey = SshPublicKey(self.name, ssh_public_key_body) + self.ssh_public_keys.append(pubkey) + return pubkey + + def get_ssh_public_key(self, ssh_public_key_id): + for key in self.ssh_public_keys: + if key.ssh_public_key_id == ssh_public_key_id: + return key + else: + raise IAMNotFoundException( + "The SSH Public Key with id {0} cannot be found".format( + ssh_public_key_id + ) + ) + + def get_all_ssh_public_keys(self): + return self.ssh_public_keys + + def update_ssh_public_key(self, ssh_public_key_id, status): + key = self.get_ssh_public_key(ssh_public_key_id) + key.status = status + + def delete_ssh_public_key(self, ssh_public_key_id): + key = self.get_ssh_public_key(ssh_public_key_id) + self.ssh_public_keys.remove(key) def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') raise UnformattedGetAttTemplateException() def to_csv(self): - date_format = '%Y-%m-%dT%H:%M:%S+00:00' + date_format = "%Y-%m-%dT%H:%M:%S+00:00" date_created = self.create_date # aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A if not self.password: - password_enabled = 'false' - password_last_used = 'not_supported' + password_enabled = "false" + password_last_used = "not_supported" else: - password_enabled = 'true' - password_last_used = 'no_information' + password_enabled = "true" + password_last_used = "no_information" if len(self.access_keys) == 0: - access_key_1_active = 'false' - access_key_1_last_rotated = 'N/A' - access_key_2_active = 'false' - access_key_2_last_rotated = 'N/A' + access_key_1_active = "false" + access_key_1_last_rotated = "N/A" + access_key_2_active = "false" + access_key_2_last_rotated = "N/A" elif len(self.access_keys) == 1: - access_key_1_active = 'true' + access_key_1_active = "true" access_key_1_last_rotated = date_created.strftime(date_format) - access_key_2_active = 'false' - access_key_2_last_rotated = 'N/A' + access_key_2_active = "false" + access_key_2_last_rotated = "N/A" else: - access_key_1_active = 'true' + access_key_1_active = "true" access_key_1_last_rotated = date_created.strftime(date_format) - access_key_2_active = 'true' + access_key_2_active = "true" access_key_2_last_rotated = date_created.strftime(date_format) - return '{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A'.format(self.name, - self.arn, - date_created.strftime( - date_format), - password_enabled, - password_last_used, - date_created.strftime( - date_format), - access_key_1_active, - access_key_1_last_rotated, - access_key_2_active, - access_key_2_last_rotated - ) + return "{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A".format( + self.name, + self.arn, + date_created.strftime(date_format), + password_enabled, + password_last_used, + date_created.strftime(date_format), + access_key_1_active, + access_key_1_last_rotated, + access_key_2_active, + access_key_2_last_rotated, + ) + + +class AccountPasswordPolicy(BaseModel): + def __init__( + self, + allow_change_password, + hard_expiry, + max_password_age, + minimum_password_length, + password_reuse_prevention, + require_lowercase_characters, + require_numbers, + require_symbols, + require_uppercase_characters, + ): + self._errors = [] + self._validate( + max_password_age, minimum_password_length, password_reuse_prevention + ) + + self.allow_users_to_change_password = allow_change_password + self.hard_expiry = hard_expiry + self.max_password_age = max_password_age + self.minimum_password_length = minimum_password_length + self.password_reuse_prevention = password_reuse_prevention + self.require_lowercase_characters = require_lowercase_characters + self.require_numbers = require_numbers + self.require_symbols = require_symbols + self.require_uppercase_characters = require_uppercase_characters + + @property + def expire_passwords(self): + return True if self.max_password_age and self.max_password_age > 0 else False + + def _validate( + self, max_password_age, minimum_password_length, password_reuse_prevention + ): + if minimum_password_length > 128: + self._errors.append( + self._format_error( + key="minimumPasswordLength", + value=minimum_password_length, + constraint="Member must have value less than or equal to 128", + ) + ) + + if password_reuse_prevention and password_reuse_prevention > 24: + self._errors.append( + self._format_error( + key="passwordReusePrevention", + value=password_reuse_prevention, + constraint="Member must have value less than or equal to 24", + ) + ) + + if max_password_age and max_password_age > 1095: + self._errors.append( + self._format_error( + key="maxPasswordAge", + value=max_password_age, + constraint="Member must have value less than or equal to 1095", + ) + ) + + self._raise_errors() + + def _format_error(self, key, value, constraint): + return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format( + constraint=constraint, key=key, value=value + ) + + def _raise_errors(self): + if self._errors: + count = len(self._errors) + plural = "s" if len(self._errors) > 1 else "" + errors = "; ".join(self._errors) + self._errors = [] # reset collected errors + + raise ValidationError( + "{count} validation error{plural} detected: {errors}".format( + count=count, plural=plural, errors=errors + ) + ) + + +class AccountSummary(BaseModel): + def __init__(self, iam_backend): + self._iam_backend = iam_backend + + self._group_policy_size_quota = 5120 + self._instance_profiles_quota = 1000 + self._groups_per_user_quota = 10 + self._attached_policies_per_user_quota = 10 + self._policies_quota = 1500 + self._account_mfa_enabled = 0 # Haven't found any information being able to activate MFA for the root account programmatically + self._access_keys_per_user_quota = 2 + self._assume_role_policy_size_quota = 2048 + self._policy_versions_in_use_quota = 10000 + self._global_endpoint_token_version = ( + 1 # ToDo: Implement set_security_token_service_preferences() + ) + self._versions_per_policy_quota = 5 + self._attached_policies_per_group_quota = 10 + self._policy_size_quota = 6144 + self._account_signing_certificates_present = 0 # valid values: 0 | 1 + self._users_quota = 5000 + self._server_certificates_quota = 20 + self._user_policy_size_quota = 2048 + self._roles_quota = 1000 + self._signing_certificates_per_user_quota = 2 + self._role_policy_size_quota = 10240 + self._attached_policies_per_role_quota = 10 + self._account_access_keys_present = 0 # valid values: 0 | 1 + self._groups_quota = 300 + + @property + def summary_map(self): + return { + "GroupPolicySizeQuota": self._group_policy_size_quota, + "InstanceProfilesQuota": self._instance_profiles_quota, + "Policies": self._policies, + "GroupsPerUserQuota": self._groups_per_user_quota, + "InstanceProfiles": self._instance_profiles, + "AttachedPoliciesPerUserQuota": self._attached_policies_per_user_quota, + "Users": self._users, + "PoliciesQuota": self._policies_quota, + "Providers": self._providers, + "AccountMFAEnabled": self._account_mfa_enabled, + "AccessKeysPerUserQuota": self._access_keys_per_user_quota, + "AssumeRolePolicySizeQuota": self._assume_role_policy_size_quota, + "PolicyVersionsInUseQuota": self._policy_versions_in_use_quota, + "GlobalEndpointTokenVersion": self._global_endpoint_token_version, + "VersionsPerPolicyQuota": self._versions_per_policy_quota, + "AttachedPoliciesPerGroupQuota": self._attached_policies_per_group_quota, + "PolicySizeQuota": self._policy_size_quota, + "Groups": self._groups, + "AccountSigningCertificatesPresent": self._account_signing_certificates_present, + "UsersQuota": self._users_quota, + "ServerCertificatesQuota": self._server_certificates_quota, + "MFADevices": self._mfa_devices, + "UserPolicySizeQuota": self._user_policy_size_quota, + "PolicyVersionsInUse": self._policy_versions_in_use, + "ServerCertificates": self._server_certificates, + "Roles": self._roles, + "RolesQuota": self._roles_quota, + "SigningCertificatesPerUserQuota": self._signing_certificates_per_user_quota, + "MFADevicesInUse": self._mfa_devices_in_use, + "RolePolicySizeQuota": self._role_policy_size_quota, + "AttachedPoliciesPerRoleQuota": self._attached_policies_per_role_quota, + "AccountAccessKeysPresent": self._account_access_keys_present, + "GroupsQuota": self._groups_quota, + } + + @property + def _groups(self): + return len(self._iam_backend.groups) + + @property + def _instance_profiles(self): + return len(self._iam_backend.instance_profiles) + + @property + def _mfa_devices(self): + # Don't know, if hardware devices are also counted here + return len(self._iam_backend.virtual_mfa_devices) + + @property + def _mfa_devices_in_use(self): + devices = 0 + + for user in self._iam_backend.users.values(): + devices += len(user.mfa_devices) + + return devices + + @property + def _policies(self): + customer_policies = [ + policy + for policy in self._iam_backend.managed_policies + if not policy.startswith("arn:aws:iam::aws:policy") + ] + return len(customer_policies) + + @property + def _policy_versions_in_use(self): + attachments = 0 + + for policy in self._iam_backend.managed_policies.values(): + attachments += policy.attachment_count + + return attachments + + @property + def _providers(self): + providers = len(self._iam_backend.saml_providers) + len( + self._iam_backend.open_id_providers + ) + return providers + + @property + def _roles(self): + return len(self._iam_backend.roles) + + @property + def _server_certificates(self): + return len(self._iam_backend.certificates) + + @property + def _users(self): + return len(self._iam_backend.users) class IAMBackend(BaseBackend): - def __init__(self): self.instance_profiles = {} self.roles = {} @@ -515,8 +919,11 @@ class IAMBackend(BaseBackend): self.managed_policies = self._init_managed_policies() self.account_aliases = [] self.saml_providers = {} - self.policy_arn_regex = re.compile( - r'^arn:aws:iam::[0-9]*:policy/.*$') + self.open_id_providers = {} + self.policy_arn_regex = re.compile(r"^arn:aws:iam::[0-9]*:policy/.*$") + self.virtual_mfa_devices = {} + self.account_password_policy = None + self.account_summary = AccountSummary(self) super(IAMBackend, self).__init__() def _init_managed_policies(self): @@ -532,9 +939,10 @@ class IAMBackend(BaseBackend): role.description = role_description return role - def update_role(self, role_name, role_description): + def update_role(self, role_name, role_description, max_session_duration): role = self.get_role(role_name) role.description = role_description + role.max_session_duration = max_session_duration return role def detach_role_policy(self, policy_arn, role_name): @@ -582,11 +990,14 @@ class IAMBackend(BaseBackend): iam_policy_document_validator.validate() policy = ManagedPolicy( - policy_name, - description=description, - document=policy_document, - path=path, + policy_name, description=description, document=policy_document, path=path ) + if policy.arn in self.managed_policies: + raise EntityAlreadyExists( + "A policy called {0} already exists. Duplicate names are not allowed.".format( + policy_name + ) + ) self.managed_policies[policy.arn] = policy return policy @@ -595,15 +1006,21 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException("Policy {0} not found".format(policy_arn)) return self.managed_policies.get(policy_arn) - def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_role_policies( + self, role_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_role(role_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def list_attached_group_policies(self, group_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_group_policies( + self, group_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_group(group_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def list_attached_user_policies(self, user_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_user_policies( + self, user_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_user(user_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) @@ -613,11 +1030,10 @@ class IAMBackend(BaseBackend): if only_attached: policies = [p for p in policies if p.attachment_count > 0] - if scope == 'AWS': + if scope == "AWS": policies = [p for p in policies if isinstance(p, AWSManagedPolicy)] - elif scope == 'Local': - policies = [p for p in policies if not isinstance( - p, AWSManagedPolicy)] + elif scope == "Local": + policies = [p for p in policies if not isinstance(p, AWSManagedPolicy)] return self._filter_attached_policies(policies, marker, max_items, path_prefix) @@ -628,7 +1044,7 @@ class IAMBackend(BaseBackend): policies = sorted(policies, key=lambda policy: policy.name) start_idx = int(marker) if marker else 0 - policies = policies[start_idx:start_idx + max_items] + policies = policies[start_idx : start_idx + max_items] if len(policies) < max_items: marker = None @@ -637,13 +1053,42 @@ class IAMBackend(BaseBackend): return policies, marker - def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary, description, tags): + def create_role( + self, + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + max_session_duration, + ): role_id = random_resource_id() - if permissions_boundary and not self.policy_arn_regex.match(permissions_boundary): - raise RESTError('InvalidParameterValue', 'Value ({}) for parameter PermissionsBoundary is invalid.'.format(permissions_boundary)) + if permissions_boundary and not self.policy_arn_regex.match( + permissions_boundary + ): + raise RESTError( + "InvalidParameterValue", + "Value ({}) for parameter PermissionsBoundary is invalid.".format( + permissions_boundary + ), + ) + if [role for role in self.get_roles() if role.name == role_name]: + raise EntityAlreadyExists( + "Role with name {0} already exists.".format(role_name) + ) clean_tags = self._tag_verification(tags) - role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary, description, clean_tags) + role = Role( + role_id, + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + clean_tags, + max_session_duration, + ) self.roles[role_id] = role return role @@ -663,11 +1108,25 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException("Role {0} not found".format(arn)) def delete_role(self, role_name): - for role in self.get_roles(): - if role.name == role_name: - del self.roles[role.id] - return - raise IAMNotFoundException("Role {0} not found".format(role_name)) + role = self.get_role(role_name) + for instance_profile in self.get_instance_profiles(): + for role in instance_profile.roles: + if role.name == role_name: + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete entity, must remove roles from instance profile first.", + ) + if role.managed_policies: + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete entity, must detach all policies first.", + ) + if role.policies: + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete entity, must delete policies first.", + ) + del self.roles[role.id] def get_roles(self): return self.roles.values() @@ -688,7 +1147,11 @@ class IAMBackend(BaseBackend): for p, d in role.policies.items(): if p == policy_name: return p, d - raise IAMNotFoundException("Policy Document {0} not attached to role {1}".format(policy_name, role_name)) + raise IAMNotFoundException( + "Policy Document {0} not attached to role {1}".format( + policy_name, role_name + ) + ) def list_role_policies(self, role_name): role = self.get_role(role_name) @@ -701,17 +1164,17 @@ class IAMBackend(BaseBackend): tag_keys = {} for tag in tags: # Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained. - ref_key = tag['Key'].lower() + ref_key = tag["Key"].lower() self._check_tag_duplicate(tag_keys, ref_key) - self._validate_tag_key(tag['Key']) - if len(tag['Value']) > 256: - raise TagValueTooBig(tag['Value']) + self._validate_tag_key(tag["Key"]) + if len(tag["Value"]) > 256: + raise TagValueTooBig(tag["Value"]) tag_keys[ref_key] = tag return tag_keys - def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'): + def _validate_tag_key(self, tag_key, exception_param="tags.X.member.key"): """Validates the tag key. :param tag_key: The tag key to check against. @@ -725,7 +1188,7 @@ class IAMBackend(BaseBackend): # Validate that the tag key fits the proper Regex: # [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+ - match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key) + match = re.findall(r"[\w\s_.:/=+\-@]+", tag_key) # Kudos if you can come up with a better way of doing a global search :) if not len(match) or len(match[0]) < len(tag_key): raise InvalidTagCharacters(tag_key, param=exception_param) @@ -747,7 +1210,7 @@ class IAMBackend(BaseBackend): tag_index = sorted(role.tags) start_idx = int(marker) if marker else 0 - tag_index = tag_index[start_idx:start_idx + max_items] + tag_index = tag_index[start_idx : start_idx + max_items] if len(role.tags) <= (start_idx + max_items): marker = None @@ -766,13 +1229,13 @@ class IAMBackend(BaseBackend): def untag_role(self, role_name, tag_keys): if len(tag_keys) > 50: - raise TooManyTags(tag_keys, param='tagKeys') + raise TooManyTags(tag_keys, param="tagKeys") role = self.get_role(role_name) for key in tag_keys: ref_key = key.lower() - self._validate_tag_key(key, exception_param='tagKeys') + self._validate_tag_key(key, exception_param="tagKeys") role.tags.pop(ref_key, None) @@ -784,11 +1247,13 @@ class IAMBackend(BaseBackend): if not policy: raise IAMNotFoundException("Policy not found") if len(policy.versions) >= 5: - raise IAMLimitExceededException("A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version.") - set_as_default = (set_as_default == "true") # convert it to python bool + raise IAMLimitExceededException( + "A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version." + ) + set_as_default = set_as_default == "true" # convert it to python bool version = PolicyVersion(policy_arn, policy_document, set_as_default) policy.versions.append(version) - version.version_id = 'v{0}'.format(policy.next_version_num) + version.version_id = "v{0}".format(policy.next_version_num) policy.next_version_num += 1 if set_as_default: policy.update_default_version(version.version_id) @@ -814,8 +1279,10 @@ class IAMBackend(BaseBackend): if not policy: raise IAMNotFoundException("Policy not found") if version_id == policy.default_version_id: - raise IAMConflictException(code="DeleteConflict", - message="Cannot delete the default version of a policy.") + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete the default version of a policy.", + ) for i, v in enumerate(policy.versions): if v.version_id == version_id: del policy.versions[i] @@ -823,12 +1290,17 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException("Policy not found") def create_instance_profile(self, name, path, role_ids): + if self.instance_profiles.get(name): + raise IAMConflictException( + code="EntityAlreadyExists", + message="Instance Profile {0} already exists.".format(name), + ) + instance_profile_id = random_resource_id() roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids] - instance_profile = InstanceProfile( - instance_profile_id, name, path, roles) - self.instance_profiles[instance_profile_id] = instance_profile + instance_profile = InstanceProfile(instance_profile_id, name, path, roles) + self.instance_profiles[name] = instance_profile return instance_profile def get_instance_profile(self, profile_name): @@ -837,7 +1309,8 @@ class IAMBackend(BaseBackend): return profile raise IAMNotFoundException( - "Instance profile {0} not found".format(profile_name)) + "Instance profile {0} not found".format(profile_name) + ) def get_instance_profiles(self): return self.instance_profiles.values() @@ -865,7 +1338,9 @@ class IAMBackend(BaseBackend): def get_all_server_certs(self, marker=None): return self.certificates.values() - def upload_server_cert(self, cert_name, cert_body, private_key, cert_chain=None, path=None): + def upload_server_certificate( + self, cert_name, cert_body, private_key, cert_chain=None, path=None + ): certificate_id = random_resource_id() cert = Certificate(cert_name, cert_body, private_key, cert_chain, path) self.certificates[certificate_id] = cert @@ -877,8 +1352,8 @@ class IAMBackend(BaseBackend): return cert raise IAMNotFoundException( - "The Server Certificate with name {0} cannot be " - "found.".format(name)) + "The Server Certificate with name {0} cannot be " "found.".format(name) + ) def delete_server_certificate(self, name): cert_id = None @@ -889,15 +1364,14 @@ class IAMBackend(BaseBackend): if cert_id is None: raise IAMNotFoundException( - "The Server Certificate with name {0} cannot be " - "found.".format(name)) + "The Server Certificate with name {0} cannot be " "found.".format(name) + ) self.certificates.pop(cert_id, None) - def create_group(self, group_name, path='/'): + def create_group(self, group_name, path="/"): if group_name in self.groups: - raise IAMConflictException( - "Group {0} already exists".format(group_name)) + raise IAMConflictException("Group {0} already exists".format(group_name)) group = Group(group_name, path) self.groups[group_name] = group @@ -908,8 +1382,7 @@ class IAMBackend(BaseBackend): try: group = self.groups[group_name] except KeyError: - raise IAMNotFoundException( - "Group {0} not found".format(group_name)) + raise IAMNotFoundException("Group {0} not found".format(group_name)) return group @@ -940,10 +1413,19 @@ class IAMBackend(BaseBackend): group = self.get_group(group_name) return group.get_policy(policy_name) - def create_user(self, user_name, path='/'): + def delete_group(self, group_name): + try: + del self.groups[group_name] + except KeyError: + raise IAMNotFoundException( + "The group with name {0} cannot be found.".format(group_name) + ) + + def create_user(self, user_name, path="/"): if user_name in self.users: raise IAMConflictException( - "EntityAlreadyExists", "User {0} already exists".format(user_name)) + "EntityAlreadyExists", "User {0} already exists".format(user_name) + ) user = User(user_name, path) self.users[user_name] = user @@ -964,7 +1446,8 @@ class IAMBackend(BaseBackend): users = self.users.values() except KeyError: raise IAMNotFoundException( - "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items)) + "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items) + ) return users @@ -986,7 +1469,8 @@ class IAMBackend(BaseBackend): roles = self.roles.values() except KeyError: raise IAMNotFoundException( - "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items)) + "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items) + ) return roles @@ -999,14 +1483,16 @@ class IAMBackend(BaseBackend): if sys.version_info < (3, 0): data = bytes(body) else: - data = bytes(body, 'utf8') + data = bytes(body, "utf8") x509.load_pem_x509_certificate(data, default_backend()) except Exception: raise MalformedCertificate(body) - user.signing_certificates[cert_id] = SigningCertificate(cert_id, user_name, body) + user.signing_certificates[cert_id] = SigningCertificate( + cert_id, user_name, body + ) return user.signing_certificates[cert_id] @@ -1016,7 +1502,9 @@ class IAMBackend(BaseBackend): try: del user.signing_certificates[cert_id] except KeyError: - raise IAMNotFoundException("The Certificate with id {id} cannot be found.".format(id=cert_id)) + raise IAMNotFoundException( + "The Certificate with id {id} cannot be found.".format(id=cert_id) + ) def list_signing_certificates(self, user_name): user = self.get_user(user_name) @@ -1030,14 +1518,17 @@ class IAMBackend(BaseBackend): user.signing_certificates[cert_id].status = status except KeyError: - raise IAMNotFoundException("The Certificate with id {id} cannot be found.".format(id=cert_id)) + raise IAMNotFoundException( + "The Certificate with id {id} cannot be found.".format(id=cert_id) + ) def create_login_profile(self, user_name, password): # This does not currently deal with PasswordPolicyViolation. user = self.get_user(user_name) if user.password: raise IAMConflictException( - "User {0} already has password".format(user_name)) + "User {0} already has password".format(user_name) + ) user.password = password return user @@ -1045,7 +1536,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) return user def update_login_profile(self, user_name, password, password_reset_required): @@ -1053,7 +1545,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) user.password = password user.password_reset_required = password_reset_required return user @@ -1062,7 +1555,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) user.password = None def add_user_to_group(self, group_name, user_name): @@ -1077,7 +1571,8 @@ class IAMBackend(BaseBackend): group.users.remove(user) except ValueError: raise IAMNotFoundException( - "User {0} not in group {1}".format(user_name, group_name)) + "User {0} not in group {1}".format(user_name, group_name) + ) def get_user_policy(self, user_name, policy_name): user = self.get_user(user_name) @@ -1099,6 +1594,9 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) user.delete_policy(policy_name) + def delete_policy(self, policy_arn): + del self.managed_policies[policy_arn] + def create_access_key(self, user_name=None): user = self.get_user(user_name) key = user.create_access_key() @@ -1112,13 +1610,11 @@ class IAMBackend(BaseBackend): access_keys_list = self.get_all_access_keys_for_all_users() for key in access_keys_list: if key.access_key_id == access_key_id: - return { - 'user_name': key.user_name, - 'last_used': key.last_used_iso_8601, - } + return {"user_name": key.user_name, "last_used": key.last_used_iso_8601} else: raise IAMNotFoundException( - "The Access Key with id {0} cannot be found".format(access_key_id)) + "The Access Key with id {0} cannot be found".format(access_key_id) + ) def get_all_access_keys_for_all_users(self): access_keys_list = [] @@ -1135,32 +1631,66 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) user.delete_access_key(access_key_id) - def enable_mfa_device(self, - user_name, - serial_number, - authentication_code_1, - authentication_code_2): + def upload_ssh_public_key(self, user_name, ssh_public_key_body): + user = self.get_user(user_name) + return user.upload_ssh_public_key(ssh_public_key_body) + + def get_ssh_public_key(self, user_name, ssh_public_key_id): + user = self.get_user(user_name) + return user.get_ssh_public_key(ssh_public_key_id) + + def get_all_ssh_public_keys(self, user_name): + user = self.get_user(user_name) + return user.get_all_ssh_public_keys() + + def update_ssh_public_key(self, user_name, ssh_public_key_id, status): + user = self.get_user(user_name) + return user.update_ssh_public_key(ssh_public_key_id, status) + + def delete_ssh_public_key(self, user_name, ssh_public_key_id): + user = self.get_user(user_name) + return user.delete_ssh_public_key(ssh_public_key_id) + + def enable_mfa_device( + self, user_name, serial_number, authentication_code_1, authentication_code_2 + ): """Enable MFA Device for user.""" user = self.get_user(user_name) if serial_number in user.mfa_devices: raise IAMConflictException( - "EntityAlreadyExists", - "Device {0} already exists".format(serial_number) + "EntityAlreadyExists", "Device {0} already exists".format(serial_number) ) + device = self.virtual_mfa_devices.get(serial_number, None) + if device: + device.enable_date = datetime.utcnow() + device.user = user + device.user_attribute = { + "Path": user.path, + "UserName": user.name, + "UserId": user.id, + "Arn": user.arn, + "CreateDate": user.created_iso_8601, + "PasswordLastUsed": None, # not supported + "PermissionsBoundary": {}, # ToDo: add put_user_permissions_boundary() functionality + "Tags": {}, # ToDo: add tag_user() functionality + } + user.enable_mfa_device( - serial_number, - authentication_code_1, - authentication_code_2 + serial_number, authentication_code_1, authentication_code_2 ) def deactivate_mfa_device(self, user_name, serial_number): """Deactivate and detach MFA Device from user if device exists.""" user = self.get_user(user_name) if serial_number not in user.mfa_devices: - raise IAMNotFoundException( - "Device {0} not found".format(serial_number) - ) + raise IAMNotFoundException("Device {0} not found".format(serial_number)) + + device = self.virtual_mfa_devices.get(serial_number, None) + if device: + device.enable_date = None + device.user = None + device.user_attribute = None user.deactivate_mfa_device(serial_number) @@ -1168,11 +1698,87 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) return user.mfa_devices.values() + def create_virtual_mfa_device(self, device_name, path): + if not path: + path = "/" + + if not path.startswith("/") and not path.endswith("/"): + raise ValidationError( + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters." + ) + + if any(not len(part) for part in path.split("/")[1:-1]): + raise ValidationError( + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters." + ) + + if len(path) > 512: + raise ValidationError( + "1 validation error detected: " + 'Value "{}" at "path" failed to satisfy constraint: ' + "Member must have length less than or equal to 512" + ) + + device = VirtualMfaDevice(path + device_name) + + if device.serial_number in self.virtual_mfa_devices: + raise EntityAlreadyExists( + "MFADevice entity at the same path and name already exists." + ) + + self.virtual_mfa_devices[device.serial_number] = device + return device + + def delete_virtual_mfa_device(self, serial_number): + device = self.virtual_mfa_devices.pop(serial_number, None) + + if not device: + raise IAMNotFoundException( + "VirtualMFADevice with serial number {0} doesn't exist.".format( + serial_number + ) + ) + + def list_virtual_mfa_devices(self, assignment_status, marker, max_items): + devices = list(self.virtual_mfa_devices.values()) + + if assignment_status == "Assigned": + devices = [device for device in devices if device.enable_date] + + if assignment_status == "Unassigned": + devices = [device for device in devices if not device.enable_date] + + sorted(devices, key=lambda device: device.serial_number) + max_items = int(max_items) + start_idx = int(marker) if marker else 0 + + if start_idx > len(devices): + raise ValidationError("Invalid Marker.") + + devices = devices[start_idx : start_idx + max_items] + + if len(devices) < max_items: + marker = None + else: + marker = str(start_idx + max_items) + + return devices, marker + def delete_user(self, user_name): - try: - del self.users[user_name] - except KeyError: - raise IAMNotFoundException("User {0} not found".format(user_name)) + user = self.get_user(user_name) + if user.managed_policies: + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete entity, must detach all policies first.", + ) + if user.policies: + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete entity, must delete policies first.", + ) + del self.users[user_name] def report_generated(self): return self.credential_report @@ -1183,10 +1789,10 @@ class IAMBackend(BaseBackend): def get_credential_report(self): if not self.credential_report: raise IAMReportNotPresentException("Credential report not present") - report = 'user,arn,user_creation_time,password_enabled,password_last_used,password_last_changed,password_next_rotation,mfa_active,access_key_1_active,access_key_1_last_rotated,access_key_2_active,access_key_2_last_rotated,cert_1_active,cert_1_last_rotated,cert_2_active,cert_2_last_rotated\n' + report = "user,arn,user_creation_time,password_enabled,password_last_used,password_last_changed,password_next_rotation,mfa_active,access_key_1_active,access_key_1_last_rotated,access_key_2_active,access_key_2_last_rotated,cert_1_active,cert_1_last_rotated,cert_2_active,cert_2_last_rotated\n" for user in self.users: report += self.users[user].to_csv() - return base64.b64encode(report.encode('ascii')).decode('ascii') + return base64.b64encode(report.encode("ascii")).decode("ascii") def list_account_aliases(self): return self.account_aliases @@ -1205,24 +1811,24 @@ class IAMBackend(BaseBackend): if len(filter) == 0: return { - 'instance_profiles': self.instance_profiles.values(), - 'roles': self.roles.values(), - 'groups': self.groups.values(), - 'users': self.users.values(), - 'managed_policies': self.managed_policies.values() + "instance_profiles": self.instance_profiles.values(), + "roles": self.roles.values(), + "groups": self.groups.values(), + "users": self.users.values(), + "managed_policies": self.managed_policies.values(), } - if 'AWSManagedPolicy' in filter: + if "AWSManagedPolicy" in filter: returned_policies = aws_managed_policies - if 'LocalManagedPolicy' in filter: + if "LocalManagedPolicy" in filter: returned_policies = returned_policies + list(local_policies) return { - 'instance_profiles': self.instance_profiles.values(), - 'roles': self.roles.values() if 'Role' in filter else [], - 'groups': self.groups.values() if 'Group' in filter else [], - 'users': self.users.values() if 'User' in filter else [], - 'managed_policies': returned_policies + "instance_profiles": self.instance_profiles.values(), + "roles": self.roles.values() if "Role" in filter else [], + "groups": self.groups.values() if "Group" in filter else [], + "users": self.users.values() if "User" in filter else [], + "managed_policies": returned_policies, } def create_saml_provider(self, name, saml_metadata_document): @@ -1242,7 +1848,8 @@ class IAMBackend(BaseBackend): del self.saml_providers[saml_provider.name] except KeyError: raise IAMNotFoundException( - "SAMLProvider {0} not found".format(saml_provider_arn)) + "SAMLProvider {0} not found".format(saml_provider_arn) + ) def list_saml_providers(self): return self.saml_providers.values() @@ -1251,7 +1858,9 @@ class IAMBackend(BaseBackend): for saml_provider in self.list_saml_providers(): if saml_provider.arn == saml_provider_arn: return saml_provider - raise IAMNotFoundException("SamlProvider {0} not found".format(saml_provider_arn)) + raise IAMNotFoundException( + "SamlProvider {0} not found".format(saml_provider_arn) + ) def get_user_from_access_key_id(self, access_key_id): for user_name, user in self.users.items(): @@ -1261,5 +1870,75 @@ class IAMBackend(BaseBackend): return user return None + def create_open_id_connect_provider(self, url, thumbprint_list, client_id_list): + open_id_provider = OpenIDConnectProvider(url, thumbprint_list, client_id_list) + + if open_id_provider.arn in self.open_id_providers: + raise EntityAlreadyExists("Unknown") + + self.open_id_providers[open_id_provider.arn] = open_id_provider + return open_id_provider + + def delete_open_id_connect_provider(self, arn): + self.open_id_providers.pop(arn, None) + + def get_open_id_connect_provider(self, arn): + open_id_provider = self.open_id_providers.get(arn) + + if not open_id_provider: + raise IAMNotFoundException( + "OpenIDConnect Provider not found for arn {}".format(arn) + ) + + return open_id_provider + + def list_open_id_connect_providers(self): + return list(self.open_id_providers.keys()) + + def update_account_password_policy( + self, + allow_change_password, + hard_expiry, + max_password_age, + minimum_password_length, + password_reuse_prevention, + require_lowercase_characters, + require_numbers, + require_symbols, + require_uppercase_characters, + ): + self.account_password_policy = AccountPasswordPolicy( + allow_change_password, + hard_expiry, + max_password_age, + minimum_password_length, + password_reuse_prevention, + require_lowercase_characters, + require_numbers, + require_symbols, + require_uppercase_characters, + ) + + def get_account_password_policy(self): + if not self.account_password_policy: + raise NoSuchEntity( + "The Password Policy with domain name {} cannot be found.".format( + ACCOUNT_ID + ) + ) + + return self.account_password_policy + + def delete_account_password_policy(self): + if not self.account_password_policy: + raise NoSuchEntity( + "The account policy with name PasswordPolicy cannot be found." + ) + + self.account_password_policy = None + + def get_account_summary(self): + return self.account_summary + iam_backend = IAMBackend() diff --git a/moto/iam/policy_validation.py b/moto/iam/policy_validation.py index 6ee286072..95610ac4d 100644 --- a/moto/iam/policy_validation.py +++ b/moto/iam/policy_validation.py @@ -6,17 +6,9 @@ from six import string_types from moto.iam.exceptions import MalformedPolicyDocument -VALID_TOP_ELEMENTS = [ - "Version", - "Id", - "Statement", - "Conditions" -] +VALID_TOP_ELEMENTS = ["Version", "Id", "Statement", "Conditions"] -VALID_VERSIONS = [ - "2008-10-17", - "2012-10-17" -] +VALID_VERSIONS = ["2008-10-17", "2012-10-17"] VALID_STATEMENT_ELEMENTS = [ "Sid", @@ -25,13 +17,10 @@ VALID_STATEMENT_ELEMENTS = [ "Resource", "NotResource", "Effect", - "Condition" + "Condition", ] -VALID_EFFECTS = [ - "Allow", - "Deny" -] +VALID_EFFECTS = ["Allow", "Deny"] VALID_CONDITIONS = [ "StringEquals", @@ -60,34 +49,41 @@ VALID_CONDITIONS = [ "ArnLike", "ArnNotEquals", "ArnNotLike", - "Null" + "Null", ] -VALID_CONDITION_PREFIXES = [ - "ForAnyValue:", - "ForAllValues:" -] +VALID_CONDITION_PREFIXES = ["ForAnyValue:", "ForAllValues:"] -VALID_CONDITION_POSTFIXES = [ - "IfExists" -] +VALID_CONDITION_POSTFIXES = ["IfExists"] SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS = { - "iam": 'IAM resource {resource} cannot contain region information.', - "s3": 'Resource {resource} can not contain region information.' + "iam": "IAM resource {resource} cannot contain region information.", + "s3": "Resource {resource} can not contain region information.", } VALID_RESOURCE_PATH_STARTING_VALUES = { "iam": { - "values": ["user/", "federated-user/", "role/", "group/", "instance-profile/", "mfa/", "server-certificate/", - "policy/", "sms-mfa/", "saml-provider/", "oidc-provider/", "report/", "access-report/"], - "error_message": 'IAM resource path must either be "*" or start with {values}.' + "values": [ + "user/", + "federated-user/", + "role/", + "group/", + "instance-profile/", + "mfa/", + "server-certificate/", + "policy/", + "sms-mfa/", + "saml-provider/", + "oidc-provider/", + "report/", + "access-report/", + ], + "error_message": 'IAM resource path must either be "*" or start with {values}.', } } class IAMPolicyDocumentValidator: - def __init__(self, policy_document): self._policy_document = policy_document self._policy_json = {} @@ -102,7 +98,9 @@ class IAMPolicyDocumentValidator: try: self._validate_version() except Exception: - raise MalformedPolicyDocument("Policy document must be version 2012-10-17 or greater.") + raise MalformedPolicyDocument( + "Policy document must be version 2012-10-17 or greater." + ) try: self._perform_first_legacy_parsing() self._validate_resources_for_formats() @@ -112,7 +110,9 @@ class IAMPolicyDocumentValidator: try: self._validate_sid_uniqueness() except Exception: - raise MalformedPolicyDocument("Statement IDs (SID) in a single policy must be unique.") + raise MalformedPolicyDocument( + "Statement IDs (SID) in a single policy must be unique." + ) try: self._validate_action_like_exist() except Exception: @@ -152,8 +152,10 @@ class IAMPolicyDocumentValidator: sids = [] for statement in self._statements: if "Sid" in statement: - assert statement["Sid"] not in sids - sids.append(statement["Sid"]) + statementId = statement["Sid"] + if statementId: + assert statementId not in sids + sids.append(statementId) def _validate_statements_syntax(self): assert "Statement" in self._policy_json @@ -174,8 +176,8 @@ class IAMPolicyDocumentValidator: for statement_element in statement.keys(): assert statement_element in VALID_STATEMENT_ELEMENTS - assert ("Resource" not in statement or "NotResource" not in statement) - assert ("Action" not in statement or "NotAction" not in statement) + assert "Resource" not in statement or "NotResource" not in statement + assert "Action" not in statement or "NotAction" not in statement IAMPolicyDocumentValidator._validate_effect_syntax(statement) IAMPolicyDocumentValidator._validate_action_syntax(statement) @@ -189,23 +191,33 @@ class IAMPolicyDocumentValidator: def _validate_effect_syntax(statement): assert "Effect" in statement assert isinstance(statement["Effect"], string_types) - assert statement["Effect"].lower() in [allowed_effect.lower() for allowed_effect in VALID_EFFECTS] + assert statement["Effect"].lower() in [ + allowed_effect.lower() for allowed_effect in VALID_EFFECTS + ] @staticmethod def _validate_action_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "Action") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "Action" + ) @staticmethod def _validate_not_action_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "NotAction") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "NotAction" + ) @staticmethod def _validate_resource_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "Resource") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "Resource" + ) @staticmethod def _validate_not_resource_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "NotResource") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "NotResource" + ) @staticmethod def _validate_string_or_list_of_strings_syntax(statement, key): @@ -221,22 +233,28 @@ class IAMPolicyDocumentValidator: assert isinstance(statement["Condition"], dict) for condition_key, condition_value in statement["Condition"].items(): assert isinstance(condition_value, dict) - for condition_element_key, condition_element_value in condition_value.items(): + for ( + condition_element_key, + condition_element_value, + ) in condition_value.items(): assert isinstance(condition_element_value, (list, string_types)) - if IAMPolicyDocumentValidator._strip_condition_key(condition_key) not in VALID_CONDITIONS: + if ( + IAMPolicyDocumentValidator._strip_condition_key(condition_key) + not in VALID_CONDITIONS + ): assert not condition_value # empty dict @staticmethod def _strip_condition_key(condition_key): for valid_prefix in VALID_CONDITION_PREFIXES: if condition_key.startswith(valid_prefix): - condition_key = condition_key[len(valid_prefix):] + condition_key = condition_key[len(valid_prefix) :] break # strip only the first match for valid_postfix in VALID_CONDITION_POSTFIXES: if condition_key.endswith(valid_postfix): - condition_key = condition_key[:-len(valid_postfix)] + condition_key = condition_key[: -len(valid_postfix)] break # strip only the first match return condition_key @@ -252,15 +270,17 @@ class IAMPolicyDocumentValidator: def _validate_resource_exist(self): for statement in self._statements: - assert ("Resource" in statement or "NotResource" in statement) + assert "Resource" in statement or "NotResource" in statement if "Resource" in statement and isinstance(statement["Resource"], list): assert statement["Resource"] - elif "NotResource" in statement and isinstance(statement["NotResource"], list): + elif "NotResource" in statement and isinstance( + statement["NotResource"], list + ): assert statement["NotResource"] def _validate_action_like_exist(self): for statement in self._statements: - assert ("Action" in statement or "NotAction" in statement) + assert "Action" in statement or "NotAction" in statement if "Action" in statement and isinstance(statement["Action"], list): assert statement["Action"] elif "NotAction" in statement and isinstance(statement["NotAction"], list): @@ -285,13 +305,19 @@ class IAMPolicyDocumentValidator: def _validate_action_prefix(action): action_parts = action.split(":") if len(action_parts) == 1 and action_parts[0] != "*": - raise MalformedPolicyDocument("Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.") + raise MalformedPolicyDocument( + "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc." + ) elif len(action_parts) > 2: - raise MalformedPolicyDocument("Actions/Condition can contain only one colon.") + raise MalformedPolicyDocument( + "Actions/Condition can contain only one colon." + ) - vendor_pattern = re.compile(r'[^a-zA-Z0-9\-.]') + vendor_pattern = re.compile(r"[^a-zA-Z0-9\-.]") if action_parts[0] != "*" and vendor_pattern.search(action_parts[0]): - raise MalformedPolicyDocument("Vendor {vendor} is not valid".format(vendor=action_parts[0])) + raise MalformedPolicyDocument( + "Vendor {vendor} is not valid".format(vendor=action_parts[0]) + ) def _validate_resources_for_formats(self): self._validate_resource_like_for_formats("Resource") @@ -308,30 +334,51 @@ class IAMPolicyDocumentValidator: for resource in sorted(statement[key], reverse=True): self._validate_resource_format(resource) if self._resource_error == "": - IAMPolicyDocumentValidator._legacy_parse_resource_like(statement, key) + IAMPolicyDocumentValidator._legacy_parse_resource_like( + statement, key + ) def _validate_resource_format(self, resource): if resource != "*": resource_partitions = resource.partition(":") if resource_partitions[1] == "": - self._resource_error = 'Resource {resource} must be in ARN format or "*".'.format(resource=resource) + self._resource_error = 'Resource {resource} must be in ARN format or "*".'.format( + resource=resource + ) return resource_partitions = resource_partitions[2].partition(":") if resource_partitions[0] != "aws": remaining_resource_parts = resource_partitions[2].split(":") - arn1 = remaining_resource_parts[0] if remaining_resource_parts[0] != "" or len(remaining_resource_parts) > 1 else "*" - arn2 = remaining_resource_parts[1] if len(remaining_resource_parts) > 1 else "*" - arn3 = remaining_resource_parts[2] if len(remaining_resource_parts) > 2 else "*" - arn4 = ":".join(remaining_resource_parts[3:]) if len(remaining_resource_parts) > 3 else "*" + arn1 = ( + remaining_resource_parts[0] + if remaining_resource_parts[0] != "" + or len(remaining_resource_parts) > 1 + else "*" + ) + arn2 = ( + remaining_resource_parts[1] + if len(remaining_resource_parts) > 1 + else "*" + ) + arn3 = ( + remaining_resource_parts[2] + if len(remaining_resource_parts) > 2 + else "*" + ) + arn4 = ( + ":".join(remaining_resource_parts[3:]) + if len(remaining_resource_parts) > 3 + else "*" + ) self._resource_error = 'Partition "{partition}" is not valid for resource "arn:{partition}:{arn1}:{arn2}:{arn3}:{arn4}".'.format( partition=resource_partitions[0], arn1=arn1, arn2=arn2, arn3=arn3, - arn4=arn4 + arn4=arn4, ) return @@ -343,8 +390,16 @@ class IAMPolicyDocumentValidator: service = resource_partitions[0] - if service in SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS.keys() and not resource_partitions[2].startswith(":"): - self._resource_error = SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS[service].format(resource=resource) + if service in SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS.keys() and not resource_partitions[ + 2 + ].startswith( + ":" + ): + self._resource_error = SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS[ + service + ].format( + resource=resource + ) return resource_partitions = resource_partitions[2].partition(":") @@ -352,13 +407,19 @@ class IAMPolicyDocumentValidator: if service in VALID_RESOURCE_PATH_STARTING_VALUES.keys(): valid_start = False - for valid_starting_value in VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"]: + for valid_starting_value in VALID_RESOURCE_PATH_STARTING_VALUES[ + service + ]["values"]: if resource_partitions[2].startswith(valid_starting_value): valid_start = True break if not valid_start: - self._resource_error = VALID_RESOURCE_PATH_STARTING_VALUES[service]["error_message"].format( - values=", ".join(VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"]) + self._resource_error = VALID_RESOURCE_PATH_STARTING_VALUES[service][ + "error_message" + ].format( + values=", ".join( + VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"] + ) ) def _perform_first_legacy_parsing(self): @@ -371,7 +432,9 @@ class IAMPolicyDocumentValidator: assert statement["Effect"] in VALID_EFFECTS # case-sensitive matching if "Condition" in statement: for condition_key, condition_value in statement["Condition"].items(): - IAMPolicyDocumentValidator._legacy_parse_condition(condition_key, condition_value) + IAMPolicyDocumentValidator._legacy_parse_condition( + condition_key, condition_value + ) @staticmethod def _legacy_parse_resource_like(statement, key): @@ -387,20 +450,31 @@ class IAMPolicyDocumentValidator: @staticmethod def _legacy_parse_condition(condition_key, condition_value): - stripped_condition_key = IAMPolicyDocumentValidator._strip_condition_key(condition_key) + stripped_condition_key = IAMPolicyDocumentValidator._strip_condition_key( + condition_key + ) if stripped_condition_key.startswith("Date"): - for condition_element_key, condition_element_value in condition_value.items(): + for ( + condition_element_key, + condition_element_value, + ) in condition_value.items(): if isinstance(condition_element_value, string_types): - IAMPolicyDocumentValidator._legacy_parse_date_condition_value(condition_element_value) + IAMPolicyDocumentValidator._legacy_parse_date_condition_value( + condition_element_value + ) else: # it has to be a list for date_condition_value in condition_element_value: - IAMPolicyDocumentValidator._legacy_parse_date_condition_value(date_condition_value) + IAMPolicyDocumentValidator._legacy_parse_date_condition_value( + date_condition_value + ) @staticmethod def _legacy_parse_date_condition_value(date_condition_value): if "t" in date_condition_value.lower() or "-" in date_condition_value: - IAMPolicyDocumentValidator._validate_iso_8601_datetime(date_condition_value.lower()) + IAMPolicyDocumentValidator._validate_iso_8601_datetime( + date_condition_value.lower() + ) else: # timestamp assert 0 <= int(date_condition_value) <= 9223372036854775807 @@ -408,7 +482,11 @@ class IAMPolicyDocumentValidator: def _validate_iso_8601_datetime(datetime): datetime_parts = datetime.partition("t") negative_year = datetime_parts[0].startswith("-") - date_parts = datetime_parts[0][1:].split("-") if negative_year else datetime_parts[0].split("-") + date_parts = ( + datetime_parts[0][1:].split("-") + if negative_year + else datetime_parts[0].split("-") + ) year = "-" + date_parts[0] if negative_year else date_parts[0] assert -292275054 <= int(year) <= 292278993 if len(date_parts) > 1: @@ -442,7 +520,9 @@ class IAMPolicyDocumentValidator: assert 0 <= int(time_zone_minutes) <= 59 else: seconds_with_decimal_fraction = time_parts[2] - seconds_with_decimal_fraction_partition = seconds_with_decimal_fraction.partition(".") + seconds_with_decimal_fraction_partition = seconds_with_decimal_fraction.partition( + "." + ) seconds = seconds_with_decimal_fraction_partition[0] assert 0 <= int(seconds) <= 59 if seconds_with_decimal_fraction_partition[1] == ".": diff --git a/moto/iam/responses.py b/moto/iam/responses.py index 806dd37f4..45bd28c36 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -6,123 +6,125 @@ from .models import iam_backend, User class IamResponse(BaseResponse): - def attach_role_policy(self): - policy_arn = self._get_param('PolicyArn') - role_name = self._get_param('RoleName') + policy_arn = self._get_param("PolicyArn") + role_name = self._get_param("RoleName") iam_backend.attach_role_policy(policy_arn, role_name) template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE) return template.render() def detach_role_policy(self): - role_name = self._get_param('RoleName') - policy_arn = self._get_param('PolicyArn') + role_name = self._get_param("RoleName") + policy_arn = self._get_param("PolicyArn") iam_backend.detach_role_policy(policy_arn, role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DetachRolePolicyResponse") def attach_group_policy(self): - policy_arn = self._get_param('PolicyArn') - group_name = self._get_param('GroupName') + policy_arn = self._get_param("PolicyArn") + group_name = self._get_param("GroupName") iam_backend.attach_group_policy(policy_arn, group_name) template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE) return template.render() def detach_group_policy(self): - policy_arn = self._get_param('PolicyArn') - group_name = self._get_param('GroupName') + policy_arn = self._get_param("PolicyArn") + group_name = self._get_param("GroupName") iam_backend.detach_group_policy(policy_arn, group_name) template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE) return template.render() def attach_user_policy(self): - policy_arn = self._get_param('PolicyArn') - user_name = self._get_param('UserName') + policy_arn = self._get_param("PolicyArn") + user_name = self._get_param("UserName") iam_backend.attach_user_policy(policy_arn, user_name) template = self.response_template(ATTACH_USER_POLICY_TEMPLATE) return template.render() def detach_user_policy(self): - policy_arn = self._get_param('PolicyArn') - user_name = self._get_param('UserName') + policy_arn = self._get_param("PolicyArn") + user_name = self._get_param("UserName") iam_backend.detach_user_policy(policy_arn, user_name) template = self.response_template(DETACH_USER_POLICY_TEMPLATE) return template.render() def create_policy(self): - description = self._get_param('Description') - path = self._get_param('Path') - policy_document = self._get_param('PolicyDocument') - policy_name = self._get_param('PolicyName') + description = self._get_param("Description") + path = self._get_param("Path") + policy_document = self._get_param("PolicyDocument") + policy_name = self._get_param("PolicyName") policy = iam_backend.create_policy( - description, path, policy_document, policy_name) + description, path, policy_document, policy_name + ) template = self.response_template(CREATE_POLICY_TEMPLATE) return template.render(policy=policy) def get_policy(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") policy = iam_backend.get_policy(policy_arn) template = self.response_template(GET_POLICY_TEMPLATE) return template.render(policy=policy) def list_attached_role_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - role_name = self._get_param('RoleName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + role_name = self._get_param("RoleName") policies, marker = iam_backend.list_attached_role_policies( - role_name, marker=marker, max_items=max_items, path_prefix=path_prefix) + role_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_attached_group_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - group_name = self._get_param('GroupName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + group_name = self._get_param("GroupName") policies, marker = iam_backend.list_attached_group_policies( - group_name, marker=marker, max_items=max_items, - path_prefix=path_prefix) + group_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_attached_user_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - user_name = self._get_param('UserName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + user_name = self._get_param("UserName") policies, marker = iam_backend.list_attached_user_policies( - user_name, marker=marker, max_items=max_items, - path_prefix=path_prefix) + user_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - only_attached = self._get_bool_param('OnlyAttached', False) - path_prefix = self._get_param('PathPrefix', '/') - scope = self._get_param('Scope', 'All') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + only_attached = self._get_bool_param("OnlyAttached", False) + path_prefix = self._get_param("PathPrefix", "/") + scope = self._get_param("Scope", "All") policies, marker = iam_backend.list_policies( - marker, max_items, only_attached, path_prefix, scope) + marker, max_items, only_attached, path_prefix, scope + ) template = self.response_template(LIST_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_entities_for_policy(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") # Options 'User'|'Role'|'Group'|'LocalManagedPolicy'|'AWSManagedPolicy - entity = self._get_param('EntityFilter') - path_prefix = self._get_param('PathPrefix') + entity = self._get_param("EntityFilter") + path_prefix = self._get_param("PathPrefix") # policy_usage_filter = self._get_param('PolicyUsageFilter') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") entity_roles = [] entity_groups = [] entity_users = [] - if entity == 'User': + if entity == "User": users = iam_backend.list_users(path_prefix, marker, max_items) if users: for user in users: @@ -130,7 +132,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_users.append(user.name) - elif entity == 'Role': + elif entity == "Role": roles = iam_backend.list_roles(path_prefix, marker, max_items) if roles: for role in roles: @@ -138,7 +140,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_roles.append(role.name) - elif entity == 'Group': + elif entity == "Group": groups = iam_backend.list_groups() if groups: for group in groups: @@ -146,7 +148,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_groups.append(group.name) - elif entity == 'LocalManagedPolicy' or entity == 'AWSManagedPolicy': + elif entity == "LocalManagedPolicy" or entity == "AWSManagedPolicy": users = iam_backend.list_users(path_prefix, marker, max_items) if users: for user in users: @@ -169,150 +171,161 @@ class IamResponse(BaseResponse): entity_groups.append(group.name) template = self.response_template(LIST_ENTITIES_FOR_POLICY_TEMPLATE) - return template.render(roles=entity_roles, users=entity_users, groups=entity_groups) + return template.render( + roles=entity_roles, users=entity_users, groups=entity_groups + ) def create_role(self): - role_name = self._get_param('RoleName') - path = self._get_param('Path') - assume_role_policy_document = self._get_param( - 'AssumeRolePolicyDocument') - permissions_boundary = self._get_param( - 'PermissionsBoundary') - description = self._get_param('Description') - tags = self._get_multi_param('Tags.member') + role_name = self._get_param("RoleName") + path = self._get_param("Path") + assume_role_policy_document = self._get_param("AssumeRolePolicyDocument") + permissions_boundary = self._get_param("PermissionsBoundary") + description = self._get_param("Description") + tags = self._get_multi_param("Tags.member") + max_session_duration = self._get_param("MaxSessionDuration", 3600) role = iam_backend.create_role( - role_name, assume_role_policy_document, path, permissions_boundary, description, tags) + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + max_session_duration, + ) template = self.response_template(CREATE_ROLE_TEMPLATE) return template.render(role=role) def get_role(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role = iam_backend.get_role(role_name) template = self.response_template(GET_ROLE_TEMPLATE) return template.render(role=role) def delete_role(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") iam_backend.delete_role(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRoleResponse") def list_role_policies(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role_policies_names = iam_backend.list_role_policies(role_name) template = self.response_template(LIST_ROLE_POLICIES) return template.render(role_policies=role_policies_names) def put_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_role_policy(role_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePolicyResponse") def delete_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") iam_backend.delete_role_policy(role_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePolicyResponse") def get_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") policy_name, policy_document = iam_backend.get_role_policy( - role_name, policy_name) + role_name, policy_name + ) template = self.response_template(GET_ROLE_POLICY_TEMPLATE) - return template.render(role_name=role_name, - policy_name=policy_name, - policy_document=policy_document) + return template.render( + role_name=role_name, + policy_name=policy_name, + policy_document=policy_document, + ) def update_assume_role_policy(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role = iam_backend.get_role(role_name) - role.assume_role_policy_document = self._get_param('PolicyDocument') + role.assume_role_policy_document = self._get_param("PolicyDocument") template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAssumeRolePolicyResponse") def update_role_description(self): - role_name = self._get_param('RoleName') - description = self._get_param('Description') + role_name = self._get_param("RoleName") + description = self._get_param("Description") role = iam_backend.update_role_description(role_name, description) template = self.response_template(UPDATE_ROLE_DESCRIPTION_TEMPLATE) return template.render(role=role) def update_role(self): - role_name = self._get_param('RoleName') - description = self._get_param('Description') - role = iam_backend.update_role(role_name, description) + role_name = self._get_param("RoleName") + description = self._get_param("Description") + max_session_duration = self._get_param("MaxSessionDuration", 3600) + role = iam_backend.update_role(role_name, description, max_session_duration) template = self.response_template(UPDATE_ROLE_TEMPLATE) return template.render(role=role) def create_policy_version(self): - policy_arn = self._get_param('PolicyArn') - policy_document = self._get_param('PolicyDocument') - set_as_default = self._get_param('SetAsDefault') - policy_version = iam_backend.create_policy_version(policy_arn, policy_document, set_as_default) + policy_arn = self._get_param("PolicyArn") + policy_document = self._get_param("PolicyDocument") + set_as_default = self._get_param("SetAsDefault") + policy_version = iam_backend.create_policy_version( + policy_arn, policy_document, set_as_default + ) template = self.response_template(CREATE_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) def get_policy_version(self): - policy_arn = self._get_param('PolicyArn') - version_id = self._get_param('VersionId') + policy_arn = self._get_param("PolicyArn") + version_id = self._get_param("VersionId") policy_version = iam_backend.get_policy_version(policy_arn, version_id) template = self.response_template(GET_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) def list_policy_versions(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") policy_versions = iam_backend.list_policy_versions(policy_arn) template = self.response_template(LIST_POLICY_VERSIONS_TEMPLATE) return template.render(policy_versions=policy_versions) def delete_policy_version(self): - policy_arn = self._get_param('PolicyArn') - version_id = self._get_param('VersionId') + policy_arn = self._get_param("PolicyArn") + version_id = self._get_param("VersionId") iam_backend.delete_policy_version(policy_arn, version_id) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeletePolicyVersion') + return template.render(name="DeletePolicyVersion") def create_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - path = self._get_param('Path', '/') + profile_name = self._get_param("InstanceProfileName") + path = self._get_param("Path", "/") - profile = iam_backend.create_instance_profile( - profile_name, path, role_ids=[]) + profile = iam_backend.create_instance_profile(profile_name, path, role_ids=[]) template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) def get_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') + profile_name = self._get_param("InstanceProfileName") profile = iam_backend.get_instance_profile(profile_name) template = self.response_template(GET_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) def add_role_to_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - role_name = self._get_param('RoleName') + profile_name = self._get_param("InstanceProfileName") + role_name = self._get_param("RoleName") iam_backend.add_role_to_instance_profile(profile_name, role_name) - template = self.response_template( - ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) + template = self.response_template(ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) return template.render() def remove_role_from_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - role_name = self._get_param('RoleName') + profile_name = self._get_param("InstanceProfileName") + role_name = self._get_param("RoleName") iam_backend.remove_role_from_instance_profile(profile_name, role_name) - template = self.response_template( - REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) + template = self.response_template(REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) return template.render() def list_roles(self): @@ -328,23 +341,22 @@ class IamResponse(BaseResponse): return template.render(instance_profiles=profiles) def list_instance_profiles_for_role(self): - role_name = self._get_param('RoleName') - profiles = iam_backend.get_instance_profiles_for_role( - role_name=role_name) + role_name = self._get_param("RoleName") + profiles = iam_backend.get_instance_profiles_for_role(role_name=role_name) - template = self.response_template( - LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) + template = self.response_template(LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) return template.render(instance_profiles=profiles) def upload_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') - cert_body = self._get_param('CertificateBody') - path = self._get_param('Path') - private_key = self._get_param('PrivateKey') - cert_chain = self._get_param('CertificateName') + cert_name = self._get_param("ServerCertificateName") + cert_body = self._get_param("CertificateBody") + path = self._get_param("Path") + private_key = self._get_param("PrivateKey") + cert_chain = self._get_param("CertificateName") - cert = iam_backend.upload_server_cert( - cert_name, cert_body, private_key, cert_chain=cert_chain, path=path) + cert = iam_backend.upload_server_certificate( + cert_name, cert_body, private_key, cert_chain=cert_chain, path=path + ) template = self.response_template(UPLOAD_CERT_TEMPLATE) return template.render(certificate=cert) @@ -354,27 +366,27 @@ class IamResponse(BaseResponse): return template.render(server_certificates=certs) def get_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') + cert_name = self._get_param("ServerCertificateName") cert = iam_backend.get_server_certificate(cert_name) template = self.response_template(GET_SERVER_CERTIFICATE_TEMPLATE) return template.render(certificate=cert) def delete_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') + cert_name = self._get_param("ServerCertificateName") iam_backend.delete_server_certificate(cert_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteServerCertificate") def create_group(self): - group_name = self._get_param('GroupName') - path = self._get_param('Path', '/') + group_name = self._get_param("GroupName") + path = self._get_param("Path", "/") group = iam_backend.create_group(group_name, path) template = self.response_template(CREATE_GROUP_TEMPLATE) return template.render(group=group) def get_group(self): - group_name = self._get_param('GroupName') + group_name = self._get_param("GroupName") group = iam_backend.get_group(group_name) template = self.response_template(GET_GROUP_TEMPLATE) @@ -386,48 +398,55 @@ class IamResponse(BaseResponse): return template.render(groups=groups) def list_groups_for_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") groups = iam_backend.get_groups_for_user(user_name) template = self.response_template(LIST_GROUPS_FOR_USER_TEMPLATE) return template.render(groups=groups) def put_group_policy(self): - group_name = self._get_param('GroupName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + group_name = self._get_param("GroupName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_group_policy(group_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutGroupPolicyResponse") def list_group_policies(self): - group_name = self._get_param('GroupName') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') - policies = iam_backend.list_group_policies(group_name, - marker=marker, max_items=max_items) + group_name = self._get_param("GroupName") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") + policies = iam_backend.list_group_policies( + group_name, marker=marker, max_items=max_items + ) template = self.response_template(LIST_GROUP_POLICIES_TEMPLATE) - return template.render(name="ListGroupPoliciesResponse", - policies=policies, - marker=marker) + return template.render( + name="ListGroupPoliciesResponse", policies=policies, marker=marker + ) def get_group_policy(self): - group_name = self._get_param('GroupName') - policy_name = self._get_param('PolicyName') + group_name = self._get_param("GroupName") + policy_name = self._get_param("PolicyName") policy_result = iam_backend.get_group_policy(group_name, policy_name) template = self.response_template(GET_GROUP_POLICY_TEMPLATE) return template.render(name="GetGroupPolicyResponse", **policy_result) + def delete_group(self): + group_name = self._get_param("GroupName") + iam_backend.delete_group(group_name) + template = self.response_template(GENERIC_EMPTY_TEMPLATE) + return template.render(name="DeleteGroup") + def create_user(self): - user_name = self._get_param('UserName') - path = self._get_param('Path') + user_name = self._get_param("UserName") + path = self._get_param("Path") user = iam_backend.create_user(user_name, path) template = self.response_template(USER_TEMPLATE) - return template.render(action='Create', user=user) + return template.render(action="Create", user=user) def get_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_current_user() user = iam_backend.get_user_from_access_key_id(access_key_id) @@ -437,178 +456,257 @@ class IamResponse(BaseResponse): user = iam_backend.get_user(user_name) template = self.response_template(USER_TEMPLATE) - return template.render(action='Get', user=user) + return template.render(action="Get", user=user) def list_users(self): - path_prefix = self._get_param('PathPrefix') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') + path_prefix = self._get_param("PathPrefix") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") users = iam_backend.list_users(path_prefix, marker, max_items) template = self.response_template(LIST_USERS_TEMPLATE) - return template.render(action='List', users=users) + return template.render(action="List", users=users) def update_user(self): - user_name = self._get_param('UserName') - new_path = self._get_param('NewPath') - new_user_name = self._get_param('NewUserName') + user_name = self._get_param("UserName") + new_path = self._get_param("NewPath") + new_user_name = self._get_param("NewUserName") iam_backend.update_user(user_name, new_path, new_user_name) if new_user_name: user = iam_backend.get_user(new_user_name) else: user = iam_backend.get_user(user_name) template = self.response_template(USER_TEMPLATE) - return template.render(action='Update', user=user) + return template.render(action="Update", user=user) def create_login_profile(self): - user_name = self._get_param('UserName') - password = self._get_param('Password') + user_name = self._get_param("UserName") + password = self._get_param("Password") user = iam_backend.create_login_profile(user_name, password) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def get_login_profile(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") user = iam_backend.get_login_profile(user_name) template = self.response_template(GET_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def update_login_profile(self): - user_name = self._get_param('UserName') - password = self._get_param('Password') - password_reset_required = self._get_param('PasswordResetRequired') - user = iam_backend.update_login_profile(user_name, password, password_reset_required) + user_name = self._get_param("UserName") + password = self._get_param("Password") + password_reset_required = self._get_param("PasswordResetRequired") + user = iam_backend.update_login_profile( + user_name, password, password_reset_required + ) template = self.response_template(UPDATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def add_user_to_group(self): - group_name = self._get_param('GroupName') - user_name = self._get_param('UserName') + group_name = self._get_param("GroupName") + user_name = self._get_param("UserName") iam_backend.add_user_to_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='AddUserToGroup') + return template.render(name="AddUserToGroup") def remove_user_from_group(self): - group_name = self._get_param('GroupName') - user_name = self._get_param('UserName') + group_name = self._get_param("GroupName") + user_name = self._get_param("UserName") iam_backend.remove_user_from_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='RemoveUserFromGroup') + return template.render(name="RemoveUserFromGroup") def get_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") policy_document = iam_backend.get_user_policy(user_name, policy_name) template = self.response_template(GET_USER_POLICY_TEMPLATE) return template.render( user_name=user_name, policy_name=policy_name, - policy_document=policy_document.get('policy_document') + policy_document=policy_document.get("policy_document"), ) def list_user_policies(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") policies = iam_backend.list_user_policies(user_name) template = self.response_template(LIST_USER_POLICIES_TEMPLATE) return template.render(policies=policies) def put_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_user_policy(user_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='PutUserPolicy') + return template.render(name="PutUserPolicy") def delete_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") iam_backend.delete_user_policy(user_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteUserPolicy') + return template.render(name="DeleteUserPolicy") def create_access_key(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") key = iam_backend.create_access_key(user_name) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) return template.render(key=key) def update_access_key(self): - user_name = self._get_param('UserName') - access_key_id = self._get_param('AccessKeyId') - status = self._get_param('Status') + user_name = self._get_param("UserName") + access_key_id = self._get_param("AccessKeyId") + status = self._get_param("Status") iam_backend.update_access_key(user_name, access_key_id, status) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='UpdateAccessKey') + return template.render(name="UpdateAccessKey") def get_access_key_last_used(self): - access_key_id = self._get_param('AccessKeyId') + access_key_id = self._get_param("AccessKeyId") last_used_response = iam_backend.get_access_key_last_used(access_key_id) template = self.response_template(GET_ACCESS_KEY_LAST_USED_TEMPLATE) - return template.render(user_name=last_used_response["user_name"], last_used=last_used_response["last_used"]) + return template.render( + user_name=last_used_response["user_name"], + last_used=last_used_response["last_used"], + ) def list_access_keys(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") keys = iam_backend.get_all_access_keys(user_name) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) return template.render(user_name=user_name, keys=keys) def delete_access_key(self): - user_name = self._get_param('UserName') - access_key_id = self._get_param('AccessKeyId') + user_name = self._get_param("UserName") + access_key_id = self._get_param("AccessKeyId") iam_backend.delete_access_key(access_key_id, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteAccessKey') + return template.render(name="DeleteAccessKey") + + def upload_ssh_public_key(self): + user_name = self._get_param("UserName") + ssh_public_key_body = self._get_param("SSHPublicKeyBody") + + key = iam_backend.upload_ssh_public_key(user_name, ssh_public_key_body) + template = self.response_template(UPLOAD_SSH_PUBLIC_KEY_TEMPLATE) + return template.render(key=key) + + def get_ssh_public_key(self): + user_name = self._get_param("UserName") + ssh_public_key_id = self._get_param("SSHPublicKeyId") + + key = iam_backend.get_ssh_public_key(user_name, ssh_public_key_id) + template = self.response_template(GET_SSH_PUBLIC_KEY_TEMPLATE) + return template.render(key=key) + + def list_ssh_public_keys(self): + user_name = self._get_param("UserName") + + keys = iam_backend.get_all_ssh_public_keys(user_name) + template = self.response_template(LIST_SSH_PUBLIC_KEYS_TEMPLATE) + return template.render(keys=keys) + + def update_ssh_public_key(self): + user_name = self._get_param("UserName") + ssh_public_key_id = self._get_param("SSHPublicKeyId") + status = self._get_param("Status") + + iam_backend.update_ssh_public_key(user_name, ssh_public_key_id, status) + template = self.response_template(UPDATE_SSH_PUBLIC_KEY_TEMPLATE) + return template.render() + + def delete_ssh_public_key(self): + user_name = self._get_param("UserName") + ssh_public_key_id = self._get_param("SSHPublicKeyId") + + iam_backend.delete_ssh_public_key(user_name, ssh_public_key_id) + template = self.response_template(DELETE_SSH_PUBLIC_KEY_TEMPLATE) + return template.render() def deactivate_mfa_device(self): - user_name = self._get_param('UserName') - serial_number = self._get_param('SerialNumber') + user_name = self._get_param("UserName") + serial_number = self._get_param("SerialNumber") iam_backend.deactivate_mfa_device(user_name, serial_number) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeactivateMFADevice') + return template.render(name="DeactivateMFADevice") def enable_mfa_device(self): - user_name = self._get_param('UserName') - serial_number = self._get_param('SerialNumber') - authentication_code_1 = self._get_param('AuthenticationCode1') - authentication_code_2 = self._get_param('AuthenticationCode2') + user_name = self._get_param("UserName") + serial_number = self._get_param("SerialNumber") + authentication_code_1 = self._get_param("AuthenticationCode1") + authentication_code_2 = self._get_param("AuthenticationCode2") iam_backend.enable_mfa_device( - user_name, - serial_number, - authentication_code_1, - authentication_code_2 + user_name, serial_number, authentication_code_1, authentication_code_2 ) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='EnableMFADevice') + return template.render(name="EnableMFADevice") def list_mfa_devices(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") devices = iam_backend.list_mfa_devices(user_name) template = self.response_template(LIST_MFA_DEVICES_TEMPLATE) return template.render(user_name=user_name, devices=devices) + def create_virtual_mfa_device(self): + path = self._get_param("Path") + virtual_mfa_device_name = self._get_param("VirtualMFADeviceName") + + virtual_mfa_device = iam_backend.create_virtual_mfa_device( + virtual_mfa_device_name, path + ) + + template = self.response_template(CREATE_VIRTUAL_MFA_DEVICE_TEMPLATE) + return template.render(device=virtual_mfa_device) + + def delete_virtual_mfa_device(self): + serial_number = self._get_param("SerialNumber") + + iam_backend.delete_virtual_mfa_device(serial_number) + + template = self.response_template(DELETE_VIRTUAL_MFA_DEVICE_TEMPLATE) + return template.render() + + def list_virtual_mfa_devices(self): + assignment_status = self._get_param("AssignmentStatus", "Any") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems", 100) + + devices, marker = iam_backend.list_virtual_mfa_devices( + assignment_status, marker, max_items + ) + + template = self.response_template(LIST_VIRTUAL_MFA_DEVICES_TEMPLATE) + return template.render(devices=devices, marker=marker) + def delete_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") iam_backend.delete_user(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteUser') + return template.render(name="DeleteUser") + + def delete_policy(self): + policy_arn = self._get_param("PolicyArn") + iam_backend.delete_policy(policy_arn) + template = self.response_template(GENERIC_EMPTY_TEMPLATE) + return template.render(name="DeletePolicy") def delete_login_profile(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") iam_backend.delete_login_profile(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteLoginProfile') + return template.render(name="DeleteLoginProfile") def generate_credential_report(self): if iam_backend.report_generated(): @@ -629,48 +727,52 @@ class IamResponse(BaseResponse): return template.render(aliases=aliases) def create_account_alias(self): - alias = self._get_param('AccountAlias') + alias = self._get_param("AccountAlias") iam_backend.create_account_alias(alias) template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def delete_account_alias(self): - alias = self._get_param('AccountAlias') + alias = self._get_param("AccountAlias") iam_backend.delete_account_alias(alias) template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def get_account_authorization_details(self): - filter_param = self._get_multi_param('Filter.member') + filter_param = self._get_multi_param("Filter.member") account_details = iam_backend.get_account_authorization_details(filter_param) template = self.response_template(GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE) return template.render( - instance_profiles=account_details['instance_profiles'], - policies=account_details['managed_policies'], - users=account_details['users'], - groups=account_details['groups'], - roles=account_details['roles'], - get_groups_for_user=iam_backend.get_groups_for_user + instance_profiles=account_details["instance_profiles"], + policies=account_details["managed_policies"], + users=account_details["users"], + groups=account_details["groups"], + roles=account_details["roles"], + get_groups_for_user=iam_backend.get_groups_for_user, ) def create_saml_provider(self): - saml_provider_name = self._get_param('Name') - saml_metadata_document = self._get_param('SAMLMetadataDocument') - saml_provider = iam_backend.create_saml_provider(saml_provider_name, saml_metadata_document) + saml_provider_name = self._get_param("Name") + saml_metadata_document = self._get_param("SAMLMetadataDocument") + saml_provider = iam_backend.create_saml_provider( + saml_provider_name, saml_metadata_document + ) template = self.response_template(CREATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def update_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') - saml_metadata_document = self._get_param('SAMLMetadataDocument') - saml_provider = iam_backend.update_saml_provider(saml_provider_arn, saml_metadata_document) + saml_provider_arn = self._get_param("SAMLProviderArn") + saml_metadata_document = self._get_param("SAMLMetadataDocument") + saml_provider = iam_backend.update_saml_provider( + saml_provider_arn, saml_metadata_document + ) template = self.response_template(UPDATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def delete_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') + saml_provider_arn = self._get_param("SAMLProviderArn") iam_backend.delete_saml_provider(saml_provider_arn) template = self.response_template(DELETE_SAML_PROVIDER_TEMPLATE) @@ -683,48 +785,48 @@ class IamResponse(BaseResponse): return template.render(saml_providers=saml_providers) def get_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') + saml_provider_arn = self._get_param("SAMLProviderArn") saml_provider = iam_backend.get_saml_provider(saml_provider_arn) template = self.response_template(GET_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def upload_signing_certificate(self): - user_name = self._get_param('UserName') - cert_body = self._get_param('CertificateBody') + user_name = self._get_param("UserName") + cert_body = self._get_param("CertificateBody") cert = iam_backend.upload_signing_certificate(user_name, cert_body) template = self.response_template(UPLOAD_SIGNING_CERTIFICATE_TEMPLATE) return template.render(cert=cert) def update_signing_certificate(self): - user_name = self._get_param('UserName') - cert_id = self._get_param('CertificateId') - status = self._get_param('Status') + user_name = self._get_param("UserName") + cert_id = self._get_param("CertificateId") + status = self._get_param("Status") iam_backend.update_signing_certificate(user_name, cert_id, status) template = self.response_template(UPDATE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() def delete_signing_certificate(self): - user_name = self._get_param('UserName') - cert_id = self._get_param('CertificateId') + user_name = self._get_param("UserName") + cert_id = self._get_param("CertificateId") iam_backend.delete_signing_certificate(user_name, cert_id) template = self.response_template(DELETE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() def list_signing_certificates(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") certs = iam_backend.list_signing_certificates(user_name) template = self.response_template(LIST_SIGNING_CERTIFICATES_TEMPLATE) return template.render(user_name=user_name, certificates=certs) def list_role_tags(self): - role_name = self._get_param('RoleName') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems', 100) + role_name = self._get_param("RoleName") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems", 100) tags, marker = iam_backend.list_role_tags(role_name, marker, max_items) @@ -732,8 +834,8 @@ class IamResponse(BaseResponse): return template.render(tags=tags, marker=marker) def tag_role(self): - role_name = self._get_param('RoleName') - tags = self._get_multi_param('Tags.member') + role_name = self._get_param("RoleName") + tags = self._get_multi_param("Tags.member") iam_backend.tag_role(role_name, tags) @@ -741,14 +843,100 @@ class IamResponse(BaseResponse): return template.render() def untag_role(self): - role_name = self._get_param('RoleName') - tag_keys = self._get_multi_param('TagKeys.member') + role_name = self._get_param("RoleName") + tag_keys = self._get_multi_param("TagKeys.member") iam_backend.untag_role(role_name, tag_keys) template = self.response_template(UNTAG_ROLE_TEMPLATE) return template.render() + def create_open_id_connect_provider(self): + open_id_provider_url = self._get_param("Url") + thumbprint_list = self._get_multi_param("ThumbprintList.member") + client_id_list = self._get_multi_param("ClientIDList.member") + + open_id_provider = iam_backend.create_open_id_connect_provider( + open_id_provider_url, thumbprint_list, client_id_list + ) + + template = self.response_template(CREATE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) + return template.render(open_id_provider=open_id_provider) + + def delete_open_id_connect_provider(self): + open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") + + iam_backend.delete_open_id_connect_provider(open_id_provider_arn) + + template = self.response_template(DELETE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) + return template.render() + + def get_open_id_connect_provider(self): + open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") + + open_id_provider = iam_backend.get_open_id_connect_provider( + open_id_provider_arn + ) + + template = self.response_template(GET_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) + return template.render(open_id_provider=open_id_provider) + + def list_open_id_connect_providers(self): + open_id_provider_arns = iam_backend.list_open_id_connect_providers() + + template = self.response_template(LIST_OPEN_ID_CONNECT_PROVIDERS_TEMPLATE) + return template.render(open_id_provider_arns=open_id_provider_arns) + + def update_account_password_policy(self): + allow_change_password = self._get_bool_param( + "AllowUsersToChangePassword", False + ) + hard_expiry = self._get_bool_param("HardExpiry") + max_password_age = self._get_int_param("MaxPasswordAge") + minimum_password_length = self._get_int_param("MinimumPasswordLength", 6) + password_reuse_prevention = self._get_int_param("PasswordReusePrevention") + require_lowercase_characters = self._get_bool_param( + "RequireLowercaseCharacters", False + ) + require_numbers = self._get_bool_param("RequireNumbers", False) + require_symbols = self._get_bool_param("RequireSymbols", False) + require_uppercase_characters = self._get_bool_param( + "RequireUppercaseCharacters", False + ) + + iam_backend.update_account_password_policy( + allow_change_password, + hard_expiry, + max_password_age, + minimum_password_length, + password_reuse_prevention, + require_lowercase_characters, + require_numbers, + require_symbols, + require_uppercase_characters, + ) + + template = self.response_template(UPDATE_ACCOUNT_PASSWORD_POLICY_TEMPLATE) + return template.render() + + def get_account_password_policy(self): + account_password_policy = iam_backend.get_account_password_policy() + + template = self.response_template(GET_ACCOUNT_PASSWORD_POLICY_TEMPLATE) + return template.render(password_policy=account_password_policy) + + def delete_account_password_policy(self): + iam_backend.delete_account_password_policy() + + template = self.response_template(DELETE_ACCOUNT_PASSWORD_POLICY_TEMPLATE) + return template.render() + + def get_account_summary(self): + account_summary = iam_backend.get_account_summary() + + template = self.response_template(GET_ACCOUNT_SUMMARY_TEMPLATE) + return template.render(summary_map=account_summary.summary_map) + LIST_ENTITIES_FOR_POLICY_TEMPLATE = """ @@ -1004,9 +1192,12 @@ CREATE_ROLE_TEMPLATE = """ """ +UPLOAD_SSH_PUBLIC_KEY_TEMPLATE = """ + + + {{ key.user_name }} + {{ key.ssh_public_key_body }} + {{ key.ssh_public_key_id }} + {{ key.fingerprint }} + {{ key.status }} + {{ key.uploaded_iso_8601 }} + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + +GET_SSH_PUBLIC_KEY_TEMPLATE = """ + + + {{ key.user_name }} + {{ key.ssh_public_key_body }} + {{ key.ssh_public_key_id }} + {{ key.fingerprint }} + {{ key.status }} + {{ key.uploaded_iso_8601 }} + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + +LIST_SSH_PUBLIC_KEYS_TEMPLATE = """ + + + {% for key in keys %} + + {{ key.user_name }} + {{ key.ssh_public_key_id }} + {{ key.status }} + {{ key.uploaded_iso_8601 }} + + {% endfor %} + + false + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + +UPDATE_SSH_PUBLIC_KEY_TEMPLATE = """ + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + +DELETE_SSH_PUBLIC_KEY_TEMPLATE = """ + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + CREDENTIAL_REPORT_GENERATING = """ @@ -1562,6 +1824,7 @@ CREDENTIAL_REPORT_GENERATING = """ """ + CREDENTIAL_REPORT_GENERATED = """ COMPLETE @@ -1571,6 +1834,7 @@ CREDENTIAL_REPORT_GENERATED = """ """ + CREDENTIAL_REPORT = """ {{ report }} @@ -1582,6 +1846,7 @@ CREDENTIAL_REPORT = """ """ + LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """ false @@ -1614,6 +1879,7 @@ LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE = """ """ + LIST_MFA_DEVICES_TEMPLATE = """ @@ -1632,6 +1898,61 @@ LIST_MFA_DEVICES_TEMPLATE = """ """ +CREATE_VIRTUAL_MFA_DEVICE_TEMPLATE = """ + + + {{ device.serial_number }} + {{ device.base32_string_seed }} + {{ device.qr_code_png }} + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + +DELETE_VIRTUAL_MFA_DEVICE_TEMPLATE = """ + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + +LIST_VIRTUAL_MFA_DEVICES_TEMPLATE = """ + + {% if marker is none %} + false + {% else %} + true + {{ marker }} + {% endif %} + + {% for device in devices %} + + {{ device.serial_number }} + {% if device.enable_date %} + {{ device.enabled_iso_8601 }} + {% endif %} + {% if device.user %} + + {{ device.user.path }} + {{ device.user.name }} + {{ device.user.id }} + {{ device.user.created_iso_8601 }} + {{ device.user.arn }} + + {% endif %} + + {% endfor %} + + + + b61ce1b1-0401-11e1-b2f8-2dEXAMPLEbfc + +""" + + LIST_ACCOUNT_ALIASES_TEMPLATE = """ false @@ -1968,3 +2289,115 @@ UNTAG_ROLE_TEMPLATE = """ + + {{ open_id_provider.arn }} + + + f248366a-4f64-11e4-aefa-bfd6aEXAMPLE + +""" + + +DELETE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE = """ + + b5e49e29-4f64-11e4-aefa-bfd6aEXAMPLE + +""" + + +GET_OPEN_ID_CONNECT_PROVIDER_TEMPLATE = """ + + + {% for thumbprint in open_id_provider.thumbprint_list %} + {{ thumbprint }} + {% endfor %} + + {{ open_id_provider.created_iso_8601 }} + + {% for client_id in open_id_provider.client_id_list %} + {{ client_id }} + {% endfor %} + + {{ open_id_provider.url }} + + + 2c91531b-4f65-11e4-aefa-bfd6aEXAMPLE + +""" + + +LIST_OPEN_ID_CONNECT_PROVIDERS_TEMPLATE = """ + + + {% for open_id_provider_arn in open_id_provider_arns %} + + {{ open_id_provider_arn }} + + {% endfor %} + + + + de2c0228-4f63-11e4-aefa-bfd6aEXAMPLE + +""" + + +UPDATE_ACCOUNT_PASSWORD_POLICY_TEMPLATE = """ + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + +GET_ACCOUNT_PASSWORD_POLICY_TEMPLATE = """ + + + {{ password_policy.allow_users_to_change_password | lower }} + {{ password_policy.expire_passwords | lower }} + {% if password_policy.hard_expiry %} + {{ password_policy.hard_expiry | lower }} + {% endif %} + {% if password_policy.max_password_age %} + {{ password_policy.max_password_age }} + {% endif %} + {{ password_policy.minimum_password_length }} + {% if password_policy.password_reuse_prevention %} + {{ password_policy.password_reuse_prevention }} + {% endif %} + {{ password_policy.require_lowercase_characters | lower }} + {{ password_policy.require_numbers | lower }} + {{ password_policy.require_symbols | lower }} + {{ password_policy.require_uppercase_characters | lower }} + + + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + +DELETE_ACCOUNT_PASSWORD_POLICY_TEMPLATE = """ + + 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE + +""" + + +GET_ACCOUNT_SUMMARY_TEMPLATE = """ + + + {% for key, value in summary_map.items() %} + + {{ key }} + {{ value }} + + {% endfor %} + + + + 85cb9b90-ac28-11e4-a88d-97964EXAMPLE + +""" diff --git a/moto/iam/urls.py b/moto/iam/urls.py index 46db41e46..c4ce1d81f 100644 --- a/moto/iam/urls.py +++ b/moto/iam/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import IamResponse -url_bases = [ - "https?://iam(.*).amazonaws.com", -] +url_bases = ["https?://iam(.*).amazonaws.com"] -url_paths = { - '{0}/$': IamResponse.dispatch, -} +url_paths = {"{0}/$": IamResponse.dispatch} diff --git a/moto/iam/utils.py b/moto/iam/utils.py index 2bd6448f9..391f54dbd 100644 --- a/moto/iam/utils.py +++ b/moto/iam/utils.py @@ -5,29 +5,26 @@ import six def random_alphanumeric(length): - return ''.join(six.text_type( - random.choice( - string.ascii_letters + string.digits + "+" + "/" - )) for _ in range(length) + return "".join( + six.text_type(random.choice(string.ascii_letters + string.digits + "+" + "/")) + for _ in range(length) ) def random_resource_id(size=20): chars = list(range(10)) + list(string.ascii_lowercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def random_access_key(): - return ''.join(six.text_type( - random.choice( - string.ascii_uppercase + string.digits - )) for _ in range(16) + return "".join( + six.text_type(random.choice(string.ascii_uppercase + string.digits)) + for _ in range(16) ) def random_policy_id(): - return 'A' + ''.join( - random.choice(string.ascii_uppercase + string.digits) - for _ in range(20) + return "A" + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(20) ) diff --git a/moto/instance_metadata/responses.py b/moto/instance_metadata/responses.py index 460e65aca..81dfd8b59 100644 --- a/moto/instance_metadata/responses.py +++ b/moto/instance_metadata/responses.py @@ -7,7 +7,6 @@ from moto.core.responses import BaseResponse class InstanceMetadataResponse(BaseResponse): - def metadata_response(self, request, full_url, headers): """ Mock response for localhost metadata @@ -21,7 +20,7 @@ class InstanceMetadataResponse(BaseResponse): AccessKeyId="test-key", SecretAccessKey="test-secret-key", Token="test-session-token", - Expiration=tomorrow.strftime("%Y-%m-%dT%H:%M:%SZ") + Expiration=tomorrow.strftime("%Y-%m-%dT%H:%M:%SZ"), ) path = parsed_url.path @@ -29,21 +28,18 @@ class InstanceMetadataResponse(BaseResponse): meta_data_prefix = "/latest/meta-data/" # Strip prefix if it is there if path.startswith(meta_data_prefix): - path = path[len(meta_data_prefix):] + path = path[len(meta_data_prefix) :] - if path == '': - result = 'iam' - elif path == 'iam': - result = json.dumps({ - 'security-credentials': { - 'default-role': credentials - } - }) - elif path == 'iam/security-credentials/': - result = 'default-role' - elif path == 'iam/security-credentials/default-role': + if path == "": + result = "iam" + elif path == "iam": + result = json.dumps({"security-credentials": {"default-role": credentials}}) + elif path == "iam/security-credentials/": + result = "default-role" + elif path == "iam/security-credentials/default-role": result = json.dumps(credentials) else: raise NotImplementedError( - "The {0} metadata path has not been implemented".format(path)) + "The {0} metadata path has not been implemented".format(path) + ) return 200, headers, result diff --git a/moto/instance_metadata/urls.py b/moto/instance_metadata/urls.py index 7776b364a..b77935473 100644 --- a/moto/instance_metadata/urls.py +++ b/moto/instance_metadata/urls.py @@ -1,12 +1,8 @@ from __future__ import unicode_literals from .responses import InstanceMetadataResponse -url_bases = [ - "http://169.254.169.254" -] +url_bases = ["http://169.254.169.254"] instance_metadata = InstanceMetadataResponse() -url_paths = { - '{0}/(?P.+)': instance_metadata.metadata_response, -} +url_paths = {"{0}/(?P.+)": instance_metadata.metadata_response} diff --git a/moto/iot/__init__.py b/moto/iot/__init__.py index 199b8aeae..97d36fbcc 100644 --- a/moto/iot/__init__.py +++ b/moto/iot/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import iot_backends from ..core.models import base_decorator -iot_backend = iot_backends['us-east-1'] +iot_backend = iot_backends["us-east-1"] mock_iot = base_decorator(iot_backends) diff --git a/moto/iot/exceptions.py b/moto/iot/exceptions.py index b5725d8fe..2854fbb17 100644 --- a/moto/iot/exceptions.py +++ b/moto/iot/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(IoTClientError): def __init__(self): self.code = 404 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -19,8 +18,7 @@ class InvalidRequestException(IoTClientError): def __init__(self, msg=None): self.code = 400 super(InvalidRequestException, self).__init__( - "InvalidRequestException", - msg or "The request is not valid." + "InvalidRequestException", msg or "The request is not valid." ) @@ -37,8 +35,8 @@ class VersionConflictException(IoTClientError): def __init__(self, name): self.code = 409 super(VersionConflictException, self).__init__( - 'VersionConflictException', - 'The version for thing %s does not match the expected version.' % name + "VersionConflictException", + "The version for thing %s does not match the expected version." % name, ) @@ -46,14 +44,11 @@ class CertificateStateException(IoTClientError): def __init__(self, msg, cert_id): self.code = 406 super(CertificateStateException, self).__init__( - 'CertificateStateException', - '%s Id: %s' % (msg, cert_id) + "CertificateStateException", "%s Id: %s" % (msg, cert_id) ) class DeleteConflictException(IoTClientError): def __init__(self, msg): self.code = 409 - super(DeleteConflictException, self).__init__( - 'DeleteConflictException', msg - ) + super(DeleteConflictException, self).__init__("DeleteConflictException", msg) diff --git a/moto/iot/models.py b/moto/iot/models.py index 89d71dd14..b2599de1d 100644 --- a/moto/iot/models.py +++ b/moto/iot/models.py @@ -28,7 +28,7 @@ class FakeThing(BaseModel): self.thing_name = thing_name self.thing_type = thing_type self.attributes = attributes - self.arn = 'arn:aws:iot:%s:1:thing/%s' % (self.region_name, thing_name) + self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name) self.version = 1 # TODO: we need to handle 'version'? @@ -37,15 +37,15 @@ class FakeThing(BaseModel): def to_dict(self, include_default_client_id=False): obj = { - 'thingName': self.thing_name, - 'thingArn': self.arn, - 'attributes': self.attributes, - 'version': self.version + "thingName": self.thing_name, + "thingArn": self.arn, + "attributes": self.attributes, + "version": self.version, } if self.thing_type: - obj['thingTypeName'] = self.thing_type.thing_type_name + obj["thingTypeName"] = self.thing_type.thing_type_name if include_default_client_id: - obj['defaultClientId'] = self.thing_name + obj["defaultClientId"] = self.thing_name return obj @@ -56,23 +56,27 @@ class FakeThingType(BaseModel): self.thing_type_properties = thing_type_properties self.thing_type_id = str(uuid.uuid4()) # I don't know the rule of id t = time.time() - self.metadata = { - 'deprecated': False, - 'creationData': int(t * 1000) / 1000.0 - } - self.arn = 'arn:aws:iot:%s:1:thingtype/%s' % (self.region_name, thing_type_name) + self.metadata = {"deprecated": False, "creationDate": int(t * 1000) / 1000.0} + self.arn = "arn:aws:iot:%s:1:thingtype/%s" % (self.region_name, thing_type_name) def to_dict(self): return { - 'thingTypeName': self.thing_type_name, - 'thingTypeId': self.thing_type_id, - 'thingTypeProperties': self.thing_type_properties, - 'thingTypeMetadata': self.metadata + "thingTypeName": self.thing_type_name, + "thingTypeId": self.thing_type_id, + "thingTypeProperties": self.thing_type_properties, + "thingTypeMetadata": self.metadata, } class FakeThingGroup(BaseModel): - def __init__(self, thing_group_name, parent_group_name, thing_group_properties, region_name): + def __init__( + self, + thing_group_name, + parent_group_name, + thing_group_properties, + region_name, + thing_groups, + ): self.region_name = region_name self.thing_group_name = thing_group_name self.thing_group_id = str(uuid.uuid4()) # I don't know the rule of id @@ -80,33 +84,59 @@ class FakeThingGroup(BaseModel): self.parent_group_name = parent_group_name self.thing_group_properties = thing_group_properties or {} t = time.time() - self.metadata = { - 'creationData': int(t * 1000) / 1000.0 - } - self.arn = 'arn:aws:iot:%s:1:thinggroup/%s' % (self.region_name, thing_group_name) + self.metadata = {"creationDate": int(t * 1000) / 1000.0} + if parent_group_name: + self.metadata["parentGroupName"] = parent_group_name + # initilize rootToParentThingGroups + if "rootToParentThingGroups" not in self.metadata: + self.metadata["rootToParentThingGroups"] = [] + # search for parent arn + for thing_group_arn, thing_group in thing_groups.items(): + if thing_group.thing_group_name == parent_group_name: + parent_thing_group_structure = thing_group + break + # if parent arn found (should always be found) + if parent_thing_group_structure: + # copy parent's rootToParentThingGroups + if "rootToParentThingGroups" in parent_thing_group_structure.metadata: + self.metadata["rootToParentThingGroups"].extend( + parent_thing_group_structure.metadata["rootToParentThingGroups"] + ) + self.metadata["rootToParentThingGroups"].extend( + [ + { + "groupName": parent_group_name, + "groupArn": parent_thing_group_structure.arn, + } + ] + ) + self.arn = "arn:aws:iot:%s:1:thinggroup/%s" % ( + self.region_name, + thing_group_name, + ) self.things = OrderedDict() def to_dict(self): return { - 'thingGroupName': self.thing_group_name, - 'thingGroupId': self.thing_group_id, - 'version': self.version, - 'thingGroupProperties': self.thing_group_properties, - 'thingGroupMetadata': self.metadata + "thingGroupName": self.thing_group_name, + "thingGroupId": self.thing_group_id, + "version": self.version, + "thingGroupProperties": self.thing_group_properties, + "thingGroupMetadata": self.metadata, } class FakeCertificate(BaseModel): def __init__(self, certificate_pem, status, region_name, ca_certificate_pem=None): m = hashlib.sha256() - m.update(str(uuid.uuid4()).encode('utf-8')) + m.update(str(uuid.uuid4()).encode("utf-8")) self.certificate_id = m.hexdigest() - self.arn = 'arn:aws:iot:%s:1:cert/%s' % (region_name, self.certificate_id) + self.arn = "arn:aws:iot:%s:1:cert/%s" % (region_name, self.certificate_id) self.certificate_pem = certificate_pem self.status = status # TODO: must adjust - self.owner = '1' + self.owner = "1" self.transfer_data = {} self.creation_date = time.time() self.last_modified_date = self.creation_date @@ -114,16 +144,16 @@ class FakeCertificate(BaseModel): self.ca_certificate_id = None self.ca_certificate_pem = ca_certificate_pem if ca_certificate_pem: - m.update(str(uuid.uuid4()).encode('utf-8')) + m.update(str(uuid.uuid4()).encode("utf-8")) self.ca_certificate_id = m.hexdigest() def to_dict(self): return { - 'certificateArn': self.arn, - 'certificateId': self.certificate_id, - 'caCertificateId': self.ca_certificate_id, - 'status': self.status, - 'creationDate': self.creation_date + "certificateArn": self.arn, + "certificateId": self.certificate_id, + "caCertificateId": self.ca_certificate_id, + "status": self.status, + "creationDate": self.creation_date, } def to_description_dict(self): @@ -133,14 +163,14 @@ class FakeCertificate(BaseModel): - previousOwnedBy """ return { - 'certificateArn': self.arn, - 'certificateId': self.certificate_id, - 'status': self.status, - 'certificatePem': self.certificate_pem, - 'ownedBy': self.owner, - 'creationDate': self.creation_date, - 'lastModifiedDate': self.last_modified_date, - 'transferData': self.transfer_data + "certificateArn": self.arn, + "certificateId": self.certificate_id, + "status": self.status, + "certificatePem": self.certificate_pem, + "ownedBy": self.owner, + "creationDate": self.creation_date, + "lastModifiedDate": self.last_modified_date, + "transferData": self.transfer_data, } @@ -169,10 +199,7 @@ class FakePolicy(BaseModel): } def to_dict(self): - return { - 'policyName': self.name, - 'policyArn': self.arn, - } + return {"policyName": self.name, "policyArn": self.arn} class FakePolicyVersion(object): @@ -223,14 +250,25 @@ class FakeJob(BaseModel): JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]" JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN) - def __init__(self, job_id, targets, document_source, document, description, presigned_url_config, target_selection, - job_executions_rollout_config, document_parameters, region_name): + def __init__( + self, + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + region_name, + ): if not self._job_id_matcher(self.JOB_ID_REGEX, job_id): raise InvalidRequestException() self.region_name = region_name self.job_id = job_id - self.job_arn = 'arn:aws:iot:%s:1:job/%s' % (self.region_name, job_id) + self.job_arn = "arn:aws:iot:%s:1:job/%s" % (self.region_name, job_id) self.targets = targets self.document_source = document_source self.document = document @@ -246,14 +284,14 @@ class FakeJob(BaseModel): self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.completed_at = None self.job_process_details = { - 'processingTargets': targets, - 'numberOfQueuedThings': 1, - 'numberOfCanceledThings': 0, - 'numberOfSucceededThings': 0, - 'numberOfFailedThings': 0, - 'numberOfRejectedThings': 0, - 'numberOfInProgressThings': 0, - 'numberOfRemovedThings': 0 + "processingTargets": targets, + "numberOfQueuedThings": 1, + "numberOfCanceledThings": 0, + "numberOfSucceededThings": 0, + "numberOfFailedThings": 0, + "numberOfRejectedThings": 0, + "numberOfInProgressThings": 0, + "numberOfRemovedThings": 0, } self.document_parameters = document_parameters @@ -358,16 +396,18 @@ class IoTBackend(BaseBackend): thing_types = self.list_thing_types() thing_type = None if thing_type_name: - filtered_thing_types = [_ for _ in thing_types if _.thing_type_name == thing_type_name] + filtered_thing_types = [ + _ for _ in thing_types if _.thing_type_name == thing_type_name + ] if len(filtered_thing_types) == 0: raise ResourceNotFoundException() thing_type = filtered_thing_types[0] if attribute_payload is None: attributes = {} - elif 'attributes' not in attribute_payload: + elif "attributes" not in attribute_payload: attributes = {} else: - attributes = attribute_payload['attributes'] + attributes = attribute_payload["attributes"] thing = FakeThing(thing_name, thing_type, attributes, self.region_name) self.things[thing.arn] = thing return thing.thing_name, thing.arn @@ -375,41 +415,68 @@ class IoTBackend(BaseBackend): def create_thing_type(self, thing_type_name, thing_type_properties): if thing_type_properties is None: thing_type_properties = {} - thing_type = FakeThingType(thing_type_name, thing_type_properties, self.region_name) + thing_type = FakeThingType( + thing_type_name, thing_type_properties, self.region_name + ) self.thing_types[thing_type.arn] = thing_type return thing_type.thing_type_name, thing_type.arn def list_thing_types(self, thing_type_name=None): if thing_type_name: # It's weird but thing_type_name is filtered by forward match, not complete match - return [_ for _ in self.thing_types.values() if _.thing_type_name.startswith(thing_type_name)] + return [ + _ + for _ in self.thing_types.values() + if _.thing_type_name.startswith(thing_type_name) + ] return self.thing_types.values() - def list_things(self, attribute_name, attribute_value, thing_type_name, max_results, token): + def list_things( + self, attribute_name, attribute_value, thing_type_name, max_results, token + ): all_things = [_.to_dict() for _ in self.things.values()] if attribute_name is not None and thing_type_name is not None: - filtered_things = list(filter(lambda elem: - attribute_name in elem["attributes"] and - elem["attributes"][attribute_name] == attribute_value and - "thingTypeName" in elem and - elem["thingTypeName"] == thing_type_name, all_things)) + filtered_things = list( + filter( + lambda elem: attribute_name in elem["attributes"] + and elem["attributes"][attribute_name] == attribute_value + and "thingTypeName" in elem + and elem["thingTypeName"] == thing_type_name, + all_things, + ) + ) elif attribute_name is not None and thing_type_name is None: - filtered_things = list(filter(lambda elem: - attribute_name in elem["attributes"] and - elem["attributes"][attribute_name] == attribute_value, all_things)) + filtered_things = list( + filter( + lambda elem: attribute_name in elem["attributes"] + and elem["attributes"][attribute_name] == attribute_value, + all_things, + ) + ) elif attribute_name is None and thing_type_name is not None: filtered_things = list( - filter(lambda elem: "thingTypeName" in elem and elem["thingTypeName"] == thing_type_name, all_things)) + filter( + lambda elem: "thingTypeName" in elem + and elem["thingTypeName"] == thing_type_name, + all_things, + ) + ) else: filtered_things = all_things if token is None: things = filtered_things[0:max_results] - next_token = str(max_results) if len(filtered_things) > max_results else None + next_token = ( + str(max_results) if len(filtered_things) > max_results else None + ) else: token = int(token) - things = filtered_things[token:token + max_results] - next_token = str(token + max_results) if len(filtered_things) > token + max_results else None + things = filtered_things[token : token + max_results] + next_token = ( + str(token + max_results) + if len(filtered_things) > token + max_results + else None + ) return things, next_token @@ -420,7 +487,9 @@ class IoTBackend(BaseBackend): return things[0] def describe_thing_type(self, thing_type_name): - thing_types = [_ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name] + thing_types = [ + _ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name + ] if len(thing_types) == 0: raise ResourceNotFoundException() return thing_types[0] @@ -430,6 +499,12 @@ class IoTBackend(BaseBackend): # can raise ResourceNotFoundError thing = self.describe_thing(thing_name) + + # detach all principals + for k in list(self.principal_things.keys()): + if k[1] == thing_name: + del self.principal_things[k] + del self.things[thing.arn] def delete_thing_type(self, thing_type_name): @@ -437,7 +512,14 @@ class IoTBackend(BaseBackend): thing_type = self.describe_thing_type(thing_type_name) del self.thing_types[thing_type.arn] - def update_thing(self, thing_name, thing_type_name, attribute_payload, expected_version, remove_thing_type): + def update_thing( + self, + thing_name, + thing_type_name, + attribute_payload, + expected_version, + remove_thing_type, + ): # if attributes payload = {}, nothing thing = self.describe_thing(thing_name) thing_type = None @@ -448,7 +530,9 @@ class IoTBackend(BaseBackend): # thing_type if thing_type_name: thing_types = self.list_thing_types() - filtered_thing_types = [_ for _ in thing_types if _.thing_type_name == thing_type_name] + filtered_thing_types = [ + _ for _ in thing_types if _.thing_type_name == thing_type_name + ] if len(filtered_thing_types) == 0: raise ResourceNotFoundException() thing_type = filtered_thing_types[0] @@ -458,9 +542,9 @@ class IoTBackend(BaseBackend): thing.thing_type = None # attribute - if attribute_payload is not None and 'attributes' in attribute_payload: - do_merge = attribute_payload.get('merge', False) - attributes = attribute_payload['attributes'] + if attribute_payload is not None and "attributes" in attribute_payload: + do_merge = attribute_payload.get("merge", False) + attributes = attribute_payload["attributes"] if not do_merge: thing.attributes = attributes else: @@ -468,46 +552,59 @@ class IoTBackend(BaseBackend): def _random_string(self): n = 20 - random_str = ''.join([random.choice(string.ascii_letters + string.digits) for i in range(n)]) + random_str = "".join( + [random.choice(string.ascii_letters + string.digits) for i in range(n)] + ) return random_str def create_keys_and_certificate(self, set_as_active): # implement here # caCertificate can be blank key_pair = { - 'PublicKey': self._random_string(), - 'PrivateKey': self._random_string() + "PublicKey": self._random_string(), + "PrivateKey": self._random_string(), } certificate_pem = self._random_string() - status = 'ACTIVE' if set_as_active else 'INACTIVE' + status = "ACTIVE" if set_as_active else "INACTIVE" certificate = FakeCertificate(certificate_pem, status, self.region_name) self.certificates[certificate.certificate_id] = certificate return certificate, key_pair def delete_certificate(self, certificate_id): cert = self.describe_certificate(certificate_id) - if cert.status == 'ACTIVE': + if cert.status == "ACTIVE": raise CertificateStateException( - 'Certificate must be deactivated (not ACTIVE) before deletion.', certificate_id) - - certs = [k[0] for k, v in self.principal_things.items() - if self._get_principal(k[0]).certificate_id == certificate_id] - if len(certs) > 0: - raise DeleteConflictException( - 'Things must be detached before deletion (arn: %s)' % certs[0] + "Certificate must be deactivated (not ACTIVE) before deletion.", + certificate_id, ) - certs = [k[0] for k, v in self.principal_policies.items() - if self._get_principal(k[0]).certificate_id == certificate_id] + certs = [ + k[0] + for k, v in self.principal_things.items() + if self._get_principal(k[0]).certificate_id == certificate_id + ] if len(certs) > 0: raise DeleteConflictException( - 'Certificate policies must be detached before deletion (arn: %s)' % certs[0] + "Things must be detached before deletion (arn: %s)" % certs[0] + ) + + certs = [ + k[0] + for k, v in self.principal_policies.items() + if self._get_principal(k[0]).certificate_id == certificate_id + ] + if len(certs) > 0: + raise DeleteConflictException( + "Certificate policies must be detached before deletion (arn: %s)" + % certs[0] ) del self.certificates[certificate_id] def describe_certificate(self, certificate_id): - certs = [_ for _ in self.certificates.values() if _.certificate_id == certificate_id] + certs = [ + _ for _ in self.certificates.values() if _.certificate_id == certificate_id + ] if len(certs) == 0: raise ResourceNotFoundException() return certs[0] @@ -515,9 +612,15 @@ class IoTBackend(BaseBackend): def list_certificates(self): return self.certificates.values() - def register_certificate(self, certificate_pem, ca_certificate_pem, set_as_active, status): - certificate = FakeCertificate(certificate_pem, 'ACTIVE' if set_as_active else status, - self.region_name, ca_certificate_pem) + def register_certificate( + self, certificate_pem, ca_certificate_pem, set_as_active, status + ): + certificate = FakeCertificate( + certificate_pem, + "ACTIVE" if set_as_active else status, + self.region_name, + ca_certificate_pem, + ) self.certificates[certificate.certificate_id] = certificate return certificate @@ -565,10 +668,12 @@ class IoTBackend(BaseBackend): def delete_policy(self, policy_name): - policies = [k[1] for k, v in self.principal_policies.items() if k[1] == policy_name] + policies = [ + k[1] for k, v in self.principal_policies.items() if k[1] == policy_name + ] if len(policies) > 0: raise DeleteConflictException( - 'The policy cannot be deleted as the policy is attached to one or more principals (name=%s)' + "The policy cannot be deleted as the policy is attached to one or more principals (name=%s)" % policy_name ) @@ -630,7 +735,7 @@ class IoTBackend(BaseBackend): """ raise ResourceNotFoundException """ - if ':cert/' in principal_arn: + if ":cert/" in principal_arn: certs = [_ for _ in self.certificates.values() if _.arn == principal_arn] if len(certs) == 0: raise ResourceNotFoundException() @@ -660,11 +765,15 @@ class IoTBackend(BaseBackend): del self.principal_policies[k] def list_principal_policies(self, principal_arn): - policies = [v[1] for k, v in self.principal_policies.items() if k[0] == principal_arn] + policies = [ + v[1] for k, v in self.principal_policies.items() if k[0] == principal_arn + ] return policies def list_policy_principals(self, policy_name): - principals = [k[0] for k, v in self.principal_policies.items() if k[1] == policy_name] + principals = [ + k[0] for k, v in self.principal_policies.items() if k[1] == policy_name + ] return principals def attach_thing_principal(self, thing_name, principal_arn): @@ -686,21 +795,37 @@ class IoTBackend(BaseBackend): del self.principal_things[k] def list_principal_things(self, principal_arn): - thing_names = [k[0] for k, v in self.principal_things.items() if k[0] == principal_arn] + thing_names = [ + k[0] for k, v in self.principal_things.items() if k[0] == principal_arn + ] return thing_names def list_thing_principals(self, thing_name): - principals = [k[0] for k, v in self.principal_things.items() if k[1] == thing_name] + principals = [ + k[0] for k, v in self.principal_things.items() if k[1] == thing_name + ] return principals def describe_thing_group(self, thing_group_name): - thing_groups = [_ for _ in self.thing_groups.values() if _.thing_group_name == thing_group_name] + thing_groups = [ + _ + for _ in self.thing_groups.values() + if _.thing_group_name == thing_group_name + ] if len(thing_groups) == 0: raise ResourceNotFoundException() return thing_groups[0] - def create_thing_group(self, thing_group_name, parent_group_name, thing_group_properties): - thing_group = FakeThingGroup(thing_group_name, parent_group_name, thing_group_properties, self.region_name) + def create_thing_group( + self, thing_group_name, parent_group_name, thing_group_properties + ): + thing_group = FakeThingGroup( + thing_group_name, + parent_group_name, + thing_group_properties, + self.region_name, + self.thing_groups, + ) self.thing_groups[thing_group.arn] = thing_group return thing_group.thing_group_name, thing_group.arn, thing_group.thing_group_id @@ -712,19 +837,25 @@ class IoTBackend(BaseBackend): thing_groups = self.thing_groups.values() return thing_groups - def update_thing_group(self, thing_group_name, thing_group_properties, expected_version): + def update_thing_group( + self, thing_group_name, thing_group_properties, expected_version + ): thing_group = self.describe_thing_group(thing_group_name) if expected_version and expected_version != thing_group.version: raise VersionConflictException(thing_group_name) - attribute_payload = thing_group_properties.get('attributePayload', None) - if attribute_payload is not None and 'attributes' in attribute_payload: - do_merge = attribute_payload.get('merge', False) - attributes = attribute_payload['attributes'] + attribute_payload = thing_group_properties.get("attributePayload", None) + if attribute_payload is not None and "attributes" in attribute_payload: + do_merge = attribute_payload.get("merge", False) + attributes = attribute_payload["attributes"] if not do_merge: - thing_group.thing_group_properties['attributePayload']['attributes'] = attributes + thing_group.thing_group_properties["attributePayload"][ + "attributes" + ] = attributes else: - thing_group.thing_group_properties['attributePayload']['attributes'].update(attributes) - elif attribute_payload is not None and 'attributes' not in attribute_payload: + thing_group.thing_group_properties["attributePayload"][ + "attributes" + ].update(attributes) + elif attribute_payload is not None and "attributes" not in attribute_payload: thing_group.attributes = {} thing_group.version = thing_group.version + 1 return thing_group.version @@ -733,13 +864,13 @@ class IoTBackend(BaseBackend): # identify thing group if thing_group_name is None and thing_group_arn is None: raise InvalidRequestException( - ' Both thingGroupArn and thingGroupName are empty. Need to specify at least one of them' + " Both thingGroupArn and thingGroupName are empty. Need to specify at least one of them" ) if thing_group_name is not None: thing_group = self.describe_thing_group(thing_group_name) if thing_group_arn and thing_group.arn != thing_group_arn: raise InvalidRequestException( - 'ThingGroupName thingGroupArn does not match specified thingGroupName in request' + "ThingGroupName thingGroupArn does not match specified thingGroupName in request" ) elif thing_group_arn is not None: if thing_group_arn not in self.thing_groups: @@ -751,13 +882,13 @@ class IoTBackend(BaseBackend): # identify thing if thing_name is None and thing_arn is None: raise InvalidRequestException( - 'Both thingArn and thingName are empty. Need to specify at least one of them' + "Both thingArn and thingName are empty. Need to specify at least one of them" ) if thing_name is not None: thing = self.describe_thing(thing_name) if thing_arn and thing.arn != thing_arn: raise InvalidRequestException( - 'ThingName thingArn does not match specified thingName in request' + "ThingName thingArn does not match specified thingName in request" ) elif thing_arn is not None: if thing_arn not in self.things: @@ -765,7 +896,9 @@ class IoTBackend(BaseBackend): thing = self.things[thing_arn] return thing - def add_thing_to_thing_group(self, thing_group_name, thing_group_arn, thing_name, thing_arn): + def add_thing_to_thing_group( + self, thing_group_name, thing_group_arn, thing_name, thing_arn + ): thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn in thing_group.things: @@ -773,7 +906,9 @@ class IoTBackend(BaseBackend): return thing_group.things[thing.arn] = thing - def remove_thing_from_thing_group(self, thing_group_name, thing_group_arn, thing_name, thing_arn): + def remove_thing_from_thing_group( + self, thing_group_name, thing_group_arn, thing_name, thing_arn + ): thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn not in thing_group.things: @@ -791,31 +926,53 @@ class IoTBackend(BaseBackend): ret = [] for thing_group in all_thing_groups: if thing.arn in thing_group.things: - ret.append({ - 'groupName': thing_group.thing_group_name, - 'groupArn': thing_group.arn - }) + ret.append( + { + "groupName": thing_group.thing_group_name, + "groupArn": thing_group.arn, + } + ) return ret - def update_thing_groups_for_thing(self, thing_name, thing_groups_to_add, thing_groups_to_remove): + def update_thing_groups_for_thing( + self, thing_name, thing_groups_to_add, thing_groups_to_remove + ): thing = self.describe_thing(thing_name) for thing_group_name in thing_groups_to_add: thing_group = self.describe_thing_group(thing_group_name) self.add_thing_to_thing_group( - thing_group.thing_group_name, None, - thing.thing_name, None + thing_group.thing_group_name, None, thing.thing_name, None ) for thing_group_name in thing_groups_to_remove: thing_group = self.describe_thing_group(thing_group_name) self.remove_thing_from_thing_group( - thing_group.thing_group_name, None, - thing.thing_name, None + thing_group.thing_group_name, None, thing.thing_name, None ) - def create_job(self, job_id, targets, document_source, document, description, presigned_url_config, - target_selection, job_executions_rollout_config, document_parameters): - job = FakeJob(job_id, targets, document_source, document, description, presigned_url_config, target_selection, - job_executions_rollout_config, document_parameters, self.region_name) + def create_job( + self, + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + ): + job = FakeJob( + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + self.region_name, + ) self.jobs[job_id] = job for thing_arn in targets: diff --git a/moto/iot/responses.py b/moto/iot/responses.py index 8954c7003..e88e9264a 100644 --- a/moto/iot/responses.py +++ b/moto/iot/responses.py @@ -8,7 +8,7 @@ from .models import iot_backends class IoTResponse(BaseResponse): - SERVICE_NAME = 'iot' + SERVICE_NAME = "iot" @property def iot_backend(self): @@ -29,18 +29,19 @@ class IoTResponse(BaseResponse): thing_type_name = self._get_param("thingTypeName") thing_type_properties = self._get_param("thingTypeProperties") thing_type_name, thing_type_arn = self.iot_backend.create_thing_type( - thing_type_name=thing_type_name, - thing_type_properties=thing_type_properties, + thing_type_name=thing_type_name, thing_type_properties=thing_type_properties + ) + return json.dumps( + dict(thingTypeName=thing_type_name, thingTypeArn=thing_type_arn) ) - return json.dumps(dict(thingTypeName=thing_type_name, thingTypeArn=thing_type_arn)) def list_thing_types(self): previous_next_token = self._get_param("nextToken") - max_results = self._get_int_param("maxResults", 50) # not the default, but makes testing easier + max_results = self._get_int_param( + "maxResults", 50 + ) # not the default, but makes testing easier thing_type_name = self._get_param("thingTypeName") - thing_types = self.iot_backend.list_thing_types( - thing_type_name=thing_type_name - ) + thing_types = self.iot_backend.list_thing_types(thing_type_name=thing_type_name) thing_types = [_.to_dict() for _ in thing_types] if previous_next_token is None: @@ -48,14 +49,20 @@ class IoTResponse(BaseResponse): next_token = str(max_results) if len(thing_types) > max_results else None else: token = int(previous_next_token) - result = thing_types[token:token + max_results] - next_token = str(token + max_results) if len(thing_types) > token + max_results else None + result = thing_types[token : token + max_results] + next_token = ( + str(token + max_results) + if len(thing_types) > token + max_results + else None + ) return json.dumps(dict(thingTypes=result, nextToken=next_token)) def list_things(self): previous_next_token = self._get_param("nextToken") - max_results = self._get_int_param("maxResults", 50) # not the default, but makes testing easier + max_results = self._get_int_param( + "maxResults", 50 + ) # not the default, but makes testing easier attribute_name = self._get_param("attributeName") attribute_value = self._get_param("attributeValue") thing_type_name = self._get_param("thingTypeName") @@ -64,22 +71,20 @@ class IoTResponse(BaseResponse): attribute_value=attribute_value, thing_type_name=thing_type_name, max_results=max_results, - token=previous_next_token + token=previous_next_token, ) return json.dumps(dict(things=things, nextToken=next_token)) def describe_thing(self): thing_name = self._get_param("thingName") - thing = self.iot_backend.describe_thing( - thing_name=thing_name, - ) + thing = self.iot_backend.describe_thing(thing_name=thing_name) return json.dumps(thing.to_dict(include_default_client_id=True)) def describe_thing_type(self): thing_type_name = self._get_param("thingTypeName") thing_type = self.iot_backend.describe_thing_type( - thing_type_name=thing_type_name, + thing_type_name=thing_type_name ) return json.dumps(thing_type.to_dict()) @@ -87,16 +92,13 @@ class IoTResponse(BaseResponse): thing_name = self._get_param("thingName") expected_version = self._get_param("expectedVersion") self.iot_backend.delete_thing( - thing_name=thing_name, - expected_version=expected_version, + thing_name=thing_name, expected_version=expected_version ) return json.dumps(dict()) def delete_thing_type(self): thing_type_name = self._get_param("thingTypeName") - self.iot_backend.delete_thing_type( - thing_type_name=thing_type_name, - ) + self.iot_backend.delete_thing_type(thing_type_name=thing_type_name) return json.dumps(dict()) def update_thing(self): @@ -124,7 +126,7 @@ class IoTResponse(BaseResponse): presigned_url_config=self._get_param("presignedUrlConfig"), target_selection=self._get_param("targetSelection"), job_executions_rollout_config=self._get_param("jobExecutionsRolloutConfig"), - document_parameters=self._get_param("documentParameters") + document_parameters=self._get_param("documentParameters"), ) return json.dumps(dict(jobArn=job_arn, jobId=job_id, description=description)) @@ -265,28 +267,30 @@ class IoTResponse(BaseResponse): def create_keys_and_certificate(self): set_as_active = self._get_bool_param("setAsActive") cert, key_pair = self.iot_backend.create_keys_and_certificate( - set_as_active=set_as_active, + set_as_active=set_as_active + ) + return json.dumps( + dict( + certificateArn=cert.arn, + certificateId=cert.certificate_id, + certificatePem=cert.certificate_pem, + keyPair=key_pair, + ) ) - return json.dumps(dict( - certificateArn=cert.arn, - certificateId=cert.certificate_id, - certificatePem=cert.certificate_pem, - keyPair=key_pair - )) def delete_certificate(self): certificate_id = self._get_param("certificateId") - self.iot_backend.delete_certificate( - certificate_id=certificate_id, - ) + self.iot_backend.delete_certificate(certificate_id=certificate_id) return json.dumps(dict()) def describe_certificate(self): certificate_id = self._get_param("certificateId") certificate = self.iot_backend.describe_certificate( - certificate_id=certificate_id, + certificate_id=certificate_id + ) + return json.dumps( + dict(certificateDescription=certificate.to_description_dict()) ) - return json.dumps(dict(certificateDescription=certificate.to_description_dict())) def list_certificates(self): # page_size = self._get_int_param("pageSize") @@ -306,16 +310,17 @@ class IoTResponse(BaseResponse): certificate_pem=certificate_pem, ca_certificate_pem=ca_certificate_pem, set_as_active=set_as_active, - status=status + status=status, + ) + return json.dumps( + dict(certificateId=cert.certificate_id, certificateArn=cert.arn) ) - return json.dumps(dict(certificateId=cert.certificate_id, certificateArn=cert.arn)) def update_certificate(self): certificate_id = self._get_param("certificateId") new_status = self._get_param("newStatus") self.iot_backend.update_certificate( - certificate_id=certificate_id, - new_status=new_status, + certificate_id=certificate_id, new_status=new_status ) return json.dumps(dict()) @@ -323,8 +328,7 @@ class IoTResponse(BaseResponse): policy_name = self._get_param("policyName") policy_document = self._get_param("policyDocument") policy = self.iot_backend.create_policy( - policy_name=policy_name, - policy_document=policy_document, + policy_name=policy_name, policy_document=policy_document ) return json.dumps(policy.to_dict_at_creation()) @@ -339,16 +343,12 @@ class IoTResponse(BaseResponse): def get_policy(self): policy_name = self._get_param("policyName") - policy = self.iot_backend.get_policy( - policy_name=policy_name, - ) + policy = self.iot_backend.get_policy(policy_name=policy_name) return json.dumps(policy.to_get_dict()) def delete_policy(self): policy_name = self._get_param("policyName") - self.iot_backend.delete_policy( - policy_name=policy_name, - ) + self.iot_backend.delete_policy(policy_name=policy_name) return json.dumps(dict()) def create_policy_version(self): @@ -387,11 +387,8 @@ class IoTResponse(BaseResponse): def attach_policy(self): policy_name = self._get_param("policyName") - target = self._get_param('target') - self.iot_backend.attach_policy( - policy_name=policy_name, - target=target, - ) + target = self._get_param("target") + self.iot_backend.attach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) def list_attached_policies(self): @@ -407,95 +404,82 @@ class IoTResponse(BaseResponse): def attach_principal_policy(self): policy_name = self._get_param("policyName") - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.attach_principal_policy( - policy_name=policy_name, - principal_arn=principal, + policy_name=policy_name, principal_arn=principal ) return json.dumps(dict()) def detach_policy(self): policy_name = self._get_param("policyName") - target = self._get_param('target') - self.iot_backend.detach_policy( - policy_name=policy_name, - target=target, - ) + target = self._get_param("target") + self.iot_backend.detach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) def detach_principal_policy(self): policy_name = self._get_param("policyName") - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.detach_principal_policy( - policy_name=policy_name, - principal_arn=principal, + policy_name=policy_name, principal_arn=principal ) return json.dumps(dict()) def list_principal_policies(self): - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") # ascending_order = self._get_param("ascendingOrder") - policies = self.iot_backend.list_principal_policies( - principal_arn=principal - ) + policies = self.iot_backend.list_principal_policies(principal_arn=principal) # TODO: implement pagination in the future next_marker = None - return json.dumps(dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker)) + return json.dumps( + dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker) + ) def list_policy_principals(self): - policy_name = self.headers.get('x-amzn-iot-policy') + policy_name = self.headers.get("x-amzn-iot-policy") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") # ascending_order = self._get_param("ascendingOrder") - principals = self.iot_backend.list_policy_principals( - policy_name=policy_name, - ) + principals = self.iot_backend.list_policy_principals(policy_name=policy_name) # TODO: implement pagination in the future next_marker = None return json.dumps(dict(principals=principals, nextMarker=next_marker)) def attach_thing_principal(self): thing_name = self._get_param("thingName") - principal = self.headers.get('x-amzn-principal') + principal = self.headers.get("x-amzn-principal") self.iot_backend.attach_thing_principal( - thing_name=thing_name, - principal_arn=principal, + thing_name=thing_name, principal_arn=principal ) return json.dumps(dict()) def detach_thing_principal(self): thing_name = self._get_param("thingName") - principal = self.headers.get('x-amzn-principal') + principal = self.headers.get("x-amzn-principal") self.iot_backend.detach_thing_principal( - thing_name=thing_name, - principal_arn=principal, + thing_name=thing_name, principal_arn=principal ) return json.dumps(dict()) def list_principal_things(self): next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") - principal = self.headers.get('x-amzn-principal') - things = self.iot_backend.list_principal_things( - principal_arn=principal, - ) + principal = self.headers.get("x-amzn-principal") + things = self.iot_backend.list_principal_things(principal_arn=principal) # TODO: implement pagination in the future next_token = None return json.dumps(dict(things=things, nextToken=next_token)) def list_thing_principals(self): thing_name = self._get_param("thingName") - principals = self.iot_backend.list_thing_principals( - thing_name=thing_name, - ) + principals = self.iot_backend.list_thing_principals(thing_name=thing_name) return json.dumps(dict(principals=principals)) def describe_thing_group(self): thing_group_name = self._get_param("thingGroupName") thing_group = self.iot_backend.describe_thing_group( - thing_group_name=thing_group_name, + thing_group_name=thing_group_name ) return json.dumps(thing_group.to_dict()) @@ -503,23 +487,28 @@ class IoTResponse(BaseResponse): thing_group_name = self._get_param("thingGroupName") parent_group_name = self._get_param("parentGroupName") thing_group_properties = self._get_param("thingGroupProperties") - thing_group_name, thing_group_arn, thing_group_id = self.iot_backend.create_thing_group( + ( + thing_group_name, + thing_group_arn, + thing_group_id, + ) = self.iot_backend.create_thing_group( thing_group_name=thing_group_name, parent_group_name=parent_group_name, thing_group_properties=thing_group_properties, ) - return json.dumps(dict( - thingGroupName=thing_group_name, - thingGroupArn=thing_group_arn, - thingGroupId=thing_group_id) + return json.dumps( + dict( + thingGroupName=thing_group_name, + thingGroupArn=thing_group_arn, + thingGroupId=thing_group_id, + ) ) def delete_thing_group(self): thing_group_name = self._get_param("thingGroupName") expected_version = self._get_param("expectedVersion") self.iot_backend.delete_thing_group( - thing_group_name=thing_group_name, - expected_version=expected_version, + thing_group_name=thing_group_name, expected_version=expected_version ) return json.dumps(dict()) @@ -535,7 +524,9 @@ class IoTResponse(BaseResponse): recursive=recursive, ) next_token = None - rets = [{'groupName': _.thing_group_name, 'groupArn': _.arn} for _ in thing_groups] + rets = [ + {"groupName": _.thing_group_name, "groupArn": _.arn} for _ in thing_groups + ] # TODO: implement pagination in the future return json.dumps(dict(thingGroups=rets, nextToken=next_token)) @@ -582,8 +573,7 @@ class IoTResponse(BaseResponse): # next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") things = self.iot_backend.list_things_in_thing_group( - thing_group_name=thing_group_name, - recursive=recursive, + thing_group_name=thing_group_name, recursive=recursive ) next_token = None thing_names = [_.thing_name for _ in things] diff --git a/moto/iot/urls.py b/moto/iot/urls.py index 6d11c15a5..2ad908714 100644 --- a/moto/iot/urls.py +++ b/moto/iot/urls.py @@ -1,14 +1,10 @@ from __future__ import unicode_literals from .responses import IoTResponse -url_bases = [ - "https?://iot.(.+).amazonaws.com", -] +url_bases = ["https?://iot.(.+).amazonaws.com"] response = IoTResponse() -url_paths = { - '{0}/.*$': response.dispatch, -} +url_paths = {"{0}/.*$": response.dispatch} diff --git a/moto/iotdata/__init__.py b/moto/iotdata/__init__.py index 214f2e575..016fef5fb 100644 --- a/moto/iotdata/__init__.py +++ b/moto/iotdata/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import iotdata_backends from ..core.models import base_decorator -iotdata_backend = iotdata_backends['us-east-1'] +iotdata_backend = iotdata_backends["us-east-1"] mock_iotdata = base_decorator(iotdata_backends) diff --git a/moto/iotdata/exceptions.py b/moto/iotdata/exceptions.py index ddc6b37fd..30998ffc3 100644 --- a/moto/iotdata/exceptions.py +++ b/moto/iotdata/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(IoTDataPlaneClientError): def __init__(self): self.code = 404 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -21,3 +20,9 @@ class InvalidRequestException(IoTDataPlaneClientError): super(InvalidRequestException, self).__init__( "InvalidRequestException", message ) + + +class ConflictException(IoTDataPlaneClientError): + def __init__(self, message): + self.code = 409 + super(ConflictException, self).__init__("ConflictException", message) diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index ad4caa89e..e534e1d1f 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -6,8 +6,9 @@ import jsondiff from moto.core import BaseBackend, BaseModel from moto.iot import iot_backends from .exceptions import ( + ConflictException, ResourceNotFoundException, - InvalidRequestException + InvalidRequestException, ) @@ -15,6 +16,7 @@ class FakeShadow(BaseModel): """See the specification: http://docs.aws.amazon.com/iot/latest/developerguide/thing-shadow-document-syntax.html """ + def __init__(self, desired, reported, requested_payload, version, deleted=False): self.desired = desired self.reported = reported @@ -23,15 +25,23 @@ class FakeShadow(BaseModel): self.timestamp = int(time.time()) self.deleted = deleted - self.metadata_desired = self._create_metadata_from_state(self.desired, self.timestamp) - self.metadata_reported = self._create_metadata_from_state(self.reported, self.timestamp) + self.metadata_desired = self._create_metadata_from_state( + self.desired, self.timestamp + ) + self.metadata_reported = self._create_metadata_from_state( + self.reported, self.timestamp + ) @classmethod def create_from_previous_version(cls, previous_shadow, payload): """ set None to payload when you want to delete shadow """ - version, previous_payload = (previous_shadow.version + 1, previous_shadow.to_dict(include_delta=False)) if previous_shadow else (1, {}) + version, previous_payload = ( + (previous_shadow.version + 1, previous_shadow.to_dict(include_delta=False)) + if previous_shadow + else (1, {}) + ) if payload is None: # if given payload is None, delete existing payload @@ -40,13 +50,11 @@ class FakeShadow(BaseModel): return shadow # we can make sure that payload has 'state' key - desired = payload['state'].get( - 'desired', - previous_payload.get('state', {}).get('desired', None) + desired = payload["state"].get( + "desired", previous_payload.get("state", {}).get("desired", None) ) - reported = payload['state'].get( - 'reported', - previous_payload.get('state', {}).get('reported', None) + reported = payload["state"].get( + "reported", previous_payload.get("state", {}).get("reported", None) ) shadow = FakeShadow(desired, reported, payload, version) return shadow @@ -75,58 +83,60 @@ class FakeShadow(BaseModel): if isinstance(elem, list): return [_f(_, ts) for _ in elem] return {"timestamp": ts} + return _f(state, ts) def to_response_dict(self): - desired = self.requested_payload['state'].get('desired', None) - reported = self.requested_payload['state'].get('reported', None) + desired = self.requested_payload["state"].get("desired", None) + reported = self.requested_payload["state"].get("reported", None) payload = {} if desired is not None: - payload['desired'] = desired + payload["desired"] = desired if reported is not None: - payload['reported'] = reported + payload["reported"] = reported metadata = {} if desired is not None: - metadata['desired'] = self._create_metadata_from_state(desired, self.timestamp) + metadata["desired"] = self._create_metadata_from_state( + desired, self.timestamp + ) if reported is not None: - metadata['reported'] = self._create_metadata_from_state(reported, self.timestamp) + metadata["reported"] = self._create_metadata_from_state( + reported, self.timestamp + ) return { - 'state': payload, - 'metadata': metadata, - 'timestamp': self.timestamp, - 'version': self.version + "state": payload, + "metadata": metadata, + "timestamp": self.timestamp, + "version": self.version, } def to_dict(self, include_delta=True): """returning nothing except for just top-level keys for now. """ if self.deleted: - return { - 'timestamp': self.timestamp, - 'version': self.version - } + return {"timestamp": self.timestamp, "version": self.version} delta = self.parse_payload(self.desired, self.reported) payload = {} if self.desired is not None: - payload['desired'] = self.desired + payload["desired"] = self.desired if self.reported is not None: - payload['reported'] = self.reported + payload["reported"] = self.reported if include_delta and (delta is not None and len(delta.keys()) != 0): - payload['delta'] = delta + payload["delta"] = delta metadata = {} if self.metadata_desired is not None: - metadata['desired'] = self.metadata_desired + metadata["desired"] = self.metadata_desired if self.metadata_reported is not None: - metadata['reported'] = self.metadata_reported + metadata["reported"] = self.metadata_reported return { - 'state': payload, - 'metadata': metadata, - 'timestamp': self.timestamp, - 'version': self.version + "state": payload, + "metadata": metadata, + "timestamp": self.timestamp, + "version": self.version, } @@ -153,15 +163,19 @@ class IoTDataPlaneBackend(BaseBackend): try: payload = json.loads(payload) except ValueError: - raise InvalidRequestException('invalid json') - if 'state' not in payload: - raise InvalidRequestException('need node `state`') - if not isinstance(payload['state'], dict): - raise InvalidRequestException('state node must be an Object') - if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']): - raise InvalidRequestException('State contains an invalid node') + raise InvalidRequestException("invalid json") + if "state" not in payload: + raise InvalidRequestException("need node `state`") + if not isinstance(payload["state"], dict): + raise InvalidRequestException("state node must be an Object") + if any(_ for _ in payload["state"].keys() if _ not in ["desired", "reported"]): + raise InvalidRequestException("State contains an invalid node") - new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) + if "version" in payload and thing.thing_shadow.version != payload["version"]: + raise ConflictException("Version conflict") + new_shadow = FakeShadow.create_from_previous_version( + thing.thing_shadow, payload + ) thing.thing_shadow = new_shadow return thing.thing_shadow @@ -180,7 +194,9 @@ class IoTDataPlaneBackend(BaseBackend): if thing.thing_shadow is None: raise ResourceNotFoundException() payload = None - new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) + new_shadow = FakeShadow.create_from_previous_version( + thing.thing_shadow, payload + ) thing.thing_shadow = new_shadow return thing.thing_shadow diff --git a/moto/iotdata/responses.py b/moto/iotdata/responses.py index 8ab724ed1..045ed5e59 100644 --- a/moto/iotdata/responses.py +++ b/moto/iotdata/responses.py @@ -5,7 +5,7 @@ import json class IoTDataPlaneResponse(BaseResponse): - SERVICE_NAME = 'iot-data' + SERVICE_NAME = "iot-data" @property def iotdata_backend(self): @@ -15,32 +15,23 @@ class IoTDataPlaneResponse(BaseResponse): thing_name = self._get_param("thingName") payload = self.body payload = self.iotdata_backend.update_thing_shadow( - thing_name=thing_name, - payload=payload, + thing_name=thing_name, payload=payload ) return json.dumps(payload.to_response_dict()) def get_thing_shadow(self): thing_name = self._get_param("thingName") - payload = self.iotdata_backend.get_thing_shadow( - thing_name=thing_name, - ) + payload = self.iotdata_backend.get_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) def delete_thing_shadow(self): thing_name = self._get_param("thingName") - payload = self.iotdata_backend.delete_thing_shadow( - thing_name=thing_name, - ) + payload = self.iotdata_backend.delete_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) def publish(self): topic = self._get_param("topic") qos = self._get_int_param("qos") payload = self._get_param("payload") - self.iotdata_backend.publish( - topic=topic, - qos=qos, - payload=payload, - ) + self.iotdata_backend.publish(topic=topic, qos=qos, payload=payload) return json.dumps(dict()) diff --git a/moto/iotdata/urls.py b/moto/iotdata/urls.py index a3bcb0a52..b3baa66cc 100644 --- a/moto/iotdata/urls.py +++ b/moto/iotdata/urls.py @@ -1,14 +1,10 @@ from __future__ import unicode_literals from .responses import IoTDataPlaneResponse -url_bases = [ - "https?://data.iot.(.+).amazonaws.com", -] +url_bases = ["https?://data.iot.(.+).amazonaws.com"] response = IoTDataPlaneResponse() -url_paths = { - '{0}/.*$': response.dispatch, -} +url_paths = {"{0}/.*$": response.dispatch} diff --git a/moto/kinesis/__init__.py b/moto/kinesis/__init__.py index 7d9767a9f..823379cd5 100644 --- a/moto/kinesis/__init__.py +++ b/moto/kinesis/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import kinesis_backends from ..core.models import base_decorator, deprecated_base_decorator -kinesis_backend = kinesis_backends['us-east-1'] +kinesis_backend = kinesis_backends["us-east-1"] mock_kinesis = base_decorator(kinesis_backends) mock_kinesis_deprecated = deprecated_base_decorator(kinesis_backends) diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index 82f796ecc..1f25d6720 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -2,47 +2,42 @@ from __future__ import unicode_literals import json from werkzeug.exceptions import BadRequest +from moto.core import ACCOUNT_ID class ResourceNotFoundError(BadRequest): - def __init__(self, message): super(ResourceNotFoundError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) class ResourceInUseError(BadRequest): - def __init__(self, message): super(ResourceInUseError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceInUseException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceInUseException"} + ) class StreamNotFoundError(ResourceNotFoundError): - def __init__(self, stream_name): super(StreamNotFoundError, self).__init__( - 'Stream {0} under account 123456789012 not found.'.format(stream_name)) + "Stream {0} under account {1} not found.".format(stream_name, ACCOUNT_ID) + ) class ShardNotFoundError(ResourceNotFoundError): - def __init__(self, shard_id): super(ShardNotFoundError, self).__init__( - 'Shard {0} under account 123456789012 not found.'.format(shard_id)) + "Shard {0} under account {1} not found.".format(shard_id, ACCOUNT_ID) + ) class InvalidArgumentError(BadRequest): - def __init__(self, message): super(InvalidArgumentError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'InvalidArgumentException', - }) + self.description = json.dumps( + {"message": message, "__type": "InvalidArgumentException"} + ) diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index e7a389981..48642f197 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -13,9 +13,19 @@ from hashlib import md5 from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time -from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \ - ResourceNotFoundError, InvalidArgumentError -from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator +from moto.core import ACCOUNT_ID +from .exceptions import ( + StreamNotFoundError, + ShardNotFoundError, + ResourceInUseError, + ResourceNotFoundError, + InvalidArgumentError, +) +from .utils import ( + compose_shard_iterator, + compose_new_shard_iterator, + decompose_shard_iterator, +) class Record(BaseModel): @@ -32,12 +42,11 @@ class Record(BaseModel): "Data": self.data, "PartitionKey": self.partition_key, "SequenceNumber": str(self.sequence_number), - "ApproximateArrivalTimestamp": self.created_at_datetime.isoformat() + "ApproximateArrivalTimestamp": self.created_at_datetime.isoformat(), } class Shard(BaseModel): - def __init__(self, shard_id, starting_hash, ending_hash): self._shard_id = shard_id self.starting_hash = starting_hash @@ -75,7 +84,8 @@ class Shard(BaseModel): last_sequence_number = 0 sequence_number = last_sequence_number + 1 self.records[sequence_number] = Record( - partition_key, data, sequence_number, explicit_hash_key) + partition_key, data, sequence_number, explicit_hash_key + ) return sequence_number def get_min_sequence_number(self): @@ -94,39 +104,46 @@ class Shard(BaseModel): else: # find the last item in the list that was created before # at_timestamp - r = next((r for r in reversed(self.records.values()) if r.created_at < at_timestamp), None) + r = next( + ( + r + for r in reversed(self.records.values()) + if r.created_at < at_timestamp + ), + None, + ) return r.sequence_number def to_json(self): return { "HashKeyRange": { "EndingHashKey": str(self.ending_hash), - "StartingHashKey": str(self.starting_hash) + "StartingHashKey": str(self.starting_hash), }, "SequenceNumberRange": { "EndingSequenceNumber": self.get_max_sequence_number(), "StartingSequenceNumber": self.get_min_sequence_number(), }, - "ShardId": self.shard_id + "ShardId": self.shard_id, } class Stream(BaseModel): - def __init__(self, stream_name, shard_count, region): self.stream_name = stream_name self.shard_count = shard_count self.creation_datetime = datetime.datetime.now() self.region = region - self.account_number = "123456789012" + self.account_number = ACCOUNT_ID self.shards = {} self.tags = {} self.status = "ACTIVE" - step = 2**128 // shard_count - hash_ranges = itertools.chain(map(lambda i: (i, i * step, (i + 1) * step), - range(shard_count - 1)), - [(shard_count - 1, (shard_count - 1) * step, 2**128)]) + step = 2 ** 128 // shard_count + hash_ranges = itertools.chain( + map(lambda i: (i, i * step, (i + 1) * step), range(shard_count - 1)), + [(shard_count - 1, (shard_count - 1) * step, 2 ** 128)], + ) for index, start, end in hash_ranges: shard = Shard(index, start, end) @@ -137,7 +154,7 @@ class Stream(BaseModel): return "arn:aws:kinesis:{region}:{account_number}:{stream_name}".format( region=self.region, account_number=self.account_number, - stream_name=self.stream_name + stream_name=self.stream_name, ) def get_shard(self, shard_id): @@ -158,21 +175,22 @@ class Stream(BaseModel): key = int(explicit_hash_key) - if key >= 2**128: + if key >= 2 ** 128: raise InvalidArgumentError("explicit_hash_key") else: - key = int(md5(partition_key.encode('utf-8')).hexdigest(), 16) + key = int(md5(partition_key.encode("utf-8")).hexdigest(), 16) for shard in self.shards.values(): if shard.starting_hash <= key < shard.ending_hash: return shard - def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data): + def put_record( + self, partition_key, explicit_hash_key, sequence_number_for_ordering, data + ): shard = self.get_shard_for_key(partition_key, explicit_hash_key) - sequence_number = shard.put_record( - partition_key, data, explicit_hash_key) + sequence_number = shard.put_record(partition_key, data, explicit_hash_key) return sequence_number, shard.shard_id def to_json(self): @@ -198,69 +216,69 @@ class Stream(BaseModel): } @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - region = properties.get('Region', 'us-east-1') - shard_count = properties.get('ShardCount', 1) - return Stream(properties['Name'], shard_count, region) + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + region = properties.get("Region", "us-east-1") + shard_count = properties.get("ShardCount", 1) + return Stream(properties["Name"], shard_count, region) class FirehoseRecord(BaseModel): - def __init__(self, record_data): self.record_id = 12345678 self.record_data = record_data class DeliveryStream(BaseModel): - def __init__(self, stream_name, **stream_kwargs): self.name = stream_name - self.redshift_username = stream_kwargs.get('redshift_username') - self.redshift_password = stream_kwargs.get('redshift_password') - self.redshift_jdbc_url = stream_kwargs.get('redshift_jdbc_url') - self.redshift_role_arn = stream_kwargs.get('redshift_role_arn') - self.redshift_copy_command = stream_kwargs.get('redshift_copy_command') + self.redshift_username = stream_kwargs.get("redshift_username") + self.redshift_password = stream_kwargs.get("redshift_password") + self.redshift_jdbc_url = stream_kwargs.get("redshift_jdbc_url") + self.redshift_role_arn = stream_kwargs.get("redshift_role_arn") + self.redshift_copy_command = stream_kwargs.get("redshift_copy_command") - self.s3_role_arn = stream_kwargs.get('s3_role_arn') - self.s3_bucket_arn = stream_kwargs.get('s3_bucket_arn') - self.s3_prefix = stream_kwargs.get('s3_prefix') - self.s3_compression_format = stream_kwargs.get( - 's3_compression_format', 'UNCOMPRESSED') - self.s3_buffering_hings = stream_kwargs.get('s3_buffering_hings') + self.s3_config = stream_kwargs.get("s3_config") + self.extended_s3_config = stream_kwargs.get("extended_s3_config") - self.redshift_s3_role_arn = stream_kwargs.get('redshift_s3_role_arn') - self.redshift_s3_bucket_arn = stream_kwargs.get( - 'redshift_s3_bucket_arn') - self.redshift_s3_prefix = stream_kwargs.get('redshift_s3_prefix') + self.redshift_s3_role_arn = stream_kwargs.get("redshift_s3_role_arn") + self.redshift_s3_bucket_arn = stream_kwargs.get("redshift_s3_bucket_arn") + self.redshift_s3_prefix = stream_kwargs.get("redshift_s3_prefix") self.redshift_s3_compression_format = stream_kwargs.get( - 'redshift_s3_compression_format', 'UNCOMPRESSED') - self.redshift_s3_buffering_hings = stream_kwargs.get( - 'redshift_s3_buffering_hings') + "redshift_s3_compression_format", "UNCOMPRESSED" + ) + self.redshift_s3_buffering_hints = stream_kwargs.get( + "redshift_s3_buffering_hints" + ) self.records = [] - self.status = 'ACTIVE' + self.status = "ACTIVE" self.created_at = datetime.datetime.utcnow() self.last_updated = datetime.datetime.utcnow() @property def arn(self): - return 'arn:aws:firehose:us-east-1:123456789012:deliverystream/{0}'.format(self.name) + return "arn:aws:firehose:us-east-1:{1}:deliverystream/{0}".format( + self.name, ACCOUNT_ID + ) def destinations_to_dict(self): - if self.s3_role_arn: - return [{ - 'DestinationId': 'string', - 'S3DestinationDescription': { - 'RoleARN': self.s3_role_arn, - 'BucketARN': self.s3_bucket_arn, - 'Prefix': self.s3_prefix, - 'BufferingHints': self.s3_buffering_hings, - 'CompressionFormat': self.s3_compression_format, + if self.s3_config: + return [ + {"DestinationId": "string", "S3DestinationDescription": self.s3_config} + ] + elif self.extended_s3_config: + return [ + { + "DestinationId": "string", + "ExtendedS3DestinationDescription": self.extended_s3_config, } - }] + ] else: - return [{ + return [ + { "DestinationId": "string", "RedshiftDestinationDescription": { "ClusterJDBCURL": self.redshift_jdbc_url, @@ -268,15 +286,15 @@ class DeliveryStream(BaseModel): "RoleARN": self.redshift_role_arn, "S3DestinationDescription": { "BucketARN": self.redshift_s3_bucket_arn, - "BufferingHints": self.redshift_s3_buffering_hings, + "BufferingHints": self.redshift_s3_buffering_hints, "CompressionFormat": self.redshift_s3_compression_format, "Prefix": self.redshift_s3_prefix, - "RoleARN": self.redshift_s3_role_arn + "RoleARN": self.redshift_s3_role_arn, }, "Username": self.redshift_username, }, - } - ] + } + ] def to_dict(self): return { @@ -299,7 +317,6 @@ class DeliveryStream(BaseModel): class KinesisBackend(BaseBackend): - def __init__(self): self.streams = OrderedDict() self.delivery_streams = {} @@ -328,14 +345,24 @@ class KinesisBackend(BaseBackend): return self.streams.pop(stream_name) raise StreamNotFoundError(stream_name) - def get_shard_iterator(self, stream_name, shard_id, shard_iterator_type, starting_sequence_number, - at_timestamp): + def get_shard_iterator( + self, + stream_name, + shard_id, + shard_iterator_type, + starting_sequence_number, + at_timestamp, + ): # Validate params stream = self.describe_stream(stream_name) shard = stream.get_shard(shard_id) shard_iterator = compose_new_shard_iterator( - stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp + stream_name, + shard, + shard_iterator_type, + starting_sequence_number, + at_timestamp, ) return shard_iterator @@ -346,14 +373,24 @@ class KinesisBackend(BaseBackend): stream = self.describe_stream(stream_name) shard = stream.get_shard(shard_id) - records, last_sequence_id, millis_behind_latest = shard.get_records(last_sequence_id, limit) + records, last_sequence_id, millis_behind_latest = shard.get_records( + last_sequence_id, limit + ) next_shard_iterator = compose_shard_iterator( - stream_name, shard, last_sequence_id) + stream_name, shard, last_sequence_id + ) return next_shard_iterator, records, millis_behind_latest - def put_record(self, stream_name, partition_key, explicit_hash_key, sequence_number_for_ordering, data): + def put_record( + self, + stream_name, + partition_key, + explicit_hash_key, + sequence_number_for_ordering, + data, + ): stream = self.describe_stream(stream_name) sequence_number, shard_id = stream.put_record( @@ -365,10 +402,7 @@ class KinesisBackend(BaseBackend): def put_records(self, stream_name, records): stream = self.describe_stream(stream_name) - response = { - "FailedRecordCount": 0, - "Records": [] - } + response = {"FailedRecordCount": 0, "Records": []} for record in records: partition_key = record.get("PartitionKey") @@ -378,10 +412,9 @@ class KinesisBackend(BaseBackend): sequence_number, shard_id = stream.put_record( partition_key, explicit_hash_key, None, data ) - response['Records'].append({ - "SequenceNumber": sequence_number, - "ShardId": shard_id - }) + response["Records"].append( + {"SequenceNumber": sequence_number, "ShardId": shard_id} + ) return response @@ -391,18 +424,18 @@ class KinesisBackend(BaseBackend): if shard_to_split not in stream.shards: raise ResourceNotFoundError(shard_to_split) - if not re.match(r'0|([1-9]\d{0,38})', new_starting_hash_key): + if not re.match(r"0|([1-9]\d{0,38})", new_starting_hash_key): raise InvalidArgumentError(new_starting_hash_key) new_starting_hash_key = int(new_starting_hash_key) shard = stream.shards[shard_to_split] - last_id = sorted(stream.shards.values(), - key=attrgetter('_shard_id'))[-1]._shard_id + last_id = sorted(stream.shards.values(), key=attrgetter("_shard_id"))[ + -1 + ]._shard_id if shard.starting_hash < new_starting_hash_key < shard.ending_hash: - new_shard = Shard( - last_id + 1, new_starting_hash_key, shard.ending_hash) + new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash) shard.ending_hash = new_starting_hash_key stream.shards[new_shard.shard_id] = new_shard else: @@ -439,10 +472,11 @@ class KinesisBackend(BaseBackend): del stream.shards[shard2.shard_id] for index in shard2.records: record = shard2.records[index] - shard1.put_record(record.partition_key, - record.data, record.explicit_hash_key) + shard1.put_record( + record.partition_key, record.data, record.explicit_hash_key + ) - ''' Firehose ''' + """ Firehose """ def create_delivery_stream(self, stream_name, **stream_kwargs): stream = DeliveryStream(stream_name, **stream_kwargs) @@ -466,25 +500,21 @@ class KinesisBackend(BaseBackend): record = stream.put_record(record_data) return record - def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None): + def list_tags_for_stream( + self, stream_name, exclusive_start_tag_key=None, limit=None + ): stream = self.describe_stream(stream_name) tags = [] - result = { - 'HasMoreTags': False, - 'Tags': tags - } + result = {"HasMoreTags": False, "Tags": tags} for key, val in sorted(stream.tags.items(), key=lambda x: x[0]): if limit and len(tags) >= limit: - result['HasMoreTags'] = True + result["HasMoreTags"] = True break if exclusive_start_tag_key and key < exclusive_start_tag_key: continue - tags.append({ - 'Key': key, - 'Value': val - }) + tags.append({"Key": key, "Value": val}) return result diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index 3a81bd9f4..500f7855d 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -7,7 +7,6 @@ from .models import kinesis_backends class KinesisResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -18,47 +17,47 @@ class KinesisResponse(BaseResponse): @property def is_firehose(self): - host = self.headers.get('host') or self.headers['Host'] - return host.startswith('firehose') or 'firehose' in self.headers.get('Authorization', '') + host = self.headers.get("host") or self.headers["Host"] + return host.startswith("firehose") or "firehose" in self.headers.get( + "Authorization", "" + ) def create_stream(self): - stream_name = self.parameters.get('StreamName') - shard_count = self.parameters.get('ShardCount') - self.kinesis_backend.create_stream( - stream_name, shard_count, self.region) + stream_name = self.parameters.get("StreamName") + shard_count = self.parameters.get("ShardCount") + self.kinesis_backend.create_stream(stream_name, shard_count, self.region) return "" def describe_stream(self): - stream_name = self.parameters.get('StreamName') + stream_name = self.parameters.get("StreamName") stream = self.kinesis_backend.describe_stream(stream_name) return json.dumps(stream.to_json()) def describe_stream_summary(self): - stream_name = self.parameters.get('StreamName') + stream_name = self.parameters.get("StreamName") stream = self.kinesis_backend.describe_stream_summary(stream_name) return json.dumps(stream.to_json_summary()) def list_streams(self): streams = self.kinesis_backend.list_streams() stream_names = [stream.stream_name for stream in streams] - max_streams = self._get_param('Limit', 10) + max_streams = self._get_param("Limit", 10) try: - token = self.parameters.get('ExclusiveStartStreamName') + token = self.parameters.get("ExclusiveStartStreamName") except ValueError: - token = self._get_param('ExclusiveStartStreamName') + token = self._get_param("ExclusiveStartStreamName") if token: start = stream_names.index(token) + 1 else: start = 0 - streams_resp = stream_names[start:start + max_streams] + streams_resp = stream_names[start : start + max_streams] has_more_streams = False if start + max_streams < len(stream_names): has_more_streams = True - return json.dumps({ - "HasMoreStreams": has_more_streams, - "StreamNames": streams_resp - }) + return json.dumps( + {"HasMoreStreams": has_more_streams, "StreamNames": streams_resp} + ) def delete_stream(self): stream_name = self.parameters.get("StreamName") @@ -69,30 +68,36 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") shard_id = self.parameters.get("ShardId") shard_iterator_type = self.parameters.get("ShardIteratorType") - starting_sequence_number = self.parameters.get( - "StartingSequenceNumber") + starting_sequence_number = self.parameters.get("StartingSequenceNumber") at_timestamp = self.parameters.get("Timestamp") shard_iterator = self.kinesis_backend.get_shard_iterator( - stream_name, shard_id, shard_iterator_type, starting_sequence_number, at_timestamp + stream_name, + shard_id, + shard_iterator_type, + starting_sequence_number, + at_timestamp, ) - return json.dumps({ - "ShardIterator": shard_iterator - }) + return json.dumps({"ShardIterator": shard_iterator}) def get_records(self): shard_iterator = self.parameters.get("ShardIterator") limit = self.parameters.get("Limit") - next_shard_iterator, records, millis_behind_latest = self.kinesis_backend.get_records( - shard_iterator, limit) + ( + next_shard_iterator, + records, + millis_behind_latest, + ) = self.kinesis_backend.get_records(shard_iterator, limit) - return json.dumps({ - "NextShardIterator": next_shard_iterator, - "Records": [record.to_json() for record in records], - 'MillisBehindLatest': millis_behind_latest - }) + return json.dumps( + { + "NextShardIterator": next_shard_iterator, + "Records": [record.to_json() for record in records], + "MillisBehindLatest": millis_behind_latest, + } + ) def put_record(self): if self.is_firehose: @@ -100,18 +105,18 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") partition_key = self.parameters.get("PartitionKey") explicit_hash_key = self.parameters.get("ExplicitHashKey") - sequence_number_for_ordering = self.parameters.get( - "SequenceNumberForOrdering") + sequence_number_for_ordering = self.parameters.get("SequenceNumberForOrdering") data = self.parameters.get("Data") sequence_number, shard_id = self.kinesis_backend.put_record( - stream_name, partition_key, explicit_hash_key, sequence_number_for_ordering, data + stream_name, + partition_key, + explicit_hash_key, + sequence_number_for_ordering, + data, ) - return json.dumps({ - "SequenceNumber": sequence_number, - "ShardId": shard_id, - }) + return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id}) def put_records(self): if self.is_firehose: @@ -119,9 +124,7 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") records = self.parameters.get("Records") - response = self.kinesis_backend.put_records( - stream_name, records - ) + response = self.kinesis_backend.put_records(stream_name, records) return json.dumps(response) @@ -143,43 +146,39 @@ class KinesisResponse(BaseResponse): ) return "" - ''' Firehose ''' + """ Firehose """ def create_delivery_stream(self): - stream_name = self.parameters['DeliveryStreamName'] - redshift_config = self.parameters.get( - 'RedshiftDestinationConfiguration') + stream_name = self.parameters["DeliveryStreamName"] + redshift_config = self.parameters.get("RedshiftDestinationConfiguration") + s3_config = self.parameters.get("S3DestinationConfiguration") + extended_s3_config = self.parameters.get("ExtendedS3DestinationConfiguration") if redshift_config: - redshift_s3_config = redshift_config['S3Configuration'] + redshift_s3_config = redshift_config["S3Configuration"] stream_kwargs = { - 'redshift_username': redshift_config['Username'], - 'redshift_password': redshift_config['Password'], - 'redshift_jdbc_url': redshift_config['ClusterJDBCURL'], - 'redshift_role_arn': redshift_config['RoleARN'], - 'redshift_copy_command': redshift_config['CopyCommand'], + "redshift_username": redshift_config["Username"], + "redshift_password": redshift_config["Password"], + "redshift_jdbc_url": redshift_config["ClusterJDBCURL"], + "redshift_role_arn": redshift_config["RoleARN"], + "redshift_copy_command": redshift_config["CopyCommand"], + "redshift_s3_role_arn": redshift_s3_config["RoleARN"], + "redshift_s3_bucket_arn": redshift_s3_config["BucketARN"], + "redshift_s3_prefix": redshift_s3_config["Prefix"], + "redshift_s3_compression_format": redshift_s3_config.get( + "CompressionFormat" + ), + "redshift_s3_buffering_hints": redshift_s3_config["BufferingHints"], + } + elif s3_config: + stream_kwargs = {"s3_config": s3_config} + elif extended_s3_config: + stream_kwargs = {"extended_s3_config": extended_s3_config} - 'redshift_s3_role_arn': redshift_s3_config['RoleARN'], - 'redshift_s3_bucket_arn': redshift_s3_config['BucketARN'], - 'redshift_s3_prefix': redshift_s3_config['Prefix'], - 'redshift_s3_compression_format': redshift_s3_config.get('CompressionFormat'), - 'redshift_s3_buffering_hings': redshift_s3_config['BufferingHints'], - } - else: - # S3 Config - s3_config = self.parameters['S3DestinationConfiguration'] - stream_kwargs = { - 's3_role_arn': s3_config['RoleARN'], - 's3_bucket_arn': s3_config['BucketARN'], - 's3_prefix': s3_config['Prefix'], - 's3_compression_format': s3_config.get('CompressionFormat'), - 's3_buffering_hings': s3_config['BufferingHints'], - } stream = self.kinesis_backend.create_delivery_stream( - stream_name, **stream_kwargs) - return json.dumps({ - 'DeliveryStreamARN': stream.arn - }) + stream_name, **stream_kwargs + ) + return json.dumps({"DeliveryStreamARN": stream.arn}) def describe_delivery_stream(self): stream_name = self.parameters["DeliveryStreamName"] @@ -188,60 +187,54 @@ class KinesisResponse(BaseResponse): def list_delivery_streams(self): streams = self.kinesis_backend.list_delivery_streams() - return json.dumps({ - "DeliveryStreamNames": [ - stream.name for stream in streams - ], - "HasMoreDeliveryStreams": False - }) + return json.dumps( + { + "DeliveryStreamNames": [stream.name for stream in streams], + "HasMoreDeliveryStreams": False, + } + ) def delete_delivery_stream(self): - stream_name = self.parameters['DeliveryStreamName'] + stream_name = self.parameters["DeliveryStreamName"] self.kinesis_backend.delete_delivery_stream(stream_name) return json.dumps({}) def firehose_put_record(self): - stream_name = self.parameters['DeliveryStreamName'] - record_data = self.parameters['Record']['Data'] + stream_name = self.parameters["DeliveryStreamName"] + record_data = self.parameters["Record"]["Data"] - record = self.kinesis_backend.put_firehose_record( - stream_name, record_data) - return json.dumps({ - "RecordId": record.record_id, - }) + record = self.kinesis_backend.put_firehose_record(stream_name, record_data) + return json.dumps({"RecordId": record.record_id}) def put_record_batch(self): - stream_name = self.parameters['DeliveryStreamName'] - records = self.parameters['Records'] + stream_name = self.parameters["DeliveryStreamName"] + records = self.parameters["Records"] request_responses = [] for record in records: record_response = self.kinesis_backend.put_firehose_record( - stream_name, record['Data']) - request_responses.append({ - "RecordId": record_response.record_id - }) - return json.dumps({ - "FailedPutCount": 0, - "RequestResponses": request_responses, - }) + stream_name, record["Data"] + ) + request_responses.append({"RecordId": record_response.record_id}) + return json.dumps({"FailedPutCount": 0, "RequestResponses": request_responses}) def add_tags_to_stream(self): - stream_name = self.parameters.get('StreamName') - tags = self.parameters.get('Tags') + stream_name = self.parameters.get("StreamName") + tags = self.parameters.get("Tags") self.kinesis_backend.add_tags_to_stream(stream_name, tags) return json.dumps({}) def list_tags_for_stream(self): - stream_name = self.parameters.get('StreamName') - exclusive_start_tag_key = self.parameters.get('ExclusiveStartTagKey') - limit = self.parameters.get('Limit') + stream_name = self.parameters.get("StreamName") + exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") + limit = self.parameters.get("Limit") response = self.kinesis_backend.list_tags_for_stream( - stream_name, exclusive_start_tag_key, limit) + stream_name, exclusive_start_tag_key, limit + ) return json.dumps(response) def remove_tags_from_stream(self): - stream_name = self.parameters.get('StreamName') - tag_keys = self.parameters.get('TagKeys') + stream_name = self.parameters.get("StreamName") + tag_keys = self.parameters.get("TagKeys") self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys) return json.dumps({}) diff --git a/moto/kinesis/urls.py b/moto/kinesis/urls.py index a8d15eecd..c95f03190 100644 --- a/moto/kinesis/urls.py +++ b/moto/kinesis/urls.py @@ -6,6 +6,4 @@ url_bases = [ "https?://firehose.(.+).amazonaws.com", ] -url_paths = { - '{0}/$': KinesisResponse.dispatch, -} +url_paths = {"{0}/$": KinesisResponse.dispatch} diff --git a/moto/kinesis/utils.py b/moto/kinesis/utils.py index 0c3edbb5a..b455cb7ba 100644 --- a/moto/kinesis/utils.py +++ b/moto/kinesis/utils.py @@ -14,8 +14,9 @@ else: raise Exception("Python version is not supported") -def compose_new_shard_iterator(stream_name, shard, shard_iterator_type, starting_sequence_number, - at_timestamp): +def compose_new_shard_iterator( + stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp +): if shard_iterator_type == "AT_SEQUENCE_NUMBER": last_sequence_id = int(starting_sequence_number) - 1 elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": @@ -28,17 +29,16 @@ def compose_new_shard_iterator(stream_name, shard, shard_iterator_type, starting last_sequence_id = shard.get_sequence_number_at(at_timestamp) else: raise InvalidArgumentError( - "Invalid ShardIteratorType: {0}".format(shard_iterator_type)) + "Invalid ShardIteratorType: {0}".format(shard_iterator_type) + ) return compose_shard_iterator(stream_name, shard, last_sequence_id) def compose_shard_iterator(stream_name, shard, last_sequence_id): return encode_method( - "{0}:{1}:{2}".format( - stream_name, - shard.shard_id, - last_sequence_id, - ).encode("utf-8") + "{0}:{1}:{2}".format(stream_name, shard.shard_id, last_sequence_id).encode( + "utf-8" + ) ).decode("utf-8") diff --git a/moto/kms/__init__.py b/moto/kms/__init__.py index b4bb0b639..ecedb8bfd 100644 --- a/moto/kms/__init__.py +++ b/moto/kms/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import kms_backends from ..core.models import base_decorator, deprecated_base_decorator -kms_backend = kms_backends['us-east-1'] +kms_backend = kms_backends["us-east-1"] mock_kms = base_decorator(kms_backends) mock_kms_deprecated = deprecated_base_decorator(kms_backends) diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index 70edd3dcd..4ddfd279f 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -6,31 +6,47 @@ class NotFoundException(JsonRESTError): code = 400 def __init__(self, message): - super(NotFoundException, self).__init__( - "NotFoundException", message) + super(NotFoundException, self).__init__("NotFoundException", message) class ValidationException(JsonRESTError): code = 400 def __init__(self, message): - super(ValidationException, self).__init__( - "ValidationException", message) + super(ValidationException, self).__init__("ValidationException", message) class AlreadyExistsException(JsonRESTError): code = 400 def __init__(self, message): - super(AlreadyExistsException, self).__init__( - "AlreadyExistsException", message) + super(AlreadyExistsException, self).__init__("AlreadyExistsException", message) class NotAuthorizedException(JsonRESTError): code = 400 def __init__(self): - super(NotAuthorizedException, self).__init__( - "NotAuthorizedException", None) + super(NotAuthorizedException, self).__init__("NotAuthorizedException", None) self.description = '{"__type":"NotAuthorizedException"}' + + +class AccessDeniedException(JsonRESTError): + code = 400 + + def __init__(self, message): + super(AccessDeniedException, self).__init__("AccessDeniedException", message) + + self.description = '{"__type":"AccessDeniedException"}' + + +class InvalidCiphertextException(JsonRESTError): + code = 400 + + def __init__(self): + super(InvalidCiphertextException, self).__init__( + "InvalidCiphertextException", None + ) + + self.description = '{"__type":"InvalidCiphertextException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 577840b06..9d7739779 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -1,16 +1,18 @@ from __future__ import unicode_literals import os -import boto.kms -from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_without_milliseconds -from .utils import generate_key_id from collections import defaultdict from datetime import datetime, timedelta +import boto.kms + +from moto.core import BaseBackend, BaseModel +from moto.core.utils import iso_8601_datetime_without_milliseconds + +from .utils import decrypt, encrypt, generate_key_id, generate_master_key + class Key(BaseModel): - def __init__(self, policy, key_usage, description, tags, region): self.id = generate_key_id() self.policy = policy @@ -19,10 +21,11 @@ class Key(BaseModel): self.description = description self.enabled = True self.region = region - self.account_id = "0123456789012" + self.account_id = "012345678912" self.key_rotation_status = False self.deletion_date = None self.tags = tags or {} + self.key_material = generate_master_key() @property def physical_resource_id(self): @@ -30,7 +33,9 @@ class Key(BaseModel): @property def arn(self): - return "arn:aws:kms:{0}:{1}:key/{2}".format(self.region, self.account_id, self.id) + return "arn:aws:kms:{0}:{1}:key/{2}".format( + self.region, self.account_id, self.id + ) def to_dict(self): key_dict = { @@ -45,38 +50,42 @@ class Key(BaseModel): "KeyState": self.key_state, } } - if self.key_state == 'PendingDeletion': - key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date) + if self.key_state == "PendingDeletion": + key_dict["KeyMetadata"][ + "DeletionDate" + ] = iso_8601_datetime_without_milliseconds(self.deletion_date) return key_dict def delete(self, region_name): kms_backends[region_name].delete_key(self.id) @classmethod - def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + self, resource_name, cloudformation_json, region_name + ): kms_backend = kms_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] key = kms_backend.create_key( - policy=properties['KeyPolicy'], - key_usage='ENCRYPT_DECRYPT', - description=properties['Description'], - tags=properties.get('Tags'), + policy=properties["KeyPolicy"], + key_usage="ENCRYPT_DECRYPT", + description=properties["Description"], + tags=properties.get("Tags"), region=region_name, ) - key.key_rotation_status = properties['EnableKeyRotation'] - key.enabled = properties['Enabled'] + key.key_rotation_status = properties["EnableKeyRotation"] + key.enabled = properties["Enabled"] return key def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class KmsBackend(BaseBackend): - def __init__(self): self.keys = {} self.key_to_aliases = defaultdict(set) @@ -109,16 +118,43 @@ class KmsBackend(BaseBackend): # allow the different methods (alias, ARN :key/, keyId, ARN alias) to # describe key not just KeyId key_id = self.get_key_id(key_id) - if r'alias/' in str(key_id).lower(): - key_id = self.get_key_id_from_alias(key_id.split('alias/')[1]) + if r"alias/" in str(key_id).lower(): + key_id = self.get_key_id_from_alias(key_id.split("alias/")[1]) return self.keys[self.get_key_id(key_id)] def list_keys(self): return self.keys.values() - def get_key_id(self, key_id): + @staticmethod + def get_key_id(key_id): # Allow use of ARN as well as pure KeyId - return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id + if key_id.startswith("arn:") and ":key/" in key_id: + return key_id.split(":key/")[1] + + return key_id + + @staticmethod + def get_alias_name(alias_name): + # Allow use of ARN as well as alias name + if alias_name.startswith("arn:") and ":alias/" in alias_name: + return alias_name.split(":alias/")[1] + + return alias_name + + def any_id_to_key_id(self, key_id): + """Go from any valid key ID to the raw key ID. + + Acceptable inputs: + - raw key ID + - key ARN + - alias name + - alias ARN + """ + key_id = self.get_alias_name(key_id) + key_id = self.get_key_id(key_id) + if key_id.startswith("alias/"): + key_id = self.get_key_id_from_alias(key_id) + return key_id def alias_exists(self, alias_name): for aliases in self.key_to_aliases.values(): @@ -162,37 +198,89 @@ class KmsBackend(BaseBackend): def disable_key(self, key_id): self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" def enable_key(self, key_id): self.keys[key_id].enabled = True - self.keys[key_id].key_state = 'Enabled' + self.keys[key_id].key_state = "Enabled" def cancel_key_deletion(self, key_id): - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" self.keys[key_id].deletion_date = None def schedule_key_deletion(self, key_id, pending_window_in_days): if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'PendingDeletion' - self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days) - return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date) + self.keys[key_id].key_state = "PendingDeletion" + self.keys[key_id].deletion_date = datetime.now() + timedelta( + days=pending_window_in_days + ) + return iso_8601_datetime_without_milliseconds( + self.keys[key_id].deletion_date + ) - def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): - key = self.keys[self.get_key_id(key_id)] + def encrypt(self, key_id, plaintext, encryption_context): + key_id = self.any_id_to_key_id(key_id) + + ciphertext_blob = encrypt( + master_keys=self.keys, + key_id=key_id, + plaintext=plaintext, + encryption_context=encryption_context, + ) + arn = self.keys[key_id].arn + return ciphertext_blob, arn + + def decrypt(self, ciphertext_blob, encryption_context): + plaintext, key_id = decrypt( + master_keys=self.keys, + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, + ) + arn = self.keys[key_id].arn + return plaintext, arn + + def re_encrypt( + self, + ciphertext_blob, + source_encryption_context, + destination_key_id, + destination_encryption_context, + ): + destination_key_id = self.any_id_to_key_id(destination_key_id) + + plaintext, decrypting_arn = self.decrypt( + ciphertext_blob=ciphertext_blob, + encryption_context=source_encryption_context, + ) + new_ciphertext_blob, encrypting_arn = self.encrypt( + key_id=destination_key_id, + plaintext=plaintext, + encryption_context=destination_encryption_context, + ) + return new_ciphertext_blob, decrypting_arn, encrypting_arn + + def generate_data_key( + self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens + ): + key_id = self.any_id_to_key_id(key_id) if key_spec: - if key_spec == 'AES_128': - bytes = 16 + # Note: Actual validation of key_spec is done in kms.responses + if key_spec == "AES_128": + plaintext_len = 16 else: - bytes = 32 + plaintext_len = 32 else: - bytes = number_of_bytes + plaintext_len = number_of_bytes - plaintext = os.urandom(bytes) + plaintext = os.urandom(plaintext_len) - return plaintext, key.arn + ciphertext_blob, arn = self.encrypt( + key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + ) + + return plaintext, ciphertext_blob, arn kms_backends = {} diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 53012b7f8..d3a9726e1 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -2,346 +2,524 @@ from __future__ import unicode_literals import base64 import json +import os import re + import six from moto.core.responses import BaseResponse from .models import kms_backends -from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException +from .exceptions import ( + NotFoundException, + ValidationException, + AlreadyExistsException, + NotAuthorizedException, +) +ACCOUNT_ID = "012345678912" reserved_aliases = [ - 'alias/aws/ebs', - 'alias/aws/s3', - 'alias/aws/redshift', - 'alias/aws/rds', + "alias/aws/ebs", + "alias/aws/s3", + "alias/aws/redshift", + "alias/aws/rds", ] class KmsResponse(BaseResponse): - @property def parameters(self): - return json.loads(self.body) + params = json.loads(self.body) + + for key in ("Plaintext", "CiphertextBlob"): + if key in params: + params[key] = base64.b64decode(params[key].encode("utf-8")) + + return params @property def kms_backend(self): return kms_backends[self.region] + def _display_arn(self, key_id): + if key_id.startswith("arn:"): + return key_id + + if key_id.startswith("alias/"): + id_type = "" + else: + id_type = "key/" + + return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format( + region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id + ) + + def _validate_cmk_id(self, key_id): + """Determine whether a CMK ID exists. + + - raw key ID + - key ARN + """ + is_arn = key_id.startswith("arn:") and ":key/" in key_id + is_raw_key_id = re.match( + r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", + key_id, + re.IGNORECASE, + ) + + if not is_arn and not is_raw_key_id: + raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id)) + + cmk_id = self.kms_backend.get_key_id(key_id) + + if cmk_id not in self.kms_backend.keys: + raise NotFoundException( + "Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id)) + ) + + def _validate_alias(self, key_id): + """Determine whether an alias exists. + + - alias name + - alias ARN + """ + error = NotFoundException( + "Alias {key_id} is not found.".format(key_id=self._display_arn(key_id)) + ) + + is_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_name = key_id.startswith("alias/") + + if not is_arn and not is_name: + raise error + + alias_name = self.kms_backend.get_alias_name(key_id) + cmk_id = self.kms_backend.get_key_id_from_alias(alias_name) + if cmk_id is None: + raise error + + def _validate_key_id(self, key_id): + """Determine whether or not a key ID exists. + + - raw key ID + - key ARN + - alias name + - alias ARN + """ + is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_alias_name = key_id.startswith("alias/") + + if is_alias_arn or is_alias_name: + self._validate_alias(key_id) + return + + self._validate_cmk_id(key_id) + def create_key(self): - policy = self.parameters.get('Policy') - key_usage = self.parameters.get('KeyUsage') - description = self.parameters.get('Description') - tags = self.parameters.get('Tags') + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" + policy = self.parameters.get("Policy") + key_usage = self.parameters.get("KeyUsage") + description = self.parameters.get("Description") + tags = self.parameters.get("Tags") key = self.kms_backend.create_key( - policy, key_usage, description, tags, self.region) + policy, key_usage, description, tags, self.region + ) return json.dumps(key.to_dict()) def update_key_description(self): - key_id = self.parameters.get('KeyId') - description = self.parameters.get('Description') + """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" + key_id = self.parameters.get("KeyId") + description = self.parameters.get("Description") + + self._validate_cmk_id(key_id) self.kms_backend.update_key_description(key_id, description) return json.dumps(None) def tag_resource(self): - key_id = self.parameters.get('KeyId') - tags = self.parameters.get('Tags') + """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" + key_id = self.parameters.get("KeyId") + tags = self.parameters.get("Tags") + + self._validate_cmk_id(key_id) + self.kms_backend.tag_resource(key_id, tags) return json.dumps({}) def list_resource_tags(self): - key_id = self.parameters.get('KeyId') + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + tags = self.kms_backend.list_resource_tags(key_id) - return json.dumps({ - "Tags": tags, - "NextMarker": None, - "Truncated": False, - }) + return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) def describe_key(self): - key_id = self.parameters.get('KeyId') - try: - key = self.kms_backend.describe_key( - self.kms_backend.get_key_id(key_id)) - except KeyError: - headers = dict(self.headers) - headers['status'] = 404 - return "{}", headers + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" + key_id = self.parameters.get("KeyId") + + self._validate_key_id(key_id) + + key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id)) + return json.dumps(key.to_dict()) def list_keys(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" keys = self.kms_backend.list_keys() - return json.dumps({ - "Keys": [ - { - "KeyArn": key.arn, - "KeyId": key.id, - } for key in keys - ], - "NextMarker": None, - "Truncated": False, - }) + return json.dumps( + { + "Keys": [{"KeyArn": key.arn, "KeyId": key.id} for key in keys], + "NextMarker": None, + "Truncated": False, + } + ) def create_alias(self): - alias_name = self.parameters['AliasName'] - target_key_id = self.parameters['TargetKeyId'] + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" + alias_name = self.parameters["AliasName"] + target_key_id = self.parameters["TargetKeyId"] - if not alias_name.startswith('alias/'): - raise ValidationException('Invalid identifier') + if not alias_name.startswith("alias/"): + raise ValidationException("Invalid identifier") if alias_name in reserved_aliases: raise NotAuthorizedException() - if ':' in alias_name: - raise ValidationException('{alias_name} contains invalid characters for an alias'.format(alias_name=alias_name)) + if ":" in alias_name: + raise ValidationException( + "{alias_name} contains invalid characters for an alias".format( + alias_name=alias_name + ) + ) - if not re.match(r'^[a-zA-Z0-9:/_-]+$', alias_name): - raise ValidationException("1 validation error detected: Value '{alias_name}' at 'aliasName' " - "failed to satisfy constraint: Member must satisfy regular " - "expression pattern: ^[a-zA-Z0-9:/_-]+$" - .format(alias_name=alias_name)) + if not re.match(r"^[a-zA-Z0-9:/_-]+$", alias_name): + raise ValidationException( + "1 validation error detected: Value '{alias_name}' at 'aliasName' " + "failed to satisfy constraint: Member must satisfy regular " + "expression pattern: ^[a-zA-Z0-9:/_-]+$".format(alias_name=alias_name) + ) if self.kms_backend.alias_exists(target_key_id): - raise ValidationException('Aliases must refer to keys. Not aliases') + raise ValidationException("Aliases must refer to keys. Not aliases") if self.kms_backend.alias_exists(alias_name): - raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' - 'already exists'.format(region=self.region, alias_name=alias_name)) + raise AlreadyExistsException( + "An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} " + "already exists".format(region=self.region, alias_name=alias_name) + ) + + self._validate_cmk_id(target_key_id) self.kms_backend.add_alias(target_key_id, alias_name) return json.dumps(None) def delete_alias(self): - alias_name = self.parameters['AliasName'] + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" + alias_name = self.parameters["AliasName"] - if not alias_name.startswith('alias/'): - raise ValidationException('Invalid identifier') + if not alias_name.startswith("alias/"): + raise ValidationException("Invalid identifier") - if not self.kms_backend.alias_exists(alias_name): - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:' - '{alias_name} is not found.'.format(region=self.region, alias_name=alias_name)) + self._validate_alias(alias_name) self.kms_backend.delete_alias(alias_name) return json.dumps(None) def list_aliases(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html""" region = self.region + # TODO: The actual API can filter on KeyId. + response_aliases = [ { - 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, - reserved_alias=reserved_alias), - 'AliasName': reserved_alias - } for reserved_alias in reserved_aliases + "AliasArn": "arn:aws:kms:{region}:012345678912:{reserved_alias}".format( + region=region, reserved_alias=reserved_alias + ), + "AliasName": reserved_alias, + } + for reserved_alias in reserved_aliases ] backend_aliases = self.kms_backend.get_all_aliases() for target_key_id, aliases in backend_aliases.items(): for alias_name in aliases: - response_aliases.append({ - 'AliasArn': u'arn:aws:kms:{region}:012345678912:{alias_name}'.format(region=region, - alias_name=alias_name), - 'AliasName': alias_name, - 'TargetKeyId': target_key_id, - }) + response_aliases.append( + { + "AliasArn": "arn:aws:kms:{region}:012345678912:{alias_name}".format( + region=region, alias_name=alias_name + ), + "AliasName": alias_name, + "TargetKeyId": target_key_id, + } + ) - return json.dumps({ - 'Truncated': False, - 'Aliases': response_aliases, - }) + return json.dumps({"Truncated": False, "Aliases": response_aliases}) def enable_key_rotation(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key_rotation(key_id) return json.dumps(None) def disable_key_rotation(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key_rotation(key_id) + return json.dumps(None) def get_key_rotation_status(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) - return json.dumps({'KeyRotationEnabled': rotation_enabled}) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) + + return json.dumps({"KeyRotationEnabled": rotation_enabled}) def put_key_policy(self): - key_id = self.parameters.get('KeyId') - policy_name = self.parameters.get('PolicyName') - policy = self.parameters.get('Policy') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" + key_id = self.parameters.get("KeyId") + policy_name = self.parameters.get("PolicyName") + policy = self.parameters.get("Policy") _assert_default_policy(policy_name) - try: - self.kms_backend.put_key_policy(key_id, policy) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + self.kms_backend.put_key_policy(key_id, policy) return json.dumps(None) def get_key_policy(self): - key_id = self.parameters.get('KeyId') - policy_name = self.parameters.get('PolicyName') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" + key_id = self.parameters.get("KeyId") + policy_name = self.parameters.get("PolicyName") _assert_default_policy(policy_name) - try: - return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + return json.dumps({"Policy": self.kms_backend.get_key_policy(key_id)}) def list_key_policies(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.describe_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" + key_id = self.parameters.get("KeyId") - return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) + self._validate_cmk_id(key_id) + + self.kms_backend.describe_key(key_id) + + return json.dumps({"Truncated": False, "PolicyNames": ["default"]}) def encrypt(self): - """ - We perform no encryption, we just encode the value as base64 and then - decode it in decrypt(). - """ - value = self.parameters.get("Plaintext") - if isinstance(value, six.text_type): - value = value.encode('utf-8') - return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'}) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" + key_id = self.parameters.get("KeyId") + encryption_context = self.parameters.get("EncryptionContext", {}) + plaintext = self.parameters.get("Plaintext") + + self._validate_key_id(key_id) + + if isinstance(plaintext, six.text_type): + plaintext = plaintext.encode("utf-8") + + ciphertext_blob, arn = self.kms_backend.encrypt( + key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + ) + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") + + return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) def decrypt(self): - # TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" + ciphertext_blob = self.parameters.get("CiphertextBlob") + encryption_context = self.parameters.get("EncryptionContext", {}) - value = self.parameters.get("CiphertextBlob") - try: - return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'}) - except UnicodeDecodeError: - # Generate data key will produce random bytes which when decrypted is still returned as base64 - return json.dumps({"Plaintext": value}) + plaintext, arn = self.kms_backend.decrypt( + ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + ) + + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + + return json.dumps({"Plaintext": plaintext_response, "KeyId": arn}) + + def re_encrypt(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" + ciphertext_blob = self.parameters.get("CiphertextBlob") + source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) + destination_key_id = self.parameters.get("DestinationKeyId") + destination_encryption_context = self.parameters.get( + "DestinationEncryptionContext", {} + ) + + self._validate_cmk_id(destination_key_id) + + ( + new_ciphertext_blob, + decrypting_arn, + encrypting_arn, + ) = self.kms_backend.re_encrypt( + ciphertext_blob=ciphertext_blob, + source_encryption_context=source_encryption_context, + destination_key_id=destination_key_id, + destination_encryption_context=destination_encryption_context, + ) + + response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8") + + return json.dumps( + { + "CiphertextBlob": response_ciphertext_blob, + "KeyId": encrypting_arn, + "SourceKeyId": decrypting_arn, + } + ) def disable_key(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key(key_id) + return json.dumps(None) def enable_key(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key(key_id) + return json.dumps(None) def cancel_key_deletion(self): - key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.cancel_key_deletion(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) - return json.dumps({'KeyId': key_id}) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" + key_id = self.parameters.get("KeyId") + + self._validate_cmk_id(key_id) + + self.kms_backend.cancel_key_deletion(key_id) + + return json.dumps({"KeyId": key_id}) def schedule_key_deletion(self): - key_id = self.parameters.get('KeyId') - if self.parameters.get('PendingWindowInDays') is None: + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" + key_id = self.parameters.get("KeyId") + if self.parameters.get("PendingWindowInDays") is None: pending_window_in_days = 30 else: - pending_window_in_days = self.parameters.get('PendingWindowInDays') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - return json.dumps({ - 'KeyId': key_id, - 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) - }) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + pending_window_in_days = self.parameters.get("PendingWindowInDays") + + self._validate_cmk_id(key_id) + + return json.dumps( + { + "KeyId": key_id, + "DeletionDate": self.kms_backend.schedule_key_deletion( + key_id, pending_window_in_days + ), + } + ) def generate_data_key(self): - key_id = self.parameters.get('KeyId') - encryption_context = self.parameters.get('EncryptionContext') - number_of_bytes = self.parameters.get('NumberOfBytes') - key_spec = self.parameters.get('KeySpec') - grant_tokens = self.parameters.get('GrantTokens') + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" + key_id = self.parameters.get("KeyId") + encryption_context = self.parameters.get("EncryptionContext", {}) + number_of_bytes = self.parameters.get("NumberOfBytes") + key_spec = self.parameters.get("KeySpec") + grant_tokens = self.parameters.get("GrantTokens") # Param validation - if key_id.startswith('alias'): - if self.kms_backend.get_key_id_from_alias(key_id) is None: - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:{alias_name} is not found.'.format( - region=self.region, alias_name=key_id)) - else: - if self.kms_backend.get_key_id(key_id) not in self.kms_backend.keys: - raise NotFoundException('Invalid keyId') + self._validate_key_id(key_id) - if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): - raise ValidationException("1 validation error detected: Value '2048' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024") + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): + raise ValidationException( + ( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) - if key_spec and key_spec not in ('AES_256', 'AES_128'): - raise ValidationException("1 validation error detected: Value 'AES_257' at 'keySpec' failed " - "to satisfy constraint: Member must satisfy enum value set: " - "[AES_256, AES_128]") + if key_spec and key_spec not in ("AES_256", "AES_128"): + raise ValidationException( + ( + "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " + "to satisfy constraint: Member must satisfy enum value set: " + "[AES_256, AES_128]" + ).format(key_spec=key_spec) + ) if not key_spec and not number_of_bytes: - raise ValidationException("Please specify either number of bytes or key spec.") + raise ValidationException( + "Please specify either number of bytes or key spec." + ) + if key_spec and number_of_bytes: - raise ValidationException("Please specify either number of bytes or key spec.") + raise ValidationException( + "Please specify either number of bytes or key spec." + ) - plaintext, key_arn = self.kms_backend.generate_data_key(key_id, encryption_context, - number_of_bytes, key_spec, grant_tokens) + plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key( + key_id=key_id, + encryption_context=encryption_context, + number_of_bytes=number_of_bytes, + key_spec=key_spec, + grant_tokens=grant_tokens, + ) - plaintext = base64.b64encode(plaintext).decode() + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") - return json.dumps({ - 'CiphertextBlob': plaintext, - 'Plaintext': plaintext, - 'KeyId': key_arn # not alias - }) + return json.dumps( + { + "CiphertextBlob": ciphertext_blob_response, + "Plaintext": plaintext_response, + "KeyId": key_arn, # not alias + } + ) def generate_data_key_without_plaintext(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" result = json.loads(self.generate_data_key()) - del result['Plaintext'] + del result["Plaintext"] return json.dumps(result) + def generate_random(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" + number_of_bytes = self.parameters.get("NumberOfBytes") -def _assert_valid_key_id(key_id): - if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): - raise NotFoundException('Invalid keyId') + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): + raise ValidationException( + ( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) + + entropy = os.urandom(number_of_bytes) + + response_entropy = base64.b64encode(entropy).decode("utf-8") + + return json.dumps({"Plaintext": response_entropy}) def _assert_default_policy(policy_name): - if policy_name != 'default': + if policy_name != "default": raise NotFoundException("No such policy exists") diff --git a/moto/kms/urls.py b/moto/kms/urls.py index 5b0b48969..97e1a3720 100644 --- a/moto/kms/urls.py +++ b/moto/kms/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import KmsResponse -url_bases = [ - "https?://kms.(.+).amazonaws.com", -] +url_bases = ["https?://kms.(.+).amazonaws.com"] -url_paths = { - '{0}/$': KmsResponse.dispatch, -} +url_paths = {"{0}/$": KmsResponse.dispatch} diff --git a/moto/kms/utils.py b/moto/kms/utils.py index fad38150f..4eacba1a6 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -1,7 +1,161 @@ from __future__ import unicode_literals +from collections import namedtuple +import io +import os +import struct import uuid +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes + +from .exceptions import ( + InvalidCiphertextException, + AccessDeniedException, + NotFoundException, +) + + +MASTER_KEY_LEN = 32 +KEY_ID_LEN = 36 +IV_LEN = 12 +TAG_LEN = 16 +HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN +# NOTE: This is just a simple binary format. It is not what KMS actually does. +CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format( + key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN +) +Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag")) + def generate_key_id(): return str(uuid.uuid4()) + + +def generate_data_key(number_of_bytes): + """Generate a data key.""" + return os.urandom(number_of_bytes) + + +def generate_master_key(): + """Generate a master key.""" + return generate_data_key(MASTER_KEY_LEN) + + +def _serialize_ciphertext_blob(ciphertext): + """Serialize Ciphertext object into a ciphertext blob. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = struct.pack( + CIPHERTEXT_HEADER_FORMAT, + ciphertext.key_id.encode("utf-8"), + ciphertext.iv, + ciphertext.tag, + ) + return header + ciphertext.ciphertext + + +def _deserialize_ciphertext_blob(ciphertext_blob): + """Deserialize ciphertext blob into a Ciphertext object. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = ciphertext_blob[:HEADER_LEN] + ciphertext = ciphertext_blob[HEADER_LEN:] + key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header) + return Ciphertext( + key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag + ) + + +def _serialize_encryption_context(encryption_context): + """Serialize encryption context for use a AAD. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + """ + aad = io.BytesIO() + for key, value in sorted(encryption_context.items(), key=lambda x: x[0]): + aad.write(key.encode("utf-8")) + aad.write(value.encode("utf-8")) + return aad.getvalue() + + +def encrypt(master_keys, key_id, plaintext, encryption_context): + """Encrypt data using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param str key_id: Key ID of moto master key + :param bytes plaintext: Plaintext data to encrypt + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: Moto-structured ciphertext blob encrypted under a moto master key in master_keys + :rtype: bytes + """ + try: + key = master_keys[key_id] + except KeyError: + is_alias = key_id.startswith("alias/") or ":alias/" in key_id + raise NotFoundException( + "{id_type} {key_id} is not found.".format( + id_type="Alias" if is_alias else "keyId", key_id=key_id + ) + ) + + iv = os.urandom(IV_LEN) + aad = _serialize_encryption_context(encryption_context=encryption_context) + + encryptor = Cipher( + algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend() + ).encryptor() + encryptor.authenticate_additional_data(aad) + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return _serialize_ciphertext_blob( + ciphertext=Ciphertext( + key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag + ) + ) + + +def decrypt(master_keys, ciphertext_blob, encryption_context): + """Decrypt a ciphertext blob using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param bytes ciphertext_blob: moto-structured ciphertext blob encrypted under a moto master key in master_keys + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: plaintext bytes and moto key ID + :rtype: bytes and str + """ + try: + ciphertext = _deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob) + except Exception: + raise InvalidCiphertextException() + + aad = _serialize_encryption_context(encryption_context=encryption_context) + + try: + key = master_keys[ciphertext.key_id] + except KeyError: + raise AccessDeniedException( + "The ciphertext refers to a customer master key that does not exist, " + "does not exist in this region, or you are not allowed to access." + ) + + try: + decryptor = Cipher( + algorithms.AES(key.key_material), + modes.GCM(ciphertext.iv, ciphertext.tag), + backend=default_backend(), + ).decryptor() + decryptor.authenticate_additional_data(aad) + plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize() + except Exception: + raise InvalidCiphertextException() + + return plaintext, ciphertext.key_id diff --git a/moto/logs/exceptions.py b/moto/logs/exceptions.py index bb02eced3..9f6628b0f 100644 --- a/moto/logs/exceptions.py +++ b/moto/logs/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(LogsClientError): def __init__(self): self.code = 400 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -19,8 +18,7 @@ class InvalidParameterException(LogsClientError): def __init__(self, msg=None): self.code = 400 super(InvalidParameterException, self).__init__( - "InvalidParameterException", - msg or "A parameter is specified incorrectly." + "InvalidParameterException", msg or "A parameter is specified incorrectly." ) @@ -28,6 +26,5 @@ class ResourceAlreadyExistsException(LogsClientError): def __init__(self): self.code = 400 super(ResourceAlreadyExistsException, self).__init__( - 'ResourceAlreadyExistsException', - 'The specified log group already exists' + "ResourceAlreadyExistsException", "The specified log group already exists" ) diff --git a/moto/logs/models.py b/moto/logs/models.py index 2b8dcfeb4..d0639524e 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -3,7 +3,8 @@ import boto.logs from moto.core.utils import unix_time_millis from .exceptions import ( ResourceNotFoundException, - ResourceAlreadyExistsException + ResourceAlreadyExistsException, + InvalidParameterException, ) @@ -13,7 +14,7 @@ class LogEvent: def __init__(self, ingestion_time, log_event): self.ingestionTime = ingestion_time self.timestamp = log_event["timestamp"] - self.message = log_event['message'] + self.message = log_event["message"] self.eventId = self.__class__._event_id self.__class__._event_id += 1 @@ -23,14 +24,14 @@ class LogEvent: "ingestionTime": self.ingestionTime, # "logStreamName": "message": self.message, - "timestamp": self.timestamp + "timestamp": self.timestamp, } def to_response_dict(self): return { "ingestionTime": self.ingestionTime, "message": self.message, - "timestamp": self.timestamp + "timestamp": self.timestamp, } @@ -40,22 +41,32 @@ class LogStream: def __init__(self, region, log_group, name): self.region = region self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format( - region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name) - self.creationTime = unix_time_millis() + region=region, + id=self.__class__._log_ids, + log_group=log_group, + log_stream=name, + ) + self.creationTime = int(unix_time_millis()) self.firstEventTimestamp = None self.lastEventTimestamp = None self.lastIngestionTime = None self.logStreamName = name self.storedBytes = 0 - self.uploadSequenceToken = 0 # I'm guessing this is token needed for sequenceToken by put_events + self.uploadSequenceToken = ( + 0 # I'm guessing this is token needed for sequenceToken by put_events + ) self.events = [] self.__class__._log_ids += 1 def _update(self): # events can be empty when stream is described soon after creation - self.firstEventTimestamp = min([x.timestamp for x in self.events]) if self.events else None - self.lastEventTimestamp = max([x.timestamp for x in self.events]) if self.events else None + self.firstEventTimestamp = ( + min([x.timestamp for x in self.events]) if self.events else None + ) + self.lastEventTimestamp = ( + max([x.timestamp for x in self.events]) if self.events else None + ) def to_describe_dict(self): # Compute start and end times @@ -77,18 +88,31 @@ class LogStream: res.update(rest) return res - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): # TODO: ensure sequence_token # TODO: to be thread safe this would need a lock - self.lastIngestionTime = unix_time_millis() + self.lastIngestionTime = int(unix_time_millis()) # TODO: make this match AWS if possible self.storedBytes += sum([len(log_event["message"]) for log_event in log_events]) - self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events] + self.events += [ + LogEvent(self.lastIngestionTime, log_event) for log_event in log_events + ] self.uploadSequenceToken += 1 - return '{:056d}'.format(self.uploadSequenceToken) + return "{:056d}".format(self.uploadSequenceToken) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): def filter_func(event): if start_time and event.timestamp < start_time: return False @@ -98,31 +122,82 @@ class LogStream: return True - def get_paging_token_from_index(index, back=False): - if index is not None: - return "b/{:056d}".format(index) if back else "f/{:056d}".format(index) - return 0 - - def get_index_from_paging_token(token): + def get_index_and_direction_from_token(token): if token is not None: - return int(token[2:]) - return 0 + try: + return token[0], int(token[2:]) + except Exception: + raise InvalidParameterException( + "The specified nextToken is invalid." + ) + return None, 0 - events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head) - next_index = get_index_from_paging_token(next_token) - back_index = next_index + events = sorted( + filter(filter_func, self.events), key=lambda event: event.timestamp, + ) - events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]] - if next_index + limit < len(self.events): - next_index += limit + direction, index = get_index_and_direction_from_token(next_token) + limit_index = limit - 1 + final_index = len(events) - 1 - back_index -= limit - if back_index <= 0: - back_index = 0 + if direction is None: + if start_from_head: + start_index = 0 + end_index = start_index + limit_index + else: + end_index = final_index + start_index = end_index - limit_index + elif direction == "f": + start_index = index + 1 + end_index = start_index + limit_index + elif direction == "b": + end_index = index - 1 + start_index = end_index - limit_index + else: + raise InvalidParameterException("The specified nextToken is invalid.") - return events_page, get_paging_token_from_index(back_index, True), get_paging_token_from_index(next_index) + if start_index < 0: + start_index = 0 + elif start_index > final_index: + return ( + [], + "b/{:056d}".format(final_index), + "f/{:056d}".format(final_index), + ) + + if end_index > final_index: + end_index = final_index + elif end_index < 0: + return ( + [], + "b/{:056d}".format(0), + "f/{:056d}".format(0), + ) + + events_page = [ + event.to_response_dict() for event in events[start_index : end_index + 1] + ] + + return ( + events_page, + "b/{:056d}".format(start_index), + "f/{:056d}".format(end_index), + ) + + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): + if filter_pattern: + raise NotImplementedError("filter_pattern is not yet implemented") - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): def filter_func(event): if start_time and event.timestamp < start_time: return False @@ -133,9 +208,11 @@ class LogStream: return True events = [] - for event in sorted(filter(filter_func, self.events), key=lambda x: x.timestamp): + for event in sorted( + filter(filter_func, self.events), key=lambda x: x.timestamp + ): event_obj = event.to_filter_dict() - event_obj['logStreamName'] = self.logStreamName + event_obj["logStreamName"] = self.logStreamName events.append(event_obj) return events @@ -145,72 +222,140 @@ class LogGroup: self.name = name self.region = region self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format( - region=region, log_group=name) - self.creationTime = unix_time_millis() + region=region, log_group=name + ) + self.creationTime = int(unix_time_millis()) self.tags = tags self.streams = dict() # {name: LogStream} - self.retentionInDays = None # AWS defaults to Never Expire for log group retention + self.retentionInDays = ( + None # AWS defaults to Never Expire for log group retention + ) def create_log_stream(self, log_stream_name): if log_stream_name in self.streams: raise ResourceAlreadyExistsException() - self.streams[log_stream_name] = LogStream(self.region, self.name, log_stream_name) + self.streams[log_stream_name] = LogStream( + self.region, self.name, log_stream_name + ) def delete_log_stream(self, log_stream_name): if log_stream_name not in self.streams: raise ResourceNotFoundException() del self.streams[log_stream_name] - def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by): + def describe_log_streams( + self, + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ): # responses only logStreamName, creationTime, arn, storedBytes when no events are stored. - log_streams = [(name, stream.to_describe_dict()) for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)] + log_streams = [ + (name, stream.to_describe_dict()) + for name, stream in self.streams.items() + if name.startswith(log_stream_name_prefix) + ] def sorter(item): - return item[0] if order_by == 'logStreamName' else item[1].get('lastEventTimestamp', 0) + return ( + item[0] + if order_by == "logStreamName" + else item[1].get("lastEventTimestamp", 0) + ) if next_token is None: next_token = 0 log_streams = sorted(log_streams, key=sorter, reverse=descending) new_token = next_token + limit - log_streams_page = [x[1] for x in log_streams[next_token: new_token]] + log_streams_page = [x[1] for x in log_streams[next_token:new_token]] if new_token >= len(log_streams): new_token = None return log_streams_page, new_token - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): if log_stream_name not in self.streams: raise ResourceNotFoundException() stream = self.streams[log_stream_name] - return stream.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) + return stream.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): if log_stream_name not in self.streams: raise ResourceNotFoundException() stream = self.streams[log_stream_name] - return stream.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) + return stream.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): - streams = [stream for name, stream in self.streams.items() if not log_stream_names or name in log_stream_names] + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): + streams = [ + stream + for name, stream in self.streams.items() + if not log_stream_names or name in log_stream_names + ] events = [] for stream in streams: - events += stream.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) + events += stream.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) if interleaved: - events = sorted(events, key=lambda event: event['timestamp']) + events = sorted(events, key=lambda event: event["timestamp"]) if next_token is None: next_token = 0 - events_page = events[next_token: next_token + limit] + events_page = events[next_token : next_token + limit] next_token += limit if next_token >= len(events): next_token = None - searched_streams = [{"logStreamName": stream.logStreamName, "searchedCompletely": True} for stream in streams] + searched_streams = [ + {"logStreamName": stream.logStreamName, "searchedCompletely": True} + for stream in streams + ] return events_page, next_token, searched_streams def to_describe_dict(self): @@ -229,6 +374,21 @@ class LogGroup: def set_retention_policy(self, retention_in_days): self.retentionInDays = retention_in_days + def list_tags(self): + return self.tags if self.tags else {} + + def tag(self, tags): + if self.tags: + self.tags.update(tags) + else: + self.tags = tags + + def untag(self, tags_to_remove): + if self.tags: + self.tags = { + k: v for (k, v) in self.tags.items() if k not in tags_to_remove + } + class LogsBackend(BaseBackend): def __init__(self, region_name): @@ -257,13 +417,17 @@ class LogsBackend(BaseBackend): def describe_log_groups(self, limit, log_group_name_prefix, next_token): if log_group_name_prefix is None: - log_group_name_prefix = '' + log_group_name_prefix = "" if next_token is None: next_token = 0 - groups = [group.to_describe_dict() for name, group in self.groups.items() if name.startswith(log_group_name_prefix)] - groups = sorted(groups, key=lambda x: x['creationTime'], reverse=True) - groups_page = groups[next_token:next_token + limit] + groups = [ + group.to_describe_dict() + for name, group in self.groups.items() + if name.startswith(log_group_name_prefix) + ] + groups = sorted(groups, key=lambda x: x["creationTime"], reverse=True) + groups_page = groups[next_token : next_token + limit] next_token += limit if next_token >= len(groups): @@ -283,30 +447,85 @@ class LogsBackend(BaseBackend): log_group = self.groups[log_group_name] return log_group.delete_log_stream(log_stream_name) - def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by): + def describe_log_streams( + self, + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.describe_log_streams(descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by) + return log_group.describe_log_streams( + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ) - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): # TODO: add support for sequence_tokens if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) + return log_group.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) + return log_group.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) + return log_group.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) def put_retention_policy(self, log_group_name, retention_in_days): if log_group_name not in self.groups: @@ -320,5 +539,25 @@ class LogsBackend(BaseBackend): log_group = self.groups[log_group_name] return log_group.set_retention_policy(None) + def list_tags_log_group(self, log_group_name): + if log_group_name not in self.groups: + raise ResourceNotFoundException() + log_group = self.groups[log_group_name] + return log_group.list_tags() -logs_backends = {region.name: LogsBackend(region.name) for region in boto.logs.regions()} + def tag_log_group(self, log_group_name, tags): + if log_group_name not in self.groups: + raise ResourceNotFoundException() + log_group = self.groups[log_group_name] + log_group.tag(tags) + + def untag_log_group(self, log_group_name, tags): + if log_group_name not in self.groups: + raise ResourceNotFoundException() + log_group = self.groups[log_group_name] + log_group.untag(tags) + + +logs_backends = { + region.name: LogsBackend(region.name) for region in boto.logs.regions() +} diff --git a/moto/logs/responses.py b/moto/logs/responses.py index 39f24a260..072c76b71 100644 --- a/moto/logs/responses.py +++ b/moto/logs/responses.py @@ -5,6 +5,7 @@ import json # See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html + class LogsResponse(BaseResponse): @property def logs_backend(self): @@ -21,116 +22,159 @@ class LogsResponse(BaseResponse): return self.request_params.get(param, if_none) def create_log_group(self): - log_group_name = self._get_param('logGroupName') - tags = self._get_param('tags') + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern self.logs_backend.create_log_group(log_group_name, tags) - return '' + return "" def delete_log_group(self): - log_group_name = self._get_param('logGroupName') + log_group_name = self._get_param("logGroupName") self.logs_backend.delete_log_group(log_group_name) - return '' + return "" def describe_log_groups(self): - log_group_name_prefix = self._get_param('logGroupNamePrefix') - next_token = self._get_param('nextToken') - limit = self._get_param('limit', 50) + log_group_name_prefix = self._get_param("logGroupNamePrefix") + next_token = self._get_param("nextToken") + limit = self._get_param("limit", 50) assert limit <= 50 groups, next_token = self.logs_backend.describe_log_groups( - limit, log_group_name_prefix, next_token) - return json.dumps({ - "logGroups": groups, - "nextToken": next_token - }) + limit, log_group_name_prefix, next_token + ) + return json.dumps({"logGroups": groups, "nextToken": next_token}) def create_log_stream(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") self.logs_backend.create_log_stream(log_group_name, log_stream_name) - return '' + return "" def delete_log_stream(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") self.logs_backend.delete_log_stream(log_group_name, log_stream_name) - return '' + return "" def describe_log_streams(self): - log_group_name = self._get_param('logGroupName') - log_stream_name_prefix = self._get_param('logStreamNamePrefix', '') - descending = self._get_param('descending', False) - limit = self._get_param('limit', 50) + log_group_name = self._get_param("logGroupName") + log_stream_name_prefix = self._get_param("logStreamNamePrefix", "") + descending = self._get_param("descending", False) + limit = self._get_param("limit", 50) assert limit <= 50 - next_token = self._get_param('nextToken') - order_by = self._get_param('orderBy', 'LogStreamName') - assert order_by in {'LogStreamName', 'LastEventTime'} + next_token = self._get_param("nextToken") + order_by = self._get_param("orderBy", "LogStreamName") + assert order_by in {"LogStreamName", "LastEventTime"} - if order_by == 'LastEventTime': + if order_by == "LastEventTime": assert not log_stream_name_prefix streams, next_token = self.logs_backend.describe_log_streams( - descending, limit, log_group_name, log_stream_name_prefix, - next_token, order_by) - return json.dumps({ - "logStreams": streams, - "nextToken": next_token - }) + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ) + return json.dumps({"logStreams": streams, "nextToken": next_token}) def put_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') - log_events = self._get_param('logEvents') - sequence_token = self._get_param('sequenceToken') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") + log_events = self._get_param("logEvents") + sequence_token = self._get_param("sequenceToken") - next_sequence_token = self.logs_backend.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) - return json.dumps({'nextSequenceToken': next_sequence_token}) + next_sequence_token = self.logs_backend.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) + return json.dumps({"nextSequenceToken": next_sequence_token}) def get_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') - start_time = self._get_param('startTime') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") + start_time = self._get_param("startTime") end_time = self._get_param("endTime") - limit = self._get_param('limit', 10000) + limit = self._get_param("limit", 10000) assert limit <= 10000 - next_token = self._get_param('nextToken') - start_from_head = self._get_param('startFromHead', False) + next_token = self._get_param("nextToken") + start_from_head = self._get_param("startFromHead", False) - events, next_backward_token, next_foward_token = \ - self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) - return json.dumps({ - "events": events, - "nextBackwardToken": next_backward_token, - "nextForwardToken": next_foward_token - }) + ( + events, + next_backward_token, + next_foward_token, + ) = self.logs_backend.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) + return json.dumps( + { + "events": events, + "nextBackwardToken": next_backward_token, + "nextForwardToken": next_foward_token, + } + ) def filter_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_names = self._get_param('logStreamNames', []) - start_time = self._get_param('startTime') + log_group_name = self._get_param("logGroupName") + log_stream_names = self._get_param("logStreamNames", []) + start_time = self._get_param("startTime") # impl, see: http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html - filter_pattern = self._get_param('filterPattern') - interleaved = self._get_param('interleaved', False) + filter_pattern = self._get_param("filterPattern") + interleaved = self._get_param("interleaved", False) end_time = self._get_param("endTime") - limit = self._get_param('limit', 10000) + limit = self._get_param("limit", 10000) assert limit <= 10000 - next_token = self._get_param('nextToken') + next_token = self._get_param("nextToken") - events, next_token, searched_streams = self.logs_backend.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) - return json.dumps({ - "events": events, - "nextToken": next_token, - "searchedLogStreams": searched_streams - }) + events, next_token, searched_streams = self.logs_backend.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) + return json.dumps( + { + "events": events, + "nextToken": next_token, + "searchedLogStreams": searched_streams, + } + ) def put_retention_policy(self): - log_group_name = self._get_param('logGroupName') - retention_in_days = self._get_param('retentionInDays') + log_group_name = self._get_param("logGroupName") + retention_in_days = self._get_param("retentionInDays") self.logs_backend.put_retention_policy(log_group_name, retention_in_days) - return '' + return "" def delete_retention_policy(self): - log_group_name = self._get_param('logGroupName') + log_group_name = self._get_param("logGroupName") self.logs_backend.delete_retention_policy(log_group_name) - return '' + return "" + + def list_tags_log_group(self): + log_group_name = self._get_param("logGroupName") + tags = self.logs_backend.list_tags_log_group(log_group_name) + return json.dumps({"tags": tags}) + + def tag_log_group(self): + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") + self.logs_backend.tag_log_group(log_group_name, tags) + return "" + + def untag_log_group(self): + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") + self.logs_backend.untag_log_group(log_group_name, tags) + return "" diff --git a/moto/logs/urls.py b/moto/logs/urls.py index b7910e675..e4e1f5a88 100644 --- a/moto/logs/urls.py +++ b/moto/logs/urls.py @@ -1,9 +1,5 @@ from .responses import LogsResponse -url_bases = [ - "https?://logs.(.+).amazonaws.com", -] +url_bases = ["https?://logs.(.+).amazonaws.com"] -url_paths = { - '{0}/$': LogsResponse.dispatch, -} +url_paths = {"{0}/$": LogsResponse.dispatch} diff --git a/moto/opsworks/__init__.py b/moto/opsworks/__init__.py index b492b6a53..e0e6b88d0 100644 --- a/moto/opsworks/__init__.py +++ b/moto/opsworks/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import opsworks_backends from ..core.models import base_decorator, deprecated_base_decorator -opsworks_backend = opsworks_backends['us-east-1'] +opsworks_backend = opsworks_backends["us-east-1"] mock_opsworks = base_decorator(opsworks_backends) mock_opsworks_deprecated = deprecated_base_decorator(opsworks_backends) diff --git a/moto/opsworks/exceptions.py b/moto/opsworks/exceptions.py index 00bdffbc5..3867b3b90 100644 --- a/moto/opsworks/exceptions.py +++ b/moto/opsworks/exceptions.py @@ -5,20 +5,16 @@ from werkzeug.exceptions import BadRequest class ResourceNotFoundException(BadRequest): - def __init__(self, message): super(ResourceNotFoundException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) class ValidationException(BadRequest): - def __init__(self, message): super(ValidationException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) diff --git a/moto/opsworks/models.py b/moto/opsworks/models.py index 4fe428c65..96d918cc9 100644 --- a/moto/opsworks/models.py +++ b/moto/opsworks/models.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from moto.core import BaseBackend, BaseModel from moto.ec2 import ec2_backends +from moto.core import ACCOUNT_ID import uuid import datetime from random import choice @@ -15,24 +16,30 @@ class OpsworkInstance(BaseModel): used to populate a reservation request when "start" is called """ - def __init__(self, stack_id, layer_ids, instance_type, ec2_backend, - auto_scale_type=None, - hostname=None, - os=None, - ami_id="ami-08111162", - ssh_keyname=None, - availability_zone=None, - virtualization_type="hvm", - subnet_id=None, - architecture="x86_64", - root_device_type="ebs", - block_device_mappings=None, - install_updates_on_boot=True, - ebs_optimized=False, - agent_version="INHERIT", - instance_profile_arn=None, - associate_public_ip=None, - security_group_ids=None): + def __init__( + self, + stack_id, + layer_ids, + instance_type, + ec2_backend, + auto_scale_type=None, + hostname=None, + os=None, + ami_id="ami-08111162", + ssh_keyname=None, + availability_zone=None, + virtualization_type="hvm", + subnet_id=None, + architecture="x86_64", + root_device_type="ebs", + block_device_mappings=None, + install_updates_on_boot=True, + ebs_optimized=False, + agent_version="INHERIT", + instance_profile_arn=None, + associate_public_ip=None, + security_group_ids=None, + ): self.ec2_backend = ec2_backend @@ -55,13 +62,12 @@ class OpsworkInstance(BaseModel): # formatting in to_dict() self.block_device_mappings = block_device_mappings if self.block_device_mappings is None: - self.block_device_mappings = [{ - 'DeviceName': 'ROOT_DEVICE', - 'Ebs': { - 'VolumeSize': 8, - 'VolumeType': 'gp2' + self.block_device_mappings = [ + { + "DeviceName": "ROOT_DEVICE", + "Ebs": {"VolumeSize": 8, "VolumeType": "gp2"}, } - }] + ] self.security_group_ids = security_group_ids if self.security_group_ids is None: self.security_group_ids = [] @@ -102,9 +108,9 @@ class OpsworkInstance(BaseModel): ) self.instance = reservation.instances[0] self.reported_os = { - 'Family': 'rhel (fixed)', - 'Name': 'amazon (fixed)', - 'Version': '2016.03 (fixed)' + "Family": "rhel (fixed)", + "Name": "amazon (fixed)", + "Version": "2016.03 (fixed)", } self.platform = self.instance.platform self.security_group_ids = self.instance.security_groups @@ -156,32 +162,43 @@ class OpsworkInstance(BaseModel): d.update({"RootDeviceVolumeId": "vol-a20e450a (fixed)"}) if self.ssh_keyname is not None: d.update( - {"SshHostDsaKeyFingerprint": "24:36:32:fe:d8:5f:9c:18:b1:ad:37:e9:eb:e8:69:58 (fixed)"}) + { + "SshHostDsaKeyFingerprint": "24:36:32:fe:d8:5f:9c:18:b1:ad:37:e9:eb:e8:69:58 (fixed)" + } + ) d.update( - {"SshHostRsaKeyFingerprint": "3c:bd:37:52:d7:ca:67:e1:6e:4b:ac:31:86:79:f5:6c (fixed)"}) + { + "SshHostRsaKeyFingerprint": "3c:bd:37:52:d7:ca:67:e1:6e:4b:ac:31:86:79:f5:6c (fixed)" + } + ) d.update({"PrivateDns": self.instance.private_dns}) d.update({"PrivateIp": self.instance.private_ip}) - d.update({"PublicDns": getattr(self.instance, 'public_dns', None)}) - d.update({"PublicIp": getattr(self.instance, 'public_ip', None)}) + d.update({"PublicDns": getattr(self.instance, "public_dns", None)}) + d.update({"PublicIp": getattr(self.instance, "public_ip", None)}) return d class Layer(BaseModel): - - def __init__(self, stack_id, type, name, shortname, - attributes=None, - custom_instance_profile_arn=None, - custom_json=None, - custom_security_group_ids=None, - packages=None, - volume_configurations=None, - enable_autohealing=None, - auto_assign_elastic_ips=None, - auto_assign_public_ips=None, - custom_recipes=None, - install_updates_on_boot=None, - use_ebs_optimized_instances=None, - lifecycle_event_configuration=None): + def __init__( + self, + stack_id, + type, + name, + shortname, + attributes=None, + custom_instance_profile_arn=None, + custom_json=None, + custom_security_group_ids=None, + packages=None, + volume_configurations=None, + enable_autohealing=None, + auto_assign_elastic_ips=None, + auto_assign_public_ips=None, + custom_recipes=None, + install_updates_on_boot=None, + use_ebs_optimized_instances=None, + lifecycle_event_configuration=None, + ): self.stack_id = stack_id self.type = type self.name = name @@ -190,31 +207,31 @@ class Layer(BaseModel): self.attributes = attributes if attributes is None: self.attributes = { - 'BundlerVersion': None, - 'EcsClusterArn': None, - 'EnableHaproxyStats': None, - 'GangliaPassword': None, - 'GangliaUrl': None, - 'GangliaUser': None, - 'HaproxyHealthCheckMethod': None, - 'HaproxyHealthCheckUrl': None, - 'HaproxyStatsPassword': None, - 'HaproxyStatsUrl': None, - 'HaproxyStatsUser': None, - 'JavaAppServer': None, - 'JavaAppServerVersion': None, - 'Jvm': None, - 'JvmOptions': None, - 'JvmVersion': None, - 'ManageBundler': None, - 'MemcachedMemory': None, - 'MysqlRootPassword': None, - 'MysqlRootPasswordUbiquitous': None, - 'NodejsVersion': None, - 'PassengerVersion': None, - 'RailsStack': None, - 'RubyVersion': None, - 'RubygemsVersion': None + "BundlerVersion": None, + "EcsClusterArn": None, + "EnableHaproxyStats": None, + "GangliaPassword": None, + "GangliaUrl": None, + "GangliaUser": None, + "HaproxyHealthCheckMethod": None, + "HaproxyHealthCheckUrl": None, + "HaproxyStatsPassword": None, + "HaproxyStatsUrl": None, + "HaproxyStatsUser": None, + "JavaAppServer": None, + "JavaAppServerVersion": None, + "Jvm": None, + "JvmOptions": None, + "JvmVersion": None, + "ManageBundler": None, + "MemcachedMemory": None, + "MysqlRootPassword": None, + "MysqlRootPasswordUbiquitous": None, + "NodejsVersion": None, + "PassengerVersion": None, + "RailsStack": None, + "RubyVersion": None, + "RubygemsVersion": None, } # May not be accurate self.packages = packages @@ -224,11 +241,11 @@ class Layer(BaseModel): self.custom_recipes = custom_recipes if custom_recipes is None: self.custom_recipes = { - 'Configure': [], - 'Deploy': [], - 'Setup': [], - 'Shutdown': [], - 'Undeploy': [], + "Configure": [], + "Deploy": [], + "Setup": [], + "Shutdown": [], + "Undeploy": [], } self.custom_security_group_ids = custom_security_group_ids @@ -271,9 +288,9 @@ class Layer(BaseModel): "Configure": [], "Setup": [], "Shutdown": [], - "Undeploy": [] + "Undeploy": [], }, # May not be accurate - "DefaultSecurityGroupNames": ['AWS-OpsWorks-Custom-Server'], + "DefaultSecurityGroupNames": ["AWS-OpsWorks-Custom-Server"], "EnableAutoHealing": self.enable_autohealing, "LayerId": self.id, "LifecycleEventConfiguration": self.lifecycle_event_configuration, @@ -287,29 +304,33 @@ class Layer(BaseModel): if self.custom_json is not None: d.update({"CustomJson": self.custom_json}) if self.custom_instance_profile_arn is not None: - d.update( - {"CustomInstanceProfileArn": self.custom_instance_profile_arn}) + d.update({"CustomInstanceProfileArn": self.custom_instance_profile_arn}) return d class Stack(BaseModel): - - def __init__(self, name, region, service_role_arn, default_instance_profile_arn, - vpcid="vpc-1f99bf7a", - attributes=None, - default_os='Ubuntu 12.04 LTS', - hostname_theme='Layer_Dependent', - default_availability_zone='us-east-1a', - default_subnet_id='subnet-73981004', - custom_json=None, - configuration_manager=None, - chef_configuration=None, - use_custom_cookbooks=False, - use_opsworks_security_groups=True, - custom_cookbooks_source=None, - default_ssh_keyname=None, - default_root_device_type='instance-store', - agent_version='LATEST'): + def __init__( + self, + name, + region, + service_role_arn, + default_instance_profile_arn, + vpcid="vpc-1f99bf7a", + attributes=None, + default_os="Ubuntu 12.04 LTS", + hostname_theme="Layer_Dependent", + default_availability_zone="us-east-1a", + default_subnet_id="subnet-73981004", + custom_json=None, + configuration_manager=None, + chef_configuration=None, + use_custom_cookbooks=False, + use_opsworks_security_groups=True, + custom_cookbooks_source=None, + default_ssh_keyname=None, + default_root_device_type="instance-store", + agent_version="LATEST", + ): self.name = name self.region = region @@ -319,11 +340,11 @@ class Stack(BaseModel): self.vpcid = vpcid self.attributes = attributes if attributes is None: - self.attributes = {'Color': None} + self.attributes = {"Color": None} self.configuration_manager = configuration_manager if configuration_manager is None: - self.configuration_manager = {'Name': 'Chef', 'Version': '11.4'} + self.configuration_manager = {"Name": "Chef", "Version": "11.4"} self.chef_configuration = chef_configuration if chef_configuration is None: @@ -347,7 +368,7 @@ class Stack(BaseModel): self.id = "{0}".format(uuid.uuid4()) self.layers = [] self.apps = [] - self.account_number = "123456789012" + self.account_number = ACCOUNT_ID self.created_at = datetime.datetime.utcnow() def __eq__(self, other): @@ -356,15 +377,13 @@ class Stack(BaseModel): def generate_hostname(self): # this doesn't match amazon's implementation return "{theme}-{rand}-(moto)".format( - theme=self.hostname_theme, - rand=[choice("abcdefghijhk") for _ in range(4)]) + theme=self.hostname_theme, rand=[choice("abcdefghijhk") for _ in range(4)] + ) @property def arn(self): return "arn:aws:opsworks:{region}:{account_number}:stack/{id}".format( - region=self.region, - account_number=self.account_number, - id=self.id + region=self.region, account_number=self.account_number, id=self.id ) def to_dict(self): @@ -389,7 +408,7 @@ class Stack(BaseModel): "StackId": self.id, "UseCustomCookbooks": self.use_custom_cookbooks, "UseOpsworksSecurityGroups": self.use_opsworks_security_groups, - "VpcId": self.vpcid + "VpcId": self.vpcid, } if self.custom_json is not None: response.update({"CustomJson": self.custom_json}) @@ -399,17 +418,21 @@ class Stack(BaseModel): class App(BaseModel): - - def __init__(self, stack_id, name, type, - shortname=None, - description=None, - datasources=None, - app_source=None, - domains=None, - enable_ssl=False, - ssl_configuration=None, - attributes=None, - environment=None): + def __init__( + self, + stack_id, + name, + type, + shortname=None, + description=None, + datasources=None, + app_source=None, + domains=None, + enable_ssl=False, + ssl_configuration=None, + attributes=None, + environment=None, + ): self.stack_id = stack_id self.name = name self.type = type @@ -463,13 +486,12 @@ class App(BaseModel): "Shortname": self.shortname, "SslConfiguration": self.ssl_configuration, "StackId": self.stack_id, - "Type": self.type + "Type": self.type, } return d class OpsWorksBackend(BaseBackend): - def __init__(self, ec2_backend): self.stacks = {} self.layers = {} @@ -488,55 +510,59 @@ class OpsWorksBackend(BaseBackend): return stack def create_layer(self, **kwargs): - name = kwargs['name'] - shortname = kwargs['shortname'] - stackid = kwargs['stack_id'] + name = kwargs["name"] + shortname = kwargs["shortname"] + stackid = kwargs["stack_id"] if stackid not in self.stacks: raise ResourceNotFoundException(stackid) if name in [l.name for l in self.stacks[stackid].layers]: raise ValidationException( - 'There is already a layer named "{0}" ' - 'for this stack'.format(name)) + 'There is already a layer named "{0}" ' "for this stack".format(name) + ) if shortname in [l.shortname for l in self.stacks[stackid].layers]: raise ValidationException( 'There is already a layer with shortname "{0}" ' - 'for this stack'.format(shortname)) + "for this stack".format(shortname) + ) layer = Layer(**kwargs) self.layers[layer.id] = layer self.stacks[stackid].layers.append(layer) return layer def create_app(self, **kwargs): - name = kwargs['name'] - stackid = kwargs['stack_id'] + name = kwargs["name"] + stackid = kwargs["stack_id"] if stackid not in self.stacks: raise ResourceNotFoundException(stackid) if name in [a.name for a in self.stacks[stackid].apps]: raise ValidationException( - 'There is already an app named "{0}" ' - 'for this stack'.format(name)) + 'There is already an app named "{0}" ' "for this stack".format(name) + ) app = App(**kwargs) self.apps[app.id] = app self.stacks[stackid].apps.append(app) return app def create_instance(self, **kwargs): - stack_id = kwargs['stack_id'] - layer_ids = kwargs['layer_ids'] + stack_id = kwargs["stack_id"] + layer_ids = kwargs["layer_ids"] if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) unknown_layers = set(layer_ids) - set(self.layers.keys()) if unknown_layers: raise ResourceNotFoundException(", ".join(unknown_layers)) layers = [self.layers[id] for id in layer_ids] - if len(set([layer.stack_id for layer in layers])) != 1 or \ - any([layer.stack_id != stack_id for layer in layers]): + if len(set([layer.stack_id for layer in layers])) != 1 or any( + [layer.stack_id != stack_id for layer in layers] + ): raise ValidationException( - "Please only provide layer IDs from the same stack") + "Please only provide layer IDs from the same stack" + ) stack = self.stacks[stack_id] # pick the first to set default instance_profile_arn and @@ -549,12 +575,9 @@ class OpsWorksBackend(BaseBackend): kwargs.setdefault("subnet_id", stack.default_subnet_id) kwargs.setdefault("root_device_type", stack.default_root_device_type) if layer.custom_instance_profile_arn: - kwargs.setdefault("instance_profile_arn", - layer.custom_instance_profile_arn) - kwargs.setdefault("instance_profile_arn", - stack.default_instance_profile_arn) - kwargs.setdefault("security_group_ids", - layer.custom_security_group_ids) + kwargs.setdefault("instance_profile_arn", layer.custom_instance_profile_arn) + kwargs.setdefault("instance_profile_arn", stack.default_instance_profile_arn) + kwargs.setdefault("security_group_ids", layer.custom_security_group_ids) kwargs.setdefault("associate_public_ip", layer.auto_assign_public_ips) kwargs.setdefault("ebs_optimized", layer.use_ebs_optimized_instances) kwargs.update({"ec2_backend": self.ec2_backend}) @@ -579,7 +602,8 @@ class OpsWorksBackend(BaseBackend): if stack_id is not None: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) return [layer.to_dict() for layer in self.stacks[stack_id].layers] unknown_layers = set(layer_ids) - set(self.layers.keys()) @@ -595,7 +619,8 @@ class OpsWorksBackend(BaseBackend): if stack_id is not None: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) return [app.to_dict() for app in self.stacks[stack_id].apps] unknown_apps = set(app_ids) - set(self.apps.keys()) @@ -605,9 +630,11 @@ class OpsWorksBackend(BaseBackend): def describe_instances(self, instance_ids, layer_id, stack_id): if len(list(filter(None, (instance_ids, layer_id, stack_id)))) != 1: - raise ValidationException("Please provide either one or more " - "instance IDs or one stack ID or one " - "layer ID") + raise ValidationException( + "Please provide either one or more " + "instance IDs or one stack ID or one " + "layer ID" + ) if instance_ids: unknown_instances = set(instance_ids) - set(self.instances.keys()) if unknown_instances: @@ -617,23 +644,28 @@ class OpsWorksBackend(BaseBackend): if layer_id: if layer_id not in self.layers: raise ResourceNotFoundException( - "Unable to find layer with ID {0}".format(layer_id)) - instances = [i.to_dict() for i in self.instances.values() - if layer_id in i.layer_ids] + "Unable to find layer with ID {0}".format(layer_id) + ) + instances = [ + i.to_dict() for i in self.instances.values() if layer_id in i.layer_ids + ] return instances if stack_id: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) - instances = [i.to_dict() for i in self.instances.values() - if stack_id == i.stack_id] + "Unable to find stack with ID {0}".format(stack_id) + ) + instances = [ + i.to_dict() for i in self.instances.values() if stack_id == i.stack_id + ] return instances def start_instance(self, instance_id): if instance_id not in self.instances: raise ResourceNotFoundException( - "Unable to find instance with ID {0}".format(instance_id)) + "Unable to find instance with ID {0}".format(instance_id) + ) self.instances[instance_id].start() diff --git a/moto/opsworks/responses.py b/moto/opsworks/responses.py index c9f8fe125..870b75244 100644 --- a/moto/opsworks/responses.py +++ b/moto/opsworks/responses.py @@ -7,7 +7,6 @@ from .models import opsworks_backends class OpsWorksResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -23,23 +22,22 @@ class OpsWorksResponse(BaseResponse): vpcid=self.parameters.get("VpcId"), attributes=self.parameters.get("Attributes"), default_instance_profile_arn=self.parameters.get( - "DefaultInstanceProfileArn"), + "DefaultInstanceProfileArn" + ), default_os=self.parameters.get("DefaultOs"), hostname_theme=self.parameters.get("HostnameTheme"), - default_availability_zone=self.parameters.get( - "DefaultAvailabilityZone"), + default_availability_zone=self.parameters.get("DefaultAvailabilityZone"), default_subnet_id=self.parameters.get("DefaultInstanceProfileArn"), custom_json=self.parameters.get("CustomJson"), configuration_manager=self.parameters.get("ConfigurationManager"), chef_configuration=self.parameters.get("ChefConfiguration"), use_custom_cookbooks=self.parameters.get("UseCustomCookbooks"), use_opsworks_security_groups=self.parameters.get( - "UseOpsworksSecurityGroups"), - custom_cookbooks_source=self.parameters.get( - "CustomCookbooksSource"), + "UseOpsworksSecurityGroups" + ), + custom_cookbooks_source=self.parameters.get("CustomCookbooksSource"), default_ssh_keyname=self.parameters.get("DefaultSshKeyName"), - default_root_device_type=self.parameters.get( - "DefaultRootDeviceType"), + default_root_device_type=self.parameters.get("DefaultRootDeviceType"), service_role_arn=self.parameters.get("ServiceRoleArn"), agent_version=self.parameters.get("AgentVersion"), ) @@ -48,47 +46,43 @@ class OpsWorksResponse(BaseResponse): def create_layer(self): kwargs = dict( - stack_id=self.parameters.get('StackId'), - type=self.parameters.get('Type'), - name=self.parameters.get('Name'), - shortname=self.parameters.get('Shortname'), - attributes=self.parameters.get('Attributes'), - custom_instance_profile_arn=self.parameters.get( - "CustomInstanceProfileArn"), + stack_id=self.parameters.get("StackId"), + type=self.parameters.get("Type"), + name=self.parameters.get("Name"), + shortname=self.parameters.get("Shortname"), + attributes=self.parameters.get("Attributes"), + custom_instance_profile_arn=self.parameters.get("CustomInstanceProfileArn"), custom_json=self.parameters.get("CustomJson"), - custom_security_group_ids=self.parameters.get( - 'CustomSecurityGroupIds'), - packages=self.parameters.get('Packages'), + custom_security_group_ids=self.parameters.get("CustomSecurityGroupIds"), + packages=self.parameters.get("Packages"), volume_configurations=self.parameters.get("VolumeConfigurations"), enable_autohealing=self.parameters.get("EnableAutoHealing"), - auto_assign_elastic_ips=self.parameters.get( - "AutoAssignElasticIps"), + auto_assign_elastic_ips=self.parameters.get("AutoAssignElasticIps"), auto_assign_public_ips=self.parameters.get("AutoAssignPublicIps"), custom_recipes=self.parameters.get("CustomRecipes"), - install_updates_on_boot=self.parameters.get( - "InstallUpdatesOnBoot"), - use_ebs_optimized_instances=self.parameters.get( - "UseEbsOptimizedInstances"), + install_updates_on_boot=self.parameters.get("InstallUpdatesOnBoot"), + use_ebs_optimized_instances=self.parameters.get("UseEbsOptimizedInstances"), lifecycle_event_configuration=self.parameters.get( - "LifecycleEventConfiguration") + "LifecycleEventConfiguration" + ), ) layer = self.opsworks_backend.create_layer(**kwargs) return json.dumps({"LayerId": layer.id}, indent=1) def create_app(self): kwargs = dict( - stack_id=self.parameters.get('StackId'), - name=self.parameters.get('Name'), - type=self.parameters.get('Type'), - shortname=self.parameters.get('Shortname'), - description=self.parameters.get('Description'), - datasources=self.parameters.get('DataSources'), - app_source=self.parameters.get('AppSource'), - domains=self.parameters.get('Domains'), - enable_ssl=self.parameters.get('EnableSsl'), - ssl_configuration=self.parameters.get('SslConfiguration'), - attributes=self.parameters.get('Attributes'), - environment=self.parameters.get('Environment') + stack_id=self.parameters.get("StackId"), + name=self.parameters.get("Name"), + type=self.parameters.get("Type"), + shortname=self.parameters.get("Shortname"), + description=self.parameters.get("Description"), + datasources=self.parameters.get("DataSources"), + app_source=self.parameters.get("AppSource"), + domains=self.parameters.get("Domains"), + enable_ssl=self.parameters.get("EnableSsl"), + ssl_configuration=self.parameters.get("SslConfiguration"), + attributes=self.parameters.get("Attributes"), + environment=self.parameters.get("Environment"), ) app = self.opsworks_backend.create_app(**kwargs) return json.dumps({"AppId": app.id}, indent=1) @@ -109,8 +103,7 @@ class OpsWorksResponse(BaseResponse): architecture=self.parameters.get("Architecture"), root_device_type=self.parameters.get("RootDeviceType"), block_device_mappings=self.parameters.get("BlockDeviceMappings"), - install_updates_on_boot=self.parameters.get( - "InstallUpdatesOnBoot"), + install_updates_on_boot=self.parameters.get("InstallUpdatesOnBoot"), ebs_optimized=self.parameters.get("EbsOptimized"), agent_version=self.parameters.get("AgentVersion"), ) @@ -139,7 +132,8 @@ class OpsWorksResponse(BaseResponse): layer_id = self.parameters.get("LayerId") stack_id = self.parameters.get("StackId") instances = self.opsworks_backend.describe_instances( - instance_ids, layer_id, stack_id) + instance_ids, layer_id, stack_id + ) return json.dumps({"Instances": instances}, indent=1) def start_instance(self): diff --git a/moto/opsworks/urls.py b/moto/opsworks/urls.py index 3d72bb0dd..1e5246e59 100644 --- a/moto/opsworks/urls.py +++ b/moto/opsworks/urls.py @@ -3,10 +3,6 @@ from .responses import OpsWorksResponse # AWS OpsWorks has a single endpoint: opsworks.us-east-1.amazonaws.com # and only supports HTTPS requests. -url_bases = [ - "https?://opsworks.us-east-1.amazonaws.com" -] +url_bases = ["https?://opsworks.us-east-1.amazonaws.com"] -url_paths = { - '{0}/$': OpsWorksResponse.dispatch, -} +url_paths = {"{0}/$": OpsWorksResponse.dispatch} diff --git a/moto/organizations/exceptions.py b/moto/organizations/exceptions.py new file mode 100644 index 000000000..01b98da7e --- /dev/null +++ b/moto/organizations/exceptions.py @@ -0,0 +1,12 @@ +from __future__ import unicode_literals +from moto.core.exceptions import JsonRESTError + + +class InvalidInputException(JsonRESTError): + code = 400 + + def __init__(self): + super(InvalidInputException, self).__init__( + "InvalidInputException", + "You provided a value that does not match the required pattern.", + ) diff --git a/moto/organizations/models.py b/moto/organizations/models.py index 561c6c3a8..42e4dd00a 100644 --- a/moto/organizations/models.py +++ b/moto/organizations/models.py @@ -8,20 +8,19 @@ from moto.core import BaseBackend, BaseModel from moto.core.exceptions import RESTError from moto.core.utils import unix_time from moto.organizations import utils +from moto.organizations.exceptions import InvalidInputException class FakeOrganization(BaseModel): - def __init__(self, feature_set): self.id = utils.make_random_org_id() self.root_id = utils.make_random_root_id() self.feature_set = feature_set self.master_account_id = utils.MASTER_ACCOUNT_ID self.master_account_email = utils.MASTER_ACCOUNT_EMAIL - self.available_policy_types = [{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }] + self.available_policy_types = [ + {"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"} + ] @property def arn(self): @@ -33,129 +32,115 @@ class FakeOrganization(BaseModel): def describe(self): return { - 'Organization': { - 'Id': self.id, - 'Arn': self.arn, - 'FeatureSet': self.feature_set, - 'MasterAccountArn': self.master_account_arn, - 'MasterAccountId': self.master_account_id, - 'MasterAccountEmail': self.master_account_email, - 'AvailablePolicyTypes': self.available_policy_types, + "Organization": { + "Id": self.id, + "Arn": self.arn, + "FeatureSet": self.feature_set, + "MasterAccountArn": self.master_account_arn, + "MasterAccountId": self.master_account_id, + "MasterAccountEmail": self.master_account_email, + "AvailablePolicyTypes": self.available_policy_types, } } class FakeAccount(BaseModel): - def __init__(self, organization, **kwargs): - self.type = 'ACCOUNT' + self.type = "ACCOUNT" self.organization_id = organization.id self.master_account_id = organization.master_account_id self.create_account_status_id = utils.make_random_create_account_status_id() self.id = utils.make_random_account_id() - self.name = kwargs['AccountName'] - self.email = kwargs['Email'] + self.name = kwargs["AccountName"] + self.email = kwargs["Email"] self.create_time = datetime.datetime.utcnow() - self.status = 'ACTIVE' - self.joined_method = 'CREATED' + self.status = "ACTIVE" + self.joined_method = "CREATED" self.parent_id = organization.root_id self.attached_policies = [] + self.tags = {} @property def arn(self): return utils.ACCOUNT_ARN_FORMAT.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) @property def create_account_status(self): return { - 'CreateAccountStatus': { - 'Id': self.create_account_status_id, - 'AccountName': self.name, - 'State': 'SUCCEEDED', - 'RequestedTimestamp': unix_time(self.create_time), - 'CompletedTimestamp': unix_time(self.create_time), - 'AccountId': self.id, + "CreateAccountStatus": { + "Id": self.create_account_status_id, + "AccountName": self.name, + "State": "SUCCEEDED", + "RequestedTimestamp": unix_time(self.create_time), + "CompletedTimestamp": unix_time(self.create_time), + "AccountId": self.id, } } def describe(self): return { - 'Account': { - 'Id': self.id, - 'Arn': self.arn, - 'Email': self.email, - 'Name': self.name, - 'Status': self.status, - 'JoinedMethod': self.joined_method, - 'JoinedTimestamp': unix_time(self.create_time), + "Account": { + "Id": self.id, + "Arn": self.arn, + "Email": self.email, + "Name": self.name, + "Status": self.status, + "JoinedMethod": self.joined_method, + "JoinedTimestamp": unix_time(self.create_time), } } class FakeOrganizationalUnit(BaseModel): - def __init__(self, organization, **kwargs): - self.type = 'ORGANIZATIONAL_UNIT' + self.type = "ORGANIZATIONAL_UNIT" self.organization_id = organization.id self.master_account_id = organization.master_account_id self.id = utils.make_random_ou_id(organization.root_id) - self.name = kwargs.get('Name') - self.parent_id = kwargs.get('ParentId') + self.name = kwargs.get("Name") + self.parent_id = kwargs.get("ParentId") self._arn_format = utils.OU_ARN_FORMAT self.attached_policies = [] @property def arn(self): return self._arn_format.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) def describe(self): return { - 'OrganizationalUnit': { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - } + "OrganizationalUnit": {"Id": self.id, "Arn": self.arn, "Name": self.name} } class FakeRoot(FakeOrganizationalUnit): - def __init__(self, organization, **kwargs): super(FakeRoot, self).__init__(organization, **kwargs) - self.type = 'ROOT' + self.type = "ROOT" self.id = organization.root_id - self.name = 'Root' - self.policy_types = [{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }] + self.name = "Root" + self.policy_types = [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}] self._arn_format = utils.ROOT_ARN_FORMAT self.attached_policies = [] def describe(self): return { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - 'PolicyTypes': self.policy_types + "Id": self.id, + "Arn": self.arn, + "Name": self.name, + "PolicyTypes": self.policy_types, } class FakeServiceControlPolicy(BaseModel): - def __init__(self, organization, **kwargs): - self.content = kwargs.get('Content') - self.description = kwargs.get('Description') - self.name = kwargs.get('Name') - self.type = kwargs.get('Type') + self.content = kwargs.get("Content") + self.description = kwargs.get("Description") + self.name = kwargs.get("Name") + self.type = kwargs.get("Type") self.id = utils.make_random_service_control_policy_id() self.aws_managed = False self.organization_id = organization.id @@ -166,29 +151,26 @@ class FakeServiceControlPolicy(BaseModel): @property def arn(self): return self._arn_format.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) def describe(self): return { - 'Policy': { - 'PolicySummary': { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - 'Description': self.description, - 'Type': self.type, - 'AwsManaged': self.aws_managed, + "Policy": { + "PolicySummary": { + "Id": self.id, + "Arn": self.arn, + "Name": self.name, + "Description": self.description, + "Type": self.type, + "AwsManaged": self.aws_managed, }, - 'Content': self.content + "Content": self.content, } } class OrganizationsBackend(BaseBackend): - def __init__(self): self.org = None self.accounts = [] @@ -196,33 +178,25 @@ class OrganizationsBackend(BaseBackend): self.policies = [] def create_organization(self, **kwargs): - self.org = FakeOrganization(kwargs['FeatureSet']) + self.org = FakeOrganization(kwargs["FeatureSet"]) root_ou = FakeRoot(self.org) self.ou.append(root_ou) master_account = FakeAccount( - self.org, - AccountName='master', - Email=self.org.master_account_email, + self.org, AccountName="master", Email=self.org.master_account_email ) master_account.id = self.org.master_account_id self.accounts.append(master_account) default_policy = FakeServiceControlPolicy( self.org, - Name='FullAWSAccess', - Description='Allows access to every operation', - Type='SERVICE_CONTROL_POLICY', + Name="FullAWSAccess", + Description="Allows access to every operation", + Type="SERVICE_CONTROL_POLICY", Content=json.dumps( { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } - ) + ), ) default_policy.id = utils.DEFAULT_POLICY_ID default_policy.aws_managed = True @@ -234,15 +208,13 @@ class OrganizationsBackend(BaseBackend): def describe_organization(self): if not self.org: raise RESTError( - 'AWSOrganizationsNotInUseException', - "Your account is not a member of an organization." + "AWSOrganizationsNotInUseException", + "Your account is not a member of an organization.", ) return self.org.describe() def list_roots(self): - return dict( - Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)] - ) + return dict(Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)]) def create_organizational_unit(self, **kwargs): new_ou = FakeOrganizationalUnit(self.org, **kwargs) @@ -254,8 +226,8 @@ class OrganizationsBackend(BaseBackend): ou = next((ou for ou in self.ou if ou.id == ou_id), None) if ou is None: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) return ou @@ -264,24 +236,19 @@ class OrganizationsBackend(BaseBackend): self.get_organizational_unit_by_id(parent_id) except RESTError: raise RESTError( - 'ParentNotFoundException', - "You specified parent that doesn't exist." + "ParentNotFoundException", "You specified parent that doesn't exist." ) return parent_id def describe_organizational_unit(self, **kwargs): - ou = self.get_organizational_unit_by_id(kwargs['OrganizationalUnitId']) + ou = self.get_organizational_unit_by_id(kwargs["OrganizationalUnitId"]) return ou.describe() def list_organizational_units_for_parent(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) + parent_id = self.validate_parent_id(kwargs["ParentId"]) return dict( OrganizationalUnits=[ - { - 'Id': ou.id, - 'Arn': ou.arn, - 'Name': ou.name, - } + {"Id": ou.id, "Arn": ou.arn, "Name": ou.name} for ou in self.ou if ou.parent_id == parent_id ] @@ -294,76 +261,88 @@ class OrganizationsBackend(BaseBackend): return new_account.create_account_status def get_account_by_id(self, account_id): - account = next(( - account for account in self.accounts - if account.id == account_id - ), None) + account = next( + (account for account in self.accounts if account.id == account_id), None + ) if account is None: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", + ) + return account + + def get_account_by_attr(self, attr, value): + account = next( + ( + account + for account in self.accounts + if hasattr(account, attr) and getattr(account, attr) == value + ), + None, + ) + if account is None: + raise RESTError( + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) return account def describe_account(self, **kwargs): - account = self.get_account_by_id(kwargs['AccountId']) + account = self.get_account_by_id(kwargs["AccountId"]) return account.describe() + def describe_create_account_status(self, **kwargs): + account = self.get_account_by_attr( + "create_account_status_id", kwargs["CreateAccountRequestId"] + ) + return account.create_account_status + def list_accounts(self): return dict( - Accounts=[account.describe()['Account'] for account in self.accounts] + Accounts=[account.describe()["Account"] for account in self.accounts] ) def list_accounts_for_parent(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) + parent_id = self.validate_parent_id(kwargs["ParentId"]) return dict( Accounts=[ - account.describe()['Account'] + account.describe()["Account"] for account in self.accounts if account.parent_id == parent_id ] ) def move_account(self, **kwargs): - new_parent_id = self.validate_parent_id(kwargs['DestinationParentId']) - self.validate_parent_id(kwargs['SourceParentId']) - account = self.get_account_by_id(kwargs['AccountId']) + new_parent_id = self.validate_parent_id(kwargs["DestinationParentId"]) + self.validate_parent_id(kwargs["SourceParentId"]) + account = self.get_account_by_id(kwargs["AccountId"]) index = self.accounts.index(account) self.accounts[index].parent_id = new_parent_id def list_parents(self, **kwargs): - if re.compile(r'[0-9]{12}').match(kwargs['ChildId']): - child_object = self.get_account_by_id(kwargs['ChildId']) + if re.compile(r"[0-9]{12}").match(kwargs["ChildId"]): + child_object = self.get_account_by_id(kwargs["ChildId"]) else: - child_object = self.get_organizational_unit_by_id(kwargs['ChildId']) + child_object = self.get_organizational_unit_by_id(kwargs["ChildId"]) return dict( Parents=[ - { - 'Id': ou.id, - 'Type': ou.type, - } + {"Id": ou.id, "Type": ou.type} for ou in self.ou if ou.id == child_object.parent_id ] ) def list_children(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) - if kwargs['ChildType'] == 'ACCOUNT': + parent_id = self.validate_parent_id(kwargs["ParentId"]) + if kwargs["ChildType"] == "ACCOUNT": obj_list = self.accounts - elif kwargs['ChildType'] == 'ORGANIZATIONAL_UNIT': + elif kwargs["ChildType"] == "ORGANIZATIONAL_UNIT": obj_list = self.ou else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") return dict( Children=[ - { - 'Id': obj.id, - 'Type': kwargs['ChildType'], - } + {"Id": obj.id, "Type": kwargs["ChildType"]} for obj in obj_list if obj.parent_id == parent_id ] @@ -375,101 +354,122 @@ class OrganizationsBackend(BaseBackend): return new_policy.describe() def describe_policy(self, **kwargs): - if re.compile(utils.SCP_ID_REGEX).match(kwargs['PolicyId']): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) + if re.compile(utils.SCP_ID_REGEX).match(kwargs["PolicyId"]): + policy = next( + (p for p in self.policies if p.id == kwargs["PolicyId"]), None + ) if policy is None: raise RESTError( - 'PolicyNotFoundException', - "You specified a policy that doesn't exist." + "PolicyNotFoundException", + "You specified a policy that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") return policy.describe() def attach_policy(self, **kwargs): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) - if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])): - ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) + policy = next((p for p in self.policies if p.id == kwargs["PolicyId"]), None) + if re.compile(utils.ROOT_ID_REGEX).match(kwargs["TargetId"]) or re.compile( + utils.OU_ID_REGEX + ).match(kwargs["TargetId"]): + ou = next((ou for ou in self.ou if ou.id == kwargs["TargetId"]), None) if ou is not None: if ou not in ou.attached_policies: ou.attached_policies.append(policy) policy.attachments.append(ou) else: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) - elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs['TargetId']): - account = next((a for a in self.accounts if a.id == kwargs['TargetId']), None) + elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs["TargetId"]): + account = next( + (a for a in self.accounts if a.id == kwargs["TargetId"]), None + ) if account is not None: if account not in account.attached_policies: account.attached_policies.append(policy) policy.attachments.append(account) else: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") def list_policies(self, **kwargs): - return dict(Policies=[ - p.describe()['Policy']['PolicySummary'] for p in self.policies - ]) + return dict( + Policies=[p.describe()["Policy"]["PolicySummary"] for p in self.policies] + ) def list_policies_for_target(self, **kwargs): - if re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId']): - obj = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) + if re.compile(utils.OU_ID_REGEX).match(kwargs["TargetId"]): + obj = next((ou for ou in self.ou if ou.id == kwargs["TargetId"]), None) if obj is None: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) - elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs['TargetId']): - obj = next((a for a in self.accounts if a.id == kwargs['TargetId']), None) + elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs["TargetId"]): + obj = next((a for a in self.accounts if a.id == kwargs["TargetId"]), None) if obj is None: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) - return dict(Policies=[ - p.describe()['Policy']['PolicySummary'] for p in obj.attached_policies - ]) + raise RESTError("InvalidInputException", "You specified an invalid value.") + return dict( + Policies=[ + p.describe()["Policy"]["PolicySummary"] for p in obj.attached_policies + ] + ) def list_targets_for_policy(self, **kwargs): - if re.compile(utils.SCP_ID_REGEX).match(kwargs['PolicyId']): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) + if re.compile(utils.SCP_ID_REGEX).match(kwargs["PolicyId"]): + policy = next( + (p for p in self.policies if p.id == kwargs["PolicyId"]), None + ) if policy is None: raise RESTError( - 'PolicyNotFoundException', - "You specified a policy that doesn't exist." + "PolicyNotFoundException", + "You specified a policy that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") objects = [ - { - 'TargetId': obj.id, - 'Arn': obj.arn, - 'Name': obj.name, - 'Type': obj.type, - } for obj in policy.attachments + {"TargetId": obj.id, "Arn": obj.arn, "Name": obj.name, "Type": obj.type} + for obj in policy.attachments ] return dict(Targets=objects) + def tag_resource(self, **kwargs): + account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None) + + if account is None: + raise InvalidInputException + + new_tags = {tag["Key"]: tag["Value"] for tag in kwargs["Tags"]} + account.tags.update(new_tags) + + def list_tags_for_resource(self, **kwargs): + account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None) + + if account is None: + raise InvalidInputException + + tags = [{"Key": key, "Value": value} for key, value in account.tags.items()] + return dict(Tags=tags) + + def untag_resource(self, **kwargs): + account = next((a for a in self.accounts if a.id == kwargs["ResourceId"]), None) + + if account is None: + raise InvalidInputException + + for key in kwargs["TagKeys"]: + account.tags.pop(key, None) + organizations_backend = OrganizationsBackend() diff --git a/moto/organizations/responses.py b/moto/organizations/responses.py index 814f30bad..7c42eb4ec 100644 --- a/moto/organizations/responses.py +++ b/moto/organizations/responses.py @@ -6,7 +6,6 @@ from .models import organizations_backend class OrganizationsResponse(BaseResponse): - @property def organizations_backend(self): return organizations_backend @@ -27,14 +26,10 @@ class OrganizationsResponse(BaseResponse): ) def describe_organization(self): - return json.dumps( - self.organizations_backend.describe_organization() - ) + return json.dumps(self.organizations_backend.describe_organization()) def list_roots(self): - return json.dumps( - self.organizations_backend.list_roots() - ) + return json.dumps(self.organizations_backend.list_roots()) def create_organizational_unit(self): return json.dumps( @@ -43,12 +38,16 @@ class OrganizationsResponse(BaseResponse): def describe_organizational_unit(self): return json.dumps( - self.organizations_backend.describe_organizational_unit(**self.request_params) + self.organizations_backend.describe_organizational_unit( + **self.request_params + ) ) def list_organizational_units_for_parent(self): return json.dumps( - self.organizations_backend.list_organizational_units_for_parent(**self.request_params) + self.organizations_backend.list_organizational_units_for_parent( + **self.request_params + ) ) def list_parents(self): @@ -66,11 +65,16 @@ class OrganizationsResponse(BaseResponse): self.organizations_backend.describe_account(**self.request_params) ) - def list_accounts(self): + def describe_create_account_status(self): return json.dumps( - self.organizations_backend.list_accounts() + self.organizations_backend.describe_create_account_status( + **self.request_params + ) ) + def list_accounts(self): + return json.dumps(self.organizations_backend.list_accounts()) + def list_accounts_for_parent(self): return json.dumps( self.organizations_backend.list_accounts_for_parent(**self.request_params) @@ -115,3 +119,18 @@ class OrganizationsResponse(BaseResponse): return json.dumps( self.organizations_backend.list_targets_for_policy(**self.request_params) ) + + def tag_resource(self): + return json.dumps( + self.organizations_backend.tag_resource(**self.request_params) + ) + + def list_tags_for_resource(self): + return json.dumps( + self.organizations_backend.list_tags_for_resource(**self.request_params) + ) + + def untag_resource(self): + return json.dumps( + self.organizations_backend.untag_resource(**self.request_params) + ) diff --git a/moto/organizations/urls.py b/moto/organizations/urls.py index 7911f5b53..d0909bbef 100644 --- a/moto/organizations/urls.py +++ b/moto/organizations/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import OrganizationsResponse -url_bases = [ - "https?://organizations.(.+).amazonaws.com", -] +url_bases = ["https?://organizations.(.+).amazonaws.com"] -url_paths = { - '{0}/$': OrganizationsResponse.dispatch, -} +url_paths = {"{0}/$": OrganizationsResponse.dispatch} diff --git a/moto/organizations/utils.py b/moto/organizations/utils.py index 5cbe59ada..e71357ce6 100644 --- a/moto/organizations/utils.py +++ b/moto/organizations/utils.py @@ -2,16 +2,18 @@ from __future__ import unicode_literals import random import string +from moto.core import ACCOUNT_ID -MASTER_ACCOUNT_ID = '123456789012' -MASTER_ACCOUNT_EMAIL = 'master@example.com' -DEFAULT_POLICY_ID = 'p-FullAWSAccess' -ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}' -MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}' -ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}' -ROOT_ARN_FORMAT = 'arn:aws:organizations::{0}:root/{1}/{2}' -OU_ARN_FORMAT = 'arn:aws:organizations::{0}:ou/{1}/{2}' -SCP_ARN_FORMAT = 'arn:aws:organizations::{0}:policy/{1}/service_control_policy/{2}' + +MASTER_ACCOUNT_ID = ACCOUNT_ID +MASTER_ACCOUNT_EMAIL = "master@example.com" +DEFAULT_POLICY_ID = "p-FullAWSAccess" +ORGANIZATION_ARN_FORMAT = "arn:aws:organizations::{0}:organization/{1}" +MASTER_ACCOUNT_ARN_FORMAT = "arn:aws:organizations::{0}:account/{1}/{0}" +ACCOUNT_ARN_FORMAT = "arn:aws:organizations::{0}:account/{1}/{2}" +ROOT_ARN_FORMAT = "arn:aws:organizations::{0}:root/{1}/{2}" +OU_ARN_FORMAT = "arn:aws:organizations::{0}:ou/{1}/{2}" +SCP_ARN_FORMAT = "arn:aws:organizations::{0}:policy/{1}/service_control_policy/{2}" CHARSET = string.ascii_lowercase + string.digits ORG_ID_SIZE = 10 @@ -22,26 +24,26 @@ CREATE_ACCOUNT_STATUS_ID_SIZE = 8 SCP_ID_SIZE = 8 EMAIL_REGEX = "^.+@[a-zA-Z0-9-.]+.[a-zA-Z]{2,3}|[0-9]{1,3}$" -ORG_ID_REGEX = r'o-[a-z0-9]{%s}' % ORG_ID_SIZE -ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % ROOT_ID_SIZE -OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) -ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE -CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE -SCP_ID_REGEX = r'%s|p-[a-z0-9]{%s}' % (DEFAULT_POLICY_ID, SCP_ID_SIZE) +ORG_ID_REGEX = r"o-[a-z0-9]{%s}" % ORG_ID_SIZE +ROOT_ID_REGEX = r"r-[a-z0-9]{%s}" % ROOT_ID_SIZE +OU_ID_REGEX = r"ou-[a-z0-9]{%s}-[a-z0-9]{%s}" % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) +ACCOUNT_ID_REGEX = r"[0-9]{%s}" % ACCOUNT_ID_SIZE +CREATE_ACCOUNT_STATUS_ID_REGEX = r"car-[a-z0-9]{%s}" % CREATE_ACCOUNT_STATUS_ID_SIZE +SCP_ID_REGEX = r"%s|p-[a-z0-9]{%s}" % (DEFAULT_POLICY_ID, SCP_ID_SIZE) def make_random_org_id(): # The regex pattern for an organization ID string requires "o-" # followed by from 10 to 32 lower-case letters or digits. # e.g. 'o-vipjnq5z86' - return 'o-' + ''.join(random.choice(CHARSET) for x in range(ORG_ID_SIZE)) + return "o-" + "".join(random.choice(CHARSET) for x in range(ORG_ID_SIZE)) def make_random_root_id(): # The regex pattern for a root ID string requires "r-" followed by # from 4 to 32 lower-case letters or digits. # e.g. 'r-3zwx' - return 'r-' + ''.join(random.choice(CHARSET) for x in range(ROOT_ID_SIZE)) + return "r-" + "".join(random.choice(CHARSET) for x in range(ROOT_ID_SIZE)) def make_random_ou_id(root_id): @@ -50,28 +52,32 @@ def make_random_ou_id(root_id): # that contains the OU) followed by a second "-" dash and from 8 to 32 # additional lower-case letters or digits. # e.g. ou-g8sd-5oe3bjaw - return '-'.join([ - 'ou', - root_id.partition('-')[2], - ''.join(random.choice(CHARSET) for x in range(OU_ID_SUFFIX_SIZE)), - ]) + return "-".join( + [ + "ou", + root_id.partition("-")[2], + "".join(random.choice(CHARSET) for x in range(OU_ID_SUFFIX_SIZE)), + ] + ) def make_random_account_id(): # The regex pattern for an account ID string requires exactly 12 digits. # e.g. '488633172133' - return ''.join([random.choice(string.digits) for n in range(ACCOUNT_ID_SIZE)]) + return "".join([random.choice(string.digits) for n in range(ACCOUNT_ID_SIZE)]) def make_random_create_account_status_id(): # The regex pattern for an create account request ID string requires # "car-" followed by from 8 to 32 lower-case letters or digits. # e.g. 'car-35gxzwrp' - return 'car-' + ''.join(random.choice(CHARSET) for x in range(CREATE_ACCOUNT_STATUS_ID_SIZE)) + return "car-" + "".join( + random.choice(CHARSET) for x in range(CREATE_ACCOUNT_STATUS_ID_SIZE) + ) def make_random_service_control_policy_id(): # The regex pattern for a policy ID string requires "p-" followed by # from 8 to 128 lower-case letters or digits. # e.g. 'p-k2av4a8a' - return 'p-' + ''.join(random.choice(CHARSET) for x in range(SCP_ID_SIZE)) + return "p-" + "".join(random.choice(CHARSET) for x in range(SCP_ID_SIZE)) diff --git a/moto/packages/httpretty/__init__.py b/moto/packages/httpretty/__init__.py index 679294a4b..c6a78526f 100644 --- a/moto/packages/httpretty/__init__.py +++ b/moto/packages/httpretty/__init__.py @@ -25,7 +25,7 @@ # OTHER DEALINGS IN THE SOFTWARE. from __future__ import unicode_literals -__version__ = version = '0.8.10' +__version__ = version = "0.8.10" from .core import httpretty, httprettified, EmptyRequestHeaders from .errors import HTTPrettyError, UnmockedError diff --git a/moto/packages/httpretty/compat.py b/moto/packages/httpretty/compat.py index b9e215b13..c452dec0e 100644 --- a/moto/packages/httpretty/compat.py +++ b/moto/packages/httpretty/compat.py @@ -34,33 +34,36 @@ if PY3: # pragma: no cover text_type = str byte_type = bytes import io + StringIO = io.BytesIO basestring = (str, bytes) class BaseClass(object): - def __repr__(self): return self.__str__() + + else: # pragma: no cover text_type = unicode byte_type = str import StringIO + StringIO = StringIO.StringIO basestring = basestring class BaseClass(object): - def __repr__(self): ret = self.__str__() if PY3: # pragma: no cover return ret else: - return ret.encode('utf-8') + return ret.encode("utf-8") try: # pragma: no cover from urllib.parse import urlsplit, urlunsplit, parse_qs, quote, quote_plus, unquote + unquote_utf8 = unquote except ImportError: # pragma: no cover from urlparse import urlsplit, urlunsplit, parse_qs, unquote @@ -68,7 +71,7 @@ except ImportError: # pragma: no cover def unquote_utf8(qs): if isinstance(qs, text_type): - qs = qs.encode('utf-8') + qs = qs.encode("utf-8") s = unquote(qs) if isinstance(s, byte_type): return s.decode("utf-8") @@ -88,16 +91,16 @@ if not PY3: # pragma: no cover __all__ = [ - 'PY3', - 'StringIO', - 'text_type', - 'byte_type', - 'BaseClass', - 'BaseHTTPRequestHandler', - 'quote', - 'quote_plus', - 'urlunsplit', - 'urlsplit', - 'parse_qs', - 'ClassTypes', + "PY3", + "StringIO", + "text_type", + "byte_type", + "BaseClass", + "BaseHTTPRequestHandler", + "quote", + "quote_plus", + "urlunsplit", + "urlsplit", + "parse_qs", + "ClassTypes", ] diff --git a/moto/packages/httpretty/core.py b/moto/packages/httpretty/core.py index f94723017..83bd19237 100644 --- a/moto/packages/httpretty/core.py +++ b/moto/packages/httpretty/core.py @@ -52,19 +52,11 @@ from .compat import ( unquote, unquote_utf8, ClassTypes, - basestring -) -from .http import ( - STATUSES, - HttpBaseClass, - parse_requestline, - last_requestline, + basestring, ) +from .http import STATUSES, HttpBaseClass, parse_requestline, last_requestline -from .utils import ( - utf8, - decode_utf8, -) +from .utils import utf8, decode_utf8 from .errors import HTTPrettyError, UnmockedError @@ -91,12 +83,14 @@ if PY3: # pragma: no cover basestring = (bytes, str) try: # pragma: no cover import socks + old_socksocket = socks.socksocket except ImportError: socks = None try: # pragma: no cover import ssl + old_ssl_wrap_socket = ssl.wrap_socket if not PY3: old_sslwrap_simple = ssl.sslwrap_simple @@ -109,7 +103,11 @@ except ImportError: # pragma: no cover ssl = None try: # pragma: no cover - from requests.packages.urllib3.contrib.pyopenssl import inject_into_urllib3, extract_from_urllib3 + from requests.packages.urllib3.contrib.pyopenssl import ( + inject_into_urllib3, + extract_from_urllib3, + ) + pyopenssl_override = True except: pyopenssl_override = False @@ -127,7 +125,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): internal `parse_request` method. It also replaces the `rfile` and `wfile` attributes with StringIO - instances so that we garantee that it won't make any I/O, neighter + instances so that we guarantee that it won't make any I/O, neighter for writing nor reading. It has some convenience attributes: @@ -154,7 +152,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): 'application/x-www-form-urlencoded' """ - def __init__(self, headers, body=''): + def __init__(self, headers, body=""): # first of all, lets make sure that if headers or body are # unicode strings, it must be converted into a utf-8 encoded # byte string @@ -163,7 +161,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): # Now let's concatenate the headers with the body, and create # `rfile` based on it - self.rfile = StringIO(b'\r\n\r\n'.join([self.raw_headers, self.body])) + self.rfile = StringIO(b"\r\n\r\n".join([self.raw_headers, self.body])) self.wfile = StringIO() # Creating `wfile` as an empty # StringIO, just to avoid any real # I/O calls @@ -186,7 +184,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): # `querystring` holds a dictionary with the parsed query string try: - self.path = self.path.encode('iso-8859-1') + self.path = self.path.encode("iso-8859-1") except UnicodeDecodeError: pass @@ -201,9 +199,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): def __str__(self): return ''.format( - self.headers.get('content-type', ''), - len(self.headers), - len(self.body), + self.headers.get("content-type", ""), len(self.headers), len(self.body) ) def parse_querystring(self, qs): @@ -219,13 +215,13 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): """ Attempt to parse the post based on the content-type passed. Return the regular body if not """ PARSING_FUNCTIONS = { - 'application/json': json.loads, - 'text/json': json.loads, - 'application/x-www-form-urlencoded': self.parse_querystring, + "application/json": json.loads, + "text/json": json.loads, + "application/x-www-form-urlencoded": self.parse_querystring, } FALLBACK_FUNCTION = lambda x: x - content_type = self.headers.get('content-type', '') + content_type = self.headers.get("content-type", "") do_parse = PARSING_FUNCTIONS.get(content_type, FALLBACK_FUNCTION) try: @@ -240,19 +236,17 @@ class EmptyRequestHeaders(dict): class HTTPrettyRequestEmpty(object): - body = '' + body = "" headers = EmptyRequestHeaders() class FakeSockFile(StringIO): - def close(self): self.socket.close() StringIO.close(self) class FakeSSLSocket(object): - def __init__(self, sock, *args, **kw): self._httpretty_sock = sock @@ -261,14 +255,19 @@ class FakeSSLSocket(object): class fakesock(object): - class socket(object): _entry = None debuglevel = 0 _sent_data = [] - def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, - proto=0, fileno=None, _sock=None): + def __init__( + self, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=0, + fileno=None, + _sock=None, + ): """ Matches both the Python 2 API: def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): @@ -300,23 +299,16 @@ class fakesock(object): now = datetime.now() shift = now + timedelta(days=30 * 12) return { - 'notAfter': shift.strftime('%b %d %H:%M:%S GMT'), - 'subjectAltName': ( - ('DNS', '*.%s' % self._host), - ('DNS', self._host), - ('DNS', '*'), + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", "*.%s" % self._host), + ("DNS", self._host), + ("DNS", "*"), ), - 'subject': ( - ( - ('organizationName', '*.%s' % self._host), - ), - ( - ('organizationalUnitName', - 'Domain Control Validated'), - ), - ( - ('commonName', '*.%s' % self._host), - ), + "subject": ( + (("organizationName", "*.%s" % self._host),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", "*.%s" % self._host),), ), } @@ -339,7 +331,9 @@ class fakesock(object): # See issue #206 self.is_http = False else: - self.is_http = self._port in POTENTIAL_HTTP_PORTS | POTENTIAL_HTTPS_PORTS + self.is_http = ( + self._port in POTENTIAL_HTTP_PORTS | POTENTIAL_HTTPS_PORTS + ) if not self.is_http: if self.truesock: @@ -353,7 +347,7 @@ class fakesock(object): self.truesock.close() self._closed = True - def makefile(self, mode='r', bufsize=-1): + def makefile(self, mode="r", bufsize=-1): """Returns this fake socket's own StringIO buffer. If there is an entry associated with the socket, the file @@ -408,9 +402,8 @@ class fakesock(object): self.fd = FakeSockFile() self.fd.socket = self try: - requestline, _ = data.split(b'\r\n', 1) - method, path, version = parse_requestline( - decode_utf8(requestline)) + requestline, _ = data.split(b"\r\n", 1) + method, path, version = parse_requestline(decode_utf8(requestline)) is_parsing_headers = True except ValueError: is_parsing_headers = False @@ -427,8 +420,12 @@ class fakesock(object): headers = utf8(last_requestline(self._sent_data)) meta = self._entry.request.headers body = utf8(self._sent_data[-1]) - if meta.get('transfer-encoding', '') == 'chunked': - if not body.isdigit() and body != b'\r\n' and body != b'0\r\n\r\n': + if meta.get("transfer-encoding", "") == "chunked": + if ( + not body.isdigit() + and body != b"\r\n" + and body != b"0\r\n\r\n" + ): self._entry.request.body += body else: self._entry.request.body += body @@ -439,14 +436,17 @@ class fakesock(object): # path might come with s = urlsplit(path) POTENTIAL_HTTP_PORTS.add(int(s.port or 80)) - headers, body = list(map(utf8, data.split(b'\r\n\r\n', 1))) + headers, body = list(map(utf8, data.split(b"\r\n\r\n", 1))) request = httpretty.historify_request(headers, body) - info = URIInfo(hostname=self._host, port=self._port, - path=s.path, - query=s.query, - last_request=request) + info = URIInfo( + hostname=self._host, + port=self._port, + path=s.path, + query=s.query, + last_request=request, + ) matcher, entries = httpretty.match_uriinfo(info) @@ -464,8 +464,10 @@ class fakesock(object): message = [ "HTTPretty intercepted and unexpected socket method call.", - ("Please open an issue at " - "'https://github.com/gabrielfalcao/HTTPretty/issues'"), + ( + "Please open an issue at " + "'https://github.com/gabrielfalcao/HTTPretty/issues'" + ), "And paste the following traceback:\n", "".join(decode_utf8(lines)), ] @@ -478,22 +480,22 @@ class fakesock(object): self.timeout = new_timeout def send(self, *args, **kwargs): - return self.debug('send', *args, **kwargs) + return self.debug("send", *args, **kwargs) def sendto(self, *args, **kwargs): - return self.debug('sendto', *args, **kwargs) + return self.debug("sendto", *args, **kwargs) def recvfrom_into(self, *args, **kwargs): - return self.debug('recvfrom_into', *args, **kwargs) + return self.debug("recvfrom_into", *args, **kwargs) def recv_into(self, *args, **kwargs): - return self.debug('recv_into', *args, **kwargs) + return self.debug("recv_into", *args, **kwargs) def recvfrom(self, *args, **kwargs): - return self.debug('recvfrom', *args, **kwargs) + return self.debug("recvfrom", *args, **kwargs) def recv(self, *args, **kwargs): - return self.debug('recv', *args, **kwargs) + return self.debug("recv", *args, **kwargs) def __getattr__(self, name): if not self.truesock: @@ -505,7 +507,9 @@ def fake_wrap_socket(s, *args, **kw): return s -def create_fake_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): +def create_fake_connection( + address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None +): s = fakesock.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: s.settimeout(timeout) @@ -516,26 +520,29 @@ def create_fake_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, sour def fake_gethostbyname(host): - return '127.0.0.1' + return "127.0.0.1" def fake_gethostname(): - return 'localhost' + return "localhost" -def fake_getaddrinfo( - host, port, family=None, socktype=None, proto=None, flags=None): - return [(2, 1, 6, '', (host, port))] +def fake_getaddrinfo(host, port, family=None, socktype=None, proto=None, flags=None): + return [(2, 1, 6, "", (host, port))] class Entry(BaseClass): - - def __init__(self, method, uri, body, - adding_headers=None, - forcing_headers=None, - status=200, - streaming=False, - **headers): + def __init__( + self, + method, + uri, + body, + adding_headers=None, + forcing_headers=None, + status=200, + streaming=False, + **headers + ): self.method = method self.uri = uri @@ -554,7 +561,7 @@ class Entry(BaseClass): self.streaming = streaming if not streaming and not self.body_is_callable: - self.body_length = len(self.body or '') + self.body_length = len(self.body or "") else: self.body_length = 0 @@ -569,10 +576,9 @@ class Entry(BaseClass): self.validate() def validate(self): - content_length_keys = 'Content-Length', 'content-length' + content_length_keys = "Content-Length", "content-length" for key in content_length_keys: - got = self.adding_headers.get( - key, self.forcing_headers.get(key, None)) + got = self.adding_headers.get(key, self.forcing_headers.get(key, None)) if got is None: continue @@ -581,28 +587,25 @@ class Entry(BaseClass): igot = int(got) except ValueError: warnings.warn( - 'HTTPretty got to register the Content-Length header ' - 'with "%r" which is not a number' % got, + "HTTPretty got to register the Content-Length header " + 'with "%r" which is not a number' % got ) if igot > self.body_length: raise HTTPrettyError( - 'HTTPretty got inconsistent parameters. The header ' + "HTTPretty got inconsistent parameters. The header " 'Content-Length you registered expects size "%d" but ' - 'the body you registered for that has actually length ' - '"%d".' % ( - igot, self.body_length, - ) + "the body you registered for that has actually length " + '"%d".' % (igot, self.body_length) ) def __str__(self): - return r'' % ( - self.method, self.uri, self.status) + return r"" % (self.method, self.uri, self.status) def normalize_headers(self, headers): new = {} for k in headers: - new_k = '-'.join([s.lower() for s in k.split('-')]) + new_k = "-".join([s.lower() for s in k.split("-")]) new[new_k] = headers[k] return new @@ -611,10 +614,10 @@ class Entry(BaseClass): now = datetime.utcnow() headers = { - 'status': self.status, - 'date': now.strftime('%a, %d %b %Y %H:%M:%S GMT'), - 'server': 'Python/HTTPretty', - 'connection': 'close', + "status": self.status, + "date": now.strftime("%a, %d %b %Y %H:%M:%S GMT"), + "server": "Python/HTTPretty", + "connection": "close", } if self.forcing_headers: @@ -624,44 +627,38 @@ class Entry(BaseClass): headers.update(self.normalize_headers(self.adding_headers)) headers = self.normalize_headers(headers) - status = headers.get('status', self.status) + status = headers.get("status", self.status) if self.body_is_callable: status, headers, self.body = self.callable_body( - self.request, self.info.full_url(), headers) + self.request, self.info.full_url(), headers + ) headers = self.normalize_headers(headers) if self.request.method != "HEAD": - headers.update({ - 'content-length': len(self.body) - }) + headers.update({"content-length": len(self.body)}) - string_list = [ - 'HTTP/1.1 %d %s' % (status, STATUSES[status]), - ] + string_list = ["HTTP/1.1 %d %s" % (status, STATUSES[status])] - if 'date' in headers: - string_list.append('date: %s' % headers.pop('date')) + if "date" in headers: + string_list.append("date: %s" % headers.pop("date")) if not self.forcing_headers: - content_type = headers.pop('content-type', - 'text/plain; charset=utf-8') + content_type = headers.pop("content-type", "text/plain; charset=utf-8") - content_length = headers.pop('content-length', self.body_length) + content_length = headers.pop("content-length", self.body_length) - string_list.append('content-type: %s' % content_type) + string_list.append("content-type: %s" % content_type) if not self.streaming: - string_list.append('content-length: %s' % content_length) + string_list.append("content-length: %s" % content_length) - string_list.append('server: %s' % headers.pop('server')) + string_list.append("server: %s" % headers.pop("server")) for k, v in headers.items(): - string_list.append( - '{0}: {1}'.format(k, v), - ) + string_list.append("{0}: {1}".format(k, v)) for item in string_list: - fk.write(utf8(item) + b'\n') + fk.write(utf8(item) + b"\n") - fk.write(b'\r\n') + fk.write(b"\r\n") if self.streaming: self.body, body = itertools.tee(self.body) @@ -673,58 +670,53 @@ class Entry(BaseClass): fk.seek(0) -def url_fix(s, charset='utf-8'): +def url_fix(s, charset="utf-8"): scheme, netloc, path, querystring, fragment = urlsplit(s) - path = quote(path, b'/%') - querystring = quote_plus(querystring, b':&=') + path = quote(path, b"/%") + querystring = quote_plus(querystring, b":&=") return urlunsplit((scheme, netloc, path, querystring, fragment)) class URIInfo(BaseClass): + def __init__( + self, + username="", + password="", + hostname="", + port=80, + path="/", + query="", + fragment="", + scheme="", + last_request=None, + ): - def __init__(self, - username='', - password='', - hostname='', - port=80, - path='/', - query='', - fragment='', - scheme='', - last_request=None): - - self.username = username or '' - self.password = password or '' - self.hostname = hostname or '' + self.username = username or "" + self.password = password or "" + self.hostname = hostname or "" if port: port = int(port) - elif scheme == 'https': + elif scheme == "https": port = 443 self.port = port or 80 - self.path = path or '' - self.query = query or '' + self.path = path or "" + self.query = query or "" if scheme: self.scheme = scheme elif self.port in POTENTIAL_HTTPS_PORTS: - self.scheme = 'https' + self.scheme = "https" else: - self.scheme = 'http' - self.fragment = fragment or '' + self.scheme = "http" + self.fragment = fragment or "" self.last_request = last_request def __str__(self): - attrs = ( - 'username', - 'password', - 'hostname', - 'port', - 'path', - ) - fmt = ", ".join(['%s="%s"' % (k, getattr(self, k, '')) for k in attrs]) - return r'' % fmt + attrs = ("username", "password", "hostname", "port", "path") + fmt = ", ".join(['%s="%s"' % (k, getattr(self, k, "")) for k in attrs]) + return r"" % fmt def __hash__(self): return hash(text_type(self)) @@ -745,8 +737,7 @@ class URIInfo(BaseClass): def full_url(self, use_querystring=True): credentials = "" if self.password: - credentials = "{0}:{1}@".format( - self.username, self.password) + credentials = "{0}:{1}@".format(self.username, self.password) query = "" if use_querystring and self.query: @@ -757,7 +748,7 @@ class URIInfo(BaseClass): credentials=credentials, domain=self.get_full_domain(), path=decode_utf8(self.path), - query=query + query=query, ) return result @@ -772,19 +763,21 @@ class URIInfo(BaseClass): @classmethod def from_uri(cls, uri, entry): result = urlsplit(uri) - if result.scheme == 'https': + if result.scheme == "https": POTENTIAL_HTTPS_PORTS.add(int(result.port or 443)) else: POTENTIAL_HTTP_PORTS.add(int(result.port or 80)) - return cls(result.username, - result.password, - result.hostname, - result.port, - result.path, - result.query, - result.fragment, - result.scheme, - entry) + return cls( + result.username, + result.password, + result.hostname, + result.port, + result.path, + result.query, + result.fragment, + result.scheme, + entry, + ) class URIMatcher(object): @@ -793,10 +786,10 @@ class URIMatcher(object): def __init__(self, uri, entries, match_querystring=False): self._match_querystring = match_querystring - if type(uri).__name__ in ('SRE_Pattern', 'Pattern'): + if type(uri).__name__ in ("SRE_Pattern", "Pattern"): self.regex = uri result = urlsplit(uri.pattern) - if result.scheme == 'https': + if result.scheme == "https": POTENTIAL_HTTPS_PORTS.add(int(result.port or 443)) else: POTENTIAL_HTTP_PORTS.add(int(result.port or 80)) @@ -812,11 +805,12 @@ class URIMatcher(object): if self.info: return self.info == info else: - return self.regex.search(info.full_url( - use_querystring=self._match_querystring)) + return self.regex.search( + info.full_url(use_querystring=self._match_querystring) + ) def __str__(self): - wrap = 'URLMatcher({0})' + wrap = "URLMatcher({0})" if self.info: return wrap.format(text_type(self.info)) else: @@ -836,8 +830,7 @@ class URIMatcher(object): self.current_entries[method] = -1 if not self.entries or not entries_for_method: - raise ValueError('I have no entries for method %s: %s' - % (method, self)) + raise ValueError("I have no entries for method %s: %s" % (method, self)) entry = entries_for_method[self.current_entries[method]] if self.current_entries[method] != -1: @@ -861,6 +854,7 @@ class URIMatcher(object): class httpretty(HttpBaseClass): """The URI registration class""" + _entries = {} latest_requests = [] @@ -878,12 +872,13 @@ class httpretty(HttpBaseClass): @classmethod @contextlib.contextmanager - def record(cls, filename, indentation=4, encoding='utf-8'): + def record(cls, filename, indentation=4, encoding="utf-8"): try: import urllib3 except ImportError: raise RuntimeError( - 'HTTPretty requires urllib3 installed for recording actual requests.') + "HTTPretty requires urllib3 installed for recording actual requests." + ) http = urllib3.PoolManager() @@ -894,30 +889,31 @@ class httpretty(HttpBaseClass): cls.disable() response = http.request(request.method, uri) - calls.append({ - 'request': { - 'uri': uri, - 'method': request.method, - 'headers': dict(request.headers), - 'body': decode_utf8(request.body), - 'querystring': request.querystring - }, - 'response': { - 'status': response.status, - 'body': decode_utf8(response.data), - 'headers': dict(response.headers) + calls.append( + { + "request": { + "uri": uri, + "method": request.method, + "headers": dict(request.headers), + "body": decode_utf8(request.body), + "querystring": request.querystring, + }, + "response": { + "status": response.status, + "body": decode_utf8(response.data), + "headers": dict(response.headers), + }, } - }) + ) cls.enable() return response.status, response.headers, response.data for method in cls.METHODS: - cls.register_uri(method, re.compile( - r'.*', re.M), body=record_request) + cls.register_uri(method, re.compile(r".*", re.M), body=record_request) yield cls.disable() - with codecs.open(filename, 'w', encoding) as f: + with codecs.open(filename, "w", encoding) as f: f.write(json.dumps(calls, indent=indentation)) @classmethod @@ -927,10 +923,14 @@ class httpretty(HttpBaseClass): data = json.loads(open(origin).read()) for item in data: - uri = item['request']['uri'] - method = item['request']['method'] - cls.register_uri(method, uri, body=item['response'][ - 'body'], forcing_headers=item['response']['headers']) + uri = item["request"]["uri"] + method = item["request"]["method"] + cls.register_uri( + method, + uri, + body=item["response"]["body"], + forcing_headers=item["response"]["headers"], + ) yield cls.disable() @@ -944,7 +944,7 @@ class httpretty(HttpBaseClass): cls.last_request = HTTPrettyRequestEmpty() @classmethod - def historify_request(cls, headers, body='', append=True): + def historify_request(cls, headers, body="", append=True): request = HTTPrettyRequest(headers, body) cls.last_request = request if append or not cls.latest_requests: @@ -954,17 +954,23 @@ class httpretty(HttpBaseClass): return request @classmethod - def register_uri(cls, method, uri, body='HTTPretty :)', - adding_headers=None, - forcing_headers=None, - status=200, - responses=None, match_querystring=False, - **headers): + def register_uri( + cls, + method, + uri, + body="HTTPretty :)", + adding_headers=None, + forcing_headers=None, + status=200, + responses=None, + match_querystring=False, + **headers + ): uri_is_string = isinstance(uri, basestring) - if uri_is_string and re.search(r'^\w+://[^/]+[.]\w{2,}$', uri): - uri += '/' + if uri_is_string and re.search(r"^\w+://[^/]+[.]\w{2,}$", uri): + uri += "/" if isinstance(responses, list) and len(responses) > 0: for response in responses: @@ -972,17 +978,14 @@ class httpretty(HttpBaseClass): response.method = method entries_for_this_uri = responses else: - headers[str('body')] = body - headers[str('adding_headers')] = adding_headers - headers[str('forcing_headers')] = forcing_headers - headers[str('status')] = status + headers[str("body")] = body + headers[str("adding_headers")] = adding_headers + headers[str("forcing_headers")] = forcing_headers + headers[str("status")] = status - entries_for_this_uri = [ - cls.Response(method=method, uri=uri, **headers), - ] + entries_for_this_uri = [cls.Response(method=method, uri=uri, **headers)] - matcher = URIMatcher(uri, entries_for_this_uri, - match_querystring) + matcher = URIMatcher(uri, entries_for_this_uri, match_querystring) if matcher in cls._entries: matcher.entries.extend(cls._entries[matcher]) del cls._entries[matcher] @@ -990,17 +993,26 @@ class httpretty(HttpBaseClass): cls._entries[matcher] = entries_for_this_uri def __str__(self): - return '' % len(self._entries) + return "" % len(self._entries) @classmethod - def Response(cls, body, method=None, uri=None, adding_headers=None, forcing_headers=None, - status=200, streaming=False, **headers): + def Response( + cls, + body, + method=None, + uri=None, + adding_headers=None, + forcing_headers=None, + status=200, + streaming=False, + **headers + ): - headers[str('body')] = body - headers[str('adding_headers')] = adding_headers - headers[str('forcing_headers')] = forcing_headers - headers[str('status')] = int(status) - headers[str('streaming')] = streaming + headers[str("body")] = body + headers[str("adding_headers")] = adding_headers + headers[str("forcing_headers")] = forcing_headers + headers[str("status")] = int(status) + headers[str("streaming")] = streaming return Entry(method, uri, **headers) @classmethod @@ -1016,19 +1028,19 @@ class httpretty(HttpBaseClass): socket.gethostbyname = old_gethostbyname socket.getaddrinfo = old_getaddrinfo - socket.__dict__['socket'] = old_socket - socket.__dict__['_socketobject'] = old_socket + socket.__dict__["socket"] = old_socket + socket.__dict__["_socketobject"] = old_socket if not BAD_SOCKET_SHADOW: - socket.__dict__['SocketType'] = old_socket + socket.__dict__["SocketType"] = old_socket - socket.__dict__['create_connection'] = old_create_connection - socket.__dict__['gethostname'] = old_gethostname - socket.__dict__['gethostbyname'] = old_gethostbyname - socket.__dict__['getaddrinfo'] = old_getaddrinfo + socket.__dict__["create_connection"] = old_create_connection + socket.__dict__["gethostname"] = old_gethostname + socket.__dict__["gethostbyname"] = old_gethostbyname + socket.__dict__["getaddrinfo"] = old_getaddrinfo if socks: socks.socksocket = old_socksocket - socks.__dict__['socksocket'] = old_socksocket + socks.__dict__["socksocket"] = old_socksocket if ssl: ssl.wrap_socket = old_ssl_wrap_socket @@ -1037,12 +1049,12 @@ class httpretty(HttpBaseClass): ssl.SSLContext.wrap_socket = old_sslcontext_wrap_socket except AttributeError: pass - ssl.__dict__['wrap_socket'] = old_ssl_wrap_socket - ssl.__dict__['SSLSocket'] = old_sslsocket + ssl.__dict__["wrap_socket"] = old_ssl_wrap_socket + ssl.__dict__["SSLSocket"] = old_sslsocket if not PY3: ssl.sslwrap_simple = old_sslwrap_simple - ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple + ssl.__dict__["sslwrap_simple"] = old_sslwrap_simple if pyopenssl_override: inject_into_urllib3() @@ -1065,25 +1077,26 @@ class httpretty(HttpBaseClass): socket.gethostbyname = fake_gethostbyname socket.getaddrinfo = fake_getaddrinfo - socket.__dict__['socket'] = fakesock.socket - socket.__dict__['_socketobject'] = fakesock.socket + socket.__dict__["socket"] = fakesock.socket + socket.__dict__["_socketobject"] = fakesock.socket if not BAD_SOCKET_SHADOW: - socket.__dict__['SocketType'] = fakesock.socket + socket.__dict__["SocketType"] = fakesock.socket - socket.__dict__['create_connection'] = create_fake_connection - socket.__dict__['gethostname'] = fake_gethostname - socket.__dict__['gethostbyname'] = fake_gethostbyname - socket.__dict__['getaddrinfo'] = fake_getaddrinfo + socket.__dict__["create_connection"] = create_fake_connection + socket.__dict__["gethostname"] = fake_gethostname + socket.__dict__["gethostbyname"] = fake_gethostbyname + socket.__dict__["getaddrinfo"] = fake_getaddrinfo if socks: socks.socksocket = fakesock.socket - socks.__dict__['socksocket'] = fakesock.socket + socks.__dict__["socksocket"] = fakesock.socket if ssl: ssl.wrap_socket = fake_wrap_socket ssl.SSLSocket = FakeSSLSocket try: + def fake_sslcontext_wrap_socket(cls, *args, **kwargs): return fake_wrap_socket(*args, **kwargs) @@ -1091,12 +1104,12 @@ class httpretty(HttpBaseClass): except AttributeError: pass - ssl.__dict__['wrap_socket'] = fake_wrap_socket - ssl.__dict__['SSLSocket'] = FakeSSLSocket + ssl.__dict__["wrap_socket"] = fake_wrap_socket + ssl.__dict__["SSLSocket"] = FakeSSLSocket if not PY3: ssl.sslwrap_simple = fake_wrap_socket - ssl.__dict__['sslwrap_simple'] = fake_wrap_socket + ssl.__dict__["sslwrap_simple"] = fake_wrap_socket if pyopenssl_override: extract_from_urllib3() @@ -1104,9 +1117,10 @@ class httpretty(HttpBaseClass): def httprettified(test): "A decorator tests that use HTTPretty" + def decorate_class(klass): for attr in dir(klass): - if not attr.startswith('test_'): + if not attr.startswith("test_"): continue attr_value = getattr(klass, attr) @@ -1125,8 +1139,9 @@ def httprettified(test): return test(*args, **kw) finally: httpretty.disable() + return wrapper if isinstance(test, ClassTypes): return decorate_class(test) - return decorate_callable(test) \ No newline at end of file + return decorate_callable(test) diff --git a/moto/packages/httpretty/errors.py b/moto/packages/httpretty/errors.py index e2dcad357..8221e5f66 100644 --- a/moto/packages/httpretty/errors.py +++ b/moto/packages/httpretty/errors.py @@ -32,9 +32,8 @@ class HTTPrettyError(Exception): class UnmockedError(HTTPrettyError): - def __init__(self): super(UnmockedError, self).__init__( - 'No mocking was registered, and real connections are ' - 'not allowed (httpretty.allow_net_connect = False).' + "No mocking was registered, and real connections are " + "not allowed (httpretty.allow_net_connect = False)." ) diff --git a/moto/packages/httpretty/http.py b/moto/packages/httpretty/http.py index ee1625905..20c00707e 100644 --- a/moto/packages/httpretty/http.py +++ b/moto/packages/httpretty/http.py @@ -109,14 +109,14 @@ STATUSES = { class HttpBaseClass(BaseClass): - GET = 'GET' - PUT = 'PUT' - POST = 'POST' - DELETE = 'DELETE' - HEAD = 'HEAD' - PATCH = 'PATCH' - OPTIONS = 'OPTIONS' - CONNECT = 'CONNECT' + GET = "GET" + PUT = "PUT" + POST = "POST" + DELETE = "DELETE" + HEAD = "HEAD" + PATCH = "PATCH" + OPTIONS = "OPTIONS" + CONNECT = "CONNECT" METHODS = (GET, PUT, POST, DELETE, HEAD, PATCH, OPTIONS, CONNECT) @@ -133,12 +133,12 @@ def parse_requestline(s): ... ValueError: Not a Request-Line """ - methods = '|'.join(HttpBaseClass.METHODS) - m = re.match(r'(' + methods + ')\s+(.*)\s+HTTP/(1.[0|1])', s, re.I) + methods = "|".join(HttpBaseClass.METHODS) + m = re.match(r"(" + methods + ")\s+(.*)\s+HTTP/(1.[0|1])", s, re.I) if m: return m.group(1).upper(), m.group(2), m.group(3) else: - raise ValueError('Not a Request-Line') + raise ValueError("Not a Request-Line") def last_requestline(sent_data): diff --git a/moto/packages/httpretty/utils.py b/moto/packages/httpretty/utils.py index caa8fa13b..2bf5d0829 100644 --- a/moto/packages/httpretty/utils.py +++ b/moto/packages/httpretty/utils.py @@ -25,14 +25,12 @@ # OTHER DEALINGS IN THE SOFTWARE. from __future__ import unicode_literals -from .compat import ( - byte_type, text_type -) +from .compat import byte_type, text_type def utf8(s): if isinstance(s, text_type): - s = s.encode('utf-8') + s = s.encode("utf-8") elif s is None: return byte_type() diff --git a/moto/polly/__init__.py b/moto/polly/__init__.py index 9c2281126..6db0215de 100644 --- a/moto/polly/__init__.py +++ b/moto/polly/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import polly_backends from ..core.models import base_decorator -polly_backend = polly_backends['us-east-1'] +polly_backend = polly_backends["us-east-1"] mock_polly = base_decorator(polly_backends) diff --git a/moto/polly/models.py b/moto/polly/models.py index e7b7117dc..f91c80c64 100644 --- a/moto/polly/models.py +++ b/moto/polly/models.py @@ -8,7 +8,7 @@ from moto.core import BaseBackend, BaseModel from .resources import VOICE_DATA from .utils import make_arn_for_lexicon -DEFAULT_ACCOUNT_ID = 123456789012 +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID class Lexicon(BaseModel): @@ -32,33 +32,36 @@ class Lexicon(BaseModel): try: root = ET.fromstring(self.content) self.size = len(self.content) - self.last_modified = int((datetime.datetime.now() - - datetime.datetime(1970, 1, 1)).total_seconds()) - self.lexemes_count = len(root.findall('.')) + self.last_modified = int( + ( + datetime.datetime.now() - datetime.datetime(1970, 1, 1) + ).total_seconds() + ) + self.lexemes_count = len(root.findall(".")) for key, value in root.attrib.items(): - if key.endswith('alphabet'): + if key.endswith("alphabet"): self.alphabet = value - elif key.endswith('lang'): + elif key.endswith("lang"): self.language_code = value except Exception as err: - raise ValueError('Failure parsing XML: {0}'.format(err)) + raise ValueError("Failure parsing XML: {0}".format(err)) def to_dict(self): return { - 'Attributes': { - 'Alphabet': self.alphabet, - 'LanguageCode': self.language_code, - 'LastModified': self.last_modified, - 'LexemesCount': self.lexemes_count, - 'LexiconArn': self.arn, - 'Size': self.size + "Attributes": { + "Alphabet": self.alphabet, + "LanguageCode": self.language_code, + "LastModified": self.last_modified, + "LexemesCount": self.lexemes_count, + "LexiconArn": self.arn, + "Size": self.size, } } def __repr__(self): - return ''.format(self.name) + return "".format(self.name) class PollyBackend(BaseBackend): @@ -77,7 +80,7 @@ class PollyBackend(BaseBackend): if language_code is None: return VOICE_DATA - return [item for item in VOICE_DATA if item['LanguageCode'] == language_code] + return [item for item in VOICE_DATA if item["LanguageCode"] == language_code] def delete_lexicon(self, name): # implement here @@ -93,7 +96,7 @@ class PollyBackend(BaseBackend): for name, lexicon in self._lexicons.items(): lexicon_dict = lexicon.to_dict() - lexicon_dict['Name'] = name + lexicon_dict["Name"] = name result.append(lexicon_dict) @@ -111,4 +114,6 @@ class PollyBackend(BaseBackend): available_regions = boto3.session.Session().get_available_regions("polly") -polly_backends = {region: PollyBackend(region_name=region) for region in available_regions} +polly_backends = { + region: PollyBackend(region_name=region) for region in available_regions +} diff --git a/moto/polly/resources.py b/moto/polly/resources.py index f4ad69a98..560e62b7b 100644 --- a/moto/polly/resources.py +++ b/moto/polly/resources.py @@ -1,63 +1,418 @@ # -*- coding: utf-8 -*- VOICE_DATA = [ - {'Id': 'Joanna', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Joanna'}, - {'Id': 'Mizuki', 'LanguageCode': 'ja-JP', 'LanguageName': 'Japanese', 'Gender': 'Female', 'Name': 'Mizuki'}, - {'Id': 'Filiz', 'LanguageCode': 'tr-TR', 'LanguageName': 'Turkish', 'Gender': 'Female', 'Name': 'Filiz'}, - {'Id': 'Astrid', 'LanguageCode': 'sv-SE', 'LanguageName': 'Swedish', 'Gender': 'Female', 'Name': 'Astrid'}, - {'Id': 'Tatyana', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Female', 'Name': 'Tatyana'}, - {'Id': 'Maxim', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Male', 'Name': 'Maxim'}, - {'Id': 'Carmen', 'LanguageCode': 'ro-RO', 'LanguageName': 'Romanian', 'Gender': 'Female', 'Name': 'Carmen'}, - {'Id': 'Ines', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Female', 'Name': 'Inês'}, - {'Id': 'Cristiano', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Male', 'Name': 'Cristiano'}, - {'Id': 'Vitoria', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Female', 'Name': 'Vitória'}, - {'Id': 'Ricardo', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Male', 'Name': 'Ricardo'}, - {'Id': 'Maja', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Maja'}, - {'Id': 'Jan', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jan'}, - {'Id': 'Ewa', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Ewa'}, - {'Id': 'Ruben', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Male', 'Name': 'Ruben'}, - {'Id': 'Lotte', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Female', 'Name': 'Lotte'}, - {'Id': 'Liv', 'LanguageCode': 'nb-NO', 'LanguageName': 'Norwegian', 'Gender': 'Female', 'Name': 'Liv'}, - {'Id': 'Giorgio', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Male', 'Name': 'Giorgio'}, - {'Id': 'Carla', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Female', 'Name': 'Carla'}, - {'Id': 'Karl', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Male', 'Name': 'Karl'}, - {'Id': 'Dora', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Female', 'Name': 'Dóra'}, - {'Id': 'Mathieu', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Male', 'Name': 'Mathieu'}, - {'Id': 'Celine', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Female', 'Name': 'Céline'}, - {'Id': 'Chantal', 'LanguageCode': 'fr-CA', 'LanguageName': 'Canadian French', 'Gender': 'Female', 'Name': 'Chantal'}, - {'Id': 'Penelope', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Female', 'Name': 'Penélope'}, - {'Id': 'Miguel', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Male', 'Name': 'Miguel'}, - {'Id': 'Enrique', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Male', 'Name': 'Enrique'}, - {'Id': 'Conchita', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Female', 'Name': 'Conchita'}, - {'Id': 'Geraint', 'LanguageCode': 'en-GB-WLS', 'LanguageName': 'Welsh English', 'Gender': 'Male', 'Name': 'Geraint'}, - {'Id': 'Salli', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Salli'}, - {'Id': 'Kimberly', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kimberly'}, - {'Id': 'Kendra', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kendra'}, - {'Id': 'Justin', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Justin'}, - {'Id': 'Joey', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Joey'}, - {'Id': 'Ivy', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Ivy'}, - {'Id': 'Raveena', 'LanguageCode': 'en-IN', 'LanguageName': 'Indian English', 'Gender': 'Female', 'Name': 'Raveena'}, - {'Id': 'Emma', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Emma'}, - {'Id': 'Brian', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Male', 'Name': 'Brian'}, - {'Id': 'Amy', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Amy'}, - {'Id': 'Russell', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Male', 'Name': 'Russell'}, - {'Id': 'Nicole', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Female', 'Name': 'Nicole'}, - {'Id': 'Vicki', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Vicki'}, - {'Id': 'Marlene', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Marlene'}, - {'Id': 'Hans', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Male', 'Name': 'Hans'}, - {'Id': 'Naja', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Female', 'Name': 'Naja'}, - {'Id': 'Mads', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Male', 'Name': 'Mads'}, - {'Id': 'Gwyneth', 'LanguageCode': 'cy-GB', 'LanguageName': 'Welsh', 'Gender': 'Female', 'Name': 'Gwyneth'}, - {'Id': 'Jacek', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jacek'} + { + "Id": "Joanna", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Joanna", + }, + { + "Id": "Mizuki", + "LanguageCode": "ja-JP", + "LanguageName": "Japanese", + "Gender": "Female", + "Name": "Mizuki", + }, + { + "Id": "Filiz", + "LanguageCode": "tr-TR", + "LanguageName": "Turkish", + "Gender": "Female", + "Name": "Filiz", + }, + { + "Id": "Astrid", + "LanguageCode": "sv-SE", + "LanguageName": "Swedish", + "Gender": "Female", + "Name": "Astrid", + }, + { + "Id": "Tatyana", + "LanguageCode": "ru-RU", + "LanguageName": "Russian", + "Gender": "Female", + "Name": "Tatyana", + }, + { + "Id": "Maxim", + "LanguageCode": "ru-RU", + "LanguageName": "Russian", + "Gender": "Male", + "Name": "Maxim", + }, + { + "Id": "Carmen", + "LanguageCode": "ro-RO", + "LanguageName": "Romanian", + "Gender": "Female", + "Name": "Carmen", + }, + { + "Id": "Ines", + "LanguageCode": "pt-PT", + "LanguageName": "Portuguese", + "Gender": "Female", + "Name": "Inês", + }, + { + "Id": "Cristiano", + "LanguageCode": "pt-PT", + "LanguageName": "Portuguese", + "Gender": "Male", + "Name": "Cristiano", + }, + { + "Id": "Vitoria", + "LanguageCode": "pt-BR", + "LanguageName": "Brazilian Portuguese", + "Gender": "Female", + "Name": "Vitória", + }, + { + "Id": "Ricardo", + "LanguageCode": "pt-BR", + "LanguageName": "Brazilian Portuguese", + "Gender": "Male", + "Name": "Ricardo", + }, + { + "Id": "Maja", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Female", + "Name": "Maja", + }, + { + "Id": "Jan", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Male", + "Name": "Jan", + }, + { + "Id": "Ewa", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Female", + "Name": "Ewa", + }, + { + "Id": "Ruben", + "LanguageCode": "nl-NL", + "LanguageName": "Dutch", + "Gender": "Male", + "Name": "Ruben", + }, + { + "Id": "Lotte", + "LanguageCode": "nl-NL", + "LanguageName": "Dutch", + "Gender": "Female", + "Name": "Lotte", + }, + { + "Id": "Liv", + "LanguageCode": "nb-NO", + "LanguageName": "Norwegian", + "Gender": "Female", + "Name": "Liv", + }, + { + "Id": "Giorgio", + "LanguageCode": "it-IT", + "LanguageName": "Italian", + "Gender": "Male", + "Name": "Giorgio", + }, + { + "Id": "Carla", + "LanguageCode": "it-IT", + "LanguageName": "Italian", + "Gender": "Female", + "Name": "Carla", + }, + { + "Id": "Karl", + "LanguageCode": "is-IS", + "LanguageName": "Icelandic", + "Gender": "Male", + "Name": "Karl", + }, + { + "Id": "Dora", + "LanguageCode": "is-IS", + "LanguageName": "Icelandic", + "Gender": "Female", + "Name": "Dóra", + }, + { + "Id": "Mathieu", + "LanguageCode": "fr-FR", + "LanguageName": "French", + "Gender": "Male", + "Name": "Mathieu", + }, + { + "Id": "Celine", + "LanguageCode": "fr-FR", + "LanguageName": "French", + "Gender": "Female", + "Name": "Céline", + }, + { + "Id": "Chantal", + "LanguageCode": "fr-CA", + "LanguageName": "Canadian French", + "Gender": "Female", + "Name": "Chantal", + }, + { + "Id": "Penelope", + "LanguageCode": "es-US", + "LanguageName": "US Spanish", + "Gender": "Female", + "Name": "Penélope", + }, + { + "Id": "Miguel", + "LanguageCode": "es-US", + "LanguageName": "US Spanish", + "Gender": "Male", + "Name": "Miguel", + }, + { + "Id": "Enrique", + "LanguageCode": "es-ES", + "LanguageName": "Castilian Spanish", + "Gender": "Male", + "Name": "Enrique", + }, + { + "Id": "Conchita", + "LanguageCode": "es-ES", + "LanguageName": "Castilian Spanish", + "Gender": "Female", + "Name": "Conchita", + }, + { + "Id": "Geraint", + "LanguageCode": "en-GB-WLS", + "LanguageName": "Welsh English", + "Gender": "Male", + "Name": "Geraint", + }, + { + "Id": "Salli", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Salli", + }, + { + "Id": "Kimberly", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Kimberly", + }, + { + "Id": "Kendra", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Kendra", + }, + { + "Id": "Justin", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Male", + "Name": "Justin", + }, + { + "Id": "Joey", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Male", + "Name": "Joey", + }, + { + "Id": "Ivy", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Ivy", + }, + { + "Id": "Raveena", + "LanguageCode": "en-IN", + "LanguageName": "Indian English", + "Gender": "Female", + "Name": "Raveena", + }, + { + "Id": "Emma", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Female", + "Name": "Emma", + }, + { + "Id": "Brian", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Male", + "Name": "Brian", + }, + { + "Id": "Amy", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Female", + "Name": "Amy", + }, + { + "Id": "Russell", + "LanguageCode": "en-AU", + "LanguageName": "Australian English", + "Gender": "Male", + "Name": "Russell", + }, + { + "Id": "Nicole", + "LanguageCode": "en-AU", + "LanguageName": "Australian English", + "Gender": "Female", + "Name": "Nicole", + }, + { + "Id": "Vicki", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Female", + "Name": "Vicki", + }, + { + "Id": "Marlene", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Female", + "Name": "Marlene", + }, + { + "Id": "Hans", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Male", + "Name": "Hans", + }, + { + "Id": "Naja", + "LanguageCode": "da-DK", + "LanguageName": "Danish", + "Gender": "Female", + "Name": "Naja", + }, + { + "Id": "Mads", + "LanguageCode": "da-DK", + "LanguageName": "Danish", + "Gender": "Male", + "Name": "Mads", + }, + { + "Id": "Gwyneth", + "LanguageCode": "cy-GB", + "LanguageName": "Welsh", + "Gender": "Female", + "Name": "Gwyneth", + }, + { + "Id": "Jacek", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Male", + "Name": "Jacek", + }, ] # {...} is also shorthand set syntax -LANGUAGE_CODES = {'cy-GB', 'da-DK', 'de-DE', 'en-AU', 'en-GB', 'en-GB-WLS', 'en-IN', 'en-US', 'es-ES', 'es-US', - 'fr-CA', 'fr-FR', 'is-IS', 'it-IT', 'ja-JP', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-BR', 'pt-PT', 'ro-RO', - 'ru-RU', 'sv-SE', 'tr-TR'} +LANGUAGE_CODES = { + "cy-GB", + "da-DK", + "de-DE", + "en-AU", + "en-GB", + "en-GB-WLS", + "en-IN", + "en-US", + "es-ES", + "es-US", + "fr-CA", + "fr-FR", + "is-IS", + "it-IT", + "ja-JP", + "nb-NO", + "nl-NL", + "pl-PL", + "pt-BR", + "pt-PT", + "ro-RO", + "ru-RU", + "sv-SE", + "tr-TR", +} -VOICE_IDS = {'Geraint', 'Gwyneth', 'Mads', 'Naja', 'Hans', 'Marlene', 'Nicole', 'Russell', 'Amy', 'Brian', 'Emma', - 'Raveena', 'Ivy', 'Joanna', 'Joey', 'Justin', 'Kendra', 'Kimberly', 'Salli', 'Conchita', 'Enrique', - 'Miguel', 'Penelope', 'Chantal', 'Celine', 'Mathieu', 'Dora', 'Karl', 'Carla', 'Giorgio', 'Mizuki', - 'Liv', 'Lotte', 'Ruben', 'Ewa', 'Jacek', 'Jan', 'Maja', 'Ricardo', 'Vitoria', 'Cristiano', 'Ines', - 'Carmen', 'Maxim', 'Tatyana', 'Astrid', 'Filiz'} +VOICE_IDS = { + "Geraint", + "Gwyneth", + "Mads", + "Naja", + "Hans", + "Marlene", + "Nicole", + "Russell", + "Amy", + "Brian", + "Emma", + "Raveena", + "Ivy", + "Joanna", + "Joey", + "Justin", + "Kendra", + "Kimberly", + "Salli", + "Conchita", + "Enrique", + "Miguel", + "Penelope", + "Chantal", + "Celine", + "Mathieu", + "Dora", + "Karl", + "Carla", + "Giorgio", + "Mizuki", + "Liv", + "Lotte", + "Ruben", + "Ewa", + "Jacek", + "Jan", + "Maja", + "Ricardo", + "Vitoria", + "Cristiano", + "Ines", + "Carmen", + "Maxim", + "Tatyana", + "Astrid", + "Filiz", +} diff --git a/moto/polly/responses.py b/moto/polly/responses.py index 810264424..e7de01b2b 100644 --- a/moto/polly/responses.py +++ b/moto/polly/responses.py @@ -9,7 +9,7 @@ from moto.core.responses import BaseResponse from .models import polly_backends from .resources import LANGUAGE_CODES, VOICE_IDS -LEXICON_NAME_REGEX = re.compile(r'^[0-9A-Za-z]{1,20}$') +LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$") class PollyResponse(BaseResponse): @@ -19,71 +19,75 @@ class PollyResponse(BaseResponse): @property def json(self): - if not hasattr(self, '_json'): + if not hasattr(self, "_json"): self._json = json.loads(self.body) return self._json def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) def _get_action(self): # Amazon is now naming things /v1/api_name - url_parts = urlsplit(self.uri).path.lstrip('/').split('/') + url_parts = urlsplit(self.uri).path.lstrip("/").split("/") # [0] = 'v1' return url_parts[1] # DescribeVoices def voices(self): - language_code = self._get_param('LanguageCode') - next_token = self._get_param('NextToken') + language_code = self._get_param("LanguageCode") + next_token = self._get_param("NextToken") if language_code is not None and language_code not in LANGUAGE_CODES: - msg = "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " \ - "Member must satisfy enum value set: [{1}]".format(language_code, ', '.join(LANGUAGE_CODES)) + msg = ( + "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " + "Member must satisfy enum value set: [{1}]".format( + language_code, ", ".join(LANGUAGE_CODES) + ) + ) return msg, dict(status=400) voices = self.polly_backend.describe_voices(language_code, next_token) - return json.dumps({'Voices': voices}) + return json.dumps({"Voices": voices}) def lexicons(self): # Dish out requests based on methods # anything after the /v1/lexicons/ - args = urlsplit(self.uri).path.lstrip('/').split('/')[2:] + args = urlsplit(self.uri).path.lstrip("/").split("/")[2:] - if self.method == 'GET': + if self.method == "GET": if len(args) == 0: return self._get_lexicons_list() else: return self._get_lexicon(*args) - elif self.method == 'PUT': + elif self.method == "PUT": return self._put_lexicons(*args) - elif self.method == 'DELETE': + elif self.method == "DELETE": return self._delete_lexicon(*args) - return self._error('InvalidAction', 'Bad route') + return self._error("InvalidAction", "Bad route") # PutLexicon def _put_lexicons(self, lexicon_name): if LEXICON_NAME_REGEX.match(lexicon_name) is None: - return self._error('InvalidParameterValue', 'Lexicon name must match [0-9A-Za-z]{1,20}') + return self._error( + "InvalidParameterValue", "Lexicon name must match [0-9A-Za-z]{1,20}" + ) - if 'Content' not in self.json: - return self._error('MissingParameter', 'Content is missing from the body') + if "Content" not in self.json: + return self._error("MissingParameter", "Content is missing from the body") - self.polly_backend.put_lexicon(lexicon_name, self.json['Content']) + self.polly_backend.put_lexicon(lexicon_name, self.json["Content"]) - return '' + return "" # ListLexicons def _get_lexicons_list(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") - result = { - 'Lexicons': self.polly_backend.list_lexicons(next_token) - } + result = {"Lexicons": self.polly_backend.list_lexicons(next_token)} return json.dumps(result) @@ -92,14 +96,11 @@ class PollyResponse(BaseResponse): try: lexicon = self.polly_backend.get_lexicon(lexicon_name) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") result = { - 'Lexicon': { - 'Name': lexicon_name, - 'Content': lexicon.content - }, - 'LexiconAttributes': lexicon.to_dict()['Attributes'] + "Lexicon": {"Name": lexicon_name, "Content": lexicon.content}, + "LexiconAttributes": lexicon.to_dict()["Attributes"], } return json.dumps(result) @@ -109,80 +110,94 @@ class PollyResponse(BaseResponse): try: self.polly_backend.delete_lexicon(lexicon_name) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") - return '' + return "" # SynthesizeSpeech def speech(self): # Sanity check params args = { - 'lexicon_names': None, - 'sample_rate': 22050, - 'speech_marks': None, - 'text': None, - 'text_type': 'text' + "lexicon_names": None, + "sample_rate": 22050, + "speech_marks": None, + "text": None, + "text_type": "text", } - if 'LexiconNames' in self.json: - for lex in self.json['LexiconNames']: + if "LexiconNames" in self.json: + for lex in self.json["LexiconNames"]: try: self.polly_backend.get_lexicon(lex) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") - args['lexicon_names'] = self.json['LexiconNames'] + args["lexicon_names"] = self.json["LexiconNames"] - if 'OutputFormat' not in self.json: - return self._error('MissingParameter', 'Missing parameter OutputFormat') - if self.json['OutputFormat'] not in ('json', 'mp3', 'ogg_vorbis', 'pcm'): - return self._error('InvalidParameterValue', 'Not one of json, mp3, ogg_vorbis, pcm') - args['output_format'] = self.json['OutputFormat'] + if "OutputFormat" not in self.json: + return self._error("MissingParameter", "Missing parameter OutputFormat") + if self.json["OutputFormat"] not in ("json", "mp3", "ogg_vorbis", "pcm"): + return self._error( + "InvalidParameterValue", "Not one of json, mp3, ogg_vorbis, pcm" + ) + args["output_format"] = self.json["OutputFormat"] - if 'SampleRate' in self.json: - sample_rate = int(self.json['SampleRate']) + if "SampleRate" in self.json: + sample_rate = int(self.json["SampleRate"]) if sample_rate not in (8000, 16000, 22050): - return self._error('InvalidSampleRateException', 'The specified sample rate is not valid.') - args['sample_rate'] = sample_rate + return self._error( + "InvalidSampleRateException", + "The specified sample rate is not valid.", + ) + args["sample_rate"] = sample_rate - if 'SpeechMarkTypes' in self.json: - for value in self.json['SpeechMarkTypes']: - if value not in ('sentance', 'ssml', 'viseme', 'word'): - return self._error('InvalidParameterValue', 'Not one of sentance, ssml, viseme, word') - args['speech_marks'] = self.json['SpeechMarkTypes'] + if "SpeechMarkTypes" in self.json: + for value in self.json["SpeechMarkTypes"]: + if value not in ("sentance", "ssml", "viseme", "word"): + return self._error( + "InvalidParameterValue", + "Not one of sentance, ssml, viseme, word", + ) + args["speech_marks"] = self.json["SpeechMarkTypes"] - if 'Text' not in self.json: - return self._error('MissingParameter', 'Missing parameter Text') - args['text'] = self.json['Text'] + if "Text" not in self.json: + return self._error("MissingParameter", "Missing parameter Text") + args["text"] = self.json["Text"] - if 'TextType' in self.json: - if self.json['TextType'] not in ('ssml', 'text'): - return self._error('InvalidParameterValue', 'Not one of ssml, text') - args['text_type'] = self.json['TextType'] + if "TextType" in self.json: + if self.json["TextType"] not in ("ssml", "text"): + return self._error("InvalidParameterValue", "Not one of ssml, text") + args["text_type"] = self.json["TextType"] - if 'VoiceId' not in self.json: - return self._error('MissingParameter', 'Missing parameter VoiceId') - if self.json['VoiceId'] not in VOICE_IDS: - return self._error('InvalidParameterValue', 'Not one of {0}'.format(', '.join(VOICE_IDS))) - args['voice_id'] = self.json['VoiceId'] + if "VoiceId" not in self.json: + return self._error("MissingParameter", "Missing parameter VoiceId") + if self.json["VoiceId"] not in VOICE_IDS: + return self._error( + "InvalidParameterValue", "Not one of {0}".format(", ".join(VOICE_IDS)) + ) + args["voice_id"] = self.json["VoiceId"] # More validation - if len(args['text']) > 3000: - return self._error('TextLengthExceededException', 'Text too long') + if len(args["text"]) > 3000: + return self._error("TextLengthExceededException", "Text too long") - if args['speech_marks'] is not None and args['output_format'] != 'json': - return self._error('MarksNotSupportedForFormatException', 'OutputFormat must be json') - if args['speech_marks'] is not None and args['text_type'] == 'text': - return self._error('SsmlMarksNotSupportedForTextTypeException', 'TextType must be ssml') + if args["speech_marks"] is not None and args["output_format"] != "json": + return self._error( + "MarksNotSupportedForFormatException", "OutputFormat must be json" + ) + if args["speech_marks"] is not None and args["text_type"] == "text": + return self._error( + "SsmlMarksNotSupportedForTextTypeException", "TextType must be ssml" + ) - content_type = 'audio/json' - if args['output_format'] == 'mp3': - content_type = 'audio/mpeg' - elif args['output_format'] == 'ogg_vorbis': - content_type = 'audio/ogg' - elif args['output_format'] == 'pcm': - content_type = 'audio/pcm' + content_type = "audio/json" + if args["output_format"] == "mp3": + content_type = "audio/mpeg" + elif args["output_format"] == "ogg_vorbis": + content_type = "audio/ogg" + elif args["output_format"] == "pcm": + content_type = "audio/pcm" - headers = {'Content-Type': content_type} + headers = {"Content-Type": content_type} - return '\x00\x00\x00\x00\x00\x00\x00\x00', headers + return "\x00\x00\x00\x00\x00\x00\x00\x00", headers diff --git a/moto/polly/urls.py b/moto/polly/urls.py index bd4057a0b..5408c8cc1 100644 --- a/moto/polly/urls.py +++ b/moto/polly/urls.py @@ -1,13 +1,11 @@ from __future__ import unicode_literals from .responses import PollyResponse -url_bases = [ - "https?://polly.(.+).amazonaws.com", -] +url_bases = ["https?://polly.(.+).amazonaws.com"] url_paths = { - '{0}/v1/voices': PollyResponse.dispatch, - '{0}/v1/lexicons/(?P[^/]+)': PollyResponse.dispatch, - '{0}/v1/lexicons': PollyResponse.dispatch, - '{0}/v1/speech': PollyResponse.dispatch, + "{0}/v1/voices": PollyResponse.dispatch, + "{0}/v1/lexicons/(?P[^/]+)": PollyResponse.dispatch, + "{0}/v1/lexicons": PollyResponse.dispatch, + "{0}/v1/speech": PollyResponse.dispatch, } diff --git a/moto/rds/__init__.py b/moto/rds/__init__.py index a4086d89c..bd260d023 100644 --- a/moto/rds/__init__.py +++ b/moto/rds/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import rds_backends from ..core.models import base_decorator, deprecated_base_decorator -rds_backend = rds_backends['us-east-1'] +rds_backend = rds_backends["us-east-1"] mock_rds = base_decorator(rds_backends) mock_rds_deprecated = deprecated_base_decorator(rds_backends) diff --git a/moto/rds/exceptions.py b/moto/rds/exceptions.py index 5bcc95560..cf9b9aac6 100644 --- a/moto/rds/exceptions.py +++ b/moto/rds/exceptions.py @@ -5,38 +5,34 @@ from werkzeug.exceptions import BadRequest class RDSClientError(BadRequest): - def __init__(self, code, message): super(RDSClientError, self).__init__() - self.description = json.dumps({ - "Error": { - "Code": code, - "Message": message, - 'Type': 'Sender', - }, - 'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', - }) + self.description = json.dumps( + { + "Error": {"Code": code, "Message": message, "Type": "Sender"}, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) class DBInstanceNotFoundError(RDSClientError): - def __init__(self, database_identifier): super(DBInstanceNotFoundError, self).__init__( - 'DBInstanceNotFound', - "Database {0} not found.".format(database_identifier)) + "DBInstanceNotFound", "Database {0} not found.".format(database_identifier) + ) class DBSecurityGroupNotFoundError(RDSClientError): - def __init__(self, security_group_name): super(DBSecurityGroupNotFoundError, self).__init__( - 'DBSecurityGroupNotFound', - "Security Group {0} not found.".format(security_group_name)) + "DBSecurityGroupNotFound", + "Security Group {0} not found.".format(security_group_name), + ) class DBSubnetGroupNotFoundError(RDSClientError): - def __init__(self, subnet_group_name): super(DBSubnetGroupNotFoundError, self).__init__( - 'DBSubnetGroupNotFound', - "Subnet Group {0} not found.".format(subnet_group_name)) + "DBSubnetGroupNotFound", + "Subnet Group {0} not found.".format(subnet_group_name), + ) diff --git a/moto/rds/models.py b/moto/rds/models.py index feecefe0c..421f3784b 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -1,7 +1,5 @@ from __future__ import unicode_literals -import datetime - import boto.rds from jinja2 import Template @@ -13,122 +11,34 @@ from moto.rds2.models import rds2_backends class Database(BaseModel): - - def __init__(self, **kwargs): - self.status = "available" - - self.is_replica = False - self.replicas = [] - - self.region = kwargs.get('region') - self.engine = kwargs.get("engine") - self.engine_version = kwargs.get("engine_version") - if self.engine_version is None: - self.engine_version = "5.6.21" - self.iops = kwargs.get("iops") - self.storage_encrypted = kwargs.get("storage_encrypted", False) - if self.storage_encrypted: - self.kms_key_id = kwargs.get("kms_key_id", "default_kms_key_id") - else: - self.kms_key_id = kwargs.get("kms_key_id") - self.storage_type = kwargs.get("storage_type") - self.master_username = kwargs.get('master_username') - self.master_password = kwargs.get('master_password') - self.auto_minor_version_upgrade = kwargs.get( - 'auto_minor_version_upgrade') - if self.auto_minor_version_upgrade is None: - self.auto_minor_version_upgrade = True - self.allocated_storage = kwargs.get('allocated_storage') - self.db_instance_identifier = kwargs.get('db_instance_identifier') - self.source_db_identifier = kwargs.get("source_db_identifier") - self.db_instance_class = kwargs.get('db_instance_class') - self.port = kwargs.get('port') - self.db_name = kwargs.get("db_name") - self.publicly_accessible = kwargs.get("publicly_accessible") - if self.publicly_accessible is None: - self.publicly_accessible = True - - self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot") - if self.copy_tags_to_snapshot is None: - self.copy_tags_to_snapshot = False - - self.backup_retention_period = kwargs.get("backup_retention_period") - if self.backup_retention_period is None: - self.backup_retention_period = 1 - - self.availability_zone = kwargs.get("availability_zone") - self.multi_az = kwargs.get("multi_az") - self.db_subnet_group_name = kwargs.get("db_subnet_group_name") - self.instance_create_time = str(datetime.datetime.utcnow()) - if self.db_subnet_group_name: - self.db_subnet_group = rds_backends[ - self.region].describe_subnet_groups(self.db_subnet_group_name)[0] - else: - self.db_subnet_group = [] - - self.security_groups = kwargs.get('security_groups', []) - - # PreferredBackupWindow - # PreferredMaintenanceWindow - # backup_retention_period = self._get_param("BackupRetentionPeriod") - # OptionGroupName - # DBParameterGroupName - # VpcSecurityGroupIds.member.N - - @property - def db_instance_arn(self): - return "arn:aws:rds:{0}:1234567890:db:{1}".format( - self.region, self.db_instance_identifier) - - @property - def physical_resource_id(self): - return self.db_instance_identifier - - @property - def address(self): - return "{0}.aaaaaaaaaa.{1}.rds.amazonaws.com".format(self.db_instance_identifier, self.region) - - def add_replica(self, replica): - self.replicas.append(replica.db_instance_identifier) - - def remove_replica(self, replica): - self.replicas.remove(replica.db_instance_identifier) - - def set_as_replica(self): - self.is_replica = True - self.replicas = [] - - def update(self, db_kwargs): - for key, value in db_kwargs.items(): - if value is not None: - setattr(self, key, value) - def get_cfn_attribute(self, attribute_name): - if attribute_name == 'Endpoint.Address': + if attribute_name == "Endpoint.Address": return self.address - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - db_instance_identifier = properties.get('DBInstanceIdentifier') + db_instance_identifier = properties.get("DBInstanceIdentifier") if not db_instance_identifier: db_instance_identifier = resource_name.lower() + get_random_hex(12) - db_security_groups = properties.get('DBSecurityGroups') + db_security_groups = properties.get("DBSecurityGroups") if not db_security_groups: db_security_groups = [] security_groups = [group.group_name for group in db_security_groups] db_subnet_group = properties.get("DBSubnetGroupName") db_subnet_group_name = db_subnet_group.subnet_name if db_subnet_group else None db_kwargs = { - "auto_minor_version_upgrade": properties.get('AutoMinorVersionUpgrade'), - "allocated_storage": properties.get('AllocatedStorage'), + "auto_minor_version_upgrade": properties.get("AutoMinorVersionUpgrade"), + "allocated_storage": properties.get("AllocatedStorage"), "availability_zone": properties.get("AvailabilityZone"), "backup_retention_period": properties.get("BackupRetentionPeriod"), - "db_instance_class": properties.get('DBInstanceClass'), + "db_instance_class": properties.get("DBInstanceClass"), "db_instance_identifier": db_instance_identifier, "db_name": properties.get("DBName"), "db_subnet_group_name": db_subnet_group_name, @@ -136,10 +46,10 @@ class Database(BaseModel): "engine_version": properties.get("EngineVersion"), "iops": properties.get("Iops"), "kms_key_id": properties.get("KmsKeyId"), - "master_password": properties.get('MasterUserPassword'), - "master_username": properties.get('MasterUsername'), + "master_password": properties.get("MasterUserPassword"), + "master_username": properties.get("MasterUsername"), "multi_az": properties.get("MultiAZ"), - "port": properties.get('Port', 3306), + "port": properties.get("Port", 3306), "publicly_accessible": properties.get("PubliclyAccessible"), "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"), "region": region_name, @@ -160,7 +70,8 @@ class Database(BaseModel): return database def to_xml(self): - template = Template(""" + template = Template( + """ {{ database.backup_retention_period }} {{ database.status }} {{ database.multi_az }} @@ -243,7 +154,8 @@ class Database(BaseModel): {{ database.port }} {{ database.db_instance_arn }} - """) + """ + ) return template.render(database=self) def delete(self, region_name): @@ -252,7 +164,6 @@ class Database(BaseModel): class SecurityGroup(BaseModel): - def __init__(self, group_name, description): self.group_name = group_name self.description = description @@ -261,7 +172,8 @@ class SecurityGroup(BaseModel): self.ec2_security_groups = [] def to_xml(self): - template = Template(""" + template = Template( + """ {% for security_group in security_group.ec2_security_groups %} @@ -284,7 +196,8 @@ class SecurityGroup(BaseModel): {{ security_group.ownder_id }} {{ security_group.group_name }} - """) + """ + ) return template.render(security_group=self) def authorize_cidr(self, cidr_ip): @@ -294,20 +207,19 @@ class SecurityGroup(BaseModel): self.ec2_security_groups.append(security_group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] group_name = resource_name.lower() + get_random_hex(12) - description = properties['GroupDescription'] - security_group_ingress_rules = properties.get( - 'DBSecurityGroupIngress', []) - tags = properties.get('Tags') + description = properties["GroupDescription"] + security_group_ingress_rules = properties.get("DBSecurityGroupIngress", []) + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] rds_backend = rds_backends[region_name] security_group = rds_backend.create_security_group( - group_name, - description, - tags, + group_name, description, tags ) for security_group_ingress in security_group_ingress_rules: @@ -315,12 +227,10 @@ class SecurityGroup(BaseModel): if ingress_type == "CIDRIP": security_group.authorize_cidr(ingress_value) elif ingress_type == "EC2SecurityGroupName": - subnet = ec2_backend.get_security_group_from_name( - ingress_value) + subnet = ec2_backend.get_security_group_from_name(ingress_value) security_group.authorize_security_group(subnet) elif ingress_type == "EC2SecurityGroupId": - subnet = ec2_backend.get_security_group_from_id( - ingress_value) + subnet = ec2_backend.get_security_group_from_id(ingress_value) security_group.authorize_security_group(subnet) return security_group @@ -330,7 +240,6 @@ class SecurityGroup(BaseModel): class SubnetGroup(BaseModel): - def __init__(self, subnet_name, description, subnets): self.subnet_name = subnet_name self.description = description @@ -340,7 +249,8 @@ class SubnetGroup(BaseModel): self.vpc_id = self.subnets[0].vpc_id def to_xml(self): - template = Template(""" + template = Template( + """ {{ subnet_group.vpc_id }} {{ subnet_group.status }} {{ subnet_group.description }} @@ -357,27 +267,26 @@ class SubnetGroup(BaseModel): {% endfor %} - """) + """ + ) return template.render(subnet_group=self) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] subnet_name = resource_name.lower() + get_random_hex(12) - description = properties['DBSubnetGroupDescription'] - subnet_ids = properties['SubnetIds'] - tags = properties.get('Tags') + description = properties["DBSubnetGroupDescription"] + subnet_ids = properties["SubnetIds"] + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] - subnets = [ec2_backend.get_subnet(subnet_id) - for subnet_id in subnet_ids] + subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids] rds_backend = rds_backends[region_name] subnet_group = rds_backend.create_subnet_group( - subnet_name, - description, - subnets, - tags, + subnet_name, description, subnets, tags ) return subnet_group @@ -387,7 +296,6 @@ class SubnetGroup(BaseModel): class RDSBackend(BaseBackend): - def __init__(self, region): self.region = region @@ -405,5 +313,6 @@ class RDSBackend(BaseBackend): return rds2_backends[self.region] -rds_backends = dict((region.name, RDSBackend(region.name)) - for region in boto.rds.regions()) +rds_backends = dict( + (region.name, RDSBackend(region.name)) for region in boto.rds.regions() +) diff --git a/moto/rds/responses.py b/moto/rds/responses.py index 0afb03979..e3d37effc 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -6,19 +6,18 @@ from .models import rds_backends class RDSResponse(BaseResponse): - @property def backend(self): return rds_backends[self.region] def _get_db_kwargs(self): args = { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), - "allocated_storage": self._get_int_param('AllocatedStorage'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), + "allocated_storage": self._get_int_param("AllocatedStorage"), "availability_zone": self._get_param("AvailabilityZone"), "backup_retention_period": self._get_param("BackupRetentionPeriod"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_name": self._get_param("DBName"), # DBParameterGroupName "db_subnet_group_name": self._get_param("DBSubnetGroupName"), @@ -26,48 +25,48 @@ class RDSResponse(BaseResponse): "engine_version": self._get_param("EngineVersion"), "iops": self._get_int_param("Iops"), "kms_key_id": self._get_param("KmsKeyId"), - "master_password": self._get_param('MasterUserPassword'), - "master_username": self._get_param('MasterUsername'), + "master_password": self._get_param("MasterUserPassword"), + "master_username": self._get_param("MasterUsername"), "multi_az": self._get_bool_param("MultiAZ"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), # PreferredBackupWindow # PreferredMaintenanceWindow "publicly_accessible": self._get_param("PubliclyAccessible"), "region": self.region, - "security_groups": self._get_multi_param('DBSecurityGroups.member'), + "security_groups": self._get_multi_param("DBSecurityGroups.member"), "storage_encrypted": self._get_param("StorageEncrypted"), "storage_type": self._get_param("StorageType"), # VpcSecurityGroupIds.member.N "tags": list(), } - args['tags'] = self.unpack_complex_list_params( - 'Tags.Tag', ('Key', 'Value')) + args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) return args def _get_db_replica_kwargs(self): return { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "availability_zone": self._get_param("AvailabilityZone"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "iops": self._get_int_param("Iops"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), "publicly_accessible": self._get_param("PubliclyAccessible"), - "source_db_identifier": self._get_param('SourceDBInstanceIdentifier'), + "source_db_identifier": self._get_param("SourceDBInstanceIdentifier"), "storage_type": self._get_param("StorageType"), } def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -87,16 +86,18 @@ class RDSResponse(BaseResponse): return template.render(database=database) def describe_db_instances(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") all_instances = list(self.backend.describe_databases(db_instance_identifier)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] if marker: start = all_ids.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - instances_resp = all_instances[start:start + page_size] + page_size = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + instances_resp = all_instances[start : start + page_size] next_marker = None if len(all_instances) > start + page_size: next_marker = instances_resp[-1].db_instance_identifier @@ -105,73 +106,74 @@ class RDSResponse(BaseResponse): return template.render(databases=instances_resp, marker=next_marker) def modify_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() - new_db_instance_identifier = self._get_param('NewDBInstanceIdentifier') + new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier") if new_db_instance_identifier: - db_kwargs['new_db_instance_identifier'] = new_db_instance_identifier - database = self.backend.modify_database( - db_instance_identifier, db_kwargs) + db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier + database = self.backend.modify_database(db_instance_identifier, db_kwargs) template = self.response_template(MODIFY_DATABASE_TEMPLATE) return template.render(database=database) def delete_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.delete_database(db_instance_identifier) template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) def create_db_security_group(self): - group_name = self._get_param('DBSecurityGroupName') - description = self._get_param('DBSecurityGroupDescription') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + group_name = self._get_param("DBSecurityGroupName") + description = self._get_param("DBSecurityGroupDescription") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.backend.create_security_group( - group_name, description, tags) + group_name, description, tags + ) template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def describe_db_security_groups(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_groups = self.backend.describe_security_groups( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_groups = self.backend.describe_security_groups(security_group_name) template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) return template.render(security_groups=security_groups) def delete_db_security_group(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_group = self.backend.delete_security_group( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_group = self.backend.delete_security_group(security_group_name) template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def authorize_db_security_group_ingress(self): - security_group_name = self._get_param('DBSecurityGroupName') - cidr_ip = self._get_param('CIDRIP') + security_group_name = self._get_param("DBSecurityGroupName") + cidr_ip = self._get_param("CIDRIP") security_group = self.backend.authorize_security_group( - security_group_name, cidr_ip) + security_group_name, cidr_ip + ) template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def create_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') - description = self._get_param('DBSubnetGroupDescription') - subnet_ids = self._get_multi_param('SubnetIds.member') - subnets = [ec2_backends[self.region].get_subnet( - subnet_id) for subnet_id in subnet_ids] - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + subnet_name = self._get_param("DBSubnetGroupName") + description = self._get_param("DBSubnetGroupDescription") + subnet_ids = self._get_multi_param("SubnetIds.member") + subnets = [ + ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ] + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) subnet_group = self.backend.create_subnet_group( - subnet_name, description, subnets, tags) + subnet_name, description, subnets, tags + ) template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) def describe_db_subnet_groups(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_groups = self.backend.describe_subnet_groups(subnet_name) template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_groups=subnet_groups) def delete_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_group = self.backend.delete_subnet_group(subnet_name) template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) diff --git a/moto/rds/urls.py b/moto/rds/urls.py index 646f17304..9c7570167 100644 --- a/moto/rds/urls.py +++ b/moto/rds/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import RDSResponse -url_bases = [ - "https?://rds(\..+)?.amazonaws.com", -] +url_bases = ["https?://rds(\..+)?.amazonaws.com"] -url_paths = { - '{0}/$': RDSResponse.dispatch, -} +url_paths = {"{0}/$": RDSResponse.dispatch} diff --git a/moto/rds2/__init__.py b/moto/rds2/__init__.py index 723fa0968..acc8564e2 100644 --- a/moto/rds2/__init__.py +++ b/moto/rds2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import rds2_backends from ..core.models import base_decorator, deprecated_base_decorator -rds2_backend = rds2_backends['us-west-1'] +rds2_backend = rds2_backends["us-west-1"] mock_rds2 = base_decorator(rds2_backends) mock_rds2_deprecated = deprecated_base_decorator(rds2_backends) diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py index e82ae7077..b6dc5bb99 100644 --- a/moto/rds2/exceptions.py +++ b/moto/rds2/exceptions.py @@ -5,10 +5,10 @@ from werkzeug.exceptions import BadRequest class RDSClientError(BadRequest): - def __init__(self, code, message): super(RDSClientError, self).__init__() - template = Template(""" + template = Template( + """ {{ code }} @@ -16,87 +16,94 @@ class RDSClientError(BadRequest): Sender 6876f774-7273-11e4-85dc-39e55ca848d1 - """) + """ + ) self.description = template.render(code=code, message=message) class DBInstanceNotFoundError(RDSClientError): - def __init__(self, database_identifier): super(DBInstanceNotFoundError, self).__init__( - 'DBInstanceNotFound', - "Database {0} not found.".format(database_identifier)) + "DBInstanceNotFound", "Database {0} not found.".format(database_identifier) + ) class DBSnapshotNotFoundError(RDSClientError): - def __init__(self): super(DBSnapshotNotFoundError, self).__init__( - 'DBSnapshotNotFound', - "DBSnapshotIdentifier does not refer to an existing DB snapshot.") + "DBSnapshotNotFound", + "DBSnapshotIdentifier does not refer to an existing DB snapshot.", + ) class DBSecurityGroupNotFoundError(RDSClientError): - def __init__(self, security_group_name): super(DBSecurityGroupNotFoundError, self).__init__( - 'DBSecurityGroupNotFound', - "Security Group {0} not found.".format(security_group_name)) + "DBSecurityGroupNotFound", + "Security Group {0} not found.".format(security_group_name), + ) class DBSubnetGroupNotFoundError(RDSClientError): - def __init__(self, subnet_group_name): super(DBSubnetGroupNotFoundError, self).__init__( - 'DBSubnetGroupNotFound', - "Subnet Group {0} not found.".format(subnet_group_name)) + "DBSubnetGroupNotFound", + "Subnet Group {0} not found.".format(subnet_group_name), + ) class DBParameterGroupNotFoundError(RDSClientError): - def __init__(self, db_parameter_group_name): super(DBParameterGroupNotFoundError, self).__init__( - 'DBParameterGroupNotFound', - 'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) + "DBParameterGroupNotFound", + "DB Parameter Group {0} not found.".format(db_parameter_group_name), + ) class OptionGroupNotFoundFaultError(RDSClientError): - def __init__(self, option_group_name): super(OptionGroupNotFoundFaultError, self).__init__( - 'OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_name) + "OptionGroupNotFoundFault", + "Specified OptionGroupName: {0} not found.".format(option_group_name), ) class InvalidDBClusterStateFaultError(RDSClientError): - def __init__(self, database_identifier): super(InvalidDBClusterStateFaultError, self).__init__( - 'InvalidDBClusterStateFault', - 'Invalid DB type, when trying to perform StopDBInstance on {0}e. See AWS RDS documentation on rds.stop_db_instance'.format(database_identifier)) + "InvalidDBClusterStateFault", + "Invalid DB type, when trying to perform StopDBInstance on {0}e. See AWS RDS documentation on rds.stop_db_instance".format( + database_identifier + ), + ) class InvalidDBInstanceStateError(RDSClientError): - def __init__(self, database_identifier, istate): - estate = "in available state" if istate == 'stop' else "stopped, it cannot be started" + estate = ( + "in available state" + if istate == "stop" + else "stopped, it cannot be started" + ) super(InvalidDBInstanceStateError, self).__init__( - 'InvalidDBInstanceState', - 'Instance {} is not {}.'.format(database_identifier, estate)) + "InvalidDBInstanceState", + "Instance {} is not {}.".format(database_identifier, estate), + ) class SnapshotQuotaExceededError(RDSClientError): - def __init__(self): super(SnapshotQuotaExceededError, self).__init__( - 'SnapshotQuotaExceeded', - 'The request cannot be processed because it would exceed the maximum number of snapshots.') + "SnapshotQuotaExceeded", + "The request cannot be processed because it would exceed the maximum number of snapshots.", + ) class DBSnapshotAlreadyExistsError(RDSClientError): - def __init__(self, database_snapshot_identifier): super(DBSnapshotAlreadyExistsError, self).__init__( - 'DBSnapshotAlreadyExists', - 'Cannot create the snapshot because a snapshot with the identifier {} already exists.'.format(database_snapshot_identifier)) + "DBSnapshotAlreadyExists", + "Cannot create the snapshot because a snapshot with the identifier {} already exists.".format( + database_snapshot_identifier + ), + ) diff --git a/moto/rds2/models.py b/moto/rds2/models.py index 4c0daa230..686d22ccf 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -14,39 +14,41 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import get_random_hex from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.ec2.models import ec2_backends -from .exceptions import (RDSClientError, - DBInstanceNotFoundError, - DBSnapshotNotFoundError, - DBSecurityGroupNotFoundError, - DBSubnetGroupNotFoundError, - DBParameterGroupNotFoundError, - OptionGroupNotFoundFaultError, - InvalidDBClusterStateFaultError, - InvalidDBInstanceStateError, - SnapshotQuotaExceededError, - DBSnapshotAlreadyExistsError) +from .exceptions import ( + RDSClientError, + DBInstanceNotFoundError, + DBSnapshotNotFoundError, + DBSecurityGroupNotFoundError, + DBSubnetGroupNotFoundError, + DBParameterGroupNotFoundError, + OptionGroupNotFoundFaultError, + InvalidDBClusterStateFaultError, + InvalidDBInstanceStateError, + SnapshotQuotaExceededError, + DBSnapshotAlreadyExistsError, +) class Database(BaseModel): - def __init__(self, **kwargs): self.status = "available" self.is_replica = False self.replicas = [] - self.region = kwargs.get('region') + self.region = kwargs.get("region") self.engine = kwargs.get("engine") self.engine_version = kwargs.get("engine_version", None) - self.default_engine_versions = {"MySQL": "5.6.21", - "mysql": "5.6.21", - "oracle-se1": "11.2.0.4.v3", - "oracle-se": "11.2.0.4.v3", - "oracle-ee": "11.2.0.4.v3", - "sqlserver-ee": "11.00.2100.60.v1", - "sqlserver-se": "11.00.2100.60.v1", - "sqlserver-ex": "11.00.2100.60.v1", - "sqlserver-web": "11.00.2100.60.v1", - "postgres": "9.3.3" - } + self.default_engine_versions = { + "MySQL": "5.6.21", + "mysql": "5.6.21", + "oracle-se1": "11.2.0.4.v3", + "oracle-se": "11.2.0.4.v3", + "oracle-ee": "11.2.0.4.v3", + "sqlserver-ee": "11.00.2100.60.v1", + "sqlserver-se": "11.00.2100.60.v1", + "sqlserver-ex": "11.00.2100.60.v1", + "sqlserver-web": "11.00.2100.60.v1", + "postgres": "9.3.3", + } if not self.engine_version and self.engine in self.default_engine_versions: self.engine_version = self.default_engine_versions[self.engine] self.iops = kwargs.get("iops") @@ -56,22 +58,29 @@ class Database(BaseModel): else: self.kms_key_id = kwargs.get("kms_key_id") self.storage_type = kwargs.get("storage_type") - self.master_username = kwargs.get('master_username') - self.master_user_password = kwargs.get('master_user_password') - self.auto_minor_version_upgrade = kwargs.get( - 'auto_minor_version_upgrade') + if self.storage_type is None: + self.storage_type = Database.default_storage_type(iops=self.iops) + self.master_username = kwargs.get("master_username") + self.master_user_password = kwargs.get("master_user_password") + self.auto_minor_version_upgrade = kwargs.get("auto_minor_version_upgrade") if self.auto_minor_version_upgrade is None: self.auto_minor_version_upgrade = True - self.allocated_storage = kwargs.get('allocated_storage') - self.db_instance_identifier = kwargs.get('db_instance_identifier') + self.allocated_storage = kwargs.get("allocated_storage") + if self.allocated_storage is None: + self.allocated_storage = Database.default_allocated_storage( + engine=self.engine, storage_type=self.storage_type + ) + self.db_instance_identifier = kwargs.get("db_instance_identifier") self.source_db_identifier = kwargs.get("source_db_identifier") - self.db_instance_class = kwargs.get('db_instance_class') - self.port = kwargs.get('port') + self.db_instance_class = kwargs.get("db_instance_class") + self.port = kwargs.get("port") if self.port is None: self.port = Database.default_port(self.engine) - self.db_instance_identifier = kwargs.get('db_instance_identifier') + self.db_instance_identifier = kwargs.get("db_instance_identifier") self.db_name = kwargs.get("db_name") - self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) + self.instance_create_time = iso_8601_datetime_with_milliseconds( + datetime.datetime.now() + ) self.publicly_accessible = kwargs.get("publicly_accessible") if self.publicly_accessible is None: self.publicly_accessible = True @@ -85,39 +94,51 @@ class Database(BaseModel): self.multi_az = kwargs.get("multi_az") self.db_subnet_group_name = kwargs.get("db_subnet_group_name") if self.db_subnet_group_name: - self.db_subnet_group = rds2_backends[ - self.region].describe_subnet_groups(self.db_subnet_group_name)[0] + self.db_subnet_group = rds2_backends[self.region].describe_subnet_groups( + self.db_subnet_group_name + )[0] else: self.db_subnet_group = None - self.security_groups = kwargs.get('security_groups', []) - self.vpc_security_group_ids = kwargs.get('vpc_security_group_ids', []) + self.security_groups = kwargs.get("security_groups", []) + self.vpc_security_group_ids = kwargs.get("vpc_security_group_ids", []) self.preferred_maintenance_window = kwargs.get( - 'preferred_maintenance_window', 'wed:06:38-wed:07:08') - self.db_parameter_group_name = kwargs.get('db_parameter_group_name') - if self.db_parameter_group_name and self.db_parameter_group_name not in rds2_backends[self.region].db_parameter_groups: + "preferred_maintenance_window", "wed:06:38-wed:07:08" + ) + self.db_parameter_group_name = kwargs.get("db_parameter_group_name") + if ( + self.db_parameter_group_name + and self.db_parameter_group_name + not in rds2_backends[self.region].db_parameter_groups + ): raise DBParameterGroupNotFoundError(self.db_parameter_group_name) self.preferred_backup_window = kwargs.get( - 'preferred_backup_window', '13:14-13:44') - self.license_model = kwargs.get('license_model', 'general-public-license') - self.option_group_name = kwargs.get('option_group_name', None) - if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups: + "preferred_backup_window", "13:14-13:44" + ) + self.license_model = kwargs.get("license_model", "general-public-license") + self.option_group_name = kwargs.get("option_group_name", None) + if ( + self.option_group_name + and self.option_group_name not in rds2_backends[self.region].option_groups + ): raise OptionGroupNotFoundFaultError(self.option_group_name) - self.default_option_groups = {"MySQL": "default.mysql5.6", - "mysql": "default.mysql5.6", - "postgres": "default.postgres9.3" - } + self.default_option_groups = { + "MySQL": "default.mysql5.6", + "mysql": "default.mysql5.6", + "postgres": "default.postgres9.3", + } if not self.option_group_name and self.engine in self.default_option_groups: self.option_group_name = self.default_option_groups[self.engine] - self.character_set_name = kwargs.get('character_set_name', None) + self.character_set_name = kwargs.get("character_set_name", None) self.iam_database_authentication_enabled = False self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" - self.tags = kwargs.get('tags', []) + self.tags = kwargs.get("tags", []) @property def db_instance_arn(self): return "arn:aws:rds:{0}:1234567890:db:{1}".format( - self.region, self.db_instance_identifier) + self.region, self.db_instance_identifier + ) @property def physical_resource_id(self): @@ -125,26 +146,38 @@ class Database(BaseModel): def db_parameter_groups(self): if not self.db_parameter_group_name: - db_family, db_parameter_group_name = self.default_db_parameter_group_details() - description = 'Default parameter group for {0}'.format(db_family) - return [DBParameterGroup(name=db_parameter_group_name, - family=db_family, - description=description, - tags={})] + ( + db_family, + db_parameter_group_name, + ) = self.default_db_parameter_group_details() + description = "Default parameter group for {0}".format(db_family) + return [ + DBParameterGroup( + name=db_parameter_group_name, + family=db_family, + description=description, + tags={}, + ) + ] else: - return [rds2_backends[self.region].db_parameter_groups[self.db_parameter_group_name]] + return [ + rds2_backends[self.region].db_parameter_groups[ + self.db_parameter_group_name + ] + ] def default_db_parameter_group_details(self): if not self.engine_version: return (None, None) - minor_engine_version = '.'.join(self.engine_version.rsplit('.')[:-1]) - db_family = '{0}{1}'.format(self.engine.lower(), minor_engine_version) + minor_engine_version = ".".join(self.engine_version.rsplit(".")[:-1]) + db_family = "{0}{1}".format(self.engine.lower(), minor_engine_version) - return db_family, 'default.{0}'.format(db_family) + return db_family, "default.{0}".format(db_family) def to_xml(self): - template = Template(""" + template = Template( + """ {{ database.backup_retention_period }} {{ database.status }} {% if database.db_name %}{{ database.db_name }}{% endif %} @@ -247,12 +280,15 @@ class Database(BaseModel): {{ database.port }} {{ database.db_instance_arn }} - """) + """ + ) return template.render(database=self) @property def address(self): - return "{0}.aaaaaaaaaa.{1}.rds.amazonaws.com".format(self.db_instance_identifier, self.region) + return "{0}.aaaaaaaaaa.{1}.rds.amazonaws.com".format( + self.db_instance_identifier, self.region + ) def add_replica(self, replica): self.replicas.append(replica.db_instance_identifier) @@ -270,47 +306,73 @@ class Database(BaseModel): setattr(self, key, value) def get_cfn_attribute(self, attribute_name): - if attribute_name == 'Endpoint.Address': + if attribute_name == "Endpoint.Address": return self.address - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @staticmethod def default_port(engine): return { - 'mysql': 3306, - 'mariadb': 3306, - 'postgres': 5432, - 'oracle-ee': 1521, - 'oracle-se2': 1521, - 'oracle-se1': 1521, - 'oracle-se': 1521, - 'sqlserver-ee': 1433, - 'sqlserver-ex': 1433, - 'sqlserver-se': 1433, - 'sqlserver-web': 1433, + "mysql": 3306, + "mariadb": 3306, + "postgres": 5432, + "oracle-ee": 1521, + "oracle-se2": 1521, + "oracle-se1": 1521, + "oracle-se": 1521, + "sqlserver-ee": 1433, + "sqlserver-ex": 1433, + "sqlserver-se": 1433, + "sqlserver-web": 1433, }[engine] - @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + @staticmethod + def default_storage_type(iops): + if iops is None: + return "gp2" + else: + return "io1" - db_instance_identifier = properties.get('DBInstanceIdentifier') + @staticmethod + def default_allocated_storage(engine, storage_type): + return { + "aurora": {"gp2": 0, "io1": 0, "standard": 0}, + "mysql": {"gp2": 20, "io1": 100, "standard": 5}, + "mariadb": {"gp2": 20, "io1": 100, "standard": 5}, + "postgres": {"gp2": 20, "io1": 100, "standard": 5}, + "oracle-ee": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se2": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se1": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se": {"gp2": 20, "io1": 100, "standard": 10}, + "sqlserver-ee": {"gp2": 200, "io1": 200, "standard": 200}, + "sqlserver-ex": {"gp2": 20, "io1": 100, "standard": 20}, + "sqlserver-se": {"gp2": 200, "io1": 200, "standard": 200}, + "sqlserver-web": {"gp2": 20, "io1": 100, "standard": 20}, + }[engine][storage_type] + + @classmethod + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + + db_instance_identifier = properties.get("DBInstanceIdentifier") if not db_instance_identifier: db_instance_identifier = resource_name.lower() + get_random_hex(12) - db_security_groups = properties.get('DBSecurityGroups') + db_security_groups = properties.get("DBSecurityGroups") if not db_security_groups: db_security_groups = [] security_groups = [group.group_name for group in db_security_groups] db_subnet_group = properties.get("DBSubnetGroupName") db_subnet_group_name = db_subnet_group.subnet_name if db_subnet_group else None db_kwargs = { - "auto_minor_version_upgrade": properties.get('AutoMinorVersionUpgrade'), - "allocated_storage": properties.get('AllocatedStorage'), + "auto_minor_version_upgrade": properties.get("AutoMinorVersionUpgrade"), + "allocated_storage": properties.get("AllocatedStorage"), "availability_zone": properties.get("AvailabilityZone"), "backup_retention_period": properties.get("BackupRetentionPeriod"), - "db_instance_class": properties.get('DBInstanceClass'), + "db_instance_class": properties.get("DBInstanceClass"), "db_instance_identifier": db_instance_identifier, "db_name": properties.get("DBName"), "db_subnet_group_name": db_subnet_group_name, @@ -318,11 +380,11 @@ class Database(BaseModel): "engine_version": properties.get("EngineVersion"), "iops": properties.get("Iops"), "kms_key_id": properties.get("KmsKeyId"), - "master_user_password": properties.get('MasterUserPassword'), - "master_username": properties.get('MasterUsername'), + "master_user_password": properties.get("MasterUserPassword"), + "master_username": properties.get("MasterUsername"), "multi_az": properties.get("MultiAZ"), - "db_parameter_group_name": properties.get('DBParameterGroupName'), - "port": properties.get('Port', 3306), + "db_parameter_group_name": properties.get("DBParameterGroupName"), + "port": properties.get("Port", 3306), "publicly_accessible": properties.get("PubliclyAccessible"), "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"), "region": region_name, @@ -330,7 +392,7 @@ class Database(BaseModel): "storage_encrypted": properties.get("StorageEncrypted"), "storage_type": properties.get("StorageType"), "tags": properties.get("Tags"), - "vpc_security_group_ids": properties.get('VpcSecurityGroupIds', []), + "vpc_security_group_ids": properties.get("VpcSecurityGroupIds", []), } rds2_backend = rds2_backends[region_name] @@ -344,7 +406,8 @@ class Database(BaseModel): return database def to_json(self): - template = Template("""{ + template = Template( + """{ "AllocatedStorage": 10, "AutoMinorVersionUpgrade": "{{ database.auto_minor_version_upgrade }}", "AvailabilityZone": "{{ database.availability_zone }}", @@ -413,22 +476,21 @@ class Database(BaseModel): {% endfor %} ], "DBInstanceArn": "{{ database.db_instance_arn }}" - }""") + }""" + ) return template.render(database=self) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -444,10 +506,13 @@ class Snapshot(BaseModel): @property def snapshot_arn(self): - return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format(self.database.region, self.snapshot_id) + return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format( + self.database.region, self.snapshot_id + ) def to_xml(self): - template = Template(""" + template = Template( + """ {{ snapshot.snapshot_id }} {{ database.db_instance_identifier }} {{ snapshot.created_at }} @@ -478,26 +543,24 @@ class Snapshot(BaseModel): {{ snapshot.snapshot_arn }} false - """) + """ + ) return template.render(snapshot=self, database=self.database) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class SecurityGroup(BaseModel): - def __init__(self, group_name, description, tags): self.group_name = group_name self.description = description @@ -505,11 +568,12 @@ class SecurityGroup(BaseModel): self.ip_ranges = [] self.ec2_security_groups = [] self.tags = tags - self.owner_id = '1234567890' + self.owner_id = "1234567890" self.vpc_id = None def to_xml(self): - template = Template(""" + template = Template( + """ {% for security_group in security_group.ec2_security_groups %} @@ -532,11 +596,13 @@ class SecurityGroup(BaseModel): {{ security_group.ownder_id }} {{ security_group.group_name }} - """) + """ + ) return template.render(security_group=self) def to_json(self): - template = Template("""{ + template = Template( + """{ "DBSecurityGroupDescription": "{{ security_group.description }}", "DBSecurityGroupName": "{{ security_group.group_name }}", "EC2SecurityGroups": {{ security_group.ec2_security_groups }}, @@ -547,7 +613,8 @@ class SecurityGroup(BaseModel): ], "OwnerId": "{{ security_group.owner_id }}", "VpcId": "{{ security_group.vpc_id }}" - }""") + }""" + ) return template.render(security_group=self) def authorize_cidr(self, cidr_ip): @@ -557,32 +624,29 @@ class SecurityGroup(BaseModel): self.ec2_security_groups.append(security_group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] group_name = resource_name.lower() + get_random_hex(12) - description = properties['GroupDescription'] - security_group_ingress_rules = properties.get( - 'DBSecurityGroupIngress', []) - tags = properties.get('Tags') + description = properties["GroupDescription"] + security_group_ingress_rules = properties.get("DBSecurityGroupIngress", []) + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] rds2_backend = rds2_backends[region_name] security_group = rds2_backend.create_security_group( - group_name, - description, - tags, + group_name, description, tags ) for security_group_ingress in security_group_ingress_rules: for ingress_type, ingress_value in security_group_ingress.items(): if ingress_type == "CIDRIP": security_group.authorize_cidr(ingress_value) elif ingress_type == "EC2SecurityGroupName": - subnet = ec2_backend.get_security_group_from_name( - ingress_value) + subnet = ec2_backend.get_security_group_from_name(ingress_value) security_group.authorize_security_group(subnet) elif ingress_type == "EC2SecurityGroupId": - subnet = ec2_backend.get_security_group_from_id( - ingress_value) + subnet = ec2_backend.get_security_group_from_id(ingress_value) security_group.authorize_security_group(subnet) return security_group @@ -590,15 +654,13 @@ class SecurityGroup(BaseModel): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -606,7 +668,6 @@ class SecurityGroup(BaseModel): class SubnetGroup(BaseModel): - def __init__(self, subnet_name, description, subnets, tags): self.subnet_name = subnet_name self.description = description @@ -616,7 +677,8 @@ class SubnetGroup(BaseModel): self.vpc_id = self.subnets[0].vpc_id def to_xml(self): - template = Template(""" + template = Template( + """ {{ subnet_group.vpc_id }} {{ subnet_group.status }} {{ subnet_group.description }} @@ -633,11 +695,13 @@ class SubnetGroup(BaseModel): {% endfor %} - """) + """ + ) return template.render(subnet_group=self) def to_json(self): - template = Template(""""DBSubnetGroup": { + template = Template( + """"DBSubnetGroup": { "VpcId": "{{ subnet_group.vpc_id }}", "SubnetGroupStatus": "{{ subnet_group.status }}", "DBSubnetGroupDescription": "{{ subnet_group.description }}", @@ -654,27 +718,26 @@ class SubnetGroup(BaseModel): }{%- if not loop.last -%},{%- endif -%}{% endfor %} ] } - }""") + }""" + ) return template.render(subnet_group=self) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] subnet_name = resource_name.lower() + get_random_hex(12) - description = properties['DBSubnetGroupDescription'] - subnet_ids = properties['SubnetIds'] - tags = properties.get('Tags') + description = properties["DBSubnetGroupDescription"] + subnet_ids = properties["SubnetIds"] + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] - subnets = [ec2_backend.get_subnet(subnet_id) - for subnet_id in subnet_ids] + subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids] rds2_backend = rds2_backends[region_name] subnet_group = rds2_backend.create_subnet_group( - subnet_name, - description, - subnets, - tags, + subnet_name, description, subnets, tags ) return subnet_group @@ -682,15 +745,13 @@ class SubnetGroup(BaseModel): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -698,11 +759,11 @@ class SubnetGroup(BaseModel): class RDS2Backend(BaseBackend): - def __init__(self, region): self.region = region self.arn_regex = re_compile( - r'^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$') + r"^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$" + ) self.databases = OrderedDict() self.snapshots = OrderedDict() self.db_parameter_groups = {} @@ -717,18 +778,20 @@ class RDS2Backend(BaseBackend): self.__init__(region) def create_database(self, db_kwargs): - database_id = db_kwargs['db_instance_identifier'] + database_id = db_kwargs["db_instance_identifier"] database = Database(**db_kwargs) self.databases[database_id] = database return database - def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags=None): + def create_snapshot( + self, db_instance_identifier, db_snapshot_identifier, tags=None + ): database = self.databases.get(db_instance_identifier) if not database: raise DBInstanceNotFoundError(db_instance_identifier) if db_snapshot_identifier in self.snapshots: raise DBSnapshotAlreadyExistsError(db_snapshot_identifier) - if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')): + if len(self.snapshots) >= int(os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100")): raise SnapshotQuotaExceededError() if tags is None: tags = list() @@ -745,11 +808,11 @@ class RDS2Backend(BaseBackend): return self.snapshots.pop(db_snapshot_identifier) def create_database_replica(self, db_kwargs): - database_id = db_kwargs['db_instance_identifier'] - source_database_id = db_kwargs['source_db_identifier'] + database_id = db_kwargs["db_instance_identifier"] + source_database_id = db_kwargs["source_db_identifier"] primary = self.find_db_from_id(source_database_id) if self.arn_regex.match(source_database_id): - db_kwargs['region'] = self.region + db_kwargs["region"] = self.region # Shouldn't really copy here as the instance is duplicated. RDS replicas have different instances. replica = copy.copy(primary) @@ -784,9 +847,11 @@ class RDS2Backend(BaseBackend): def modify_database(self, db_instance_identifier, db_kwargs): database = self.describe_databases(db_instance_identifier)[0] - if 'new_db_instance_identifier' in db_kwargs: + if "new_db_instance_identifier" in db_kwargs: del self.databases[db_instance_identifier] - db_instance_identifier = db_kwargs['db_instance_identifier'] = db_kwargs.pop('new_db_instance_identifier') + db_instance_identifier = db_kwargs[ + "db_instance_identifier" + ] = db_kwargs.pop("new_db_instance_identifier") self.databases[db_instance_identifier] = database database.update(db_kwargs) return database @@ -799,26 +864,26 @@ class RDS2Backend(BaseBackend): database = self.describe_databases(db_instance_identifier)[0] # todo: certain rds types not allowed to be stopped at this time. if database.is_replica or database.multi_az: - # todo: more db types not supported by stop/start instance api - raise InvalidDBClusterStateFaultError(db_instance_identifier) - if database.status != 'available': - raise InvalidDBInstanceStateError(db_instance_identifier, 'stop') + # todo: more db types not supported by stop/start instance api + raise InvalidDBClusterStateFaultError(db_instance_identifier) + if database.status != "available": + raise InvalidDBInstanceStateError(db_instance_identifier, "stop") if db_snapshot_identifier: self.create_snapshot(db_instance_identifier, db_snapshot_identifier) - database.status = 'stopped' + database.status = "stopped" return database def start_database(self, db_instance_identifier): database = self.describe_databases(db_instance_identifier)[0] # todo: bunch of different error messages to be generated from this api call - if database.status != 'stopped': - raise InvalidDBInstanceStateError(db_instance_identifier, 'start') - database.status = 'available' + if database.status != "stopped": + raise InvalidDBInstanceStateError(db_instance_identifier, "start") + database.status = "available" return database def find_db_from_id(self, db_id): if self.arn_regex.match(db_id): - arn_breakdown = db_id.split(':') + arn_breakdown = db_id.split(":") region = arn_breakdown[3] backend = rds2_backends[region] db_name = arn_breakdown[-1] @@ -836,7 +901,7 @@ class RDS2Backend(BaseBackend): if database.is_replica: primary = self.find_db_from_id(database.source_db_identifier) primary.remove_replica(database) - database.status = 'deleting' + database.status = "deleting" return database else: raise DBInstanceNotFoundError(db_instance_identifier) @@ -891,34 +956,49 @@ class RDS2Backend(BaseBackend): raise DBSubnetGroupNotFoundError(subnet_name) def create_option_group(self, option_group_kwargs): - option_group_id = option_group_kwargs['name'] - valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'], - 'mysql': ['5.5', '5.6', '5.7', '8.0'], - 'oracle-se2': ['11.2', '12.1', '12.2'], - 'oracle-se1': ['11.2', '12.1', '12.2'], - 'oracle-se': ['11.2', '12.1', '12.2'], - 'oracle-ee': ['11.2', '12.1', '12.2'], - 'sqlserver-se': ['10.50', '11.00'], - 'sqlserver-ee': ['10.50', '11.00'], - 'sqlserver-ex': ['10.50', '11.00'], - 'sqlserver-web': ['10.50', '11.00']} - if option_group_kwargs['name'] in self.option_groups: - raise RDSClientError('OptionGroupAlreadyExistsFault', - 'An option group named {0} already exists.'.format(option_group_kwargs['name'])) - if 'description' not in option_group_kwargs or not option_group_kwargs['description']: - raise RDSClientError('InvalidParameterValue', - 'The parameter OptionGroupDescription must be provided and must not be blank.') - if option_group_kwargs['engine_name'] not in valid_option_group_engines.keys(): - raise RDSClientError('InvalidParameterValue', - 'Invalid DB engine: non-existant') - if option_group_kwargs['major_engine_version'] not in\ - valid_option_group_engines[option_group_kwargs['engine_name']]: - raise RDSClientError('InvalidParameterCombination', - 'Cannot find major version {0} for {1}'.format( - option_group_kwargs[ - 'major_engine_version'], - option_group_kwargs['engine_name'] - )) + option_group_id = option_group_kwargs["name"] + valid_option_group_engines = { + "mariadb": ["10.0", "10.1", "10.2", "10.3"], + "mysql": ["5.5", "5.6", "5.7", "8.0"], + "oracle-se2": ["11.2", "12.1", "12.2"], + "oracle-se1": ["11.2", "12.1", "12.2"], + "oracle-se": ["11.2", "12.1", "12.2"], + "oracle-ee": ["11.2", "12.1", "12.2"], + "sqlserver-se": ["10.50", "11.00"], + "sqlserver-ee": ["10.50", "11.00"], + "sqlserver-ex": ["10.50", "11.00"], + "sqlserver-web": ["10.50", "11.00"], + } + if option_group_kwargs["name"] in self.option_groups: + raise RDSClientError( + "OptionGroupAlreadyExistsFault", + "An option group named {0} already exists.".format( + option_group_kwargs["name"] + ), + ) + if ( + "description" not in option_group_kwargs + or not option_group_kwargs["description"] + ): + raise RDSClientError( + "InvalidParameterValue", + "The parameter OptionGroupDescription must be provided and must not be blank.", + ) + if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys(): + raise RDSClientError( + "InvalidParameterValue", "Invalid DB engine: non-existant" + ) + if ( + option_group_kwargs["major_engine_version"] + not in valid_option_group_engines[option_group_kwargs["engine_name"]] + ): + raise RDSClientError( + "InvalidParameterCombination", + "Cannot find major version {0} for {1}".format( + option_group_kwargs["major_engine_version"], + option_group_kwargs["engine_name"], + ), + ) option_group = OptionGroup(**option_group_kwargs) self.option_groups[option_group_id] = option_group return option_group @@ -932,82 +1012,129 @@ class RDS2Backend(BaseBackend): def describe_option_groups(self, option_group_kwargs): option_group_list = [] - if option_group_kwargs['marker']: - marker = option_group_kwargs['marker'] + if option_group_kwargs["marker"]: + marker = option_group_kwargs["marker"] else: marker = 0 - if option_group_kwargs['max_records']: - if option_group_kwargs['max_records'] < 20 or option_group_kwargs['max_records'] > 100: - raise RDSClientError('InvalidParameterValue', - 'Invalid value for max records. Must be between 20 and 100') - max_records = option_group_kwargs['max_records'] + if option_group_kwargs["max_records"]: + if ( + option_group_kwargs["max_records"] < 20 + or option_group_kwargs["max_records"] > 100 + ): + raise RDSClientError( + "InvalidParameterValue", + "Invalid value for max records. Must be between 20 and 100", + ) + max_records = option_group_kwargs["max_records"] else: max_records = 100 for option_group_name, option_group in self.option_groups.items(): - if option_group_kwargs['name'] and option_group.name != option_group_kwargs['name']: + if ( + option_group_kwargs["name"] + and option_group.name != option_group_kwargs["name"] + ): continue - elif option_group_kwargs['engine_name'] and \ - option_group.engine_name != option_group_kwargs['engine_name']: + elif ( + option_group_kwargs["engine_name"] + and option_group.engine_name != option_group_kwargs["engine_name"] + ): continue - elif option_group_kwargs['major_engine_version'] and \ - option_group.major_engine_version != option_group_kwargs['major_engine_version']: + elif ( + option_group_kwargs["major_engine_version"] + and option_group.major_engine_version + != option_group_kwargs["major_engine_version"] + ): continue else: option_group_list.append(option_group) if not len(option_group_list): - raise OptionGroupNotFoundFaultError(option_group_kwargs['name']) - return option_group_list[marker:max_records + marker] + raise OptionGroupNotFoundFaultError(option_group_kwargs["name"]) + return option_group_list[marker : max_records + marker] @staticmethod def describe_option_group_options(engine_name, major_engine_version=None): - default_option_group_options = {'mysql': {'5.6': '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-ee': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-sa': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-sa1': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'sqlserver-ee': {'10.50': '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - '11.00': '\n \n \n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}} + default_option_group_options = { + "mysql": { + "5.6": '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-ee": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-sa": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-sa1": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "sqlserver-ee": { + "10.50": '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "11.00": '\n \n \n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + } if engine_name not in default_option_group_options: - raise RDSClientError('InvalidParameterValue', - 'Invalid DB engine: {0}'.format(engine_name)) - if major_engine_version and major_engine_version not in default_option_group_options[engine_name]: - raise RDSClientError('InvalidParameterCombination', - 'Cannot find major version {0} for {1}'.format(major_engine_version, engine_name)) + raise RDSClientError( + "InvalidParameterValue", "Invalid DB engine: {0}".format(engine_name) + ) + if ( + major_engine_version + and major_engine_version not in default_option_group_options[engine_name] + ): + raise RDSClientError( + "InvalidParameterCombination", + "Cannot find major version {0} for {1}".format( + major_engine_version, engine_name + ), + ) if major_engine_version: return default_option_group_options[engine_name][major_engine_version] - return default_option_group_options[engine_name]['all'] + return default_option_group_options[engine_name]["all"] - def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): + def modify_option_group( + self, + option_group_name, + options_to_include=None, + options_to_remove=None, + apply_immediately=None, + ): if option_group_name not in self.option_groups: raise OptionGroupNotFoundFaultError(option_group_name) if not options_to_include and not options_to_remove: - raise RDSClientError('InvalidParameterValue', - 'At least one option must be added, modified, or removed.') + raise RDSClientError( + "InvalidParameterValue", + "At least one option must be added, modified, or removed.", + ) if options_to_remove: - self.option_groups[option_group_name].remove_options( - options_to_remove) + self.option_groups[option_group_name].remove_options(options_to_remove) if options_to_include: - self.option_groups[option_group_name].add_options( - options_to_include) + self.option_groups[option_group_name].add_options(options_to_include) return self.option_groups[option_group_name] def create_db_parameter_group(self, db_parameter_group_kwargs): - db_parameter_group_id = db_parameter_group_kwargs['name'] - if db_parameter_group_kwargs['name'] in self.db_parameter_groups: - raise RDSClientError('DBParameterGroupAlreadyExistsFault', - 'A DB parameter group named {0} already exists.'.format(db_parameter_group_kwargs['name'])) - if not db_parameter_group_kwargs.get('description'): - raise RDSClientError('InvalidParameterValue', - 'The parameter Description must be provided and must not be blank.') - if not db_parameter_group_kwargs.get('family'): - raise RDSClientError('InvalidParameterValue', - 'The parameter DBParameterGroupName must be provided and must not be blank.') + db_parameter_group_id = db_parameter_group_kwargs["name"] + if db_parameter_group_kwargs["name"] in self.db_parameter_groups: + raise RDSClientError( + "DBParameterGroupAlreadyExistsFault", + "A DB parameter group named {0} already exists.".format( + db_parameter_group_kwargs["name"] + ), + ) + if not db_parameter_group_kwargs.get("description"): + raise RDSClientError( + "InvalidParameterValue", + "The parameter Description must be provided and must not be blank.", + ) + if not db_parameter_group_kwargs.get("family"): + raise RDSClientError( + "InvalidParameterValue", + "The parameter DBParameterGroupName must be provided and must not be blank.", + ) db_parameter_group = DBParameterGroup(**db_parameter_group_kwargs) self.db_parameter_groups[db_parameter_group_id] = db_parameter_group @@ -1016,27 +1143,39 @@ class RDS2Backend(BaseBackend): def describe_db_parameter_groups(self, db_parameter_group_kwargs): db_parameter_group_list = [] - if db_parameter_group_kwargs.get('marker'): - marker = db_parameter_group_kwargs['marker'] + if db_parameter_group_kwargs.get("marker"): + marker = db_parameter_group_kwargs["marker"] else: marker = 0 - if db_parameter_group_kwargs.get('max_records'): - if db_parameter_group_kwargs['max_records'] < 20 or db_parameter_group_kwargs['max_records'] > 100: - raise RDSClientError('InvalidParameterValue', - 'Invalid value for max records. Must be between 20 and 100') - max_records = db_parameter_group_kwargs['max_records'] + if db_parameter_group_kwargs.get("max_records"): + if ( + db_parameter_group_kwargs["max_records"] < 20 + or db_parameter_group_kwargs["max_records"] > 100 + ): + raise RDSClientError( + "InvalidParameterValue", + "Invalid value for max records. Must be between 20 and 100", + ) + max_records = db_parameter_group_kwargs["max_records"] else: max_records = 100 - for db_parameter_group_name, db_parameter_group in self.db_parameter_groups.items(): - if not db_parameter_group_kwargs.get('name') or db_parameter_group.name == db_parameter_group_kwargs.get('name'): + for ( + db_parameter_group_name, + db_parameter_group, + ) in self.db_parameter_groups.items(): + if not db_parameter_group_kwargs.get( + "name" + ) or db_parameter_group.name == db_parameter_group_kwargs.get("name"): db_parameter_group_list.append(db_parameter_group) else: continue - return db_parameter_group_list[marker:max_records + marker] + return db_parameter_group_list[marker : max_records + marker] - def modify_db_parameter_group(self, db_parameter_group_name, db_parameter_group_parameters): + def modify_db_parameter_group( + self, db_parameter_group_name, db_parameter_group_parameters + ): if db_parameter_group_name not in self.db_parameter_groups: raise DBParameterGroupNotFoundError(db_parameter_group_name) @@ -1047,103 +1186,105 @@ class RDS2Backend(BaseBackend): def list_tags_for_resource(self, arn): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: return self.databases[resource_name].get_tags() - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription # TODO: Complete call to tags on resource type Event # Subscription return [] - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].get_tags() - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group if resource_name in self.db_parameter_groups: return self.db_parameter_groups[resource_name].get_tags() - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance # TODO: Complete call to tags on resource type Reserved DB # instance return [] - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].get_tags() - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].get_tags() - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].get_tags() else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) return [] def remove_tags_from_resource(self, arn, tag_keys): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: self.databases[resource_name].remove_tags(tag_keys) - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription return None - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].remove_tags(tag_keys) - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group return None - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance return None - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].remove_tags(tag_keys) - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].remove_tags(tag_keys) - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].remove_tags(tag_keys) else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) def add_tags_to_resource(self, arn, tags): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: return self.databases[resource_name].add_tags(tags) - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription return [] - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].add_tags(tags) - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group return [] - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance return [] - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].add_tags(tags) - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].add_tags(tags) - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].add_tags(tags) else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) class OptionGroup(object): - def __init__(self, name, engine_name, major_engine_version, description=None): self.engine_name = engine_name self.major_engine_version = major_engine_version @@ -1151,11 +1292,12 @@ class OptionGroup(object): self.name = name self.vpc_and_non_vpc_instance_memberships = False self.options = {} - self.vpcId = 'null' + self.vpcId = "null" self.tags = [] def to_json(self): - template = Template("""{ + template = Template( + """{ "VpcId": null, "MajorEngineVersion": "{{ option_group.major_engine_version }}", "OptionGroupDescription": "{{ option_group.description }}", @@ -1163,18 +1305,21 @@ class OptionGroup(object): "EngineName": "{{ option_group.engine_name }}", "Options": [], "OptionGroupName": "{{ option_group.name }}" -}""") +}""" + ) return template.render(option_group=self) def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group.name }} {{ option_group.vpc_and_non_vpc_instance_memberships }} {{ option_group.major_engine_version }} {{ option_group.engine_name }} {{ option_group.description }} - """) + """ + ) return template.render(option_group=self) def remove_options(self, options_to_remove): @@ -1191,37 +1336,39 @@ class OptionGroup(object): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class OptionGroupOption(object): - def __init__(self, **kwargs): - self.default_port = kwargs.get('default_port') - self.description = kwargs.get('description') - self.engine_name = kwargs.get('engine_name') - self.major_engine_version = kwargs.get('major_engine_version') - self.name = kwargs.get('name') + self.default_port = kwargs.get("default_port") + self.description = kwargs.get("description") + self.engine_name = kwargs.get("engine_name") + self.major_engine_version = kwargs.get("major_engine_version") + self.name = kwargs.get("name") self.option_group_option_settings = self._make_option_group_option_settings( - kwargs.get('option_group_option_settings', [])) - self.options_depended_on = kwargs.get('options_depended_on', []) - self.permanent = kwargs.get('permanent') - self.persistent = kwargs.get('persistent') - self.port_required = kwargs.get('port_required') + kwargs.get("option_group_option_settings", []) + ) + self.options_depended_on = kwargs.get("options_depended_on", []) + self.permanent = kwargs.get("permanent") + self.persistent = kwargs.get("persistent") + self.port_required = kwargs.get("port_required") def _make_option_group_option_settings(self, option_group_option_settings_kwargs): - return [OptionGroupOptionSetting(**setting_kwargs) for setting_kwargs in option_group_option_settings_kwargs] + return [ + OptionGroupOptionSetting(**setting_kwargs) + for setting_kwargs in option_group_option_settings_kwargs + ] def to_json(self): - template = Template("""{ "MinimumRequiredMinorEngineVersion": + template = Template( + """{ "MinimumRequiredMinorEngineVersion": "2789.0.v1", "OptionsDependedOn": [], "MajorEngineVersion": "10.50", @@ -1233,11 +1380,13 @@ class OptionGroupOption(object): "Name": "Mirroring", "PortRequired": false, "Description": "SQLServer Database Mirroring" - }""") + }""" + ) return template.render(option_group=self) def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group.major_engine_version }} {{ option_group.default_port }} {{ option_group.port_required }} @@ -1257,34 +1406,35 @@ class OptionGroupOption(object): {{ option_group.engine_name }} {{ option_group.minimum_required_minor_engine_version }} -""") +""" + ) return template.render(option_group=self) class OptionGroupOptionSetting(object): - def __init__(self, *kwargs): - self.allowed_values = kwargs.get('allowed_values') - self.apply_type = kwargs.get('apply_type') - self.default_value = kwargs.get('default_value') - self.is_modifiable = kwargs.get('is_modifiable') - self.setting_description = kwargs.get('setting_description') - self.setting_name = kwargs.get('setting_name') + self.allowed_values = kwargs.get("allowed_values") + self.apply_type = kwargs.get("apply_type") + self.default_value = kwargs.get("default_value") + self.is_modifiable = kwargs.get("is_modifiable") + self.setting_description = kwargs.get("setting_description") + self.setting_name = kwargs.get("setting_name") def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group_option_setting.allowed_values }} {{ option_group_option_setting.apply_type }} {{ option_group_option_setting.default_value }} {{ option_group_option_setting.is_modifiable }} {{ option_group_option_setting.setting_description }} {{ option_group_option_setting.setting_name }} -""") +""" + ) return template.render(option_group_option_setting=self) class DBParameterGroup(object): - def __init__(self, name, description, family, tags): self.name = name self.description = description @@ -1293,30 +1443,30 @@ class DBParameterGroup(object): self.parameters = defaultdict(dict) def to_xml(self): - template = Template(""" + template = Template( + """ {{ param_group.name }} {{ param_group.family }} {{ param_group.description }} - """) + """ + ) return template.render(param_group=self) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def update_parameters(self, new_parameters): for new_parameter in new_parameters: - parameter = self.parameters[new_parameter['ParameterName']] + parameter = self.parameters[new_parameter["ParameterName"]] parameter.update(new_parameter) def delete(self, region_name): @@ -1324,28 +1474,33 @@ class DBParameterGroup(object): backend.delete_db_parameter_group(self.name) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] db_parameter_group_kwargs = { - 'description': properties['Description'], - 'family': properties['Family'], - 'name': resource_name.lower(), - 'tags': properties.get("Tags"), + "description": properties["Description"], + "family": properties["Family"], + "name": resource_name.lower(), + "tags": properties.get("Tags"), } db_parameter_group_parameters = [] - for db_parameter, db_parameter_value in properties.get('Parameters', {}).items(): - db_parameter_group_parameters.append({ - 'ParameterName': db_parameter, - 'ParameterValue': db_parameter_value, - }) + for db_parameter, db_parameter_value in properties.get( + "Parameters", {} + ).items(): + db_parameter_group_parameters.append( + {"ParameterName": db_parameter, "ParameterValue": db_parameter_value} + ) rds2_backend = rds2_backends[region_name] db_parameter_group = rds2_backend.create_db_parameter_group( - db_parameter_group_kwargs) + db_parameter_group_kwargs + ) db_parameter_group.update_parameters(db_parameter_group_parameters) return db_parameter_group -rds2_backends = dict((region.name, RDS2Backend(region.name)) - for region in boto.rds2.regions()) +rds2_backends = dict( + (region.name, RDS2Backend(region.name)) for region in boto.rds2.regions() +) diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index 7b8d0b63a..625838d4d 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -8,87 +8,90 @@ from .exceptions import DBParameterGroupNotFoundError class RDS2Response(BaseResponse): - @property def backend(self): return rds2_backends[self.region] def _get_db_kwargs(self): args = { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), - "allocated_storage": self._get_int_param('AllocatedStorage'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), + "allocated_storage": self._get_int_param("AllocatedStorage"), "availability_zone": self._get_param("AvailabilityZone"), "backup_retention_period": self._get_param("BackupRetentionPeriod"), "copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_name": self._get_param("DBName"), "db_parameter_group_name": self._get_param("DBParameterGroupName"), - "db_snapshot_identifier": self._get_param('DBSnapshotIdentifier'), + "db_snapshot_identifier": self._get_param("DBSnapshotIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "engine": self._get_param("Engine"), "engine_version": self._get_param("EngineVersion"), "license_model": self._get_param("LicenseModel"), "iops": self._get_int_param("Iops"), "kms_key_id": self._get_param("KmsKeyId"), - "master_user_password": self._get_param('MasterUserPassword'), - "master_username": self._get_param('MasterUsername'), + "master_user_password": self._get_param("MasterUserPassword"), + "master_username": self._get_param("MasterUsername"), "multi_az": self._get_bool_param("MultiAZ"), "option_group_name": self._get_param("OptionGroupName"), - "port": self._get_param('Port'), + "port": self._get_param("Port"), # PreferredBackupWindow # PreferredMaintenanceWindow "publicly_accessible": self._get_param("PubliclyAccessible"), "region": self.region, - "security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'), + "security_groups": self._get_multi_param( + "DBSecurityGroups.DBSecurityGroupName" + ), "storage_encrypted": self._get_param("StorageEncrypted"), - "storage_type": self._get_param("StorageType", 'standard'), - "vpc_security_group_ids": self._get_multi_param("VpcSecurityGroupIds.VpcSecurityGroupId"), + "storage_type": self._get_param("StorageType", None), + "vpc_security_group_ids": self._get_multi_param( + "VpcSecurityGroupIds.VpcSecurityGroupId" + ), "tags": list(), } - args['tags'] = self.unpack_complex_list_params( - 'Tags.Tag', ('Key', 'Value')) + args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) return args def _get_db_replica_kwargs(self): return { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "availability_zone": self._get_param("AvailabilityZone"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "iops": self._get_int_param("Iops"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), "publicly_accessible": self._get_param("PubliclyAccessible"), - "source_db_identifier": self._get_param('SourceDBInstanceIdentifier'), + "source_db_identifier": self._get_param("SourceDBInstanceIdentifier"), "storage_type": self._get_param("StorageType"), } def _get_option_group_kwargs(self): return { - 'major_engine_version': self._get_param('MajorEngineVersion'), - 'description': self._get_param('OptionGroupDescription'), - 'engine_name': self._get_param('EngineName'), - 'name': self._get_param('OptionGroupName') + "major_engine_version": self._get_param("MajorEngineVersion"), + "description": self._get_param("OptionGroupDescription"), + "engine_name": self._get_param("EngineName"), + "name": self._get_param("OptionGroupName"), } def _get_db_parameter_group_kwargs(self): return { - 'description': self._get_param('Description'), - 'family': self._get_param('DBParameterGroupFamily'), - 'name': self._get_param('DBParameterGroupName'), - 'tags': self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), + "description": self._get_param("Description"), + "family": self._get_param("DBParameterGroupFamily"), + "name": self._get_param("DBParameterGroupName"), + "tags": self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")), } def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -96,9 +99,8 @@ class RDS2Response(BaseResponse): def unpack_list_params(self, label): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}'.format(label, count)): - unpacked_list.append(self._get_param( - '{0}.{1}'.format(label, count))) + while self._get_param("{0}.{1}".format(label, count)): + unpacked_list.append(self._get_param("{0}.{1}".format(label, count))) count += 1 return unpacked_list @@ -116,16 +118,18 @@ class RDS2Response(BaseResponse): return template.render(database=database) def describe_db_instances(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") all_instances = list(self.backend.describe_databases(db_instance_identifier)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] if marker: start = all_ids.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - instances_resp = all_instances[start:start + page_size] + page_size = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + instances_resp = all_instances[start : start + page_size] next_marker = None if len(all_instances) > start + page_size: next_marker = instances_resp[-1].db_instance_identifier @@ -134,134 +138,143 @@ class RDS2Response(BaseResponse): return template.render(databases=instances_resp, marker=next_marker) def modify_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() - new_db_instance_identifier = self._get_param('NewDBInstanceIdentifier') + new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier") if new_db_instance_identifier: - db_kwargs['new_db_instance_identifier'] = new_db_instance_identifier - database = self.backend.modify_database( - db_instance_identifier, db_kwargs) + db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier + database = self.backend.modify_database(db_instance_identifier, db_kwargs) template = self.response_template(MODIFY_DATABASE_TEMPLATE) return template.render(database=database) def delete_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_name = self._get_param('FinalDBSnapshotIdentifier') - database = self.backend.delete_database(db_instance_identifier, db_snapshot_name) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_name = self._get_param("FinalDBSnapshotIdentifier") + database = self.backend.delete_database( + db_instance_identifier, db_snapshot_name + ) template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) def reboot_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.reboot_db_instance(db_instance_identifier) template = self.response_template(REBOOT_DATABASE_TEMPLATE) return template.render(database=database) def create_db_snapshot(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + snapshot = self.backend.create_snapshot( + db_instance_identifier, db_snapshot_identifier, tags + ) template = self.response_template(CREATE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) def describe_db_snapshots(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - snapshots = self.backend.describe_snapshots(db_instance_identifier, db_snapshot_identifier) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + snapshots = self.backend.describe_snapshots( + db_instance_identifier, db_snapshot_identifier + ) template = self.response_template(DESCRIBE_SNAPSHOTS_TEMPLATE) return template.render(snapshots=snapshots) def delete_db_snapshot(self): - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") snapshot = self.backend.delete_snapshot(db_snapshot_identifier) template = self.response_template(DELETE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) def list_tags_for_resource(self): - arn = self._get_param('ResourceName') + arn = self._get_param("ResourceName") template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) tags = self.backend.list_tags_for_resource(arn) return template.render(tags=tags) def add_tags_to_resource(self): - arn = self._get_param('ResourceName') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + arn = self._get_param("ResourceName") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) tags = self.backend.add_tags_to_resource(arn, tags) template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE) return template.render(tags=tags) def remove_tags_from_resource(self): - arn = self._get_param('ResourceName') - tag_keys = self.unpack_list_params('TagKeys.member') + arn = self._get_param("ResourceName") + tag_keys = self.unpack_list_params("TagKeys.member") self.backend.remove_tags_from_resource(arn, tag_keys) template = self.response_template(REMOVE_TAGS_FROM_RESOURCE_TEMPLATE) return template.render() def stop_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - database = self.backend.stop_database(db_instance_identifier, db_snapshot_identifier) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + database = self.backend.stop_database( + db_instance_identifier, db_snapshot_identifier + ) template = self.response_template(STOP_DATABASE_TEMPLATE) return template.render(database=database) def start_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.start_database(db_instance_identifier) template = self.response_template(START_DATABASE_TEMPLATE) return template.render(database=database) def create_db_security_group(self): - group_name = self._get_param('DBSecurityGroupName') - description = self._get_param('DBSecurityGroupDescription') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + group_name = self._get_param("DBSecurityGroupName") + description = self._get_param("DBSecurityGroupDescription") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.backend.create_security_group( - group_name, description, tags) + group_name, description, tags + ) template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def describe_db_security_groups(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_groups = self.backend.describe_security_groups( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_groups = self.backend.describe_security_groups(security_group_name) template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) return template.render(security_groups=security_groups) def delete_db_security_group(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_group = self.backend.delete_security_group( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_group = self.backend.delete_security_group(security_group_name) template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def authorize_db_security_group_ingress(self): - security_group_name = self._get_param('DBSecurityGroupName') - cidr_ip = self._get_param('CIDRIP') + security_group_name = self._get_param("DBSecurityGroupName") + cidr_ip = self._get_param("CIDRIP") security_group = self.backend.authorize_security_group( - security_group_name, cidr_ip) + security_group_name, cidr_ip + ) template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def create_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') - description = self._get_param('DBSubnetGroupDescription') - subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - subnets = [ec2_backends[self.region].get_subnet( - subnet_id) for subnet_id in subnet_ids] + subnet_name = self._get_param("DBSubnetGroupName") + description = self._get_param("DBSubnetGroupDescription") + subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + subnets = [ + ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ] subnet_group = self.backend.create_subnet_group( - subnet_name, description, subnets, tags) + subnet_name, description, subnets, tags + ) template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) def describe_db_subnet_groups(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_groups = self.backend.describe_subnet_groups(subnet_name) template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_groups=subnet_groups) def delete_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_group = self.backend.delete_subnet_group(subnet_name) template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) @@ -274,50 +287,67 @@ class RDS2Response(BaseResponse): def delete_option_group(self): kwargs = self._get_option_group_kwargs() - option_group = self.backend.delete_option_group(kwargs['name']) + option_group = self.backend.delete_option_group(kwargs["name"]) template = self.response_template(DELETE_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) def describe_option_groups(self): kwargs = self._get_option_group_kwargs() - kwargs['max_records'] = self._get_int_param('MaxRecords') - kwargs['marker'] = self._get_param('Marker') + kwargs["max_records"] = self._get_int_param("MaxRecords") + kwargs["marker"] = self._get_param("Marker") option_groups = self.backend.describe_option_groups(kwargs) template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE) return template.render(option_groups=option_groups) def describe_option_group_options(self): - engine_name = self._get_param('EngineName') - major_engine_version = self._get_param('MajorEngineVersion') + engine_name = self._get_param("EngineName") + major_engine_version = self._get_param("MajorEngineVersion") option_group_options = self.backend.describe_option_group_options( - engine_name, major_engine_version) + engine_name, major_engine_version + ) return option_group_options def modify_option_group(self): - option_group_name = self._get_param('OptionGroupName') + option_group_name = self._get_param("OptionGroupName") count = 1 options_to_include = [] - while self._get_param('OptionsToInclude.member.{0}.OptionName'.format(count)): - options_to_include.append({ - 'Port': self._get_param('OptionsToInclude.member.{0}.Port'.format(count)), - 'OptionName': self._get_param('OptionsToInclude.member.{0}.OptionName'.format(count)), - 'DBSecurityGroupMemberships': self._get_param('OptionsToInclude.member.{0}.DBSecurityGroupMemberships'.format(count)), - 'OptionSettings': self._get_param('OptionsToInclude.member.{0}.OptionSettings'.format(count)), - 'VpcSecurityGroupMemberships': self._get_param('OptionsToInclude.member.{0}.VpcSecurityGroupMemberships'.format(count)) - }) + while self._get_param("OptionsToInclude.member.{0}.OptionName".format(count)): + options_to_include.append( + { + "Port": self._get_param( + "OptionsToInclude.member.{0}.Port".format(count) + ), + "OptionName": self._get_param( + "OptionsToInclude.member.{0}.OptionName".format(count) + ), + "DBSecurityGroupMemberships": self._get_param( + "OptionsToInclude.member.{0}.DBSecurityGroupMemberships".format( + count + ) + ), + "OptionSettings": self._get_param( + "OptionsToInclude.member.{0}.OptionSettings".format(count) + ), + "VpcSecurityGroupMemberships": self._get_param( + "OptionsToInclude.member.{0}.VpcSecurityGroupMemberships".format( + count + ) + ), + } + ) count += 1 count = 1 options_to_remove = [] - while self._get_param('OptionsToRemove.member.{0}'.format(count)): - options_to_remove.append(self._get_param( - 'OptionsToRemove.member.{0}'.format(count))) + while self._get_param("OptionsToRemove.member.{0}".format(count)): + options_to_remove.append( + self._get_param("OptionsToRemove.member.{0}".format(count)) + ) count += 1 - apply_immediately = self._get_param('ApplyImmediately') - option_group = self.backend.modify_option_group(option_group_name, - options_to_include, - options_to_remove, - apply_immediately) + apply_immediately = self._get_param("ApplyImmediately") + option_group = self.backend.modify_option_group( + option_group_name, options_to_include, options_to_remove, apply_immediately + ) template = self.response_template(MODIFY_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) @@ -329,28 +359,28 @@ class RDS2Response(BaseResponse): def describe_db_parameter_groups(self): kwargs = self._get_db_parameter_group_kwargs() - kwargs['max_records'] = self._get_int_param('MaxRecords') - kwargs['marker'] = self._get_param('Marker') + kwargs["max_records"] = self._get_int_param("MaxRecords") + kwargs["marker"] = self._get_param("Marker") db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs) - template = self.response_template( - DESCRIBE_DB_PARAMETER_GROUPS_TEMPLATE) + template = self.response_template(DESCRIBE_DB_PARAMETER_GROUPS_TEMPLATE) return template.render(db_parameter_groups=db_parameter_groups) def modify_db_parameter_group(self): - db_parameter_group_name = self._get_param('DBParameterGroupName') + db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_group_parameters = self._get_db_parameter_group_paramters() - db_parameter_group = self.backend.modify_db_parameter_group(db_parameter_group_name, - db_parameter_group_parameters) + db_parameter_group = self.backend.modify_db_parameter_group( + db_parameter_group_name, db_parameter_group_parameters + ) template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) def _get_db_parameter_group_paramters(self): parameter_group_parameters = defaultdict(dict) for param_name, value in self.querystring.items(): - if not param_name.startswith('Parameters.Parameter'): + if not param_name.startswith("Parameters.Parameter"): continue - split_param_name = param_name.split('.') + split_param_name = param_name.split(".") param_id = split_param_name[2] param_setting = split_param_name[3] @@ -359,9 +389,10 @@ class RDS2Response(BaseResponse): return parameter_group_parameters.values() def describe_db_parameters(self): - db_parameter_group_name = self._get_param('DBParameterGroupName') + db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_groups = self.backend.describe_db_parameter_groups( - {'name': db_parameter_group_name}) + {"name": db_parameter_group_name} + ) if not db_parameter_groups: raise DBParameterGroupNotFoundError(db_parameter_group_name) @@ -370,8 +401,7 @@ class RDS2Response(BaseResponse): def delete_db_parameter_group(self): kwargs = self._get_db_parameter_group_kwargs() - db_parameter_group = self.backend.delete_db_parameter_group(kwargs[ - 'name']) + db_parameter_group = self.backend.delete_db_parameter_group(kwargs["name"]) template = self.response_template(DELETE_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) diff --git a/moto/rds2/urls.py b/moto/rds2/urls.py index d19dc2785..d937554e0 100644 --- a/moto/rds2/urls.py +++ b/moto/rds2/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import RDS2Response -url_bases = [ - "https?://rds.(.+).amazonaws.com", - "https?://rds.amazonaws.com", -] +url_bases = ["https?://rds.(.+).amazonaws.com", "https?://rds.amazonaws.com"] -url_paths = { - '{0}/$': RDS2Response.dispatch, -} +url_paths = {"{0}/$": RDS2Response.dispatch} diff --git a/moto/redshift/__init__.py b/moto/redshift/__init__.py index 06f778e8d..47cbf3b58 100644 --- a/moto/redshift/__init__.py +++ b/moto/redshift/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import redshift_backends from ..core.models import base_decorator, deprecated_base_decorator -redshift_backend = redshift_backends['us-east-1'] +redshift_backend = redshift_backends["us-east-1"] mock_redshift = base_decorator(redshift_backends) mock_redshift_deprecated = deprecated_base_decorator(redshift_backends) diff --git a/moto/redshift/exceptions.py b/moto/redshift/exceptions.py index b0cef57ad..0a17e8aab 100644 --- a/moto/redshift/exceptions.py +++ b/moto/redshift/exceptions.py @@ -5,94 +5,93 @@ from werkzeug.exceptions import BadRequest class RedshiftClientError(BadRequest): - def __init__(self, code, message): super(RedshiftClientError, self).__init__() - self.description = json.dumps({ - "Error": { - "Code": code, - "Message": message, - 'Type': 'Sender', - }, - 'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', - }) + self.description = json.dumps( + { + "Error": {"Code": code, "Message": message, "Type": "Sender"}, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) class ClusterNotFoundError(RedshiftClientError): - def __init__(self, cluster_identifier): super(ClusterNotFoundError, self).__init__( - 'ClusterNotFound', - "Cluster {0} not found.".format(cluster_identifier)) + "ClusterNotFound", "Cluster {0} not found.".format(cluster_identifier) + ) class ClusterSubnetGroupNotFoundError(RedshiftClientError): - def __init__(self, subnet_identifier): super(ClusterSubnetGroupNotFoundError, self).__init__( - 'ClusterSubnetGroupNotFound', - "Subnet group {0} not found.".format(subnet_identifier)) + "ClusterSubnetGroupNotFound", + "Subnet group {0} not found.".format(subnet_identifier), + ) class ClusterSecurityGroupNotFoundError(RedshiftClientError): - def __init__(self, group_identifier): super(ClusterSecurityGroupNotFoundError, self).__init__( - 'ClusterSecurityGroupNotFound', - "Security group {0} not found.".format(group_identifier)) + "ClusterSecurityGroupNotFound", + "Security group {0} not found.".format(group_identifier), + ) class ClusterParameterGroupNotFoundError(RedshiftClientError): - def __init__(self, group_identifier): super(ClusterParameterGroupNotFoundError, self).__init__( - 'ClusterParameterGroupNotFound', - "Parameter group {0} not found.".format(group_identifier)) + "ClusterParameterGroupNotFound", + "Parameter group {0} not found.".format(group_identifier), + ) class InvalidSubnetError(RedshiftClientError): - def __init__(self, subnet_identifier): super(InvalidSubnetError, self).__init__( - 'InvalidSubnet', - "Subnet {0} not found.".format(subnet_identifier)) + "InvalidSubnet", "Subnet {0} not found.".format(subnet_identifier) + ) class SnapshotCopyGrantAlreadyExistsFaultError(RedshiftClientError): def __init__(self, snapshot_copy_grant_name): super(SnapshotCopyGrantAlreadyExistsFaultError, self).__init__( - 'SnapshotCopyGrantAlreadyExistsFault', + "SnapshotCopyGrantAlreadyExistsFault", "Cannot create the snapshot copy grant because a grant " - "with the identifier '{0}' already exists".format(snapshot_copy_grant_name)) + "with the identifier '{0}' already exists".format(snapshot_copy_grant_name), + ) class SnapshotCopyGrantNotFoundFaultError(RedshiftClientError): def __init__(self, snapshot_copy_grant_name): super(SnapshotCopyGrantNotFoundFaultError, self).__init__( - 'SnapshotCopyGrantNotFoundFault', - "Snapshot copy grant not found: {0}".format(snapshot_copy_grant_name)) + "SnapshotCopyGrantNotFoundFault", + "Snapshot copy grant not found: {0}".format(snapshot_copy_grant_name), + ) class ClusterSnapshotNotFoundError(RedshiftClientError): def __init__(self, snapshot_identifier): super(ClusterSnapshotNotFoundError, self).__init__( - 'ClusterSnapshotNotFound', - "Snapshot {0} not found.".format(snapshot_identifier)) + "ClusterSnapshotNotFound", + "Snapshot {0} not found.".format(snapshot_identifier), + ) class ClusterSnapshotAlreadyExistsError(RedshiftClientError): def __init__(self, snapshot_identifier): super(ClusterSnapshotAlreadyExistsError, self).__init__( - 'ClusterSnapshotAlreadyExists', + "ClusterSnapshotAlreadyExists", "Cannot create the snapshot because a snapshot with the " - "identifier {0} already exists".format(snapshot_identifier)) + "identifier {0} already exists".format(snapshot_identifier), + ) class InvalidParameterValueError(RedshiftClientError): def __init__(self, message): super(InvalidParameterValueError, self).__init__( - 'InvalidParameterValue', - message) + "InvalidParameterValue", message + ) class ResourceNotFoundFaultError(RedshiftClientError): @@ -106,26 +105,34 @@ class ResourceNotFoundFaultError(RedshiftClientError): msg = "{0} ({1}) not found.".format(resource_type, resource_name) if message: msg = message - super(ResourceNotFoundFaultError, self).__init__( - 'ResourceNotFoundFault', msg) + super(ResourceNotFoundFaultError, self).__init__("ResourceNotFoundFault", msg) class SnapshotCopyDisabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyDisabledFaultError, self).__init__( - 'SnapshotCopyDisabledFault', - "Cannot modify retention period because snapshot copy is disabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyDisabledFault", + "Cannot modify retention period because snapshot copy is disabled on Cluster {0}.".format( + cluster_identifier + ), + ) class SnapshotCopyAlreadyDisabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyAlreadyDisabledFaultError, self).__init__( - 'SnapshotCopyAlreadyDisabledFault', - "Snapshot Copy is already disabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyAlreadyDisabledFault", + "Snapshot Copy is already disabled on Cluster {0}.".format( + cluster_identifier + ), + ) class SnapshotCopyAlreadyEnabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyAlreadyEnabledFaultError, self).__init__( - 'SnapshotCopyAlreadyEnabledFault', - "Snapshot Copy is already enabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyAlreadyEnabledFault", + "Snapshot Copy is already enabled on Cluster {0}.".format( + cluster_identifier + ), + ) diff --git a/moto/redshift/models.py b/moto/redshift/models.py index c0b783bde..2c57c0f06 100644 --- a/moto/redshift/models.py +++ b/moto/redshift/models.py @@ -27,7 +27,7 @@ from .exceptions import ( ) -ACCOUNT_ID = 123456789012 +from moto.core import ACCOUNT_ID class TaggableResourceMixin(object): @@ -48,58 +48,91 @@ class TaggableResourceMixin(object): region=self.region, account_id=ACCOUNT_ID, resource_type=self.resource_type, - resource_id=self.resource_id) + resource_id=self.resource_id, + ) def create_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def delete_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] return self.tags class Cluster(TaggableResourceMixin, BaseModel): - resource_type = 'cluster' + resource_type = "cluster" - def __init__(self, redshift_backend, cluster_identifier, node_type, master_username, - master_user_password, db_name, cluster_type, cluster_security_groups, - vpc_security_group_ids, cluster_subnet_group_name, availability_zone, - preferred_maintenance_window, cluster_parameter_group_name, - automated_snapshot_retention_period, port, cluster_version, - allow_version_upgrade, number_of_nodes, publicly_accessible, - encrypted, region_name, tags=None, iam_roles_arn=None, - restored_from_snapshot=False): + def __init__( + self, + redshift_backend, + cluster_identifier, + node_type, + master_username, + master_user_password, + db_name, + cluster_type, + cluster_security_groups, + vpc_security_group_ids, + cluster_subnet_group_name, + availability_zone, + preferred_maintenance_window, + cluster_parameter_group_name, + automated_snapshot_retention_period, + port, + cluster_version, + allow_version_upgrade, + number_of_nodes, + publicly_accessible, + encrypted, + region_name, + tags=None, + iam_roles_arn=None, + enhanced_vpc_routing=None, + restored_from_snapshot=False, + ): super(Cluster, self).__init__(region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier - self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()) - self.status = 'available' + self.create_time = iso_8601_datetime_with_milliseconds( + datetime.datetime.utcnow() + ) + self.status = "available" self.node_type = node_type self.master_username = master_username self.master_user_password = master_user_password self.db_name = db_name if db_name else "dev" self.vpc_security_group_ids = vpc_security_group_ids + self.enhanced_vpc_routing = ( + enhanced_vpc_routing if enhanced_vpc_routing is not None else False + ) self.cluster_subnet_group_name = cluster_subnet_group_name self.publicly_accessible = publicly_accessible self.encrypted = encrypted - self.allow_version_upgrade = allow_version_upgrade if allow_version_upgrade is not None else True + self.allow_version_upgrade = ( + allow_version_upgrade if allow_version_upgrade is not None else True + ) self.cluster_version = cluster_version if cluster_version else "1.0" self.port = int(port) if port else 5439 - self.automated_snapshot_retention_period = int( - automated_snapshot_retention_period) if automated_snapshot_retention_period else 1 - self.preferred_maintenance_window = preferred_maintenance_window if preferred_maintenance_window else "Mon:03:00-Mon:03:30" + self.automated_snapshot_retention_period = ( + int(automated_snapshot_retention_period) + if automated_snapshot_retention_period + else 1 + ) + self.preferred_maintenance_window = ( + preferred_maintenance_window + if preferred_maintenance_window + else "Mon:03:00-Mon:03:30" + ) if cluster_parameter_group_name: self.cluster_parameter_group_name = [cluster_parameter_group_name] else: - self.cluster_parameter_group_name = ['default.redshift-1.0'] + self.cluster_parameter_group_name = ["default.redshift-1.0"] if cluster_security_groups: self.cluster_security_groups = cluster_security_groups @@ -113,7 +146,7 @@ class Cluster(TaggableResourceMixin, BaseModel): # way to pull AZs for a region in boto self.availability_zone = region_name + "a" - if cluster_type == 'single-node': + if cluster_type == "single-node": self.number_of_nodes = 1 elif number_of_nodes: self.number_of_nodes = int(number_of_nodes) @@ -124,37 +157,39 @@ class Cluster(TaggableResourceMixin, BaseModel): self.restored_from_snapshot = restored_from_snapshot @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] - if 'ClusterSubnetGroupName' in properties: + if "ClusterSubnetGroupName" in properties: subnet_group_name = properties[ - 'ClusterSubnetGroupName'].cluster_subnet_group_name + "ClusterSubnetGroupName" + ].cluster_subnet_group_name else: subnet_group_name = None cluster = redshift_backend.create_cluster( cluster_identifier=resource_name, - node_type=properties.get('NodeType'), - master_username=properties.get('MasterUsername'), - master_user_password=properties.get('MasterUserPassword'), - db_name=properties.get('DBName'), - cluster_type=properties.get('ClusterType'), - cluster_security_groups=properties.get( - 'ClusterSecurityGroups', []), - vpc_security_group_ids=properties.get('VpcSecurityGroupIds', []), + node_type=properties.get("NodeType"), + master_username=properties.get("MasterUsername"), + master_user_password=properties.get("MasterUserPassword"), + db_name=properties.get("DBName"), + cluster_type=properties.get("ClusterType"), + cluster_security_groups=properties.get("ClusterSecurityGroups", []), + vpc_security_group_ids=properties.get("VpcSecurityGroupIds", []), cluster_subnet_group_name=subnet_group_name, - availability_zone=properties.get('AvailabilityZone'), - preferred_maintenance_window=properties.get( - 'PreferredMaintenanceWindow'), - cluster_parameter_group_name=properties.get( - 'ClusterParameterGroupName'), + availability_zone=properties.get("AvailabilityZone"), + preferred_maintenance_window=properties.get("PreferredMaintenanceWindow"), + cluster_parameter_group_name=properties.get("ClusterParameterGroupName"), automated_snapshot_retention_period=properties.get( - 'AutomatedSnapshotRetentionPeriod'), - port=properties.get('Port'), - cluster_version=properties.get('ClusterVersion'), - allow_version_upgrade=properties.get('AllowVersionUpgrade'), - number_of_nodes=properties.get('NumberOfNodes'), + "AutomatedSnapshotRetentionPeriod" + ), + port=properties.get("Port"), + cluster_version=properties.get("ClusterVersion"), + allow_version_upgrade=properties.get("AllowVersionUpgrade"), + enhanced_vpc_routing=properties.get("EnhancedVpcRouting"), + number_of_nodes=properties.get("NumberOfNodes"), publicly_accessible=properties.get("PubliclyAccessible"), encrypted=properties.get("Encrypted"), region_name=region_name, @@ -163,41 +198,43 @@ class Cluster(TaggableResourceMixin, BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Endpoint.Address': + + if attribute_name == "Endpoint.Address": return self.endpoint - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @property def endpoint(self): return "{0}.cg034hpkmmjt.{1}.redshift.amazonaws.com".format( - self.cluster_identifier, - self.region, + self.cluster_identifier, self.region ) @property def security_groups(self): return [ - security_group for security_group - in self.redshift_backend.describe_cluster_security_groups() - if security_group.cluster_security_group_name in self.cluster_security_groups + security_group + for security_group in self.redshift_backend.describe_cluster_security_groups() + if security_group.cluster_security_group_name + in self.cluster_security_groups ] @property def vpc_security_groups(self): return [ - security_group for security_group - in self.redshift_backend.ec2_backend.describe_security_groups() + security_group + for security_group in self.redshift_backend.ec2_backend.describe_security_groups() if security_group.id in self.vpc_security_group_ids ] @property def parameter_groups(self): return [ - parameter_group for parameter_group - in self.redshift_backend.describe_cluster_parameter_groups() - if parameter_group.cluster_parameter_group_name in self.cluster_parameter_group_name + parameter_group + for parameter_group in self.redshift_backend.describe_cluster_parameter_groups() + if parameter_group.cluster_parameter_group_name + in self.cluster_parameter_group_name ] @property @@ -209,10 +246,10 @@ class Cluster(TaggableResourceMixin, BaseModel): "MasterUsername": self.master_username, "MasterUserPassword": "****", "ClusterVersion": self.cluster_version, - "VpcSecurityGroups": [{ - "Status": "active", - "VpcSecurityGroupId": group.id - } for group in self.vpc_security_groups], + "VpcSecurityGroups": [ + {"Status": "active", "VpcSecurityGroupId": group.id} + for group in self.vpc_security_groups + ], "ClusterSubnetGroupName": self.cluster_subnet_group_name, "AvailabilityZone": self.availability_zone, "ClusterStatus": self.status, @@ -222,41 +259,47 @@ class Cluster(TaggableResourceMixin, BaseModel): "Encrypted": self.encrypted, "DBName": self.db_name, "PreferredMaintenanceWindow": self.preferred_maintenance_window, - "ClusterParameterGroups": [{ - "ParameterApplyStatus": "in-sync", - "ParameterGroupName": group.cluster_parameter_group_name, - } for group in self.parameter_groups], - "ClusterSecurityGroups": [{ - "Status": "active", - "ClusterSecurityGroupName": group.cluster_security_group_name, - } for group in self.security_groups], + "ClusterParameterGroups": [ + { + "ParameterApplyStatus": "in-sync", + "ParameterGroupName": group.cluster_parameter_group_name, + } + for group in self.parameter_groups + ], + "ClusterSecurityGroups": [ + { + "Status": "active", + "ClusterSecurityGroupName": group.cluster_security_group_name, + } + for group in self.security_groups + ], "Port": self.port, "NodeType": self.node_type, "ClusterIdentifier": self.cluster_identifier, "AllowVersionUpgrade": self.allow_version_upgrade, - "Endpoint": { - "Address": self.endpoint, - "Port": self.port - }, - 'ClusterCreateTime': self.create_time, + "Endpoint": {"Address": self.endpoint, "Port": self.port}, + "ClusterCreateTime": self.create_time, "PendingModifiedValues": [], "Tags": self.tags, - "IamRoles": [{ - "ApplyStatus": "in-sync", - "IamRoleArn": iam_role_arn - } for iam_role_arn in self.iam_roles_arn] + "EnhancedVpcRouting": self.enhanced_vpc_routing, + "IamRoles": [ + {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} + for iam_role_arn in self.iam_roles_arn + ], } if self.restored_from_snapshot: - json_response['RestoreStatus'] = { - 'Status': 'completed', - 'CurrentRestoreRateInMegaBytesPerSecond': 123.0, - 'SnapshotSizeInMegaBytes': 123, - 'ProgressInMegaBytes': 123, - 'ElapsedTimeInSeconds': 123, - 'EstimatedTimeToCompletionInSeconds': 123 + json_response["RestoreStatus"] = { + "Status": "completed", + "CurrentRestoreRateInMegaBytesPerSecond": 123.0, + "SnapshotSizeInMegaBytes": 123, + "ProgressInMegaBytes": 123, + "ElapsedTimeInSeconds": 123, + "EstimatedTimeToCompletionInSeconds": 123, } try: - json_response['ClusterSnapshotCopyStatus'] = self.cluster_snapshot_copy_status + json_response[ + "ClusterSnapshotCopyStatus" + ] = self.cluster_snapshot_copy_status except AttributeError: pass return json_response @@ -264,7 +307,7 @@ class Cluster(TaggableResourceMixin, BaseModel): class SnapshotCopyGrant(TaggableResourceMixin, BaseModel): - resource_type = 'snapshotcopygrant' + resource_type = "snapshotcopygrant" def __init__(self, snapshot_copy_grant_name, kms_key_id): self.snapshot_copy_grant_name = snapshot_copy_grant_name @@ -273,16 +316,23 @@ class SnapshotCopyGrant(TaggableResourceMixin, BaseModel): def to_json(self): return { "SnapshotCopyGrantName": self.snapshot_copy_grant_name, - "KmsKeyId": self.kms_key_id + "KmsKeyId": self.kms_key_id, } class SubnetGroup(TaggableResourceMixin, BaseModel): - resource_type = 'subnetgroup' + resource_type = "subnetgroup" - def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids, - region_name, tags=None): + def __init__( + self, + ec2_backend, + cluster_subnet_group_name, + description, + subnet_ids, + region_name, + tags=None, + ): super(SubnetGroup, self).__init__(region_name, tags) self.ec2_backend = ec2_backend self.cluster_subnet_group_name = cluster_subnet_group_name @@ -292,21 +342,23 @@ class SubnetGroup(TaggableResourceMixin, BaseModel): raise InvalidSubnetError(subnet_ids) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] subnet_group = redshift_backend.create_cluster_subnet_group( cluster_subnet_group_name=resource_name, description=properties.get("Description"), subnet_ids=properties.get("SubnetIds", []), - region_name=region_name + region_name=region_name, ) return subnet_group @property def subnets(self): - return self.ec2_backend.get_all_subnets(filters={'subnet-id': self.subnet_ids}) + return self.ec2_backend.get_all_subnets(filters={"subnet-id": self.subnet_ids}) @property def vpc_id(self): @@ -322,22 +374,25 @@ class SubnetGroup(TaggableResourceMixin, BaseModel): "Description": self.description, "ClusterSubnetGroupName": self.cluster_subnet_group_name, "SubnetGroupStatus": "Complete", - "Subnets": [{ - "SubnetStatus": "Active", - "SubnetIdentifier": subnet.id, - "SubnetAvailabilityZone": { - "Name": subnet.availability_zone - }, - } for subnet in self.subnets], - "Tags": self.tags + "Subnets": [ + { + "SubnetStatus": "Active", + "SubnetIdentifier": subnet.id, + "SubnetAvailabilityZone": {"Name": subnet.availability_zone}, + } + for subnet in self.subnets + ], + "Tags": self.tags, } class SecurityGroup(TaggableResourceMixin, BaseModel): - resource_type = 'securitygroup' + resource_type = "securitygroup" - def __init__(self, cluster_security_group_name, description, region_name, tags=None): + def __init__( + self, cluster_security_group_name, description, region_name, tags=None + ): super(SecurityGroup, self).__init__(region_name, tags) self.cluster_security_group_name = cluster_security_group_name self.description = description @@ -352,30 +407,39 @@ class SecurityGroup(TaggableResourceMixin, BaseModel): "IPRanges": [], "Description": self.description, "ClusterSecurityGroupName": self.cluster_security_group_name, - "Tags": self.tags + "Tags": self.tags, } class ParameterGroup(TaggableResourceMixin, BaseModel): - resource_type = 'parametergroup' + resource_type = "parametergroup" - def __init__(self, cluster_parameter_group_name, group_family, description, region_name, tags=None): + def __init__( + self, + cluster_parameter_group_name, + group_family, + description, + region_name, + tags=None, + ): super(ParameterGroup, self).__init__(region_name, tags) self.cluster_parameter_group_name = cluster_parameter_group_name self.group_family = group_family self.description = description @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] parameter_group = redshift_backend.create_cluster_parameter_group( cluster_parameter_group_name=resource_name, description=properties.get("Description"), group_family=properties.get("ParameterGroupFamily"), - region_name=region_name + region_name=region_name, ) return parameter_group @@ -388,77 +452,81 @@ class ParameterGroup(TaggableResourceMixin, BaseModel): "ParameterGroupFamily": self.group_family, "Description": self.description, "ParameterGroupName": self.cluster_parameter_group_name, - "Tags": self.tags + "Tags": self.tags, } class Snapshot(TaggableResourceMixin, BaseModel): - resource_type = 'snapshot' + resource_type = "snapshot" - def __init__(self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None): + def __init__( + self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None + ): super(Snapshot, self).__init__(region_name, tags) self.cluster = copy.copy(cluster) self.snapshot_identifier = snapshot_identifier - self.snapshot_type = 'manual' - self.status = 'available' - self.create_time = iso_8601_datetime_with_milliseconds( - datetime.datetime.now()) + self.snapshot_type = "manual" + self.status = "available" + self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.iam_roles_arn = iam_roles_arn or [] @property def resource_id(self): return "{cluster_id}/{snapshot_id}".format( cluster_id=self.cluster.cluster_identifier, - snapshot_id=self.snapshot_identifier) + snapshot_id=self.snapshot_identifier, + ) def to_json(self): return { - 'SnapshotIdentifier': self.snapshot_identifier, - 'ClusterIdentifier': self.cluster.cluster_identifier, - 'SnapshotCreateTime': self.create_time, - 'Status': self.status, - 'Port': self.cluster.port, - 'AvailabilityZone': self.cluster.availability_zone, - 'MasterUsername': self.cluster.master_username, - 'ClusterVersion': self.cluster.cluster_version, - 'SnapshotType': self.snapshot_type, - 'NodeType': self.cluster.node_type, - 'NumberOfNodes': self.cluster.number_of_nodes, - 'DBName': self.cluster.db_name, - 'Tags': self.tags, - "IamRoles": [{ - "ApplyStatus": "in-sync", - "IamRoleArn": iam_role_arn - } for iam_role_arn in self.iam_roles_arn] + "SnapshotIdentifier": self.snapshot_identifier, + "ClusterIdentifier": self.cluster.cluster_identifier, + "SnapshotCreateTime": self.create_time, + "Status": self.status, + "Port": self.cluster.port, + "AvailabilityZone": self.cluster.availability_zone, + "MasterUsername": self.cluster.master_username, + "ClusterVersion": self.cluster.cluster_version, + "SnapshotType": self.snapshot_type, + "NodeType": self.cluster.node_type, + "NumberOfNodes": self.cluster.number_of_nodes, + "DBName": self.cluster.db_name, + "Tags": self.tags, + "EnhancedVpcRouting": self.cluster.enhanced_vpc_routing, + "IamRoles": [ + {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} + for iam_role_arn in self.iam_roles_arn + ], } class RedshiftBackend(BaseBackend): - def __init__(self, ec2_backend, region_name): self.region = region_name self.clusters = {} self.subnet_groups = {} self.security_groups = { - "Default": SecurityGroup("Default", "Default Redshift Security Group", self.region) + "Default": SecurityGroup( + "Default", "Default Redshift Security Group", self.region + ) } self.parameter_groups = { "default.redshift-1.0": ParameterGroup( "default.redshift-1.0", "redshift-1.0", "Default Redshift parameter group", - self.region + self.region, ) } self.ec2_backend = ec2_backend self.snapshots = OrderedDict() self.RESOURCE_TYPE_MAP = { - 'cluster': self.clusters, - 'parametergroup': self.parameter_groups, - 'securitygroup': self.security_groups, - 'snapshot': self.snapshots, - 'subnetgroup': self.subnet_groups + "cluster": self.clusters, + "parametergroup": self.parameter_groups, + "securitygroup": self.security_groups, + "snapshot": self.snapshots, + "subnetgroup": self.subnet_groups, } self.snapshot_copy_grants = {} @@ -469,19 +537,22 @@ class RedshiftBackend(BaseBackend): self.__init__(ec2_backend, region_name) def enable_snapshot_copy(self, **kwargs): - cluster_identifier = kwargs['cluster_identifier'] + cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] - if not hasattr(cluster, 'cluster_snapshot_copy_status'): - if cluster.encrypted == 'true' and kwargs['snapshot_copy_grant_name'] is None: + if not hasattr(cluster, "cluster_snapshot_copy_status"): + if ( + cluster.encrypted == "true" + and kwargs["snapshot_copy_grant_name"] is None + ): raise ClientError( - 'InvalidParameterValue', - 'SnapshotCopyGrantName is required for Snapshot Copy ' - 'on KMS encrypted clusters.' + "InvalidParameterValue", + "SnapshotCopyGrantName is required for Snapshot Copy " + "on KMS encrypted clusters.", ) status = { - 'DestinationRegion': kwargs['destination_region'], - 'RetentionPeriod': kwargs['retention_period'], - 'SnapshotCopyGrantName': kwargs['snapshot_copy_grant_name'], + "DestinationRegion": kwargs["destination_region"], + "RetentionPeriod": kwargs["retention_period"], + "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"], } cluster.cluster_snapshot_copy_status = status return cluster @@ -489,24 +560,26 @@ class RedshiftBackend(BaseBackend): raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) def disable_snapshot_copy(self, **kwargs): - cluster_identifier = kwargs['cluster_identifier'] + cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] - if hasattr(cluster, 'cluster_snapshot_copy_status'): + if hasattr(cluster, "cluster_snapshot_copy_status"): del cluster.cluster_snapshot_copy_status return cluster else: raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) - def modify_snapshot_copy_retention_period(self, cluster_identifier, retention_period): + def modify_snapshot_copy_retention_period( + self, cluster_identifier, retention_period + ): cluster = self.clusters[cluster_identifier] - if hasattr(cluster, 'cluster_snapshot_copy_status'): - cluster.cluster_snapshot_copy_status['RetentionPeriod'] = retention_period + if hasattr(cluster, "cluster_snapshot_copy_status"): + cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period return cluster else: raise SnapshotCopyDisabledFaultError(cluster_identifier) def create_cluster(self, **cluster_kwargs): - cluster_identifier = cluster_kwargs['cluster_identifier'] + cluster_identifier = cluster_kwargs["cluster_identifier"] cluster = Cluster(self, **cluster_kwargs) self.clusters[cluster_identifier] = cluster return cluster @@ -521,9 +594,8 @@ class RedshiftBackend(BaseBackend): return clusters def modify_cluster(self, **cluster_kwargs): - cluster_identifier = cluster_kwargs.pop('cluster_identifier') - new_cluster_identifier = cluster_kwargs.pop( - 'new_cluster_identifier', None) + cluster_identifier = cluster_kwargs.pop("cluster_identifier") + new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None) cluster = self.describe_clusters(cluster_identifier)[0] @@ -534,7 +606,7 @@ class RedshiftBackend(BaseBackend): dic = { "cluster_identifier": cluster_identifier, "skip_final_snapshot": True, - "final_cluster_snapshot_identifier": None + "final_cluster_snapshot_identifier": None, } self.delete_cluster(**dic) cluster.cluster_identifier = new_cluster_identifier @@ -545,30 +617,46 @@ class RedshiftBackend(BaseBackend): def delete_cluster(self, **cluster_kwargs): cluster_identifier = cluster_kwargs.pop("cluster_identifier") cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot") - cluster_snapshot_identifer = cluster_kwargs.pop("final_cluster_snapshot_identifier") + cluster_snapshot_identifer = cluster_kwargs.pop( + "final_cluster_snapshot_identifier" + ) if cluster_identifier in self.clusters: - if cluster_skip_final_snapshot is False and cluster_snapshot_identifer is None: + if ( + cluster_skip_final_snapshot is False + and cluster_snapshot_identifer is None + ): raise ClientError( "InvalidParameterValue", - 'FinalSnapshotIdentifier is required for Snapshot copy ' - 'when SkipFinalSnapshot is False' + "FinalSnapshotIdentifier is required for Snapshot copy " + "when SkipFinalSnapshot is False", ) - elif cluster_skip_final_snapshot is False and cluster_snapshot_identifer is not None: # create snapshot + elif ( + cluster_skip_final_snapshot is False + and cluster_snapshot_identifer is not None + ): # create snapshot cluster = self.describe_clusters(cluster_identifier)[0] self.create_cluster_snapshot( cluster_identifier, cluster_snapshot_identifer, cluster.region, - cluster.tags) + cluster.tags, + ) return self.clusters.pop(cluster_identifier) raise ClusterNotFoundError(cluster_identifier) - def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids, - region_name, tags=None): + def create_cluster_subnet_group( + self, cluster_subnet_group_name, description, subnet_ids, region_name, tags=None + ): subnet_group = SubnetGroup( - self.ec2_backend, cluster_subnet_group_name, description, subnet_ids, region_name, tags) + self.ec2_backend, + cluster_subnet_group_name, + description, + subnet_ids, + region_name, + tags, + ) self.subnet_groups[cluster_subnet_group_name] = subnet_group return subnet_group @@ -586,9 +674,12 @@ class RedshiftBackend(BaseBackend): return self.subnet_groups.pop(subnet_identifier) raise ClusterSubnetGroupNotFoundError(subnet_identifier) - def create_cluster_security_group(self, cluster_security_group_name, description, region_name, tags=None): + def create_cluster_security_group( + self, cluster_security_group_name, description, region_name, tags=None + ): security_group = SecurityGroup( - cluster_security_group_name, description, region_name, tags) + cluster_security_group_name, description, region_name, tags + ) self.security_groups[cluster_security_group_name] = security_group return security_group @@ -606,10 +697,17 @@ class RedshiftBackend(BaseBackend): return self.security_groups.pop(security_group_identifier) raise ClusterSecurityGroupNotFoundError(security_group_identifier) - def create_cluster_parameter_group(self, cluster_parameter_group_name, - group_family, description, region_name, tags=None): + def create_cluster_parameter_group( + self, + cluster_parameter_group_name, + group_family, + description, + region_name, + tags=None, + ): parameter_group = ParameterGroup( - cluster_parameter_group_name, group_family, description, region_name, tags) + cluster_parameter_group_name, group_family, description, region_name, tags + ) self.parameter_groups[cluster_parameter_group_name] = parameter_group return parameter_group @@ -628,7 +726,9 @@ class RedshiftBackend(BaseBackend): return self.parameter_groups.pop(parameter_group_name) raise ClusterParameterGroupNotFoundError(parameter_group_name) - def create_cluster_snapshot(self, cluster_identifier, snapshot_identifier, region_name, tags): + def create_cluster_snapshot( + self, cluster_identifier, snapshot_identifier, region_name, tags + ): cluster = self.clusters.get(cluster_identifier) if not cluster: raise ClusterNotFoundError(cluster_identifier) @@ -638,7 +738,9 @@ class RedshiftBackend(BaseBackend): self.snapshots[snapshot_identifier] = snapshot return snapshot - def describe_cluster_snapshots(self, cluster_identifier=None, snapshot_identifier=None): + def describe_cluster_snapshots( + self, cluster_identifier=None, snapshot_identifier=None + ): if cluster_identifier: cluster_snapshots = [] for snapshot in self.snapshots.values(): @@ -660,47 +762,54 @@ class RedshiftBackend(BaseBackend): raise ClusterSnapshotNotFoundError(snapshot_identifier) deleted_snapshot = self.snapshots.pop(snapshot_identifier) - deleted_snapshot.status = 'deleted' + deleted_snapshot.status = "deleted" return deleted_snapshot def restore_from_cluster_snapshot(self, **kwargs): - snapshot_identifier = kwargs.pop('snapshot_identifier') - snapshot = self.describe_cluster_snapshots(snapshot_identifier=snapshot_identifier)[0] + snapshot_identifier = kwargs.pop("snapshot_identifier") + snapshot = self.describe_cluster_snapshots( + snapshot_identifier=snapshot_identifier + )[0] create_kwargs = { "node_type": snapshot.cluster.node_type, "master_username": snapshot.cluster.master_username, "master_user_password": snapshot.cluster.master_user_password, "db_name": snapshot.cluster.db_name, - "cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node', + "cluster_type": "multi-node" + if snapshot.cluster.number_of_nodes > 1 + else "single-node", "availability_zone": snapshot.cluster.availability_zone, "port": snapshot.cluster.port, "cluster_version": snapshot.cluster.cluster_version, "number_of_nodes": snapshot.cluster.number_of_nodes, "encrypted": snapshot.cluster.encrypted, "tags": snapshot.cluster.tags, - "restored_from_snapshot": True + "restored_from_snapshot": True, + "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing, } create_kwargs.update(kwargs) return self.create_cluster(**create_kwargs) def create_snapshot_copy_grant(self, **kwargs): - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] - kms_key_id = kwargs['kms_key_id'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] + kms_key_id = kwargs["kms_key_id"] if snapshot_copy_grant_name not in self.snapshot_copy_grants: - snapshot_copy_grant = SnapshotCopyGrant(snapshot_copy_grant_name, kms_key_id) + snapshot_copy_grant = SnapshotCopyGrant( + snapshot_copy_grant_name, kms_key_id + ) self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant return snapshot_copy_grant raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name) def delete_snapshot_copy_grant(self, **kwargs): - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name in self.snapshot_copy_grants: return self.snapshot_copy_grants.pop(snapshot_copy_grant_name) raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) def describe_snapshot_copy_grants(self, **kwargs): copy_grants = self.snapshot_copy_grants.values() - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name: if snapshot_copy_grant_name in self.snapshot_copy_grants: return [self.snapshot_copy_grants[snapshot_copy_grant_name]] @@ -710,10 +819,10 @@ class RedshiftBackend(BaseBackend): def _get_resource_from_arn(self, arn): try: - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[5] - if resource_type == 'snapshot': - resource_id = arn_breakdown[6].split('/')[1] + if resource_type == "snapshot": + resource_id = arn_breakdown[6].split("/")[1] else: resource_id = arn_breakdown[6] except IndexError: @@ -723,7 +832,8 @@ class RedshiftBackend(BaseBackend): message = ( "Tagging is not supported for this type of resource: '{0}' " "(the ARN is potentially malformed, please check the ARN " - "documentation for more information)".format(resource_type)) + "documentation for more information)".format(resource_type) + ) raise ResourceNotFoundFaultError(message=message) try: resource = resources[resource_id] @@ -738,12 +848,9 @@ class RedshiftBackend(BaseBackend): for resource in resources: for tag in resource.tags: data = { - 'ResourceName': resource.arn, - 'ResourceType': resource.resource_type, - 'Tag': { - 'Key': tag['Key'], - 'Value': tag['Value'] - } + "ResourceName": resource.arn, + "ResourceType": resource.resource_type, + "Tag": {"Key": tag["Key"], "Value": tag["Value"]}, } tagged_resources.append(data) return tagged_resources @@ -768,7 +875,8 @@ class RedshiftBackend(BaseBackend): "You cannot filter a list of resources using an Amazon " "Resource Name (ARN) and a resource type together in the " "same request. Retry the request using either an ARN or " - "a resource type, but not both.") + "a resource type, but not both." + ) if resource_type: return self._describe_tags_for_resource_type(resource_type.lower()) if resource_name: @@ -790,4 +898,6 @@ class RedshiftBackend(BaseBackend): redshift_backends = {} for region in boto.redshift.regions(): - redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name], region.name) + redshift_backends[region.name] = RedshiftBackend( + ec2_backends[region.name], region.name + ) diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index a7758febb..a4094949f 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -13,9 +13,10 @@ from .models import redshift_backends def convert_json_error_to_xml(json_error): error = json.loads(json_error) - code = error['Error']['Code'] - message = error['Error']['Message'] - template = Template(""" + code = error["Error"]["Code"] + message = error["Error"]["Message"] + template = Template( + """ {{ code }} @@ -23,7 +24,8 @@ def convert_json_error_to_xml(json_error): Sender 6876f774-7273-11e4-85dc-39e55ca848d1 - """) + """ + ) return template.render(code=code, message=message) @@ -40,13 +42,12 @@ def itemize(data): ret[key] = itemize(data[key]) return ret elif isinstance(data, list): - return {'item': [itemize(value) for value in data]} + return {"item": [itemize(value) for value in data]} else: return data class RedshiftResponse(BaseResponse): - @property def redshift_backend(self): return redshift_backends[self.region] @@ -56,8 +57,8 @@ class RedshiftResponse(BaseResponse): return json.dumps(response) else: xml = xmltodict.unparse(itemize(response), full_document=False) - if hasattr(xml, 'decode'): - xml = xml.decode('utf-8') + if hasattr(xml, "decode"): + xml = xml.decode("utf-8") return xml def call_action(self): @@ -69,11 +70,12 @@ class RedshiftResponse(BaseResponse): def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -81,143 +83,168 @@ class RedshiftResponse(BaseResponse): def unpack_list_params(self, label): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}'.format(label, count)): - unpacked_list.append(self._get_param( - '{0}.{1}'.format(label, count))) + while self._get_param("{0}.{1}".format(label, count)): + unpacked_list.append(self._get_param("{0}.{1}".format(label, count))) count += 1 return unpacked_list def _get_cluster_security_groups(self): - cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.member') + cluster_security_groups = self._get_multi_param("ClusterSecurityGroups.member") if not cluster_security_groups: - cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.ClusterSecurityGroupName') + cluster_security_groups = self._get_multi_param( + "ClusterSecurityGroups.ClusterSecurityGroupName" + ) return cluster_security_groups def _get_vpc_security_group_ids(self): - vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.member') + vpc_security_group_ids = self._get_multi_param("VpcSecurityGroupIds.member") if not vpc_security_group_ids: - vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.VpcSecurityGroupId') + vpc_security_group_ids = self._get_multi_param( + "VpcSecurityGroupIds.VpcSecurityGroupId" + ) return vpc_security_group_ids def _get_iam_roles(self): - iam_roles = self._get_multi_param('IamRoles.member') + iam_roles = self._get_multi_param("IamRoles.member") if not iam_roles: - iam_roles = self._get_multi_param('IamRoles.IamRoleArn') + iam_roles = self._get_multi_param("IamRoles.IamRoleArn") return iam_roles def _get_subnet_ids(self): - subnet_ids = self._get_multi_param('SubnetIds.member') + subnet_ids = self._get_multi_param("SubnetIds.member") if not subnet_ids: - subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') + subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") return subnet_ids def create_cluster(self): cluster_kwargs = { - "cluster_identifier": self._get_param('ClusterIdentifier'), - "node_type": self._get_param('NodeType'), - "master_username": self._get_param('MasterUsername'), - "master_user_password": self._get_param('MasterUserPassword'), - "db_name": self._get_param('DBName'), - "cluster_type": self._get_param('ClusterType'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "node_type": self._get_param("NodeType"), + "master_username": self._get_param("MasterUsername"), + "master_user_password": self._get_param("MasterUserPassword"), + "db_name": self._get_param("DBName"), + "cluster_type": self._get_param("ClusterType"), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), - "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), - "availability_zone": self._get_param('AvailabilityZone'), - "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), - "cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'), - "automated_snapshot_retention_period": self._get_int_param('AutomatedSnapshotRetentionPeriod'), - "port": self._get_int_param('Port'), - "cluster_version": self._get_param('ClusterVersion'), - "allow_version_upgrade": self._get_bool_param('AllowVersionUpgrade'), - "number_of_nodes": self._get_int_param('NumberOfNodes'), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), + "availability_zone": self._get_param("AvailabilityZone"), + "preferred_maintenance_window": self._get_param( + "PreferredMaintenanceWindow" + ), + "cluster_parameter_group_name": self._get_param( + "ClusterParameterGroupName" + ), + "automated_snapshot_retention_period": self._get_int_param( + "AutomatedSnapshotRetentionPeriod" + ), + "port": self._get_int_param("Port"), + "cluster_version": self._get_param("ClusterVersion"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "number_of_nodes": self._get_int_param("NumberOfNodes"), "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "region_name": self.region, - "tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), + "tags": self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting"), } cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json() - cluster['ClusterStatus'] = 'creating' - return self.get_response({ - "CreateClusterResponse": { - "CreateClusterResult": { - "Cluster": cluster, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + cluster["ClusterStatus"] = "creating" + return self.get_response( + { + "CreateClusterResponse": { + "CreateClusterResult": {"Cluster": cluster}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def restore_from_cluster_snapshot(self): + enhanced_vpc_routing = self._get_bool_param("EnhancedVpcRouting") restore_kwargs = { - "snapshot_identifier": self._get_param('SnapshotIdentifier'), - "cluster_identifier": self._get_param('ClusterIdentifier'), - "port": self._get_int_param('Port'), - "availability_zone": self._get_param('AvailabilityZone'), - "allow_version_upgrade": self._get_bool_param( - 'AllowVersionUpgrade'), - "cluster_subnet_group_name": self._get_param( - 'ClusterSubnetGroupName'), + "snapshot_identifier": self._get_param("SnapshotIdentifier"), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "port": self._get_int_param("Port"), + "availability_zone": self._get_param("AvailabilityZone"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), "publicly_accessible": self._get_param("PubliclyAccessible"), "cluster_parameter_group_name": self._get_param( - 'ClusterParameterGroupName'), + "ClusterParameterGroupName" + ), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), "preferred_maintenance_window": self._get_param( - 'PreferredMaintenanceWindow'), + "PreferredMaintenanceWindow" + ), "automated_snapshot_retention_period": self._get_int_param( - 'AutomatedSnapshotRetentionPeriod'), + "AutomatedSnapshotRetentionPeriod" + ), "region_name": self.region, "iam_roles_arn": self._get_iam_roles(), } - cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json() - cluster['ClusterStatus'] = 'creating' - return self.get_response({ - "RestoreFromClusterSnapshotResponse": { - "RestoreFromClusterSnapshotResult": { - "Cluster": cluster, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + if enhanced_vpc_routing is not None: + restore_kwargs["enhanced_vpc_routing"] = enhanced_vpc_routing + cluster = self.redshift_backend.restore_from_cluster_snapshot( + **restore_kwargs + ).to_json() + cluster["ClusterStatus"] = "creating" + return self.get_response( + { + "RestoreFromClusterSnapshotResponse": { + "RestoreFromClusterSnapshotResult": {"Cluster": cluster}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_clusters(self): cluster_identifier = self._get_param("ClusterIdentifier") clusters = self.redshift_backend.describe_clusters(cluster_identifier) - return self.get_response({ - "DescribeClustersResponse": { - "DescribeClustersResult": { - "Clusters": [cluster.to_json() for cluster in clusters] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClustersResponse": { + "DescribeClustersResult": { + "Clusters": [cluster.to_json() for cluster in clusters] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def modify_cluster(self): request_kwargs = { - "cluster_identifier": self._get_param('ClusterIdentifier'), - "new_cluster_identifier": self._get_param('NewClusterIdentifier'), - "node_type": self._get_param('NodeType'), - "master_user_password": self._get_param('MasterUserPassword'), - "cluster_type": self._get_param('ClusterType'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "new_cluster_identifier": self._get_param("NewClusterIdentifier"), + "node_type": self._get_param("NodeType"), + "master_user_password": self._get_param("MasterUserPassword"), + "cluster_type": self._get_param("ClusterType"), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), - "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), - "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), - "cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'), - "automated_snapshot_retention_period": self._get_int_param('AutomatedSnapshotRetentionPeriod'), - "cluster_version": self._get_param('ClusterVersion'), - "allow_version_upgrade": self._get_bool_param('AllowVersionUpgrade'), - "number_of_nodes": self._get_int_param('NumberOfNodes'), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), + "preferred_maintenance_window": self._get_param( + "PreferredMaintenanceWindow" + ), + "cluster_parameter_group_name": self._get_param( + "ClusterParameterGroupName" + ), + "automated_snapshot_retention_period": self._get_int_param( + "AutomatedSnapshotRetentionPeriod" + ), + "cluster_version": self._get_param("ClusterVersion"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "number_of_nodes": self._get_int_param("NumberOfNodes"), "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting"), } cluster_kwargs = {} # We only want parameters that were actually passed in, otherwise @@ -228,394 +255,442 @@ class RedshiftResponse(BaseResponse): cluster = self.redshift_backend.modify_cluster(**cluster_kwargs) - return self.get_response({ - "ModifyClusterResponse": { - "ModifyClusterResult": { - "Cluster": cluster.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "ModifyClusterResponse": { + "ModifyClusterResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster(self): request_kwargs = { "cluster_identifier": self._get_param("ClusterIdentifier"), - "final_cluster_snapshot_identifier": self._get_param("FinalClusterSnapshotIdentifier"), - "skip_final_snapshot": self._get_bool_param("SkipFinalClusterSnapshot") + "final_cluster_snapshot_identifier": self._get_param( + "FinalClusterSnapshotIdentifier" + ), + "skip_final_snapshot": self._get_bool_param("SkipFinalClusterSnapshot"), } cluster = self.redshift_backend.delete_cluster(**request_kwargs) - return self.get_response({ - "DeleteClusterResponse": { - "DeleteClusterResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterResponse": { + "DeleteClusterResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_cluster_subnet_group(self): - cluster_subnet_group_name = self._get_param('ClusterSubnetGroupName') - description = self._get_param('Description') + cluster_subnet_group_name = self._get_param("ClusterSubnetGroupName") + description = self._get_param("Description") subnet_ids = self._get_subnet_ids() - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) subnet_group = self.redshift_backend.create_cluster_subnet_group( cluster_subnet_group_name=cluster_subnet_group_name, description=description, subnet_ids=subnet_ids, region_name=self.region, - tags=tags + tags=tags, ) - return self.get_response({ - "CreateClusterSubnetGroupResponse": { - "CreateClusterSubnetGroupResult": { - "ClusterSubnetGroup": subnet_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateClusterSubnetGroupResponse": { + "CreateClusterSubnetGroupResult": { + "ClusterSubnetGroup": subnet_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_subnet_groups(self): subnet_identifier = self._get_param("ClusterSubnetGroupName") subnet_groups = self.redshift_backend.describe_cluster_subnet_groups( - subnet_identifier) + subnet_identifier + ) - return self.get_response({ - "DescribeClusterSubnetGroupsResponse": { - "DescribeClusterSubnetGroupsResult": { - "ClusterSubnetGroups": [subnet_group.to_json() for subnet_group in subnet_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterSubnetGroupsResponse": { + "DescribeClusterSubnetGroupsResult": { + "ClusterSubnetGroups": [ + subnet_group.to_json() for subnet_group in subnet_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_subnet_group(self): subnet_identifier = self._get_param("ClusterSubnetGroupName") self.redshift_backend.delete_cluster_subnet_group(subnet_identifier) - return self.get_response({ - "DeleteClusterSubnetGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSubnetGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def create_cluster_security_group(self): - cluster_security_group_name = self._get_param( - 'ClusterSecurityGroupName') - description = self._get_param('Description') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + cluster_security_group_name = self._get_param("ClusterSecurityGroupName") + description = self._get_param("Description") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.redshift_backend.create_cluster_security_group( cluster_security_group_name=cluster_security_group_name, description=description, region_name=self.region, - tags=tags + tags=tags, ) - return self.get_response({ - "CreateClusterSecurityGroupResponse": { - "CreateClusterSecurityGroupResult": { - "ClusterSecurityGroup": security_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateClusterSecurityGroupResponse": { + "CreateClusterSecurityGroupResult": { + "ClusterSecurityGroup": security_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_security_groups(self): - cluster_security_group_name = self._get_param( - "ClusterSecurityGroupName") + cluster_security_group_name = self._get_param("ClusterSecurityGroupName") security_groups = self.redshift_backend.describe_cluster_security_groups( - cluster_security_group_name) + cluster_security_group_name + ) - return self.get_response({ - "DescribeClusterSecurityGroupsResponse": { - "DescribeClusterSecurityGroupsResult": { - "ClusterSecurityGroups": [security_group.to_json() for security_group in security_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterSecurityGroupsResponse": { + "DescribeClusterSecurityGroupsResult": { + "ClusterSecurityGroups": [ + security_group.to_json() + for security_group in security_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_security_group(self): security_group_identifier = self._get_param("ClusterSecurityGroupName") - self.redshift_backend.delete_cluster_security_group( - security_group_identifier) + self.redshift_backend.delete_cluster_security_group(security_group_identifier) - return self.get_response({ - "DeleteClusterSecurityGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSecurityGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) - - def create_cluster_parameter_group(self): - cluster_parameter_group_name = self._get_param('ParameterGroupName') - group_family = self._get_param('ParameterGroupFamily') - description = self._get_param('Description') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - - parameter_group = self.redshift_backend.create_cluster_parameter_group( - cluster_parameter_group_name, - group_family, - description, - self.region, - tags ) - return self.get_response({ - "CreateClusterParameterGroupResponse": { - "CreateClusterParameterGroupResult": { - "ClusterParameterGroup": parameter_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + def create_cluster_parameter_group(self): + cluster_parameter_group_name = self._get_param("ParameterGroupName") + group_family = self._get_param("ParameterGroupFamily") + description = self._get_param("Description") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + + parameter_group = self.redshift_backend.create_cluster_parameter_group( + cluster_parameter_group_name, group_family, description, self.region, tags + ) + + return self.get_response( + { + "CreateClusterParameterGroupResponse": { + "CreateClusterParameterGroupResult": { + "ClusterParameterGroup": parameter_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_parameter_groups(self): cluster_parameter_group_name = self._get_param("ParameterGroupName") parameter_groups = self.redshift_backend.describe_cluster_parameter_groups( - cluster_parameter_group_name) + cluster_parameter_group_name + ) - return self.get_response({ - "DescribeClusterParameterGroupsResponse": { - "DescribeClusterParameterGroupsResult": { - "ParameterGroups": [parameter_group.to_json() for parameter_group in parameter_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterParameterGroupsResponse": { + "DescribeClusterParameterGroupsResult": { + "ParameterGroups": [ + parameter_group.to_json() + for parameter_group in parameter_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_parameter_group(self): cluster_parameter_group_name = self._get_param("ParameterGroupName") self.redshift_backend.delete_cluster_parameter_group( - cluster_parameter_group_name) + cluster_parameter_group_name + ) - return self.get_response({ - "DeleteClusterParameterGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterParameterGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def create_cluster_snapshot(self): - cluster_identifier = self._get_param('ClusterIdentifier') - snapshot_identifier = self._get_param('SnapshotIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + cluster_identifier = self._get_param("ClusterIdentifier") + snapshot_identifier = self._get_param("SnapshotIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) - snapshot = self.redshift_backend.create_cluster_snapshot(cluster_identifier, - snapshot_identifier, - self.region, - tags) - return self.get_response({ - 'CreateClusterSnapshotResponse': { - "CreateClusterSnapshotResult": { - "Snapshot": snapshot.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + snapshot = self.redshift_backend.create_cluster_snapshot( + cluster_identifier, snapshot_identifier, self.region, tags + ) + return self.get_response( + { + "CreateClusterSnapshotResponse": { + "CreateClusterSnapshotResult": {"Snapshot": snapshot.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_snapshots(self): - cluster_identifier = self._get_param('ClusterIdentifier') - snapshot_identifier = self._get_param('SnapshotIdentifier') - snapshots = self.redshift_backend.describe_cluster_snapshots(cluster_identifier, - snapshot_identifier) - return self.get_response({ - "DescribeClusterSnapshotsResponse": { - "DescribeClusterSnapshotsResult": { - "Snapshots": [snapshot.to_json() for snapshot in snapshots] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + cluster_identifier = self._get_param("ClusterIdentifier") + snapshot_identifier = self._get_param("SnapshotIdentifier") + snapshots = self.redshift_backend.describe_cluster_snapshots( + cluster_identifier, snapshot_identifier + ) + return self.get_response( + { + "DescribeClusterSnapshotsResponse": { + "DescribeClusterSnapshotsResult": { + "Snapshots": [snapshot.to_json() for snapshot in snapshots] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_snapshot(self): - snapshot_identifier = self._get_param('SnapshotIdentifier') + snapshot_identifier = self._get_param("SnapshotIdentifier") snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier) - return self.get_response({ - "DeleteClusterSnapshotResponse": { - "DeleteClusterSnapshotResult": { - "Snapshot": snapshot.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSnapshotResponse": { + "DeleteClusterSnapshotResult": {"Snapshot": snapshot.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_snapshot_copy_grant(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), - 'kms_key_id': self._get_param('KmsKeyId'), - 'region_name': self._get_param('Region'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"), + "kms_key_id": self._get_param("KmsKeyId"), + "region_name": self._get_param("Region"), } - copy_grant = self.redshift_backend.create_snapshot_copy_grant(**copy_grant_kwargs) - return self.get_response({ - "CreateSnapshotCopyGrantResponse": { - "CreateSnapshotCopyGrantResult": { - "SnapshotCopyGrant": copy_grant.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + copy_grant = self.redshift_backend.create_snapshot_copy_grant( + **copy_grant_kwargs + ) + return self.get_response( + { + "CreateSnapshotCopyGrantResponse": { + "CreateSnapshotCopyGrantResult": { + "SnapshotCopyGrant": copy_grant.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_snapshot_copy_grant(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") } self.redshift_backend.delete_snapshot_copy_grant(**copy_grant_kwargs) - return self.get_response({ - "DeleteSnapshotCopyGrantResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteSnapshotCopyGrantResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def describe_snapshot_copy_grants(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") } - copy_grants = self.redshift_backend.describe_snapshot_copy_grants(**copy_grant_kwargs) - return self.get_response({ - "DescribeSnapshotCopyGrantsResponse": { - "DescribeSnapshotCopyGrantsResult": { - "SnapshotCopyGrants": [copy_grant.to_json() for copy_grant in copy_grants] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + copy_grants = self.redshift_backend.describe_snapshot_copy_grants( + **copy_grant_kwargs + ) + return self.get_response( + { + "DescribeSnapshotCopyGrantsResponse": { + "DescribeSnapshotCopyGrantsResult": { + "SnapshotCopyGrants": [ + copy_grant.to_json() for copy_grant in copy_grants + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_tags(self): - resource_name = self._get_param('ResourceName') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + resource_name = self._get_param("ResourceName") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) self.redshift_backend.create_tags(resource_name, tags) - return self.get_response({ - "CreateTagsResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateTagsResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def describe_tags(self): - resource_name = self._get_param('ResourceName') - resource_type = self._get_param('ResourceType') + resource_name = self._get_param("ResourceName") + resource_type = self._get_param("ResourceType") - tagged_resources = self.redshift_backend.describe_tags(resource_name, - resource_type) - return self.get_response({ - "DescribeTagsResponse": { - "DescribeTagsResult": { - "TaggedResources": tagged_resources - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + tagged_resources = self.redshift_backend.describe_tags( + resource_name, resource_type + ) + return self.get_response( + { + "DescribeTagsResponse": { + "DescribeTagsResult": {"TaggedResources": tagged_resources}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_tags(self): - resource_name = self._get_param('ResourceName') - tag_keys = self.unpack_list_params('TagKeys.TagKey') + resource_name = self._get_param("ResourceName") + tag_keys = self.unpack_list_params("TagKeys.TagKey") self.redshift_backend.delete_tags(resource_name, tag_keys) - return self.get_response({ - "DeleteTagsResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteTagsResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def enable_snapshot_copy(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), - 'destination_region': self._get_param('DestinationRegion'), - 'retention_period': self._get_param('RetentionPeriod', 7), - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "destination_region": self._get_param("DestinationRegion"), + "retention_period": self._get_param("RetentionPeriod", 7), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"), } cluster = self.redshift_backend.enable_snapshot_copy(**snapshot_copy_kwargs) - return self.get_response({ - "EnableSnapshotCopyResponse": { - "EnableSnapshotCopyResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "EnableSnapshotCopyResponse": { + "EnableSnapshotCopyResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def disable_snapshot_copy(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), + "cluster_identifier": self._get_param("ClusterIdentifier") } cluster = self.redshift_backend.disable_snapshot_copy(**snapshot_copy_kwargs) - return self.get_response({ - "DisableSnapshotCopyResponse": { - "DisableSnapshotCopyResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DisableSnapshotCopyResponse": { + "DisableSnapshotCopyResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def modify_snapshot_copy_retention_period(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), - 'retention_period': self._get_param('RetentionPeriod'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "retention_period": self._get_param("RetentionPeriod"), } - cluster = self.redshift_backend.modify_snapshot_copy_retention_period(**snapshot_copy_kwargs) + cluster = self.redshift_backend.modify_snapshot_copy_retention_period( + **snapshot_copy_kwargs + ) - return self.get_response({ - "ModifySnapshotCopyRetentionPeriodResponse": { - "ModifySnapshotCopyRetentionPeriodResult": { - "Clusters": [cluster.to_json()] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "ModifySnapshotCopyRetentionPeriodResponse": { + "ModifySnapshotCopyRetentionPeriodResult": { + "Clusters": [cluster.to_json()] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) diff --git a/moto/redshift/urls.py b/moto/redshift/urls.py index ebef59e86..8494669ee 100644 --- a/moto/redshift/urls.py +++ b/moto/redshift/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import RedshiftResponse -url_bases = [ - "https?://redshift.(.+).amazonaws.com", -] +url_bases = ["https?://redshift.(.+).amazonaws.com"] -url_paths = { - '{0}/$': RedshiftResponse.dispatch, -} +url_paths = {"{0}/$": RedshiftResponse.dispatch} diff --git a/moto/resourcegroups/__init__.py b/moto/resourcegroups/__init__.py index 74b0eb598..13ff17307 100644 --- a/moto/resourcegroups/__init__.py +++ b/moto/resourcegroups/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import resourcegroups_backends from ..core.models import base_decorator -resourcegroups_backend = resourcegroups_backends['us-east-1'] +resourcegroups_backend = resourcegroups_backends["us-east-1"] mock_resourcegroups = base_decorator(resourcegroups_backends) diff --git a/moto/resourcegroups/exceptions.py b/moto/resourcegroups/exceptions.py index a8e542979..6c0f470be 100644 --- a/moto/resourcegroups/exceptions.py +++ b/moto/resourcegroups/exceptions.py @@ -9,5 +9,6 @@ class BadRequestException(HTTPException): def __init__(self, message, **kwargs): super(BadRequestException, self).__init__( - description=json.dumps({"Message": message, "Code": "BadRequestException"}), **kwargs + description=json.dumps({"Message": message, "Code": "BadRequestException"}), + **kwargs ) diff --git a/moto/resourcegroups/models.py b/moto/resourcegroups/models.py index 6734bd48a..7d4d88230 100644 --- a/moto/resourcegroups/models.py +++ b/moto/resourcegroups/models.py @@ -6,6 +6,7 @@ import json import re from moto.core import BaseBackend, BaseModel +from moto.core import ACCOUNT_ID from .exceptions import BadRequestException @@ -23,14 +24,14 @@ class FakeResourceGroup(BaseModel): if self._validate_tags(value=tags): self._tags = tags self._raise_errors() - self.arn = "arn:aws:resource-groups:us-west-1:123456789012:{name}".format(name=name) + self.arn = "arn:aws:resource-groups:us-west-1:{AccountId}:{name}".format( + name=name, AccountId=ACCOUNT_ID + ) @staticmethod def _format_error(key, value, constraint): return "Value '{value}' at '{key}' failed to satisfy constraint: {constraint}".format( - constraint=constraint, - key=key, - value=value, + constraint=constraint, key=key, value=value ) def _raise_errors(self): @@ -38,24 +39,30 @@ class FakeResourceGroup(BaseModel): errors_len = len(self.errors) plural = "s" if len(self.errors) > 1 else "" errors = "; ".join(self.errors) - raise BadRequestException("{errors_len} validation error{plural} detected: {errors}".format( - errors_len=errors_len, plural=plural, errors=errors, - )) + raise BadRequestException( + "{errors_len} validation error{plural} detected: {errors}".format( + errors_len=errors_len, plural=plural, errors=errors + ) + ) def _validate_description(self, value): errors = [] if len(value) > 511: - errors.append(self._format_error( - key="description", - value=value, - constraint="Member must have length less than or equal to 512", - )) + errors.append( + self._format_error( + key="description", + value=value, + constraint="Member must have length less than or equal to 512", + ) + ) if not re.match(r"^[\sa-zA-Z0-9_.-]*$", value): - errors.append(self._format_error( - key="name", - value=value, - constraint=r"Member must satisfy regular expression pattern: [\sa-zA-Z0-9_\.-]*", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint=r"Member must satisfy regular expression pattern: [\sa-zA-Z0-9_\.-]*", + ) + ) if errors: self.errors += errors return False @@ -64,18 +71,22 @@ class FakeResourceGroup(BaseModel): def _validate_name(self, value): errors = [] if len(value) > 128: - errors.append(self._format_error( - key="name", - value=value, - constraint="Member must have length less than or equal to 128", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint="Member must have length less than or equal to 128", + ) + ) # Note \ is a character to match not an escape. if not re.match(r"^[a-zA-Z0-9_\\.-]+$", value): - errors.append(self._format_error( - key="name", - value=value, - constraint=r"Member must satisfy regular expression pattern: [a-zA-Z0-9_\.-]+", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint=r"Member must satisfy regular expression pattern: [a-zA-Z0-9_\.-]+", + ) + ) if errors: self.errors += errors return False @@ -84,17 +95,21 @@ class FakeResourceGroup(BaseModel): def _validate_resource_query(self, value): errors = [] if value["Type"] not in {"CLOUDFORMATION_STACK_1_0", "TAG_FILTERS_1_0"}: - errors.append(self._format_error( - key="resourceQuery.type", - value=value, - constraint="Member must satisfy enum value set: [CLOUDFORMATION_STACK_1_0, TAG_FILTERS_1_0]", - )) + errors.append( + self._format_error( + key="resourceQuery.type", + value=value, + constraint="Member must satisfy enum value set: [CLOUDFORMATION_STACK_1_0, TAG_FILTERS_1_0]", + ) + ) if len(value["Query"]) > 2048: - errors.append(self._format_error( - key="resourceQuery.query", - value=value, - constraint="Member must have length less than or equal to 2048", - )) + errors.append( + self._format_error( + key="resourceQuery.query", + value=value, + constraint="Member must have length less than or equal to 2048", + ) + ) if errors: self.errors += errors return False @@ -183,7 +198,7 @@ class FakeResourceGroup(BaseModel): self._tags = value -class ResourceGroups(): +class ResourceGroups: def __init__(self): self.by_name = {} self.by_arn = {} @@ -213,7 +228,9 @@ class ResourceGroupsBackend(BaseBackend): type = resource_query["Type"] query = json.loads(resource_query["Query"]) query_keys = set(query.keys()) - invalid_json_exception = BadRequestException("Invalid query: Invalid query format: check JSON syntax") + invalid_json_exception = BadRequestException( + "Invalid query: Invalid query format: check JSON syntax" + ) if not isinstance(query["ResourceTypeFilters"], list): raise invalid_json_exception if type == "CLOUDFORMATION_STACK_1_0": @@ -255,7 +272,9 @@ class ResourceGroupsBackend(BaseBackend): "Invalid query: The TagFilter element cannot have empty or null Key field" ) if len(key) > 128: - raise BadRequestException("Invalid query: The maximum length for a tag Key is 128") + raise BadRequestException( + "Invalid query: The maximum length for a tag Key is 128" + ) values = tag_filter["Values"] if not isinstance(values, list): raise invalid_json_exception @@ -274,16 +293,13 @@ class ResourceGroupsBackend(BaseBackend): @staticmethod def _validate_tags(tags): for tag in tags: - if tag.lower().startswith('aws:'): + if tag.lower().startswith("aws:"): raise BadRequestException("Tag keys must not start with 'aws:'") def create_group(self, name, resource_query, description=None, tags=None): tags = tags or {} group = FakeResourceGroup( - name=name, - resource_query=resource_query, - description=description, - tags=tags, + name=name, resource_query=resource_query, description=description, tags=tags ) if name in self.groups: raise BadRequestException("Cannot create group: group already exists") @@ -335,4 +351,6 @@ class ResourceGroupsBackend(BaseBackend): available_regions = boto3.session.Session().get_available_regions("resource-groups") -resourcegroups_backends = {region: ResourceGroupsBackend(region_name=region) for region in available_regions} +resourcegroups_backends = { + region: ResourceGroupsBackend(region_name=region) for region in available_regions +} diff --git a/moto/resourcegroups/responses.py b/moto/resourcegroups/responses.py index 02ea14c1a..77edff19d 100644 --- a/moto/resourcegroups/responses.py +++ b/moto/resourcegroups/responses.py @@ -11,7 +11,7 @@ from .models import resourcegroups_backends class ResourceGroupsResponse(BaseResponse): - SERVICE_NAME = 'resource-groups' + SERVICE_NAME = "resource-groups" @property def resourcegroups_backend(self): @@ -23,140 +23,145 @@ class ResourceGroupsResponse(BaseResponse): resource_query = self._get_param("ResourceQuery") tags = self._get_param("Tags") group = self.resourcegroups_backend.create_group( - name=name, - description=description, - resource_query=resource_query, - tags=tags, + name=name, description=description, resource_query=resource_query, tags=tags + ) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + }, + "ResourceQuery": group.resource_query, + "Tags": group.tags, + } ) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - "ResourceQuery": group.resource_query, - "Tags": group.tags - }) def delete_group(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.delete_group(group_name=group_name) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - }) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + } + ) def get_group(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.get_group(group_name=group_name) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description, + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } } - }) + ) def get_group_query(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.get_group(group_name=group_name) - return json.dumps({ - "GroupQuery": { - "GroupName": group.name, - "ResourceQuery": group.resource_query, + return json.dumps( + { + "GroupQuery": { + "GroupName": group.name, + "ResourceQuery": group.resource_query, + } } - }) + ) def get_tags(self): arn = unquote(self._get_param("Arn")) - return json.dumps({ - "Arn": arn, - "Tags": self.resourcegroups_backend.get_tags(arn=arn) - }) + return json.dumps( + {"Arn": arn, "Tags": self.resourcegroups_backend.get_tags(arn=arn)} + ) def list_group_resources(self): - raise NotImplementedError('ResourceGroups.list_group_resources is not yet implemented') + raise NotImplementedError( + "ResourceGroups.list_group_resources is not yet implemented" + ) def list_groups(self): filters = self._get_param("Filters") if filters: raise NotImplementedError( - 'ResourceGroups.list_groups with filter parameter is not yet implemented' + "ResourceGroups.list_groups with filter parameter is not yet implemented" ) max_results = self._get_int_param("MaxResults", 50) next_token = self._get_param("NextToken") groups = self.resourcegroups_backend.list_groups( - filters=filters, - max_results=max_results, - next_token=next_token + filters=filters, max_results=max_results, next_token=next_token + ) + return json.dumps( + { + "GroupIdentifiers": [ + {"GroupName": group.name, "GroupArn": group.arn} + for group in groups.values() + ], + "Groups": [ + { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + for group in groups.values() + ], + "NextToken": next_token, + } ) - return json.dumps({ - "GroupIdentifiers": [{ - "GroupName": group.name, - "GroupArn": group.arn, - } for group in groups.values()], - "Groups": [{ - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description, - } for group in groups.values()], - "NextToken": next_token, - }) def search_resources(self): - raise NotImplementedError('ResourceGroups.search_resources is not yet implemented') + raise NotImplementedError( + "ResourceGroups.search_resources is not yet implemented" + ) def tag(self): arn = unquote(self._get_param("Arn")) tags = self._get_param("Tags") if arn not in self.resourcegroups_backend.groups.by_arn: raise NotImplementedError( - 'ResourceGroups.tag with non-resource-group Arn parameter is not yet implemented' + "ResourceGroups.tag with non-resource-group Arn parameter is not yet implemented" ) self.resourcegroups_backend.tag(arn=arn, tags=tags) - return json.dumps({ - "Arn": arn, - "Tags": tags - }) + return json.dumps({"Arn": arn, "Tags": tags}) def untag(self): arn = unquote(self._get_param("Arn")) keys = self._get_param("Keys") if arn not in self.resourcegroups_backend.groups.by_arn: raise NotImplementedError( - 'ResourceGroups.untag with non-resource-group Arn parameter is not yet implemented' + "ResourceGroups.untag with non-resource-group Arn parameter is not yet implemented" ) self.resourcegroups_backend.untag(arn=arn, keys=keys) - return json.dumps({ - "Arn": arn, - "Keys": keys - }) + return json.dumps({"Arn": arn, "Keys": keys}) def update_group(self): group_name = self._get_param("GroupName") description = self._get_param("Description", "") - group = self.resourcegroups_backend.update_group(group_name=group_name, description=description) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - }) + group = self.resourcegroups_backend.update_group( + group_name=group_name, description=description + ) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + } + ) def update_group_query(self): group_name = self._get_param("GroupName") resource_query = self._get_param("ResourceQuery") group = self.resourcegroups_backend.update_group_query( - group_name=group_name, - resource_query=resource_query + group_name=group_name, resource_query=resource_query + ) + return json.dumps( + {"GroupQuery": {"GroupName": group.name, "ResourceQuery": resource_query}} ) - return json.dumps({ - "GroupQuery": { - "GroupName": group.name, - "ResourceQuery": resource_query - } - }) diff --git a/moto/resourcegroups/urls.py b/moto/resourcegroups/urls.py index 518dde766..b40179145 100644 --- a/moto/resourcegroups/urls.py +++ b/moto/resourcegroups/urls.py @@ -1,14 +1,12 @@ from __future__ import unicode_literals from .responses import ResourceGroupsResponse -url_bases = [ - "https?://resource-groups(-fips)?.(.+).amazonaws.com", -] +url_bases = ["https?://resource-groups(-fips)?.(.+).amazonaws.com"] url_paths = { - '{0}/groups$': ResourceGroupsResponse.dispatch, - '{0}/groups/(?P[^/]+)$': ResourceGroupsResponse.dispatch, - '{0}/groups/(?P[^/]+)/query$': ResourceGroupsResponse.dispatch, - '{0}/groups-list$': ResourceGroupsResponse.dispatch, - '{0}/resources/(?P[^/]+)/tags$': ResourceGroupsResponse.dispatch, + "{0}/groups$": ResourceGroupsResponse.dispatch, + "{0}/groups/(?P[^/]+)$": ResourceGroupsResponse.dispatch, + "{0}/groups/(?P[^/]+)/query$": ResourceGroupsResponse.dispatch, + "{0}/groups-list$": ResourceGroupsResponse.dispatch, + "{0}/resources/(?P[^/]+)/tags$": ResourceGroupsResponse.dispatch, } diff --git a/moto/resourcegroupstaggingapi/__init__.py b/moto/resourcegroupstaggingapi/__init__.py index bd0c4a7df..2dff989b6 100644 --- a/moto/resourcegroupstaggingapi/__init__.py +++ b/moto/resourcegroupstaggingapi/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import resourcegroupstaggingapi_backends from ..core.models import base_decorator -resourcegroupstaggingapi_backend = resourcegroupstaggingapi_backends['us-east-1'] +resourcegroupstaggingapi_backend = resourcegroupstaggingapi_backends["us-east-1"] mock_resourcegroupstaggingapi = base_decorator(resourcegroupstaggingapi_backends) diff --git a/moto/resourcegroupstaggingapi/models.py b/moto/resourcegroupstaggingapi/models.py index 3f15017cc..7b0c03a88 100644 --- a/moto/resourcegroupstaggingapi/models.py +++ b/moto/resourcegroupstaggingapi/models.py @@ -42,7 +42,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): """ :rtype: moto.s3.models.S3Backend """ - return s3_backends['global'] + return s3_backends["global"] @property def ec2_backend(self): @@ -114,16 +114,18 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # TODO move these to their respective backends filters = [lambda t, v: True] for tag_filter_dict in tag_filters: - values = tag_filter_dict.get('Values', []) + values = tag_filter_dict.get("Values", []) if len(values) == 0: # Check key matches - filters.append(lambda t, v: t == tag_filter_dict['Key']) + filters.append(lambda t, v: t == tag_filter_dict["Key"]) elif len(values) == 1: # Check its exactly the same as key, value - filters.append(lambda t, v: t == tag_filter_dict['Key'] and v == values[0]) + filters.append( + lambda t, v: t == tag_filter_dict["Key"] and v == values[0] + ) else: # Check key matches and value is one of the provided values - filters.append(lambda t, v: t == tag_filter_dict['Key'] and v in values) + filters.append(lambda t, v: t == tag_filter_dict["Key"] and v in values) def tag_filter(tag_list): result = [] @@ -131,7 +133,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): for tag in tag_list: temp_result = [] for f in filters: - f_result = f(tag['Key'], tag['Value']) + f_result = f(tag["Key"], tag["Value"]) temp_result.append(f_result) result.append(all(temp_result)) @@ -140,82 +142,150 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): return True # Do S3, resource type s3 - if not resource_type_filters or 's3' in resource_type_filters: + if not resource_type_filters or "s3" in resource_type_filters: for bucket in self.s3_backend.buckets.values(): tags = [] for tag in bucket.tags.tag_set.tags: - tags.append({'Key': tag.key, 'Value': tag.value}) + tags.append({"Key": tag.key, "Value": tag.value}) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:s3:::' + bucket.name, 'Tags': tags} + yield {"ResourceARN": "arn:aws:s3:::" + bucket.name, "Tags": tags} # EC2 tags def get_ec2_tags(res_id): result = [] for key, value in self.ec2_backend.tags.get(res_id, {}).items(): - result.append({'Key': key, 'Value': value}) + result.append({"Key": key, "Value": value}) return result # EC2 AMI, resource type ec2:image - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:image' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:image" in resource_type_filters + ): for ami in self.ec2_backend.amis.values(): tags = get_ec2_tags(ami.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::image/{1}'.format(self.region_name, ami.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::image/{1}".format( + self.region_name, ami.id + ), + "Tags": tags, + } # EC2 Instance, resource type ec2:instance - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:instance' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:instance" in resource_type_filters + ): for reservation in self.ec2_backend.reservations.values(): for instance in reservation.instances: tags = get_ec2_tags(instance.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::instance/{1}'.format(self.region_name, instance.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::instance/{1}".format( + self.region_name, instance.id + ), + "Tags": tags, + } # EC2 NetworkInterface, resource type ec2:network-interface - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:network-interface' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:network-interface" in resource_type_filters + ): for eni in self.ec2_backend.enis.values(): tags = get_ec2_tags(eni.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::network-interface/{1}'.format(self.region_name, eni.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::network-interface/{1}".format( + self.region_name, eni.id + ), + "Tags": tags, + } # TODO EC2 ReservedInstance # EC2 SecurityGroup, resource type ec2:security-group - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:security-group' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:security-group" in resource_type_filters + ): for vpc in self.ec2_backend.groups.values(): for sg in vpc.values(): tags = get_ec2_tags(sg.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::security-group/{1}'.format(self.region_name, sg.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::security-group/{1}".format( + self.region_name, sg.id + ), + "Tags": tags, + } # EC2 Snapshot, resource type ec2:snapshot - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:snapshot' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:snapshot" in resource_type_filters + ): for snapshot in self.ec2_backend.snapshots.values(): tags = get_ec2_tags(snapshot.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::snapshot/{1}'.format(self.region_name, snapshot.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::snapshot/{1}".format( + self.region_name, snapshot.id + ), + "Tags": tags, + } # TODO EC2 SpotInstanceRequest # EC2 Volume, resource type ec2:volume - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:volume' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:volume" in resource_type_filters + ): for volume in self.ec2_backend.volumes.values(): tags = get_ec2_tags(volume.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::volume/{1}'.format(self.region_name, volume.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::volume/{1}".format( + self.region_name, volume.id + ), + "Tags": tags, + } # TODO add these to the keys and values functions / combine functions # ELB @@ -223,16 +293,20 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): def get_elbv2_tags(arn): result = [] for key, value in self.elbv2_backend.load_balancers[elb.arn].tags.items(): - result.append({'Key': key, 'Value': value}) + result.append({"Key": key, "Value": value}) return result - if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters: + if ( + not resource_type_filters + or "elasticloadbalancer" in resource_type_filters + or "elasticloadbalancer:loadbalancer" in resource_type_filters + ): for elb in self.elbv2_backend.load_balancers.values(): tags = get_elbv2_tags(elb.arn) if not tag_filter(tags): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': '{0}'.format(elb.arn), 'Tags': tags} + yield {"ResourceARN": "{0}".format(elb.arn), "Tags": tags} # EMR Cluster @@ -244,16 +318,16 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): def get_kms_tags(kms_key_id): result = [] for tag in self.kms_backend.list_resource_tags(kms_key_id): - result.append({'Key': tag['TagKey'], 'Value': tag['TagValue']}) + result.append({"Key": tag["TagKey"], "Value": tag["TagValue"]}) return result - if not resource_type_filters or 'kms' in resource_type_filters: + if not resource_type_filters or "kms" in resource_type_filters: for kms_key in self.kms_backend.list_keys(): tags = get_kms_tags(kms_key.id) if not tag_filter(tags): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': '{0}'.format(kms_key.arn), 'Tags': tags} + yield {"ResourceARN": "{0}".format(kms_key.arn), "Tags": tags} # RDS Instance # RDS Reserved Database Instance @@ -387,25 +461,37 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): for value in get_ec2_values(volume.id): yield value - def get_resources(self, pagination_token=None, - resources_per_page=50, tags_per_page=100, - tag_filters=None, resource_type_filters=None): + def get_resources( + self, + pagination_token=None, + resources_per_page=50, + tags_per_page=100, + tag_filters=None, + resource_type_filters=None, + ): # Simple range checking if 100 >= tags_per_page >= 500: - raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500') + raise RESTError( + "InvalidParameterException", "TagsPerPage must be between 100 and 500" + ) if 1 >= resources_per_page >= 50: - raise RESTError('InvalidParameterException', 'ResourcesPerPage must be between 1 and 50') + raise RESTError( + "InvalidParameterException", "ResourcesPerPage must be between 1 and 50" + ) # If we have a token, go and find the respective generator, or error if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: - generator = self._get_resources_generator(tag_filters=tag_filters, - resource_type_filters=resource_type_filters) + generator = self._get_resources_generator( + tag_filters=tag_filters, resource_type_filters=resource_type_filters + ) left_over = None result = [] @@ -414,13 +500,13 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if left_over: result.append(left_over) current_resources += 1 - current_tags += len(left_over['Tags']) + current_tags += len(left_over["Tags"]) try: while True: # Generator format: [{'ResourceARN': str, 'Tags': [{'Key': str, 'Value': str]}, ...] next_item = six.next(generator) - resource_tags = len(next_item['Tags']) + resource_tags = len(next_item["Tags"]) if current_resources >= resources_per_page: break @@ -438,7 +524,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -450,10 +536,12 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: generator = self._get_tag_keys_generator() left_over = None @@ -482,7 +570,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -494,10 +582,12 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: generator = self._get_tag_values_generator(key) left_over = None @@ -526,7 +616,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -546,5 +636,9 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # return failed_resources_map -available_regions = boto3.session.Session().get_available_regions("resourcegroupstaggingapi") -resourcegroupstaggingapi_backends = {region: ResourceGroupsTaggingAPIBackend(region) for region in available_regions} +available_regions = boto3.session.Session().get_available_regions( + "resourcegroupstaggingapi" +) +resourcegroupstaggingapi_backends = { + region: ResourceGroupsTaggingAPIBackend(region) for region in available_regions +} diff --git a/moto/resourcegroupstaggingapi/responses.py b/moto/resourcegroupstaggingapi/responses.py index 966778f29..02f5b5484 100644 --- a/moto/resourcegroupstaggingapi/responses.py +++ b/moto/resourcegroupstaggingapi/responses.py @@ -5,7 +5,7 @@ import json class ResourceGroupsTaggingAPIResponse(BaseResponse): - SERVICE_NAME = 'resourcegroupstaggingapi' + SERVICE_NAME = "resourcegroupstaggingapi" @property def backend(self): @@ -32,25 +32,21 @@ class ResourceGroupsTaggingAPIResponse(BaseResponse): ) # Format tag response - response = { - 'ResourceTagMappingList': resource_tag_mapping_list - } + response = {"ResourceTagMappingList": resource_tag_mapping_list} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) def get_tag_keys(self): pagination_token = self._get_param("PaginationToken") pagination_token, tag_keys = self.backend.get_tag_keys( - pagination_token=pagination_token, + pagination_token=pagination_token ) - response = { - 'TagKeys': tag_keys - } + response = {"TagKeys": tag_keys} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) @@ -58,15 +54,12 @@ class ResourceGroupsTaggingAPIResponse(BaseResponse): pagination_token = self._get_param("PaginationToken") key = self._get_param("Key") pagination_token, tag_values = self.backend.get_tag_values( - pagination_token=pagination_token, - key=key, + pagination_token=pagination_token, key=key ) - response = { - 'TagValues': tag_values - } + response = {"TagValues": tag_values} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) diff --git a/moto/resourcegroupstaggingapi/urls.py b/moto/resourcegroupstaggingapi/urls.py index a972df276..3b0182ee9 100644 --- a/moto/resourcegroupstaggingapi/urls.py +++ b/moto/resourcegroupstaggingapi/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import ResourceGroupsTaggingAPIResponse -url_bases = [ - "https?://tagging.(.+).amazonaws.com", -] +url_bases = ["https?://tagging.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ResourceGroupsTaggingAPIResponse.dispatch, -} +url_paths = {"{0}/$": ResourceGroupsTaggingAPIResponse.dispatch} diff --git a/moto/route53/models.py b/moto/route53/models.py index 61a6609aa..2ae03e54d 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -15,11 +15,10 @@ ROUTE53_ID_CHOICE = string.ascii_uppercase + string.digits def create_route53_zone_id(): # New ID's look like this Z1RWWTK7Y8UDDQ - return ''.join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)]) + return "".join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)]) class HealthCheck(BaseModel): - def __init__(self, health_check_id, health_check_args): self.id = health_check_id self.ip_address = health_check_args.get("ip_address") @@ -36,23 +35,26 @@ class HealthCheck(BaseModel): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties']['HealthCheckConfig'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"]["HealthCheckConfig"] health_check_args = { - "ip_address": properties.get('IPAddress'), - "port": properties.get('Port'), - "type": properties['Type'], - "resource_path": properties.get('ResourcePath'), - "fqdn": properties.get('FullyQualifiedDomainName'), - "search_string": properties.get('SearchString'), - "request_interval": properties.get('RequestInterval'), - "failure_threshold": properties.get('FailureThreshold'), + "ip_address": properties.get("IPAddress"), + "port": properties.get("Port"), + "type": properties["Type"], + "resource_path": properties.get("ResourcePath"), + "fqdn": properties.get("FullyQualifiedDomainName"), + "search_string": properties.get("SearchString"), + "request_interval": properties.get("RequestInterval"), + "failure_threshold": properties.get("FailureThreshold"), } health_check = route53_backend.create_health_check(health_check_args) return health_check def to_xml(self): - template = Template(""" + template = Template( + """ {{ health_check.id }} example.com 192.0.2.17 @@ -68,59 +70,66 @@ class HealthCheck(BaseModel): {% endif %} 1 - """) + """ + ) return template.render(health_check=self) class RecordSet(BaseModel): - def __init__(self, kwargs): - self.name = kwargs.get('Name') - self.type_ = kwargs.get('Type') - self.ttl = kwargs.get('TTL') - self.records = kwargs.get('ResourceRecords', []) - self.set_identifier = kwargs.get('SetIdentifier') - self.weight = kwargs.get('Weight') - self.region = kwargs.get('Region') - self.health_check = kwargs.get('HealthCheckId') - self.hosted_zone_name = kwargs.get('HostedZoneName') - self.hosted_zone_id = kwargs.get('HostedZoneId') - self.alias_target = kwargs.get('AliasTarget') + self.name = kwargs.get("Name") + self.type_ = kwargs.get("Type") + self.ttl = kwargs.get("TTL") + self.records = kwargs.get("ResourceRecords", []) + self.set_identifier = kwargs.get("SetIdentifier") + self.weight = kwargs.get("Weight") + self.region = kwargs.get("Region") + self.health_check = kwargs.get("HealthCheckId") + self.hosted_zone_name = kwargs.get("HostedZoneName") + self.hosted_zone_id = kwargs.get("HostedZoneId") + self.alias_target = kwargs.get("AliasTarget") @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone( - properties["HostedZoneId"]) + hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) record_set = hosted_zone.add_rrset(properties) return record_set @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): # this will break if you changed the zone the record is in, # unfortunately - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone( - properties["HostedZoneId"]) + hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) try: - hosted_zone.delete_rrset({'Name': resource_name}) + hosted_zone.delete_rrset({"Name": resource_name}) except KeyError: pass @@ -129,7 +138,8 @@ class RecordSet(BaseModel): return self.name def to_xml(self): - template = Template(""" + template = Template( + """ {{ record_set.name }} {{ record_set.type_ }} {% if record_set.set_identifier %} @@ -162,26 +172,25 @@ class RecordSet(BaseModel): {% if record_set.health_check %} {{ record_set.health_check }} {% endif %} - """) + """ + ) return template.render(record_set=self) def delete(self, *args, **kwargs): - ''' Not exposed as part of the Route 53 API - used for CloudFormation. args are ignored ''' - hosted_zone = route53_backend.get_hosted_zone_by_name( - self.hosted_zone_name) + """ Not exposed as part of the Route 53 API - used for CloudFormation. args are ignored """ + hosted_zone = route53_backend.get_hosted_zone_by_name(self.hosted_zone_name) if not hosted_zone: hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id) - hosted_zone.delete_rrset({'Name': self.name, 'Type': self.type_}) + hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_}) def reverse_domain_name(domain_name): - if domain_name.endswith('.'): # normalize without trailing dot + if domain_name.endswith("."): # normalize without trailing dot domain_name = domain_name[:-1] - return '.'.join(reversed(domain_name.split('.'))) + return ".".join(reversed(domain_name.split("."))) class FakeZone(BaseModel): - def __init__(self, name, id_, private_zone, comment=None): self.name = name self.id = id_ @@ -198,7 +207,11 @@ class FakeZone(BaseModel): def upsert_rrset(self, record_set): new_rrset = RecordSet(record_set) for i, rrset in enumerate(self.rrsets): - if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_ and rrset.set_identifier == new_rrset.set_identifier: + if ( + rrset.name == new_rrset.name + and rrset.type_ == new_rrset.type_ + and rrset.set_identifier == new_rrset.set_identifier + ): self.rrsets[i] = new_rrset break else: @@ -209,13 +222,16 @@ class FakeZone(BaseModel): self.rrsets = [ record_set for record_set in self.rrsets - if record_set.name != rrset['Name'] or - (rrset.get('Type') is not None and record_set.type_ != rrset['Type']) + if record_set.name != rrset["Name"] + or (rrset.get("Type") is not None and record_set.type_ != rrset["Type"]) ] def delete_rrset_by_id(self, set_identifier): self.rrsets = [ - record_set for record_set in self.rrsets if record_set.set_identifier != set_identifier] + record_set + for record_set in self.rrsets + if record_set.set_identifier != set_identifier + ] def get_record_sets(self, start_type, start_name): record_sets = list(self.rrsets) # Copy the list @@ -223,11 +239,15 @@ class FakeZone(BaseModel): record_sets = [ record_set for record_set in record_sets - if reverse_domain_name(record_set.name) >= reverse_domain_name(start_name) + if reverse_domain_name(record_set.name) + >= reverse_domain_name(start_name) ] if start_type: record_sets = [ - record_set for record_set in record_sets if record_set.type_ >= start_type] + record_set + for record_set in record_sets + if record_set.type_ >= start_type + ] return record_sets @@ -236,17 +256,17 @@ class FakeZone(BaseModel): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] name = properties["Name"] - hosted_zone = route53_backend.create_hosted_zone( - name, private_zone=False) + hosted_zone = route53_backend.create_hosted_zone(name, private_zone=False) return hosted_zone class RecordSetGroup(BaseModel): - def __init__(self, hosted_zone_id, record_sets): self.hosted_zone_id = hosted_zone_id self.record_sets = record_sets @@ -256,8 +276,10 @@ class RecordSetGroup(BaseModel): return "arn:aws:route53:::hostedzone/{0}".format(self.hosted_zone_id) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: @@ -273,7 +295,6 @@ class RecordSetGroup(BaseModel): class Route53Backend(BaseBackend): - def __init__(self): self.zones = {} self.health_checks = {} @@ -281,30 +302,30 @@ class Route53Backend(BaseBackend): def create_hosted_zone(self, name, private_zone, comment=None): new_id = create_route53_zone_id() - new_zone = FakeZone( - name, new_id, private_zone=private_zone, comment=comment) + new_zone = FakeZone(name, new_id, private_zone=private_zone, comment=comment) self.zones[new_id] = new_zone return new_zone def change_tags_for_resource(self, resource_id, tags): - if 'Tag' in tags: - if isinstance(tags['Tag'], list): - for tag in tags['Tag']: - self.resource_tags[resource_id][tag['Key']] = tag['Value'] + if "Tag" in tags: + if isinstance(tags["Tag"], list): + for tag in tags["Tag"]: + self.resource_tags[resource_id][tag["Key"]] = tag["Value"] else: - key, value = (tags['Tag']['Key'], tags['Tag']['Value']) + key, value = (tags["Tag"]["Key"], tags["Tag"]["Value"]) self.resource_tags[resource_id][key] = value else: - if 'Key' in tags: - if isinstance(tags['Key'], list): - for key in tags['Key']: - del(self.resource_tags[resource_id][key]) + if "Key" in tags: + if isinstance(tags["Key"], list): + for key in tags["Key"]: + del self.resource_tags[resource_id][key] else: - del(self.resource_tags[resource_id][tags['Key']]) + del self.resource_tags[resource_id][tags["Key"]] def list_tags_for_resource(self, resource_id): if resource_id in self.resource_tags: return self.resource_tags[resource_id] + return {} def get_all_hosted_zones(self): return self.zones.values() diff --git a/moto/route53/responses.py b/moto/route53/responses.py index f933c575a..3e688b65d 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -8,23 +8,24 @@ import xmltodict class Route53(BaseResponse): - def list_or_create_hostzone_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) if request.method == "POST": elements = xmltodict.parse(self.body) if "HostedZoneConfig" in elements["CreateHostedZoneRequest"]: - comment = elements["CreateHostedZoneRequest"][ - "HostedZoneConfig"]["Comment"] + comment = elements["CreateHostedZoneRequest"]["HostedZoneConfig"][ + "Comment" + ] try: # in boto3, this field is set directly in the xml private_zone = elements["CreateHostedZoneRequest"][ - "HostedZoneConfig"]["PrivateZone"] + "HostedZoneConfig" + ]["PrivateZone"] except KeyError: # if a VPC subsection is only included in xmls params when private_zone=True, # see boto: boto/route53/connection.py - private_zone = 'VPC' in elements["CreateHostedZoneRequest"] + private_zone = "VPC" in elements["CreateHostedZoneRequest"] else: comment = None private_zone = False @@ -35,9 +36,7 @@ class Route53(BaseResponse): name += "." new_zone = route53_backend.create_hosted_zone( - name, - comment=comment, - private_zone=private_zone, + name, comment=comment, private_zone=private_zone ) template = Template(CREATE_HOSTED_ZONE_RESPONSE) return 201, headers, template.render(zone=new_zone) @@ -54,9 +53,15 @@ class Route53(BaseResponse): dnsname = query_params.get("dnsname") if dnsname: - dnsname = dnsname[0] # parse_qs gives us a list, but this parameter doesn't repeat + dnsname = dnsname[ + 0 + ] # parse_qs gives us a list, but this parameter doesn't repeat # return all zones with that name (there can be more than one) - zones = [zone for zone in route53_backend.get_all_hosted_zones() if zone.name == dnsname] + zones = [ + zone + for zone in route53_backend.get_all_hosted_zones() + if zone.name == dnsname + ] else: # sort by names, but with domain components reversed # see http://boto3.readthedocs.io/en/latest/reference/services/route53.html#Route53.Client.list_hosted_zones_by_name @@ -76,7 +81,7 @@ class Route53(BaseResponse): def get_or_delete_hostzone_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) parsed_url = urlparse(full_url) - zoneid = parsed_url.path.rstrip('/').rsplit('/', 1)[1] + zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1] the_zone = route53_backend.get_hosted_zone(zoneid) if not the_zone: return 404, headers, "Zone %s not Found" % zoneid @@ -95,7 +100,7 @@ class Route53(BaseResponse): parsed_url = urlparse(full_url) method = request.method - zoneid = parsed_url.path.rstrip('/').rsplit('/', 2)[1] + zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1] the_zone = route53_backend.get_hosted_zone(zoneid) if not the_zone: return 404, headers, "Zone %s Not Found" % zoneid @@ -103,46 +108,55 @@ class Route53(BaseResponse): if method == "POST": elements = xmltodict.parse(self.body) - change_list = elements['ChangeResourceRecordSetsRequest'][ - 'ChangeBatch']['Changes']['Change'] + change_list = elements["ChangeResourceRecordSetsRequest"]["ChangeBatch"][ + "Changes" + ]["Change"] if not isinstance(change_list, list): - change_list = [elements['ChangeResourceRecordSetsRequest'][ - 'ChangeBatch']['Changes']['Change']] + change_list = [ + elements["ChangeResourceRecordSetsRequest"]["ChangeBatch"][ + "Changes" + ]["Change"] + ] for value in change_list: - action = value['Action'] - record_set = value['ResourceRecordSet'] + action = value["Action"] + record_set = value["ResourceRecordSet"] - cleaned_record_name = record_set['Name'].strip('.') - cleaned_hosted_zone_name = the_zone.name.strip('.') + cleaned_record_name = record_set["Name"].strip(".") + cleaned_hosted_zone_name = the_zone.name.strip(".") if not cleaned_record_name.endswith(cleaned_hosted_zone_name): error_msg = """ An error occurred (InvalidChangeBatch) when calling the ChangeResourceRecordSets operation: RRSet with DNS name %s is not permitted in zone %s - """ % (record_set['Name'], the_zone.name) + """ % ( + record_set["Name"], + the_zone.name, + ) return 400, headers, error_msg - if not record_set['Name'].endswith('.'): - record_set['Name'] += '.' + if not record_set["Name"].endswith("."): + record_set["Name"] += "." - if action in ('CREATE', 'UPSERT'): - if 'ResourceRecords' in record_set: - resource_records = list( - record_set['ResourceRecords'].values())[0] + if action in ("CREATE", "UPSERT"): + if "ResourceRecords" in record_set: + resource_records = list(record_set["ResourceRecords"].values())[ + 0 + ] if not isinstance(resource_records, list): # Depending on how many records there are, this may # or may not be a list resource_records = [resource_records] - record_set['ResourceRecords'] = [x['Value'] for x in resource_records] - if action == 'CREATE': + record_set["ResourceRecords"] = [ + x["Value"] for x in resource_records + ] + if action == "CREATE": the_zone.add_rrset(record_set) else: the_zone.upsert_rrset(record_set) elif action == "DELETE": - if 'SetIdentifier' in record_set: - the_zone.delete_rrset_by_id( - record_set["SetIdentifier"]) + if "SetIdentifier" in record_set: + the_zone.delete_rrset_by_id(record_set["SetIdentifier"]) else: the_zone.delete_rrset(record_set) @@ -163,20 +177,20 @@ class Route53(BaseResponse): method = request.method if method == "POST": - properties = xmltodict.parse(self.body)['CreateHealthCheckRequest'][ - 'HealthCheckConfig'] + properties = xmltodict.parse(self.body)["CreateHealthCheckRequest"][ + "HealthCheckConfig" + ] health_check_args = { - "ip_address": properties.get('IPAddress'), - "port": properties.get('Port'), - "type": properties['Type'], - "resource_path": properties.get('ResourcePath'), - "fqdn": properties.get('FullyQualifiedDomainName'), - "search_string": properties.get('SearchString'), - "request_interval": properties.get('RequestInterval'), - "failure_threshold": properties.get('FailureThreshold'), + "ip_address": properties.get("IPAddress"), + "port": properties.get("Port"), + "type": properties["Type"], + "resource_path": properties.get("ResourcePath"), + "fqdn": properties.get("FullyQualifiedDomainName"), + "search_string": properties.get("SearchString"), + "request_interval": properties.get("RequestInterval"), + "failure_threshold": properties.get("FailureThreshold"), } - health_check = route53_backend.create_health_check( - health_check_args) + health_check = route53_backend.create_health_check(health_check_args) template = Template(CREATE_HEALTH_CHECK_RESPONSE) return 201, headers, template.render(health_check=health_check) elif method == "DELETE": @@ -191,13 +205,14 @@ class Route53(BaseResponse): def not_implemented_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) - action = '' - if 'tags' in full_url: - action = 'tags' - elif 'trafficpolicyinstances' in full_url: - action = 'policies' + action = "" + if "tags" in full_url: + action = "tags" + elif "trafficpolicyinstances" in full_url: + action = "policies" raise NotImplementedError( - "The action for {0} has not been implemented for route 53".format(action)) + "The action for {0} has not been implemented for route 53".format(action) + ) def list_or_change_tags_for_resource_request(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -209,17 +224,19 @@ class Route53(BaseResponse): if request.method == "GET": tags = route53_backend.list_tags_for_resource(id_) template = Template(LIST_TAGS_FOR_RESOURCE_RESPONSE) - return 200, headers, template.render( - resource_type=type_, resource_id=id_, tags=tags) + return ( + 200, + headers, + template.render(resource_type=type_, resource_id=id_, tags=tags), + ) if request.method == "POST": - tags = xmltodict.parse( - self.body)['ChangeTagsForResourceRequest'] + tags = xmltodict.parse(self.body)["ChangeTagsForResourceRequest"] - if 'AddTags' in tags: - tags = tags['AddTags'] - elif 'RemoveTagKeys' in tags: - tags = tags['RemoveTagKeys'] + if "AddTags" in tags: + tags = tags["AddTags"] + elif "RemoveTagKeys" in tags: + tags = tags["RemoveTagKeys"] route53_backend.change_tags_for_resource(id_, tags) template = Template(CHANGE_TAGS_FOR_RESOURCE_RESPONSE) diff --git a/moto/route53/urls.py b/moto/route53/urls.py index 53abf23a2..a697d258a 100644 --- a/moto/route53/urls.py +++ b/moto/route53/urls.py @@ -1,9 +1,7 @@ from __future__ import unicode_literals from .responses import Route53 -url_bases = [ - "https?://route53(.*).amazonaws.com", -] +url_bases = ["https?://route53(.*).amazonaws.com"] def tag_response1(*args, **kwargs): @@ -15,12 +13,12 @@ def tag_response2(*args, **kwargs): url_paths = { - '{0}/(?P[\d_-]+)/hostedzone$': Route53().list_or_create_hostzone_response, - '{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)$': Route53().get_or_delete_hostzone_response, - '{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)/rrset/?$': Route53().rrset_response, - '{0}/(?P[\d_-]+)/hostedzonesbyname': Route53().list_hosted_zones_by_name_response, - '{0}/(?P[\d_-]+)/healthcheck': Route53().health_check_response, - '{0}/(?P[\d_-]+)/tags/healthcheck/(?P[^/]+)$': tag_response1, - '{0}/(?P[\d_-]+)/tags/hostedzone/(?P[^/]+)$': tag_response2, - '{0}/(?P[\d_-]+)/trafficpolicyinstances/*': Route53().not_implemented_response + "{0}/(?P[\d_-]+)/hostedzone$": Route53().list_or_create_hostzone_response, + "{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)$": Route53().get_or_delete_hostzone_response, + "{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)/rrset/?$": Route53().rrset_response, + "{0}/(?P[\d_-]+)/hostedzonesbyname": Route53().list_hosted_zones_by_name_response, + "{0}/(?P[\d_-]+)/healthcheck": Route53().health_check_response, + "{0}/(?P[\d_-]+)/tags/healthcheck/(?P[^/]+)$": tag_response1, + "{0}/(?P[\d_-]+)/tags/hostedzone/(?P[^/]+)$": tag_response2, + "{0}/(?P[\d_-]+)/trafficpolicyinstances/*": Route53().not_implemented_response, } diff --git a/moto/s3/config.py b/moto/s3/config.py new file mode 100644 index 000000000..8098addfc --- /dev/null +++ b/moto/s3/config.py @@ -0,0 +1,121 @@ +import json + +from moto.core.exceptions import InvalidNextTokenException +from moto.core.models import ConfigQueryModel +from moto.s3 import s3_backends + + +class S3ConfigQuery(ConfigQueryModel): + def list_config_service_resources( + self, + resource_ids, + resource_name, + limit, + next_token, + backend_region=None, + resource_region=None, + ): + # The resource_region only matters for aggregated queries as you can filter on bucket regions for them. + # For other resource types, you would need to iterate appropriately for the backend_region. + + # Resource IDs are the same as S3 bucket names + # For aggregation -- did we get both a resource ID and a resource name? + if resource_ids and resource_name: + # If the values are different, then return an empty list: + if resource_name not in resource_ids: + return [], None + + # If no filter was passed in for resource names/ids then return them all: + if not resource_ids and not resource_name: + bucket_list = list(self.backends["global"].buckets.keys()) + + else: + # Match the resource name / ID: + bucket_list = [] + filter_buckets = [resource_name] if resource_name else resource_ids + + for bucket in self.backends["global"].buckets.keys(): + if bucket in filter_buckets: + bucket_list.append(bucket) + + # Filter on the proper region if supplied: + region_filter = backend_region or resource_region + if region_filter: + region_buckets = [] + + for bucket in bucket_list: + if self.backends["global"].buckets[bucket].region_name == region_filter: + region_buckets.append(bucket) + + bucket_list = region_buckets + + if not bucket_list: + return [], None + + # Pagination logic: + sorted_buckets = sorted(bucket_list) + new_token = None + + # Get the start: + if not next_token: + start = 0 + else: + # Tokens for this moto feature is just the bucket name: + # For OTHER non-global resource types, it's the region concatenated with the resource ID. + if next_token not in sorted_buckets: + raise InvalidNextTokenException() + + start = sorted_buckets.index(next_token) + + # Get the list of items to collect: + bucket_list = sorted_buckets[start : (start + limit)] + + if len(sorted_buckets) > (start + limit): + new_token = sorted_buckets[start + limit] + + return ( + [ + { + "type": "AWS::S3::Bucket", + "id": bucket, + "name": bucket, + "region": self.backends["global"].buckets[bucket].region_name, + } + for bucket in bucket_list + ], + new_token, + ) + + def get_config_resource( + self, resource_id, resource_name=None, backend_region=None, resource_region=None + ): + # Get the bucket: + bucket = self.backends["global"].buckets.get(resource_id, {}) + + if not bucket: + return + + # Are we filtering based on region? + region_filter = backend_region or resource_region + if region_filter and bucket.region_name != region_filter: + return + + # Are we also filtering on bucket name? + if resource_name and bucket.name != resource_name: + return + + # Format the bucket to the AWS Config format: + config_data = bucket.to_config_dict() + + # The 'configuration' field is also a JSON string: + config_data["configuration"] = json.dumps(config_data["configuration"]) + + # Supplementary config need all values converted to JSON strings if they are not strings already: + for field, value in config_data["supplementaryConfiguration"].items(): + if not isinstance(value, str): + config_data["supplementaryConfiguration"][field] = json.dumps(value) + + return config_data + + +s3_config_query = S3ConfigQuery(s3_backends) diff --git a/moto/s3/exceptions.py b/moto/s3/exceptions.py index 8d2326fa1..1f2ead639 100644 --- a/moto/s3/exceptions.py +++ b/moto/s3/exceptions.py @@ -12,18 +12,16 @@ ERROR_WITH_KEY_NAME = """{% extends 'single_error' %} class S3ClientError(RESTError): - def __init__(self, *args, **kwargs): - kwargs.setdefault('template', 'single_error') - self.templates['bucket_error'] = ERROR_WITH_BUCKET_NAME + kwargs.setdefault("template", "single_error") + self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME super(S3ClientError, self).__init__(*args, **kwargs) class BucketError(S3ClientError): - def __init__(self, *args, **kwargs): - kwargs.setdefault('template', 'bucket_error') - self.templates['bucket_error'] = ERROR_WITH_BUCKET_NAME + kwargs.setdefault("template", "bucket_error") + self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME super(BucketError, self).__init__(*args, **kwargs) @@ -33,10 +31,14 @@ class BucketAlreadyExists(BucketError): def __init__(self, *args, **kwargs): super(BucketAlreadyExists, self).__init__( "BucketAlreadyExists", - ("The requested bucket name is not available. The bucket " - "namespace is shared by all users of the system. Please " - "select a different name and try again"), - *args, **kwargs) + ( + "The requested bucket name is not available. The bucket " + "namespace is shared by all users of the system. Please " + "select a different name and try again" + ), + *args, + **kwargs + ) class MissingBucket(BucketError): @@ -44,9 +46,8 @@ class MissingBucket(BucketError): def __init__(self, *args, **kwargs): super(MissingBucket, self).__init__( - "NoSuchBucket", - "The specified bucket does not exist", - *args, **kwargs) + "NoSuchBucket", "The specified bucket does not exist", *args, **kwargs + ) class MissingKey(S3ClientError): @@ -54,9 +55,7 @@ class MissingKey(S3ClientError): def __init__(self, key_name): super(MissingKey, self).__init__( - "NoSuchKey", - "The specified key does not exist.", - Key=key_name, + "NoSuchKey", "The specified key does not exist.", Key=key_name ) @@ -77,9 +76,13 @@ class InvalidPartOrder(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidPartOrder, self).__init__( "InvalidPartOrder", - ("The list of parts was not in ascending order. The parts " - "list must be specified in order by part number."), - *args, **kwargs) + ( + "The list of parts was not in ascending order. The parts " + "list must be specified in order by part number." + ), + *args, + **kwargs + ) class InvalidPart(S3ClientError): @@ -88,10 +91,14 @@ class InvalidPart(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidPart, self).__init__( "InvalidPart", - ("One or more of the specified parts could not be found. " - "The part might not have been uploaded, or the specified " - "entity tag might not have matched the part's entity tag."), - *args, **kwargs) + ( + "One or more of the specified parts could not be found. " + "The part might not have been uploaded, or the specified " + "entity tag might not have matched the part's entity tag." + ), + *args, + **kwargs + ) class EntityTooSmall(S3ClientError): @@ -101,7 +108,9 @@ class EntityTooSmall(S3ClientError): super(EntityTooSmall, self).__init__( "EntityTooSmall", "Your proposed upload is smaller than the minimum allowed object size.", - *args, **kwargs) + *args, + **kwargs + ) class InvalidRequest(S3ClientError): @@ -110,8 +119,12 @@ class InvalidRequest(S3ClientError): def __init__(self, method, *args, **kwargs): super(InvalidRequest, self).__init__( "InvalidRequest", - "Found unsupported HTTP method in CORS config. Unsupported method is {}".format(method), - *args, **kwargs) + "Found unsupported HTTP method in CORS config. Unsupported method is {}".format( + method + ), + *args, + **kwargs + ) class MalformedXML(S3ClientError): @@ -121,7 +134,9 @@ class MalformedXML(S3ClientError): super(MalformedXML, self).__init__( "MalformedXML", "The XML you provided was not well-formed or did not validate against our published schema", - *args, **kwargs) + *args, + **kwargs + ) class MalformedACLError(S3ClientError): @@ -131,14 +146,18 @@ class MalformedACLError(S3ClientError): super(MalformedACLError, self).__init__( "MalformedACLError", "The XML you provided was not well-formed or did not validate against our published schema", - *args, **kwargs) + *args, + **kwargs + ) class InvalidTargetBucketForLogging(S3ClientError): code = 400 def __init__(self, msg): - super(InvalidTargetBucketForLogging, self).__init__("InvalidTargetBucketForLogging", msg) + super(InvalidTargetBucketForLogging, self).__init__( + "InvalidTargetBucketForLogging", msg + ) class CrossLocationLoggingProhibitted(S3ClientError): @@ -146,8 +165,7 @@ class CrossLocationLoggingProhibitted(S3ClientError): def __init__(self): super(CrossLocationLoggingProhibitted, self).__init__( - "CrossLocationLoggingProhibitted", - "Cross S3 location logging not allowed." + "CrossLocationLoggingProhibitted", "Cross S3 location logging not allowed." ) @@ -156,9 +174,8 @@ class InvalidNotificationARN(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidNotificationARN, self).__init__( - "InvalidArgument", - "The ARN is not well formed", - *args, **kwargs) + "InvalidArgument", "The ARN is not well formed", *args, **kwargs + ) class InvalidNotificationDestination(S3ClientError): @@ -168,7 +185,9 @@ class InvalidNotificationDestination(S3ClientError): super(InvalidNotificationDestination, self).__init__( "InvalidArgument", "The notification destination service region is not valid for the bucket location constraint", - *args, **kwargs) + *args, + **kwargs + ) class InvalidNotificationEvent(S3ClientError): @@ -178,7 +197,9 @@ class InvalidNotificationEvent(S3ClientError): super(InvalidNotificationEvent, self).__init__( "InvalidArgument", "The event is not supported for notifications", - *args, **kwargs) + *args, + **kwargs + ) class InvalidStorageClass(S3ClientError): @@ -188,7 +209,9 @@ class InvalidStorageClass(S3ClientError): super(InvalidStorageClass, self).__init__( "InvalidStorageClass", "The storage class you specified is not valid", - *args, **kwargs) + *args, + **kwargs + ) class InvalidBucketName(S3ClientError): @@ -196,9 +219,7 @@ class InvalidBucketName(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidBucketName, self).__init__( - "InvalidBucketName", - "The specified bucket is not valid.", - *args, **kwargs + "InvalidBucketName", "The specified bucket is not valid.", *args, **kwargs ) @@ -209,35 +230,51 @@ class DuplicateTagKeys(S3ClientError): super(DuplicateTagKeys, self).__init__( "InvalidTag", "Cannot provide multiple Tags with the same key", - *args, **kwargs) + *args, + **kwargs + ) class S3AccessDeniedError(S3ClientError): code = 403 def __init__(self, *args, **kwargs): - super(S3AccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs) + super(S3AccessDeniedError, self).__init__( + "AccessDenied", "Access Denied", *args, **kwargs + ) class BucketAccessDeniedError(BucketError): code = 403 def __init__(self, *args, **kwargs): - super(BucketAccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs) + super(BucketAccessDeniedError, self).__init__( + "AccessDenied", "Access Denied", *args, **kwargs + ) class S3InvalidTokenError(S3ClientError): code = 400 def __init__(self, *args, **kwargs): - super(S3InvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs) + super(S3InvalidTokenError, self).__init__( + "InvalidToken", + "The provided token is malformed or otherwise invalid.", + *args, + **kwargs + ) class BucketInvalidTokenError(BucketError): code = 400 def __init__(self, *args, **kwargs): - super(BucketInvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs) + super(BucketInvalidTokenError, self).__init__( + "InvalidToken", + "The provided token is malformed or otherwise invalid.", + *args, + **kwargs + ) class S3InvalidAccessKeyIdError(S3ClientError): @@ -245,8 +282,11 @@ class S3InvalidAccessKeyIdError(S3ClientError): def __init__(self, *args, **kwargs): super(S3InvalidAccessKeyIdError, self).__init__( - 'InvalidAccessKeyId', - "The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs) + "InvalidAccessKeyId", + "The AWS Access Key Id you provided does not exist in our records.", + *args, + **kwargs + ) class BucketInvalidAccessKeyIdError(S3ClientError): @@ -254,8 +294,11 @@ class BucketInvalidAccessKeyIdError(S3ClientError): def __init__(self, *args, **kwargs): super(BucketInvalidAccessKeyIdError, self).__init__( - 'InvalidAccessKeyId', - "The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs) + "InvalidAccessKeyId", + "The AWS Access Key Id you provided does not exist in our records.", + *args, + **kwargs + ) class S3SignatureDoesNotMatchError(S3ClientError): @@ -263,8 +306,11 @@ class S3SignatureDoesNotMatchError(S3ClientError): def __init__(self, *args, **kwargs): super(S3SignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs) + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your key and signing method.", + *args, + **kwargs + ) class BucketSignatureDoesNotMatchError(S3ClientError): @@ -272,5 +318,32 @@ class BucketSignatureDoesNotMatchError(S3ClientError): def __init__(self, *args, **kwargs): super(BucketSignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs) + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your key and signing method.", + *args, + **kwargs + ) + + +class NoSuchPublicAccessBlockConfiguration(S3ClientError): + code = 404 + + def __init__(self, *args, **kwargs): + super(NoSuchPublicAccessBlockConfiguration, self).__init__( + "NoSuchPublicAccessBlockConfiguration", + "The public access block configuration was not found", + *args, + **kwargs + ) + + +class InvalidPublicAccessBlockConfiguration(S3ClientError): + code = 400 + + def __init__(self, *args, **kwargs): + super(InvalidPublicAccessBlockConfiguration, self).__init__( + "InvalidRequest", + "Must specify at least one configuration.", + *args, + **kwargs + ) diff --git a/moto/s3/models.py b/moto/s3/models.py index b5aef34d3..fe8e908ef 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -1,4 +1,7 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals + +import json import os import base64 import datetime @@ -10,6 +13,7 @@ import random import string import tempfile import sys +import time import uuid import six @@ -18,9 +22,21 @@ from bisect import insort from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime from .exceptions import ( - BucketAlreadyExists, MissingBucket, InvalidBucketName, InvalidPart, InvalidRequest, - EntityTooSmall, MissingKey, InvalidNotificationDestination, MalformedXML, InvalidStorageClass, - InvalidTargetBucketForLogging, DuplicateTagKeys, CrossLocationLoggingProhibitted + BucketAlreadyExists, + MissingBucket, + InvalidBucketName, + InvalidPart, + InvalidRequest, + EntityTooSmall, + MissingKey, + InvalidNotificationDestination, + MalformedXML, + InvalidStorageClass, + InvalidTargetBucketForLogging, + DuplicateTagKeys, + CrossLocationLoggingProhibitted, + NoSuchPublicAccessBlockConfiguration, + InvalidPublicAccessBlockConfiguration, ) from .utils import clean_key_name, _VersionedKeyStore @@ -28,14 +44,21 @@ MAX_BUCKET_NAME_LENGTH = 63 MIN_BUCKET_NAME_LENGTH = 3 UPLOAD_ID_BYTES = 43 UPLOAD_PART_MIN_SIZE = 5242880 -STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA", - "INTELLIGENT_TIERING", "GLACIER", "DEEP_ARCHIVE"] +STORAGE_CLASS = [ + "STANDARD", + "REDUCED_REDUNDANCY", + "STANDARD_IA", + "ONEZONE_IA", + "INTELLIGENT_TIERING", + "GLACIER", + "DEEP_ARCHIVE", +] DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024 DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() +OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a" class FakeDeleteMarker(BaseModel): - def __init__(self, key): self.key = key self.name = key.name @@ -52,7 +75,6 @@ class FakeDeleteMarker(BaseModel): class FakeKey(BaseModel): - def __init__( self, name, @@ -62,11 +84,11 @@ class FakeKey(BaseModel): is_versioned=False, version_id=0, max_buffer_size=DEFAULT_KEY_BUFFER_SIZE, - multipart=None + multipart=None, ): self.name = name self.last_modified = datetime.datetime.utcnow() - self.acl = get_canned_acl('private') + self.acl = get_canned_acl("private") self.website_redirect_location = None self._storage_class = storage if storage else "STANDARD" self._metadata = {} @@ -179,21 +201,21 @@ class FakeKey(BaseModel): @property def response_dict(self): res = { - 'ETag': self.etag, - 'last-modified': self.last_modified_RFC1123, - 'content-length': str(self.size), + "ETag": self.etag, + "last-modified": self.last_modified_RFC1123, + "content-length": str(self.size), } - if self._storage_class != 'STANDARD': - res['x-amz-storage-class'] = self._storage_class + if self._storage_class != "STANDARD": + res["x-amz-storage-class"] = self._storage_class if self._expiry is not None: rhdr = 'ongoing-request="false", expiry-date="{0}"' - res['x-amz-restore'] = rhdr.format(self.expiry_date) + res["x-amz-restore"] = rhdr.format(self.expiry_date) if self._is_versioned: - res['x-amz-version-id'] = str(self.version_id) + res["x-amz-version-id"] = str(self.version_id) if self.website_redirect_location: - res['x-amz-website-redirect-location'] = self.website_redirect_location + res["x-amz-website-redirect-location"] = self.website_redirect_location return res @@ -217,30 +239,27 @@ class FakeKey(BaseModel): # https://docs.python.org/3/library/pickle.html#handling-stateful-objects def __getstate__(self): state = self.__dict__.copy() - state['value'] = self.value - del state['_value_buffer'] + state["value"] = self.value + del state["_value_buffer"] return state def __setstate__(self, state): - self.__dict__.update({ - k: v for k, v in six.iteritems(state) - if k != 'value' - }) + self.__dict__.update({k: v for k, v in six.iteritems(state) if k != "value"}) - self._value_buffer = \ - tempfile.SpooledTemporaryFile(max_size=self._max_buffer_size) - self.value = state['value'] + self._value_buffer = tempfile.SpooledTemporaryFile( + max_size=self._max_buffer_size + ) + self.value = state["value"] class FakeMultipart(BaseModel): - def __init__(self, key_name, metadata): self.key_name = key_name self.metadata = metadata self.parts = {} self.partlist = [] # ordered list of part ID's rand_b64 = base64.b64encode(os.urandom(UPLOAD_ID_BYTES)) - self.id = rand_b64.decode('utf-8').replace('=', '').replace('+', '') + self.id = rand_b64.decode("utf-8").replace("=", "").replace("+", "") def complete(self, body): decode_hex = codecs.getdecoder("hex_codec") @@ -253,8 +272,8 @@ class FakeMultipart(BaseModel): part = self.parts.get(pn) part_etag = None if part is not None: - part_etag = part.etag.replace('"', '') - etag = etag.replace('"', '') + part_etag = part.etag.replace('"', "") + etag = etag.replace('"', "") if part is None or part_etag != etag: raise InvalidPart() if last is not None and len(last.value) < UPLOAD_PART_MIN_SIZE: @@ -284,8 +303,7 @@ class FakeMultipart(BaseModel): class FakeGrantee(BaseModel): - - def __init__(self, id='', uri='', display_name=''): + def __init__(self, id="", uri="", display_name=""): self.id = id self.uri = uri self.display_name = display_name @@ -293,43 +311,57 @@ class FakeGrantee(BaseModel): def __eq__(self, other): if not isinstance(other, FakeGrantee): return False - return self.id == other.id and self.uri == other.uri and self.display_name == other.display_name + return ( + self.id == other.id + and self.uri == other.uri + and self.display_name == other.display_name + ) @property def type(self): - return 'Group' if self.uri else 'CanonicalUser' + return "Group" if self.uri else "CanonicalUser" def __repr__(self): - return "FakeGrantee(display_name: '{}', id: '{}', uri: '{}')".format(self.display_name, self.id, self.uri) + return "FakeGrantee(display_name: '{}', id: '{}', uri: '{}')".format( + self.display_name, self.id, self.uri + ) -ALL_USERS_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/global/AllUsers') +ALL_USERS_GRANTEE = FakeGrantee(uri="http://acs.amazonaws.com/groups/global/AllUsers") AUTHENTICATED_USERS_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/global/AuthenticatedUsers') -LOG_DELIVERY_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/s3/LogDelivery') + uri="http://acs.amazonaws.com/groups/global/AuthenticatedUsers" +) +LOG_DELIVERY_GRANTEE = FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery") -PERMISSION_FULL_CONTROL = 'FULL_CONTROL' -PERMISSION_WRITE = 'WRITE' -PERMISSION_READ = 'READ' -PERMISSION_WRITE_ACP = 'WRITE_ACP' -PERMISSION_READ_ACP = 'READ_ACP' +PERMISSION_FULL_CONTROL = "FULL_CONTROL" +PERMISSION_WRITE = "WRITE" +PERMISSION_READ = "READ" +PERMISSION_WRITE_ACP = "WRITE_ACP" +PERMISSION_READ_ACP = "READ_ACP" + +CAMEL_CASED_PERMISSIONS = { + "FULL_CONTROL": "FullControl", + "WRITE": "Write", + "READ": "Read", + "WRITE_ACP": "WriteAcp", + "READ_ACP": "ReadAcp", +} class FakeGrant(BaseModel): - def __init__(self, grantees, permissions): self.grantees = grantees self.permissions = permissions def __repr__(self): - return "FakeGrant(grantees: {}, permissions: {})".format(self.grantees, self.permissions) + return "FakeGrant(grantees: {}, permissions: {})".format( + self.grantees, self.permissions + ) class FakeAcl(BaseModel): - - def __init__(self, grants=[]): + def __init__(self, grants=None): + grants = grants or [] self.grants = grants @property @@ -345,75 +377,168 @@ class FakeAcl(BaseModel): def __repr__(self): return "FakeAcl(grants: {})".format(self.grants) + def to_config_dict(self): + """Returns the object into the format expected by AWS Config""" + data = { + "grantSet": None, # Always setting this to None. Feel free to change. + "owner": {"displayName": None, "id": OWNER}, + } + + # Add details for each Grant: + grant_list = [] + for grant in self.grants: + permissions = ( + grant.permissions + if isinstance(grant.permissions, list) + else [grant.permissions] + ) + for permission in permissions: + for grantee in grant.grantees: + # Config does not add the owner if its permissions are FULL_CONTROL: + if permission == "FULL_CONTROL" and grantee.id == OWNER: + continue + + if grantee.uri: + grant_list.append( + { + "grantee": grantee.uri.split( + "http://acs.amazonaws.com/groups/s3/" + )[1], + "permission": CAMEL_CASED_PERMISSIONS[permission], + } + ) + else: + grant_list.append( + { + "grantee": { + "id": grantee.id, + "displayName": None + if not grantee.display_name + else grantee.display_name, + }, + "permission": CAMEL_CASED_PERMISSIONS[permission], + } + ) + + if grant_list: + data["grantList"] = grant_list + + return data + def get_canned_acl(acl): - owner_grantee = FakeGrantee( - id='75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a') + owner_grantee = FakeGrantee(id=OWNER) grants = [FakeGrant([owner_grantee], [PERMISSION_FULL_CONTROL])] - if acl == 'private': + if acl == "private": pass # no other permissions - elif acl == 'public-read': + elif acl == "public-read": grants.append(FakeGrant([ALL_USERS_GRANTEE], [PERMISSION_READ])) - elif acl == 'public-read-write': - grants.append(FakeGrant([ALL_USERS_GRANTEE], [ - PERMISSION_READ, PERMISSION_WRITE])) - elif acl == 'authenticated-read': + elif acl == "public-read-write": grants.append( - FakeGrant([AUTHENTICATED_USERS_GRANTEE], [PERMISSION_READ])) - elif acl == 'bucket-owner-read': + FakeGrant([ALL_USERS_GRANTEE], [PERMISSION_READ, PERMISSION_WRITE]) + ) + elif acl == "authenticated-read": + grants.append(FakeGrant([AUTHENTICATED_USERS_GRANTEE], [PERMISSION_READ])) + elif acl == "bucket-owner-read": pass # TODO: bucket owner ACL - elif acl == 'bucket-owner-full-control': + elif acl == "bucket-owner-full-control": pass # TODO: bucket owner ACL - elif acl == 'aws-exec-read': + elif acl == "aws-exec-read": pass # TODO: bucket owner, EC2 Read - elif acl == 'log-delivery-write': - grants.append(FakeGrant([LOG_DELIVERY_GRANTEE], [ - PERMISSION_READ_ACP, PERMISSION_WRITE])) + elif acl == "log-delivery-write": + grants.append( + FakeGrant([LOG_DELIVERY_GRANTEE], [PERMISSION_READ_ACP, PERMISSION_WRITE]) + ) else: - assert False, 'Unknown canned acl: %s' % (acl,) + assert False, "Unknown canned acl: %s" % (acl,) return FakeAcl(grants=grants) class FakeTagging(BaseModel): - def __init__(self, tag_set=None): self.tag_set = tag_set or FakeTagSet() class FakeTagSet(BaseModel): - def __init__(self, tags=None): self.tags = tags or [] class FakeTag(BaseModel): - def __init__(self, key, value=None): self.key = key self.value = value class LifecycleFilter(BaseModel): - def __init__(self, prefix=None, tag=None, and_filter=None): - self.prefix = prefix or '' + self.prefix = prefix self.tag = tag self.and_filter = and_filter + def to_config_dict(self): + if self.prefix is not None: + return { + "predicate": {"type": "LifecyclePrefixPredicate", "prefix": self.prefix} + } + + elif self.tag: + return { + "predicate": { + "type": "LifecycleTagPredicate", + "tag": {"key": self.tag.key, "value": self.tag.value}, + } + } + + else: + return { + "predicate": { + "type": "LifecycleAndOperator", + "operands": self.and_filter.to_config_dict(), + } + } + class LifecycleAndFilter(BaseModel): - def __init__(self, prefix=None, tags=None): - self.prefix = prefix or '' + self.prefix = prefix self.tags = tags + def to_config_dict(self): + data = [] + + if self.prefix is not None: + data.append({"type": "LifecyclePrefixPredicate", "prefix": self.prefix}) + + for tag in self.tags: + data.append( + { + "type": "LifecycleTagPredicate", + "tag": {"key": tag.key, "value": tag.value}, + } + ) + + return data + class LifecycleRule(BaseModel): - - def __init__(self, id=None, prefix=None, lc_filter=None, status=None, expiration_days=None, - expiration_date=None, transition_days=None, transition_date=None, storage_class=None, - expired_object_delete_marker=None, nve_noncurrent_days=None, nvt_noncurrent_days=None, - nvt_storage_class=None, aimu_days=None): + def __init__( + self, + id=None, + prefix=None, + lc_filter=None, + status=None, + expiration_days=None, + expiration_date=None, + transition_days=None, + transition_date=None, + storage_class=None, + expired_object_delete_marker=None, + nve_noncurrent_days=None, + nvt_noncurrent_days=None, + nvt_storage_class=None, + aimu_days=None, + ): self.id = id self.prefix = prefix self.filter = lc_filter @@ -429,40 +554,212 @@ class LifecycleRule(BaseModel): self.nvt_storage_class = nvt_storage_class self.aimu_days = aimu_days + def to_config_dict(self): + """Converts the object to the AWS Config data dict. + + Note: The following are missing that should be added in the future: + - transitions (returns None for now) + - noncurrentVersionTransitions (returns None for now) + + :param kwargs: + :return: + """ + + lifecycle_dict = { + "id": self.id, + "prefix": self.prefix, + "status": self.status, + "expirationInDays": int(self.expiration_days) + if self.expiration_days + else None, + "expiredObjectDeleteMarker": self.expired_object_delete_marker, + "noncurrentVersionExpirationInDays": -1 or int(self.nve_noncurrent_days), + "expirationDate": self.expiration_date, + "transitions": None, # Replace me with logic to fill in + "noncurrentVersionTransitions": None, # Replace me with logic to fill in + } + + if self.aimu_days: + lifecycle_dict["abortIncompleteMultipartUpload"] = { + "daysAfterInitiation": self.aimu_days + } + else: + lifecycle_dict["abortIncompleteMultipartUpload"] = None + + # Format the filter: + if self.prefix is None and self.filter is None: + lifecycle_dict["filter"] = {"predicate": None} + + elif self.prefix: + lifecycle_dict["filter"] = None + else: + lifecycle_dict["filter"] = self.filter.to_config_dict() + + return lifecycle_dict + class CorsRule(BaseModel): - - def __init__(self, allowed_methods, allowed_origins, allowed_headers=None, expose_headers=None, - max_age_seconds=None): - self.allowed_methods = [allowed_methods] if isinstance(allowed_methods, six.string_types) else allowed_methods - self.allowed_origins = [allowed_origins] if isinstance(allowed_origins, six.string_types) else allowed_origins - self.allowed_headers = [allowed_headers] if isinstance(allowed_headers, six.string_types) else allowed_headers - self.exposed_headers = [expose_headers] if isinstance(expose_headers, six.string_types) else expose_headers + def __init__( + self, + allowed_methods, + allowed_origins, + allowed_headers=None, + expose_headers=None, + max_age_seconds=None, + ): + self.allowed_methods = ( + [allowed_methods] + if isinstance(allowed_methods, six.string_types) + else allowed_methods + ) + self.allowed_origins = ( + [allowed_origins] + if isinstance(allowed_origins, six.string_types) + else allowed_origins + ) + self.allowed_headers = ( + [allowed_headers] + if isinstance(allowed_headers, six.string_types) + else allowed_headers + ) + self.exposed_headers = ( + [expose_headers] + if isinstance(expose_headers, six.string_types) + else expose_headers + ) self.max_age_seconds = max_age_seconds class Notification(BaseModel): - def __init__(self, arn, events, filters=None, id=None): - self.id = id if id else ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(50)) + self.id = ( + id + if id + else "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(50) + ) + ) self.arn = arn self.events = events self.filters = filters if filters else {} + def to_config_dict(self): + data = {} + + # Type and ARN will be filled in by NotificationConfiguration's to_config_dict: + data["events"] = [event for event in self.events] + + if self.filters: + data["filter"] = { + "s3KeyFilter": { + "filterRules": [ + {"name": fr["Name"], "value": fr["Value"]} + for fr in self.filters["S3Key"]["FilterRule"] + ] + } + } + else: + data["filter"] = None + + # Not sure why this is a thing since AWS just seems to return this as filters ¯\_(ツ)_/¯ + data["objectPrefixes"] = [] + + return data + class NotificationConfiguration(BaseModel): - def __init__(self, topic=None, queue=None, cloud_function=None): - self.topic = [Notification(t["Topic"], t["Event"], filters=t.get("Filter"), id=t.get("Id")) for t in topic] \ - if topic else [] - self.queue = [Notification(q["Queue"], q["Event"], filters=q.get("Filter"), id=q.get("Id")) for q in queue] \ - if queue else [] - self.cloud_function = [Notification(c["CloudFunction"], c["Event"], filters=c.get("Filter"), id=c.get("Id")) - for c in cloud_function] if cloud_function else [] + self.topic = ( + [ + Notification( + t["Topic"], t["Event"], filters=t.get("Filter"), id=t.get("Id") + ) + for t in topic + ] + if topic + else [] + ) + self.queue = ( + [ + Notification( + q["Queue"], q["Event"], filters=q.get("Filter"), id=q.get("Id") + ) + for q in queue + ] + if queue + else [] + ) + self.cloud_function = ( + [ + Notification( + c["CloudFunction"], + c["Event"], + filters=c.get("Filter"), + id=c.get("Id"), + ) + for c in cloud_function + ] + if cloud_function + else [] + ) + + def to_config_dict(self): + data = {"configurations": {}} + + for topic in self.topic: + topic_config = topic.to_config_dict() + topic_config["topicARN"] = topic.arn + topic_config["type"] = "TopicConfiguration" + data["configurations"][topic.id] = topic_config + + for queue in self.queue: + queue_config = queue.to_config_dict() + queue_config["queueARN"] = queue.arn + queue_config["type"] = "QueueConfiguration" + data["configurations"][queue.id] = queue_config + + for cloud_function in self.cloud_function: + cf_config = cloud_function.to_config_dict() + cf_config["queueARN"] = cloud_function.arn + cf_config["type"] = "LambdaConfiguration" + data["configurations"][cloud_function.id] = cf_config + + return data + + +def convert_str_to_bool(item): + """Converts a boolean string to a boolean value""" + if isinstance(item, str): + return item.lower() == "true" + + return False + + +class PublicAccessBlock(BaseModel): + def __init__( + self, + block_public_acls, + ignore_public_acls, + block_public_policy, + restrict_public_buckets, + ): + # The boto XML appears to expect these values to exist as lowercase strings... + self.block_public_acls = block_public_acls or "false" + self.ignore_public_acls = ignore_public_acls or "false" + self.block_public_policy = block_public_policy or "false" + self.restrict_public_buckets = restrict_public_buckets or "false" + + def to_config_dict(self): + # Need to make the string values booleans for Config: + return { + "blockPublicAcls": convert_str_to_bool(self.block_public_acls), + "ignorePublicAcls": convert_str_to_bool(self.ignore_public_acls), + "blockPublicPolicy": convert_str_to_bool(self.block_public_policy), + "restrictPublicBuckets": convert_str_to_bool(self.restrict_public_buckets), + } class FakeBucket(BaseModel): - def __init__(self, name, region_name): self.name = name self.region_name = region_name @@ -472,12 +769,15 @@ class FakeBucket(BaseModel): self.rules = [] self.policy = None self.website_configuration = None - self.acl = get_canned_acl('private') + self.acl = get_canned_acl("private") self.tags = FakeTagging() self.cors = [] self.logging = {} self.notification_configuration = None self.accelerate_configuration = None + self.payer = "BucketOwner" + self.creation_date = datetime.datetime.utcnow() + self.public_access_block = None @property def location(self): @@ -485,36 +785,52 @@ class FakeBucket(BaseModel): @property def is_versioned(self): - return self.versioning_status == 'Enabled' + return self.versioning_status == "Enabled" def set_lifecycle(self, rules): self.rules = [] for rule in rules: # Extract and validate actions from Lifecycle rule - expiration = rule.get('Expiration') - transition = rule.get('Transition') + expiration = rule.get("Expiration") + transition = rule.get("Transition") + + try: + top_level_prefix = ( + rule["Prefix"] or "" + ) # If it's `None` the set to the empty string + except KeyError: + top_level_prefix = None nve_noncurrent_days = None - if rule.get('NoncurrentVersionExpiration') is not None: - if rule["NoncurrentVersionExpiration"].get('NoncurrentDays') is None: + if rule.get("NoncurrentVersionExpiration") is not None: + if rule["NoncurrentVersionExpiration"].get("NoncurrentDays") is None: raise MalformedXML() - nve_noncurrent_days = rule["NoncurrentVersionExpiration"]["NoncurrentDays"] + nve_noncurrent_days = rule["NoncurrentVersionExpiration"][ + "NoncurrentDays" + ] nvt_noncurrent_days = None nvt_storage_class = None - if rule.get('NoncurrentVersionTransition') is not None: - if rule["NoncurrentVersionTransition"].get('NoncurrentDays') is None: + if rule.get("NoncurrentVersionTransition") is not None: + if rule["NoncurrentVersionTransition"].get("NoncurrentDays") is None: raise MalformedXML() - if rule["NoncurrentVersionTransition"].get('StorageClass') is None: + if rule["NoncurrentVersionTransition"].get("StorageClass") is None: raise MalformedXML() - nvt_noncurrent_days = rule["NoncurrentVersionTransition"]["NoncurrentDays"] + nvt_noncurrent_days = rule["NoncurrentVersionTransition"][ + "NoncurrentDays" + ] nvt_storage_class = rule["NoncurrentVersionTransition"]["StorageClass"] aimu_days = None - if rule.get('AbortIncompleteMultipartUpload') is not None: - if rule["AbortIncompleteMultipartUpload"].get('DaysAfterInitiation') is None: + if rule.get("AbortIncompleteMultipartUpload") is not None: + if ( + rule["AbortIncompleteMultipartUpload"].get("DaysAfterInitiation") + is None + ): raise MalformedXML() - aimu_days = rule["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] + aimu_days = rule["AbortIncompleteMultipartUpload"][ + "DaysAfterInitiation" + ] eodm = None if expiration and expiration.get("ExpiredObjectDeleteMarker") is not None: @@ -528,45 +844,86 @@ class FakeBucket(BaseModel): if rule.get("Filter"): # Can't have both `Filter` and `Prefix` (need to check for the presence of the key): try: + # 'Prefix' cannot be outside of a Filter: if rule["Prefix"] or not rule["Prefix"]: raise MalformedXML() except KeyError: pass + filters = 0 + try: + prefix_filter = ( + rule["Filter"]["Prefix"] or "" + ) # If it's `None` the set to the empty string + filters += 1 + except KeyError: + prefix_filter = None + and_filter = None if rule["Filter"].get("And"): + filters += 1 and_tags = [] if rule["Filter"]["And"].get("Tag"): if not isinstance(rule["Filter"]["And"]["Tag"], list): - rule["Filter"]["And"]["Tag"] = [rule["Filter"]["And"]["Tag"]] + rule["Filter"]["And"]["Tag"] = [ + rule["Filter"]["And"]["Tag"] + ] for t in rule["Filter"]["And"]["Tag"]: - and_tags.append(FakeTag(t["Key"], t.get("Value", ''))) + and_tags.append(FakeTag(t["Key"], t.get("Value", ""))) - and_filter = LifecycleAndFilter(prefix=rule["Filter"]["And"]["Prefix"], tags=and_tags) + try: + and_prefix = ( + rule["Filter"]["And"]["Prefix"] or "" + ) # If it's `None` then set to the empty string + except KeyError: + and_prefix = None + + and_filter = LifecycleAndFilter(prefix=and_prefix, tags=and_tags) filter_tag = None if rule["Filter"].get("Tag"): - filter_tag = FakeTag(rule["Filter"]["Tag"]["Key"], rule["Filter"]["Tag"].get("Value", '')) + filters += 1 + filter_tag = FakeTag( + rule["Filter"]["Tag"]["Key"], + rule["Filter"]["Tag"].get("Value", ""), + ) - lc_filter = LifecycleFilter(prefix=rule["Filter"]["Prefix"], tag=filter_tag, and_filter=and_filter) + # Can't have more than 1 filter: + if filters > 1: + raise MalformedXML() - self.rules.append(LifecycleRule( - id=rule.get('ID'), - prefix=rule.get('Prefix'), - lc_filter=lc_filter, - status=rule['Status'], - expiration_days=expiration.get('Days') if expiration else None, - expiration_date=expiration.get('Date') if expiration else None, - transition_days=transition.get('Days') if transition else None, - transition_date=transition.get('Date') if transition else None, - storage_class=transition.get('StorageClass') if transition else None, - expired_object_delete_marker=eodm, - nve_noncurrent_days=nve_noncurrent_days, - nvt_noncurrent_days=nvt_noncurrent_days, - nvt_storage_class=nvt_storage_class, - aimu_days=aimu_days, - )) + lc_filter = LifecycleFilter( + prefix=prefix_filter, tag=filter_tag, and_filter=and_filter + ) + + # If no top level prefix and no filter is present, then this is invalid: + if top_level_prefix is None: + try: + rule["Filter"] + except KeyError: + raise MalformedXML() + + self.rules.append( + LifecycleRule( + id=rule.get("ID"), + prefix=top_level_prefix, + lc_filter=lc_filter, + status=rule["Status"], + expiration_days=expiration.get("Days") if expiration else None, + expiration_date=expiration.get("Date") if expiration else None, + transition_days=transition.get("Days") if transition else None, + transition_date=transition.get("Date") if transition else None, + storage_class=transition.get("StorageClass") + if transition + else None, + expired_object_delete_marker=eodm, + nve_noncurrent_days=nve_noncurrent_days, + nvt_noncurrent_days=nvt_noncurrent_days, + nvt_storage_class=nvt_storage_class, + aimu_days=aimu_days, + ) + ) def delete_lifecycle(self): self.rules = [] @@ -578,12 +935,18 @@ class FakeBucket(BaseModel): raise MalformedXML() for rule in rules: - assert isinstance(rule["AllowedMethod"], list) or isinstance(rule["AllowedMethod"], six.string_types) - assert isinstance(rule["AllowedOrigin"], list) or isinstance(rule["AllowedOrigin"], six.string_types) - assert isinstance(rule.get("AllowedHeader", []), list) or isinstance(rule.get("AllowedHeader", ""), - six.string_types) - assert isinstance(rule.get("ExposedHeader", []), list) or isinstance(rule.get("ExposedHeader", ""), - six.string_types) + assert isinstance(rule["AllowedMethod"], list) or isinstance( + rule["AllowedMethod"], six.string_types + ) + assert isinstance(rule["AllowedOrigin"], list) or isinstance( + rule["AllowedOrigin"], six.string_types + ) + assert isinstance(rule.get("AllowedHeader", []), list) or isinstance( + rule.get("AllowedHeader", ""), six.string_types + ) + assert isinstance(rule.get("ExposedHeader", []), list) or isinstance( + rule.get("ExposedHeader", ""), six.string_types + ) assert isinstance(rule.get("MaxAgeSeconds", "0"), six.string_types) if isinstance(rule["AllowedMethod"], six.string_types): @@ -595,13 +958,15 @@ class FakeBucket(BaseModel): if method not in ["GET", "PUT", "HEAD", "POST", "DELETE"]: raise InvalidRequest(method) - self.cors.append(CorsRule( - rule["AllowedMethod"], - rule["AllowedOrigin"], - rule.get("AllowedHeader"), - rule.get("ExposedHeader"), - rule.get("MaxAgeSecond") - )) + self.cors.append( + CorsRule( + rule["AllowedMethod"], + rule["AllowedOrigin"], + rule.get("AllowedHeader"), + rule.get("ExposedHeader"), + rule.get("MaxAgeSecond"), + ) + ) def delete_cors(self): self.cors = [] @@ -623,7 +988,9 @@ class FakeBucket(BaseModel): # Target bucket must exist in the same account (assuming all moto buckets are in the same account): if not bucket_backend.buckets.get(logging_config["TargetBucket"]): - raise InvalidTargetBucketForLogging("The target bucket for logging does not exist.") + raise InvalidTargetBucketForLogging( + "The target bucket for logging does not exist." + ) # Does the target bucket have the log-delivery WRITE and READ_ACP permissions? write = read_acp = False @@ -631,20 +998,31 @@ class FakeBucket(BaseModel): # Must be granted to: http://acs.amazonaws.com/groups/s3/LogDelivery for grantee in grant.grantees: if grantee.uri == "http://acs.amazonaws.com/groups/s3/LogDelivery": - if "WRITE" in grant.permissions or "FULL_CONTROL" in grant.permissions: + if ( + "WRITE" in grant.permissions + or "FULL_CONTROL" in grant.permissions + ): write = True - if "READ_ACP" in grant.permissions or "FULL_CONTROL" in grant.permissions: + if ( + "READ_ACP" in grant.permissions + or "FULL_CONTROL" in grant.permissions + ): read_acp = True break if not write or not read_acp: - raise InvalidTargetBucketForLogging("You must give the log-delivery group WRITE and READ_ACP" - " permissions to the target bucket") + raise InvalidTargetBucketForLogging( + "You must give the log-delivery group WRITE and READ_ACP" + " permissions to the target bucket" + ) # Buckets must also exist within the same region: - if bucket_backend.buckets[logging_config["TargetBucket"]].region_name != self.region_name: + if ( + bucket_backend.buckets[logging_config["TargetBucket"]].region_name + != self.region_name + ): raise CrossLocationLoggingProhibitted() # Checks pass -- set the logging config: @@ -658,7 +1036,7 @@ class FakeBucket(BaseModel): self.notification_configuration = NotificationConfiguration( topic=notification_config.get("TopicConfiguration"), queue=notification_config.get("QueueConfiguration"), - cloud_function=notification_config.get("CloudFunctionConfiguration") + cloud_function=notification_config.get("CloudFunctionConfiguration"), ) # Validate that the region is correct: @@ -669,7 +1047,7 @@ class FakeBucket(BaseModel): raise InvalidNotificationDestination() def set_accelerate_configuration(self, accelerate_config): - if self.accelerate_configuration is None and accelerate_config == 'Suspended': + if self.accelerate_configuration is None and accelerate_config == "Suspended": # Cannot "suspend" a not active acceleration. Leaves it undefined return @@ -680,12 +1058,11 @@ class FakeBucket(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'DomainName': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "DomainName" ]"') - elif attribute_name == 'WebsiteURL': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "WebsiteURL" ]"') + + if attribute_name == "DomainName": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "DomainName" ]"') + elif attribute_name == "WebsiteURL": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "WebsiteURL" ]"') raise UnformattedGetAttTemplateException() def set_acl(self, acl): @@ -697,13 +1074,93 @@ class FakeBucket(BaseModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name): + cls, resource_name, cloudformation_json, region_name + ): bucket = s3_backend.create_bucket(resource_name, region_name) return bucket + def to_config_dict(self): + """Return the AWS Config JSON format of this S3 bucket. + + Note: The following features are not implemented and will need to be if you care about them: + - Bucket Accelerate Configuration + """ + config_dict = { + "version": "1.3", + "configurationItemCaptureTime": str(self.creation_date), + "configurationItemStatus": "ResourceDiscovered", + "configurationStateId": str( + int(time.mktime(self.creation_date.timetuple())) + ), # PY2 and 3 compatible + "configurationItemMD5Hash": "", + "arn": "arn:aws:s3:::{}".format(self.name), + "resourceType": "AWS::S3::Bucket", + "resourceId": self.name, + "resourceName": self.name, + "awsRegion": self.region_name, + "availabilityZone": "Regional", + "resourceCreationTime": str(self.creation_date), + "relatedEvents": [], + "relationships": [], + "tags": {tag.key: tag.value for tag in self.tagging.tag_set.tags}, + "configuration": { + "name": self.name, + "owner": {"id": OWNER}, + "creationDate": self.creation_date.isoformat(), + }, + } + + # Make the supplementary configuration: + # This is a dobule-wrapped JSON for some reason... + s_config = { + "AccessControlList": json.dumps(json.dumps(self.acl.to_config_dict())) + } + + if self.public_access_block: + s_config["PublicAccessBlockConfiguration"] = json.dumps( + self.public_access_block.to_config_dict() + ) + + # Tagging is special: + if config_dict["tags"]: + s_config["BucketTaggingConfiguration"] = json.dumps( + {"tagSets": [{"tags": config_dict["tags"]}]} + ) + + # TODO implement Accelerate Configuration: + s_config["BucketAccelerateConfiguration"] = {"status": None} + + if self.rules: + s_config["BucketLifecycleConfiguration"] = { + "rules": [rule.to_config_dict() for rule in self.rules] + } + + s_config["BucketLoggingConfiguration"] = { + "destinationBucketName": self.logging.get("TargetBucket", None), + "logFilePrefix": self.logging.get("TargetPrefix", None), + } + + s_config["BucketPolicy"] = { + "policyText": self.policy.decode("utf-8") if self.policy else None + } + + s_config["IsRequesterPaysEnabled"] = ( + "false" if self.payer == "BucketOwner" else "true" + ) + + if self.notification_configuration: + s_config[ + "BucketNotificationConfiguration" + ] = self.notification_configuration.to_config_dict() + else: + s_config["BucketNotificationConfiguration"] = {"configurations": {}} + + config_dict["supplementaryConfiguration"] = s_config + + return config_dict + class S3Backend(BaseBackend): - def __init__(self): self.buckets = {} @@ -749,27 +1206,33 @@ class S3Backend(BaseBackend): last_modified = version.last_modified version_id = version.version_id latest_modified_per_key[name] = max( - last_modified, - latest_modified_per_key.get(name, datetime.datetime.min) + last_modified, latest_modified_per_key.get(name, datetime.datetime.min) ) if last_modified == latest_modified_per_key[name]: latest_versions[name] = version_id return latest_versions - def get_bucket_versions(self, bucket_name, delimiter=None, - encoding_type=None, - key_marker=None, - max_keys=None, - version_id_marker=None, - prefix=''): + def get_bucket_versions( + self, + bucket_name, + delimiter=None, + encoding_type=None, + key_marker=None, + max_keys=None, + version_id_marker=None, + prefix="", + ): bucket = self.get_bucket(bucket_name) if any((delimiter, key_marker, version_id_marker)): raise NotImplementedError( - "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker") + "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker" + ) - return itertools.chain(*(l for key, l in bucket.keys.iterlists() if key.startswith(prefix))) + return itertools.chain( + *(l for key, l in bucket.keys.iterlists() if key.startswith(prefix)) + ) def get_bucket_policy(self, bucket_name): return self.get_bucket(bucket_name).policy @@ -793,14 +1256,16 @@ class S3Backend(BaseBackend): bucket = self.get_bucket(bucket_name) return bucket.website_configuration + def get_bucket_public_access_block(self, bucket_name): + bucket = self.get_bucket(bucket_name) + + if not bucket.public_access_block: + raise NoSuchPublicAccessBlockConfiguration() + + return bucket.public_access_block + def set_key( - self, - bucket_name, - key_name, - value, - storage=None, - etag=None, - multipart=None, + self, bucket_name, key_name, value, storage=None, etag=None, multipart=None ): key_name = clean_key_name(key_name) if storage is not None and storage not in STORAGE_CLASS: @@ -819,7 +1284,8 @@ class S3Backend(BaseBackend): ) keys = [ - key for key in bucket.keys.getlist(key_name, []) + key + for key in bucket.keys.getlist(key_name, []) if key.version_id != new_key.version_id ] + [new_key] bucket.keys.setlist(key_name, keys) @@ -848,7 +1314,7 @@ class S3Backend(BaseBackend): key = key_version break - if part_number and key.multipart: + if part_number and key and key.multipart: key = key.multipart.parts[part_number] if isinstance(key, FakeKey): @@ -856,8 +1322,8 @@ class S3Backend(BaseBackend): else: return None - def set_key_tagging(self, bucket_name, key_name, tagging): - key = self.get_key(bucket_name, key_name) + def set_key_tagging(self, bucket_name, key_name, tagging, version_id=None): + key = self.get_key(bucket_name, key_name, version_id) if key is None: raise MissingKey(key_name) key.set_tagging(tagging) @@ -886,19 +1352,38 @@ class S3Backend(BaseBackend): bucket = self.get_bucket(bucket_name) bucket.delete_cors() + def delete_bucket_public_access_block(self, bucket_name): + bucket = self.get_bucket(bucket_name) + bucket.public_access_block = None + def put_bucket_notification_configuration(self, bucket_name, notification_config): bucket = self.get_bucket(bucket_name) bucket.set_notification_configuration(notification_config) - def put_bucket_accelerate_configuration(self, bucket_name, accelerate_configuration): - if accelerate_configuration not in ['Enabled', 'Suspended']: + def put_bucket_accelerate_configuration( + self, bucket_name, accelerate_configuration + ): + if accelerate_configuration not in ["Enabled", "Suspended"]: raise MalformedXML() bucket = self.get_bucket(bucket_name) - if bucket.name.find('.') != -1: - raise InvalidRequest('PutBucketAccelerateConfiguration') + if bucket.name.find(".") != -1: + raise InvalidRequest("PutBucketAccelerateConfiguration") bucket.set_accelerate_configuration(accelerate_configuration) + def put_bucket_public_access_block(self, bucket_name, pub_block_config): + bucket = self.get_bucket(bucket_name) + + if not pub_block_config: + raise InvalidPublicAccessBlockConfiguration() + + bucket.public_access_block = PublicAccessBlock( + pub_block_config.get("BlockPublicAcls"), + pub_block_config.get("IgnorePublicAcls"), + pub_block_config.get("BlockPublicPolicy"), + pub_block_config.get("RestrictPublicBuckets"), + ) + def initiate_multipart(self, bucket_name, key_name, metadata): bucket = self.get_bucket(bucket_name) new_multipart = FakeMultipart(key_name, metadata) @@ -915,10 +1400,7 @@ class S3Backend(BaseBackend): del bucket.multiparts[multipart_id] key = self.set_key( - bucket_name, - multipart.key_name, - value, etag=etag, - multipart=multipart + bucket_name, multipart.key_name, value, etag=etag, multipart=multipart ) key.set_metadata(multipart.metadata) return key @@ -940,14 +1422,25 @@ class S3Backend(BaseBackend): multipart = bucket.multiparts[multipart_id] return multipart.set_part(part_id, value) - def copy_part(self, dest_bucket_name, multipart_id, part_id, - src_bucket_name, src_key_name, src_version_id, start_byte, end_byte): + def copy_part( + self, + dest_bucket_name, + multipart_id, + part_id, + src_bucket_name, + src_key_name, + src_version_id, + start_byte, + end_byte, + ): dest_bucket = self.get_bucket(dest_bucket_name) multipart = dest_bucket.multiparts[multipart_id] - src_value = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id).value + src_value = self.get_key( + src_bucket_name, src_key_name, version_id=src_version_id + ).value if start_byte is not None: - src_value = src_value[start_byte:end_byte + 1] + src_value = src_value[start_byte : end_byte + 1] return multipart.set_part(part_id, src_value) def prefix_query(self, bucket, prefix, delimiter): @@ -959,33 +1452,33 @@ class S3Backend(BaseBackend): key_without_prefix = key_name.replace(prefix, "", 1) if delimiter and delimiter in key_without_prefix: # If delimiter, we need to split out folder_results - key_without_delimiter = key_without_prefix.split(delimiter)[ - 0] - folder_results.add("{0}{1}{2}".format( - prefix, key_without_delimiter, delimiter)) + key_without_delimiter = key_without_prefix.split(delimiter)[0] + folder_results.add( + "{0}{1}{2}".format(prefix, key_without_delimiter, delimiter) + ) else: key_results.add(key) else: for key_name, key in bucket.keys.items(): if delimiter and delimiter in key_name: # If delimiter, we need to split out folder_results - folder_results.add(key_name.split( - delimiter)[0] + delimiter) + folder_results.add(key_name.split(delimiter)[0] + delimiter) else: key_results.add(key) - key_results = filter(lambda key: not isinstance(key, FakeDeleteMarker), key_results) + key_results = filter( + lambda key: not isinstance(key, FakeDeleteMarker), key_results + ) key_results = sorted(key_results, key=lambda key: key.name) - folder_results = [folder_name for folder_name in sorted( - folder_results, key=lambda key: key)] + folder_results = [ + folder_name for folder_name in sorted(folder_results, key=lambda key: key) + ] return key_results, folder_results def _set_delete_marker(self, bucket_name, key_name): bucket = self.get_bucket(bucket_name) - bucket.keys[key_name] = FakeDeleteMarker( - key=bucket.keys[key_name] - ) + bucket.keys[key_name] = FakeDeleteMarker(key=bucket.keys[key_name]) def delete_key(self, bucket_name, key_name, version_id=None): key_name = clean_key_name(key_name) @@ -1006,7 +1499,7 @@ class S3Backend(BaseBackend): key for key in bucket.keys.getlist(key_name) if str(key.version_id) != str(version_id) - ] + ], ) if not bucket.keys.getlist(key_name): @@ -1015,13 +1508,20 @@ class S3Backend(BaseBackend): except KeyError: return False - def copy_key(self, src_bucket_name, src_key_name, dest_bucket_name, - dest_key_name, storage=None, acl=None, src_version_id=None): + def copy_key( + self, + src_bucket_name, + src_key_name, + dest_bucket_name, + dest_key_name, + storage=None, + acl=None, + src_version_id=None, + ): src_key_name = clean_key_name(src_key_name) dest_key_name = clean_key_name(dest_key_name) dest_bucket = self.get_bucket(dest_bucket_name) - key = self.get_key(src_bucket_name, src_key_name, - version_id=src_version_id) + key = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id) new_key = key.copy(dest_key_name, dest_bucket.is_versioned) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index ee047a14f..3fa793f25 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals import re +import sys import six -from moto.core.utils import str_to_rfc_1123_datetime +from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys from six.moves.urllib.parse import parse_qs, urlparse, unquote import xmltodict @@ -12,19 +13,48 @@ import xmltodict from moto.packages.httpretty.core import HTTPrettyRequest from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin from moto.core.utils import path_url +from moto.core import ACCOUNT_ID -from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \ - parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys +from moto.s3bucket_path.utils import ( + bucket_name_from_url as bucketpath_bucket_name_from_url, + parse_key_name as bucketpath_parse_key_name, + is_delete_keys as bucketpath_is_delete_keys, +) -from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \ - MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError -from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ - FakeTag -from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url +from .exceptions import ( + BucketAlreadyExists, + S3ClientError, + MissingBucket, + MissingKey, + InvalidPartOrder, + MalformedXML, + MalformedACLError, + InvalidNotificationARN, + InvalidNotificationEvent, + ObjectNotInActiveTierError, +) +from .models import ( + s3_backend, + get_canned_acl, + FakeGrantee, + FakeGrant, + FakeAcl, + FakeKey, + FakeTagging, + FakeTagSet, + FakeTag, +) +from .utils import ( + bucket_name_from_url, + clean_key_name, + undo_clean_key_name, + metadata_from_headers, + parse_region_from_url, +) from xml.dom import minidom -DEFAULT_REGION_NAME = 'us-east-1' +DEFAULT_REGION_NAME = "us-east-1" ACTION_MAP = { "BUCKET": { @@ -42,7 +72,8 @@ ACTION_MAP = { "notification": "GetBucketNotification", "accelerate": "GetAccelerateConfiguration", "versions": "ListBucketVersions", - "DEFAULT": "ListBucket" + "public_access_block": "GetPublicAccessBlock", + "DEFAULT": "ListBucket", }, "PUT": { "lifecycle": "PutLifecycleConfiguration", @@ -55,15 +86,17 @@ ACTION_MAP = { "cors": "PutBucketCORS", "notification": "PutBucketNotification", "accelerate": "PutAccelerateConfiguration", - "DEFAULT": "CreateBucket" + "public_access_block": "PutPublicAccessBlock", + "DEFAULT": "CreateBucket", }, "DELETE": { "lifecycle": "PutLifecycleConfiguration", "policy": "DeleteBucketPolicy", "tagging": "PutBucketTagging", "cors": "PutBucketCORS", - "DEFAULT": "DeleteBucket" - } + "public_access_block": "DeletePublicAccessBlock", + "DEFAULT": "DeleteBucket", + }, }, "KEY": { "GET": { @@ -71,25 +104,24 @@ ACTION_MAP = { "acl": "GetObjectAcl", "tagging": "GetObjectTagging", "versionId": "GetObjectVersion", - "DEFAULT": "GetObject" + "DEFAULT": "GetObject", }, "PUT": { "acl": "PutObjectAcl", "tagging": "PutObjectTagging", - "DEFAULT": "PutObject" + "DEFAULT": "PutObject", }, "DELETE": { "uploadId": "AbortMultipartUpload", "versionId": "DeleteObjectVersion", - "DEFAULT": " DeleteObject" + "DEFAULT": " DeleteObject", }, "POST": { "uploads": "PutObject", "restore": "RestoreObject", - "uploadId": "PutObject" - } - } - + "uploadId": "PutObject", + }, + }, } @@ -98,14 +130,12 @@ def parse_key_name(pth): def is_delete_keys(request, path, bucket_name): - return path == u'/?delete' or ( - path == u'/' and - getattr(request, "query_string", "") == "delete" + return path == "/?delete" or ( + path == "/" and getattr(request, "query_string", "") == "delete" ) class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): - def __init__(self, backend): super(ResponseObject, self).__init__() self.backend = backend @@ -128,34 +158,43 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return template.render(buckets=all_buckets) def subdomain_based_buckets(self, request): - host = request.headers.get('host', request.headers.get('Host')) + host = request.headers.get("host", request.headers.get("Host")) if not host: host = urlparse(request.url).netloc - if (not host or host.startswith('localhost') or host.startswith('localstack') or - re.match(r'^[^.]+$', host) or re.match(r'^.*\.svc\.cluster\.local$', host)): + if ( + not host + or host.startswith("localhost") + or host.startswith("localstack") + or re.match(r"^[^.]+$", host) + or re.match(r"^.*\.svc\.cluster\.local$", host) + ): # Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev), # (3) local host names that do not contain a "." (e.g., Docker container host names), or # (4) kubernetes host names return False - match = re.match(r'^([^\[\]:]+)(:\d+)?$', host) - if match: - match = re.match(r'((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}', - match.groups()[0]) - if match: - return False - - match = re.match(r'^\[(.+)\](:\d+)?$', host) + match = re.match(r"^([^\[\]:]+)(:\d+)?$", host) if match: match = re.match( - r'^(((?=.*(::))(?!.*\3.+\3))\3?|[\dA-F]{1,4}:)([\dA-F]{1,4}(\3|:\b)|\2){5}(([\dA-F]{1,4}(\3|:\b|$)|\2){2}|(((2[0-4]|1\d|[1-9])?\d|25[0-5])\.?\b){4})\Z', - match.groups()[0], re.IGNORECASE) + r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}", match.groups()[0] + ) if match: return False - path_based = (host == 's3.amazonaws.com' or re.match( - r"s3[\.\-]([^.]*)\.amazonaws\.com", host)) + match = re.match(r"^\[(.+)\](:\d+)?$", host) + if match: + match = re.match( + r"^(((?=.*(::))(?!.*\3.+\3))\3?|[\dA-F]{1,4}:)([\dA-F]{1,4}(\3|:\b)|\2){5}(([\dA-F]{1,4}(\3|:\b|$)|\2){2}|(((2[0-4]|1\d|[1-9])?\d|25[0-5])\.?\b){4})\Z", + match.groups()[0], + re.IGNORECASE, + ) + if match: + return False + + path_based = host == "s3.amazonaws.com" or re.match( + r"s3[\.\-]([^.]*)\.amazonaws\.com", host + ) return not path_based def is_delete_keys(self, request, path, bucket_name): @@ -189,8 +228,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.method = request.method self.path = self._get_path(request) self.headers = request.headers - if 'host' not in self.headers: - self.headers['host'] = urlparse(full_url).netloc + if "host" not in self.headers: + self.headers["host"] = urlparse(full_url).netloc try: response = self._bucket_response(request, full_url, headers) except S3ClientError as s3error: @@ -221,31 +260,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.data["BucketName"] = bucket_name - if hasattr(request, 'body'): + if hasattr(request, "body"): # Boto body = request.body else: # Flask server body = request.data if body is None: - body = b'' + body = b"" if isinstance(body, six.binary_type): - body = body.decode('utf-8') - body = u'{0}'.format(body).encode('utf-8') + body = body.decode("utf-8") + body = "{0}".format(body).encode("utf-8") - if method == 'HEAD': + if method == "HEAD": return self._bucket_response_head(bucket_name) - elif method == 'GET': + elif method == "GET": return self._bucket_response_get(bucket_name, querystring) - elif method == 'PUT': - return self._bucket_response_put(request, body, region_name, bucket_name, querystring) - elif method == 'DELETE': + elif method == "PUT": + return self._bucket_response_put( + request, body, region_name, bucket_name, querystring + ) + elif method == "DELETE": return self._bucket_response_delete(body, bucket_name, querystring) - elif method == 'POST': + elif method == "POST": return self._bucket_response_post(request, body, bucket_name) else: raise NotImplementedError( - "Method {0} has not been impelemented in the S3 backend yet".format(method)) + "Method {0} has not been impelemented in the S3 backend yet".format( + method + ) + ) @staticmethod def _get_querystring(full_url): @@ -268,22 +312,25 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._set_action("BUCKET", "GET", querystring) self._authenticate_and_authorize_s3_action() - if 'uploads' in querystring: - for unsup in ('delimiter', 'max-uploads'): + if "uploads" in querystring: + for unsup in ("delimiter", "max-uploads"): if unsup in querystring: raise NotImplementedError( - "Listing multipart uploads with {} has not been implemented yet.".format(unsup)) - multiparts = list( - self.backend.get_all_multiparts(bucket_name).values()) - if 'prefix' in querystring: - prefix = querystring.get('prefix', [None])[0] + "Listing multipart uploads with {} has not been implemented yet.".format( + unsup + ) + ) + multiparts = list(self.backend.get_all_multiparts(bucket_name).values()) + if "prefix" in querystring: + prefix = querystring.get("prefix", [None])[0] multiparts = [ - upload for upload in multiparts if upload.key_name.startswith(prefix)] + upload + for upload in multiparts + if upload.key_name.startswith(prefix) + ] template = self.response_template(S3_ALL_MULTIPARTS) - return template.render( - bucket_name=bucket_name, - uploads=multiparts) - elif 'location' in querystring: + return template.render(bucket_name=bucket_name, uploads=multiparts) + elif "location" in querystring: bucket = self.backend.get_bucket(bucket_name) template = self.response_template(S3_BUCKET_LOCATION) @@ -293,36 +340,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): location = None return template.render(location=location) - elif 'lifecycle' in querystring: + elif "lifecycle" in querystring: bucket = self.backend.get_bucket(bucket_name) if not bucket.rules: template = self.response_template(S3_NO_LIFECYCLE) return 404, {}, template.render(bucket_name=bucket_name) - template = self.response_template( - S3_BUCKET_LIFECYCLE_CONFIGURATION) + template = self.response_template(S3_BUCKET_LIFECYCLE_CONFIGURATION) return template.render(rules=bucket.rules) - elif 'versioning' in querystring: + elif "versioning" in querystring: versioning = self.backend.get_bucket_versioning(bucket_name) template = self.response_template(S3_BUCKET_GET_VERSIONING) return template.render(status=versioning) - elif 'policy' in querystring: + elif "policy" in querystring: policy = self.backend.get_bucket_policy(bucket_name) if not policy: template = self.response_template(S3_NO_POLICY) return 404, {}, template.render(bucket_name=bucket_name) return 200, {}, policy - elif 'website' in querystring: + elif "website" in querystring: website_configuration = self.backend.get_bucket_website_configuration( - bucket_name) + bucket_name + ) if not website_configuration: template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG) return 404, {}, template.render(bucket_name=bucket_name) return 200, {}, website_configuration - elif 'acl' in querystring: + elif "acl" in querystring: bucket = self.backend.get_bucket(bucket_name) template = self.response_template(S3_OBJECT_ACL_RESPONSE) return template.render(obj=bucket) - elif 'tagging' in querystring: + elif "tagging" in querystring: bucket = self.backend.get_bucket(bucket_name) # "Special Error" if no tags: if len(bucket.tagging.tag_set.tags) == 0: @@ -330,7 +377,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 404, {}, template.render(bucket_name=bucket_name) template = self.response_template(S3_BUCKET_TAGGING_RESPONSE) return template.render(bucket=bucket) - elif 'logging' in querystring: + elif "logging" in querystring: bucket = self.backend.get_bucket(bucket_name) if not bucket.logging: template = self.response_template(S3_NO_LOGGING_CONFIG) @@ -357,14 +404,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 200, {}, template.render() template = self.response_template(S3_BUCKET_ACCELERATE) return template.render(bucket=bucket) + elif "publicAccessBlock" in querystring: + public_block_config = self.backend.get_bucket_public_access_block( + bucket_name + ) + template = self.response_template(S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION) + return template.render(public_block_config=public_block_config) - elif 'versions' in querystring: - delimiter = querystring.get('delimiter', [None])[0] - encoding_type = querystring.get('encoding-type', [None])[0] - key_marker = querystring.get('key-marker', [None])[0] - max_keys = querystring.get('max-keys', [None])[0] - prefix = querystring.get('prefix', [''])[0] - version_id_marker = querystring.get('version-id-marker', [None])[0] + elif "versions" in querystring: + delimiter = querystring.get("delimiter", [None])[0] + encoding_type = querystring.get("encoding-type", [None])[0] + key_marker = querystring.get("key-marker", [None])[0] + max_keys = querystring.get("max-keys", [None])[0] + prefix = querystring.get("prefix", [""])[0] + version_id_marker = querystring.get("version-id-marker", [None])[0] bucket = self.backend.get_bucket(bucket_name) versions = self.backend.get_bucket_versions( @@ -374,7 +427,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): key_marker=key_marker, max_keys=max_keys, version_id_marker=version_id_marker, - prefix=prefix + prefix=prefix, ) latest_versions = self.backend.get_bucket_latest_versions( bucket_name=bucket_name @@ -387,48 +440,62 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: delete_marker_list.append(version) template = self.response_template(S3_BUCKET_GET_VERSIONS) - return 200, {}, template.render( - key_list=key_list, - delete_marker_list=delete_marker_list, - latest_versions=latest_versions, - bucket=bucket, - prefix='', - max_keys=1000, - delimiter='', - is_truncated='false', + return ( + 200, + {}, + template.render( + key_list=key_list, + delete_marker_list=delete_marker_list, + latest_versions=latest_versions, + bucket=bucket, + prefix="", + max_keys=1000, + delimiter="", + is_truncated="false", + ), ) - elif querystring.get('list-type', [None])[0] == '2': + elif querystring.get("list-type", [None])[0] == "2": return 200, {}, self._handle_list_objects_v2(bucket_name, querystring) bucket = self.backend.get_bucket(bucket_name) - prefix = querystring.get('prefix', [None])[0] + prefix = querystring.get("prefix", [None])[0] if prefix and isinstance(prefix, six.binary_type): prefix = prefix.decode("utf-8") - delimiter = querystring.get('delimiter', [None])[0] - max_keys = int(querystring.get('max-keys', [1000])[0]) - marker = querystring.get('marker', [None])[0] + delimiter = querystring.get("delimiter", [None])[0] + max_keys = int(querystring.get("max-keys", [1000])[0]) + marker = querystring.get("marker", [None])[0] result_keys, result_folders = self.backend.prefix_query( - bucket, prefix, delimiter) + bucket, prefix, delimiter + ) if marker: result_keys = self._get_results_from_token(result_keys, marker) - result_keys, is_truncated, _ = self._truncate_result(result_keys, max_keys) + result_keys, is_truncated, next_marker = self._truncate_result( + result_keys, max_keys + ) template = self.response_template(S3_BUCKET_GET_RESPONSE) - return 200, {}, template.render( - bucket=bucket, - prefix=prefix, - delimiter=delimiter, - result_keys=result_keys, - result_folders=result_folders, - is_truncated=is_truncated, - max_keys=max_keys + return ( + 200, + {}, + template.render( + bucket=bucket, + prefix=prefix, + delimiter=delimiter, + result_keys=result_keys, + result_folders=result_folders, + is_truncated=is_truncated, + next_marker=next_marker, + max_keys=max_keys, + ), ) def _set_action(self, action_resource_type, method, querystring): action_set = False - for action_in_querystring, action in ACTION_MAP[action_resource_type][method].items(): + for action_in_querystring, action in ACTION_MAP[action_resource_type][ + method + ].items(): if action_in_querystring in querystring: self.data["Action"] = action action_set = True @@ -439,35 +506,37 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): template = self.response_template(S3_BUCKET_GET_RESPONSE_V2) bucket = self.backend.get_bucket(bucket_name) - prefix = querystring.get('prefix', [None])[0] + prefix = querystring.get("prefix", [None])[0] if prefix and isinstance(prefix, six.binary_type): prefix = prefix.decode("utf-8") - delimiter = querystring.get('delimiter', [None])[0] + delimiter = querystring.get("delimiter", [None])[0] result_keys, result_folders = self.backend.prefix_query( - bucket, prefix, delimiter) + bucket, prefix, delimiter + ) - fetch_owner = querystring.get('fetch-owner', [False])[0] - max_keys = int(querystring.get('max-keys', [1000])[0]) - continuation_token = querystring.get('continuation-token', [None])[0] - start_after = querystring.get('start-after', [None])[0] + fetch_owner = querystring.get("fetch-owner", [False])[0] + max_keys = int(querystring.get("max-keys", [1000])[0]) + continuation_token = querystring.get("continuation-token", [None])[0] + start_after = querystring.get("start-after", [None])[0] + + # sort the combination of folders and keys into lexicographical order + all_keys = result_keys + result_folders + all_keys.sort(key=self._get_name) if continuation_token or start_after: limit = continuation_token or start_after - if not delimiter: - result_keys = self._get_results_from_token(result_keys, limit) - else: - result_folders = self._get_results_from_token(result_folders, limit) + all_keys = self._get_results_from_token(all_keys, limit) - if not delimiter: - result_keys, is_truncated, next_continuation_token = self._truncate_result(result_keys, max_keys) - else: - result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys) + truncated_keys, is_truncated, next_continuation_token = self._truncate_result( + all_keys, max_keys + ) + result_keys, result_folders = self._split_truncated_keys(truncated_keys) key_count = len(result_keys) + len(result_folders) return template.render( bucket=bucket, - prefix=prefix or '', + prefix=prefix or "", delimiter=delimiter, key_count=key_count, result_keys=result_keys, @@ -476,9 +545,27 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): max_keys=max_keys, is_truncated=is_truncated, next_continuation_token=next_continuation_token, - start_after=None if continuation_token else start_after + start_after=None if continuation_token else start_after, ) + @staticmethod + def _get_name(key): + if isinstance(key, FakeKey): + return key.name + else: + return key + + @staticmethod + def _split_truncated_keys(truncated_keys): + result_keys = [] + result_folders = [] + for key in truncated_keys: + if isinstance(key, FakeKey): + result_keys.append(key) + else: + result_folders.append(key) + return result_keys, result_folders + def _get_results_from_token(self, result_keys, token): continuation_index = 0 for key in result_keys: @@ -489,41 +576,43 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _truncate_result(self, result_keys, max_keys): if len(result_keys) > max_keys: - is_truncated = 'true' + is_truncated = "true" result_keys = result_keys[:max_keys] item = result_keys[-1] - next_continuation_token = (item.name if isinstance(item, FakeKey) else item) + next_continuation_token = item.name if isinstance(item, FakeKey) else item else: - is_truncated = 'false' + is_truncated = "false" next_continuation_token = None return result_keys, is_truncated, next_continuation_token - def _bucket_response_put(self, request, body, region_name, bucket_name, querystring): - if not request.headers.get('Content-Length'): + def _bucket_response_put( + self, request, body, region_name, bucket_name, querystring + ): + if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" self._set_action("BUCKET", "PUT", querystring) self._authenticate_and_authorize_s3_action() - if 'versioning' in querystring: - ver = re.search('([A-Za-z]+)', body.decode()) + if "versioning" in querystring: + ver = re.search("([A-Za-z]+)", body.decode()) if ver: self.backend.set_bucket_versioning(bucket_name, ver.group(1)) template = self.response_template(S3_BUCKET_VERSIONING) return template.render(bucket_versioning_status=ver.group(1)) else: return 404, {}, "" - elif 'lifecycle' in querystring: - rules = xmltodict.parse(body)['LifecycleConfiguration']['Rule'] + elif "lifecycle" in querystring: + rules = xmltodict.parse(body)["LifecycleConfiguration"]["Rule"] if not isinstance(rules, list): # If there is only one rule, xmldict returns just the item rules = [rules] self.backend.set_bucket_lifecycle(bucket_name, rules) return "" - elif 'policy' in querystring: + elif "policy" in querystring: self.backend.set_bucket_policy(bucket_name, body) - return 'True' - elif 'acl' in querystring: + return "True" + elif "acl" in querystring: # Headers are first. If not set, then look at the body (consistent with the documentation): acls = self._acl_from_headers(request.headers) if not acls: @@ -534,7 +623,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): tagging = self._bucket_tagging_from_xml(body) self.backend.put_bucket_tagging(bucket_name, tagging) return "" - elif 'website' in querystring: + elif "website" in querystring: self.backend.set_bucket_website_configuration(bucket_name, body) return "" elif "cors" in querystring: @@ -545,14 +634,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise MalformedXML() elif "logging" in querystring: try: - self.backend.put_bucket_logging(bucket_name, self._logging_from_xml(body)) + self.backend.put_bucket_logging( + bucket_name, self._logging_from_xml(body) + ) return "" except KeyError: raise MalformedXML() elif "notification" in querystring: try: - self.backend.put_bucket_notification_configuration(bucket_name, - self._notification_config_from_xml(body)) + self.backend.put_bucket_notification_configuration( + bucket_name, self._notification_config_from_xml(body) + ) return "" except KeyError: raise MalformedXML() @@ -561,25 +653,46 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): elif "accelerate" in querystring: try: accelerate_status = self._accelerate_config_from_xml(body) - self.backend.put_bucket_accelerate_configuration(bucket_name, accelerate_status) + self.backend.put_bucket_accelerate_configuration( + bucket_name, accelerate_status + ) return "" except KeyError: raise MalformedXML() except Exception as e: raise e + elif "publicAccessBlock" in querystring: + parsed_xml = xmltodict.parse(body) + parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None) + + # If Python 2, fix the unicode strings: + if sys.version_info[0] < 3: + parsed_xml = { + "PublicAccessBlockConfiguration": py2_strip_unicode_keys( + dict(parsed_xml["PublicAccessBlockConfiguration"]) + ) + } + + self.backend.put_bucket_public_access_block( + bucket_name, parsed_xml["PublicAccessBlockConfiguration"] + ) + return "" + else: if body: # us-east-1, the default AWS region behaves a bit differently # - you should not use it as a location constraint --> it fails # - querying the location constraint returns None try: - forced_region = xmltodict.parse(body)['CreateBucketConfiguration']['LocationConstraint'] + forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ + "LocationConstraint" + ] if forced_region == DEFAULT_REGION_NAME: raise S3ClientError( - 'InvalidLocationConstraint', - 'The specified location-constraint is not valid' + "InvalidLocationConstraint", + "The specified location-constraint is not valid", ) else: region_name = forced_region @@ -587,8 +700,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): pass try: - new_bucket = self.backend.create_bucket( - bucket_name, region_name) + new_bucket = self.backend.create_bucket(bucket_name, region_name) except BucketAlreadyExists: if region_name == DEFAULT_REGION_NAME: # us-east-1 has different behavior @@ -596,9 +708,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: raise - if 'x-amz-acl' in request.headers: + if "x-amz-acl" in request.headers: # TODO: Support the XML-based ACL format - self.backend.set_bucket_acl(bucket_name, self._acl_from_headers(request.headers)) + self.backend.set_bucket_acl( + bucket_name, self._acl_from_headers(request.headers) + ) template = self.response_template(S3_BUCKET_CREATE_RESPONSE) return 200, {}, template.render(bucket=new_bucket) @@ -607,7 +721,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._set_action("BUCKET", "DELETE", querystring) self._authenticate_and_authorize_s3_action() - if 'policy' in querystring: + if "policy" in querystring: self.backend.delete_bucket_policy(bucket_name, body) return 204, {}, "" elif "tagging" in querystring: @@ -616,10 +730,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): elif "cors" in querystring: self.backend.delete_bucket_cors(bucket_name) return 204, {}, "" - elif 'lifecycle' in querystring: + elif "lifecycle" in querystring: bucket = self.backend.get_bucket(bucket_name) bucket.delete_lifecycle() return 204, {}, "" + elif "publicAccessBlock" in querystring: + self.backend.delete_bucket_public_access_block(bucket_name) + return 204, {}, "" removed_bucket = self.backend.delete_bucket(bucket_name) @@ -629,12 +746,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 204, {}, template.render(bucket=removed_bucket) else: # Tried to delete a bucket that still has keys - template = self.response_template( - S3_DELETE_BUCKET_WITH_ITEMS_ERROR) + template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) return 409, {}, template.render(bucket=removed_bucket) def _bucket_response_post(self, request, body, bucket_name): - if not request.headers.get('Content-Length'): + if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" path = self._get_path(request) @@ -649,7 +765,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._authenticate_and_authorize_s3_action() # POST to bucket-url should create file from form - if hasattr(request, 'form'): + if hasattr(request, "form"): # Not HTTPretty form = request.form else: @@ -657,15 +773,15 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): body = body.decode() form = {} - for kv in body.split('&'): - k, v = kv.split('=') + for kv in body.split("&"): + k, v = kv.split("=") form[k] = v - key = form['key'] - if 'file' in form: - f = form['file'] + key = form["key"] + if "file" in form: + f = form["file"] else: - f = request.files['file'].stream.read() + f = request.files["file"].stream.read() new_key = self.backend.set_key(bucket_name, key, f) @@ -680,13 +796,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if isinstance(request, HTTPrettyRequest): path = request.path else: - path = request.full_path if hasattr(request, 'full_path') else path_url(request.url) + path = ( + request.full_path + if hasattr(request, "full_path") + else path_url(request.url) + ) return path def _bucket_response_delete_keys(self, request, body, bucket_name): template = self.response_template(S3_DELETE_KEYS_RESPONSE) - keys = minidom.parseString(body).getElementsByTagName('Key') + keys = minidom.parseString(body).getElementsByTagName("Key") deleted_names = [] error_names = [] if len(keys) == 0: @@ -694,27 +814,32 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): for k in keys: key_name = k.firstChild.nodeValue - success = self.backend.delete_key(bucket_name, key_name) + success = self.backend.delete_key( + bucket_name, undo_clean_key_name(key_name) + ) if success: deleted_names.append(key_name) else: error_names.append(key_name) - return 200, {}, template.render(deleted=deleted_names, delete_errors=error_names) + return ( + 200, + {}, + template.render(deleted=deleted_names, delete_errors=error_names), + ) def _handle_range_header(self, request, headers, response_content): response_headers = {} length = len(response_content) last = length - 1 - _, rspec = request.headers.get('range').split('=') - if ',' in rspec: - raise NotImplementedError( - "Multiple range specifiers not supported") + _, rspec = request.headers.get("range").split("=") + if "," in rspec: + raise NotImplementedError("Multiple range specifiers not supported") def toint(i): return int(i) if i else None - begin, end = map(toint, rspec.split('-')) + begin, end = map(toint, rspec.split("-")) if begin is not None: # byte range end = last if end is None else min(end, last) elif end is not None: # suffix byte range @@ -724,16 +849,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 400, response_headers, "" if begin < 0 or end > last or begin > min(end, last): return 416, response_headers, "" - response_headers['content-range'] = "bytes {0}-{1}/{2}".format( - begin, end, length) - return 206, response_headers, response_content[begin:end + 1] + response_headers["content-range"] = "bytes {0}-{1}/{2}".format( + begin, end, length + ) + return 206, response_headers, response_content[begin : end + 1] def key_response(self, request, full_url, headers): self.method = request.method self.path = self._get_path(request) self.headers = request.headers - if 'host' not in self.headers: - self.headers['host'] = urlparse(full_url).netloc + if "host" not in self.headers: + self.headers["host"] = urlparse(full_url).netloc response_headers = {} try: response = self._key_response(request, full_url, headers) @@ -746,8 +872,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: status_code, response_headers, response_content = response - if status_code == 200 and 'range' in request.headers: - return self._handle_range_header(request, response_headers, response_content) + if status_code == 200 and "range" in request.headers: + return self._handle_range_header( + request, response_headers, response_content + ) return status_code, response_headers, response_content def _key_response(self, request, full_url, headers): @@ -764,72 +892,84 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # Here we deny public access to private files by checking the # ACL and checking for the mere presence of an Authorization # header. - if 'Authorization' not in request.headers: - if hasattr(request, 'url'): - signed_url = 'Signature=' in request.url - elif hasattr(request, 'requestline'): - signed_url = 'Signature=' in request.path + if "Authorization" not in request.headers: + if hasattr(request, "url"): + signed_url = "Signature=" in request.url + elif hasattr(request, "requestline"): + signed_url = "Signature=" in request.path key = self.backend.get_key(bucket_name, key_name) if key: if not key.acl.public_read and not signed_url: return 403, {}, "" - if hasattr(request, 'body'): + if hasattr(request, "body"): # Boto body = request.body - if hasattr(body, 'read'): + if hasattr(body, "read"): body = body.read() else: # Flask server body = request.data if body is None: - body = b'' + body = b"" - if method == 'GET': - return self._key_response_get(bucket_name, query, key_name, headers=request.headers) - elif method == 'PUT': - return self._key_response_put(request, body, bucket_name, query, key_name, headers) - elif method == 'HEAD': - return self._key_response_head(bucket_name, query, key_name, headers=request.headers) - elif method == 'DELETE': + if method == "GET": + return self._key_response_get( + bucket_name, query, key_name, headers=request.headers + ) + elif method == "PUT": + return self._key_response_put( + request, body, bucket_name, query, key_name, headers + ) + elif method == "HEAD": + return self._key_response_head( + bucket_name, query, key_name, headers=request.headers + ) + elif method == "DELETE": return self._key_response_delete(bucket_name, query, key_name) - elif method == 'POST': + elif method == "POST": return self._key_response_post(request, body, bucket_name, query, key_name) else: raise NotImplementedError( - "Method {0} has not been implemented in the S3 backend yet".format(method)) + "Method {0} has not been implemented in the S3 backend yet".format( + method + ) + ) def _key_response_get(self, bucket_name, query, key_name, headers): self._set_action("KEY", "GET", query) self._authenticate_and_authorize_s3_action() response_headers = {} - if query.get('uploadId'): - upload_id = query['uploadId'][0] + if query.get("uploadId"): + upload_id = query["uploadId"][0] parts = self.backend.list_multipart(bucket_name, upload_id) template = self.response_template(S3_MULTIPART_LIST_RESPONSE) - return 200, response_headers, template.render( - bucket_name=bucket_name, - key_name=key_name, - upload_id=upload_id, - count=len(parts), - parts=parts + return ( + 200, + response_headers, + template.render( + bucket_name=bucket_name, + key_name=key_name, + upload_id=upload_id, + count=len(parts), + parts=parts, + ), ) - version_id = query.get('versionId', [None])[0] - if_modified_since = headers.get('If-Modified-Since', None) - key = self.backend.get_key( - bucket_name, key_name, version_id=version_id) + version_id = query.get("versionId", [None])[0] + if_modified_since = headers.get("If-Modified-Since", None) + key = self.backend.get_key(bucket_name, key_name, version_id=version_id) if key is None: raise MissingKey(key_name) if if_modified_since: if_modified_since = str_to_rfc_1123_datetime(if_modified_since) if if_modified_since and key.last_modified < if_modified_since: - return 304, response_headers, 'Not Modified' - if 'acl' in query: + return 304, response_headers, "Not Modified" + if "acl" in query: template = self.response_template(S3_OBJECT_ACL_RESPONSE) return 200, response_headers, template.render(obj=key) - if 'tagging' in query: + if "tagging" in query: template = self.response_template(S3_OBJECT_TAGGING_RESPONSE) return 200, response_headers, template.render(obj=key) @@ -842,16 +982,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._authenticate_and_authorize_s3_action() response_headers = {} - if query.get('uploadId') and query.get('partNumber'): - upload_id = query['uploadId'][0] - part_number = int(query['partNumber'][0]) - if 'x-amz-copy-source' in request.headers: + if query.get("uploadId") and query.get("partNumber"): + upload_id = query["uploadId"][0] + part_number = int(query["partNumber"][0]) + if "x-amz-copy-source" in request.headers: src = unquote(request.headers.get("x-amz-copy-source")).lstrip("/") src_bucket, src_key = src.split("/", 1) - src_key, src_version_id = src_key.split("?versionId=") if "?versionId=" in src_key else (src_key, None) - src_range = request.headers.get( - 'x-amz-copy-source-range', '').split("bytes=")[-1] + src_key, src_version_id = ( + src_key.split("?versionId=") + if "?versionId=" in src_key + else (src_key, None) + ) + src_range = request.headers.get("x-amz-copy-source-range", "").split( + "bytes=" + )[-1] try: start_byte, end_byte = src_range.split("-") @@ -861,70 +1006,87 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if self.backend.get_key(src_bucket, src_key, version_id=src_version_id): key = self.backend.copy_part( - bucket_name, upload_id, part_number, src_bucket, - src_key, src_version_id, start_byte, end_byte) + bucket_name, + upload_id, + part_number, + src_bucket, + src_key, + src_version_id, + start_byte, + end_byte, + ) else: return 404, response_headers, "" template = self.response_template(S3_MULTIPART_UPLOAD_RESPONSE) response = template.render(part=key) else: - key = self.backend.set_part( - bucket_name, upload_id, part_number, body) + key = self.backend.set_part(bucket_name, upload_id, part_number, body) response = "" response_headers.update(key.response_dict) return 200, response_headers, response - storage_class = request.headers.get('x-amz-storage-class', 'STANDARD') + storage_class = request.headers.get("x-amz-storage-class", "STANDARD") acl = self._acl_from_headers(request.headers) if acl is None: acl = self.backend.get_bucket(bucket_name).acl tagging = self._tagging_from_headers(request.headers) - if 'acl' in query: + if "acl" in query: key = self.backend.get_key(bucket_name, key_name) # TODO: Support the XML-based ACL format key.set_acl(acl) return 200, response_headers, "" - if 'tagging' in query: + if "tagging" in query: + if "versionId" in query: + version_id = query["versionId"][0] + else: + version_id = None tagging = self._tagging_from_xml(body) - self.backend.set_key_tagging(bucket_name, key_name, tagging) + self.backend.set_key_tagging(bucket_name, key_name, tagging, version_id) return 200, response_headers, "" - if 'x-amz-copy-source' in request.headers: + if "x-amz-copy-source" in request.headers: # Copy key # you can have a quoted ?version=abc with a version Id, so work on # we need to parse the unquoted string first - src_key = clean_key_name(request.headers.get("x-amz-copy-source")) + src_key = request.headers.get("x-amz-copy-source") if isinstance(src_key, six.binary_type): - src_key = src_key.decode('utf-8') + src_key = src_key.decode("utf-8") src_key_parsed = urlparse(src_key) - src_bucket, src_key = unquote(src_key_parsed.path).\ - lstrip("/").split("/", 1) - src_version_id = parse_qs(src_key_parsed.query).get( - 'versionId', [None])[0] + src_bucket, src_key = ( + clean_key_name(src_key_parsed.path).lstrip("/").split("/", 1) + ) + src_version_id = parse_qs(src_key_parsed.query).get("versionId", [None])[0] key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id) if key is not None: if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]: raise ObjectNotInActiveTierError(key) - self.backend.copy_key(src_bucket, src_key, bucket_name, key_name, - storage=storage_class, acl=acl, src_version_id=src_version_id) + self.backend.copy_key( + src_bucket, + src_key, + bucket_name, + key_name, + storage=storage_class, + acl=acl, + src_version_id=src_version_id, + ) else: return 404, response_headers, "" new_key = self.backend.get_key(bucket_name, key_name) - mdirective = request.headers.get('x-amz-metadata-directive') - if mdirective is not None and mdirective == 'REPLACE': + mdirective = request.headers.get("x-amz-metadata-directive") + if mdirective is not None and mdirective == "REPLACE": metadata = metadata_from_headers(request.headers) new_key.set_metadata(metadata, replace=True) template = self.response_template(S3_OBJECT_COPY_RESPONSE) response_headers.update(new_key.response_dict) return 200, response_headers, template.render(key=new_key) - streaming_request = hasattr(request, 'streaming') and request.streaming - closing_connection = headers.get('connection') == 'close' + streaming_request = hasattr(request, "streaming") and request.streaming + closing_connection = headers.get("connection") == "close" if closing_connection and streaming_request: # Closing the connection of a streaming request. No more data new_key = self.backend.get_key(bucket_name, key_name) @@ -933,13 +1095,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): new_key = self.backend.append_to_key(bucket_name, key_name, body) else: # Initial data - new_key = self.backend.set_key(bucket_name, key_name, body, - storage=storage_class) + new_key = self.backend.set_key( + bucket_name, key_name, body, storage=storage_class + ) request.streaming = True metadata = metadata_from_headers(request.headers) new_key.set_metadata(metadata) new_key.set_acl(acl) - new_key.website_redirect_location = request.headers.get('x-amz-website-redirect-location') + new_key.website_redirect_location = request.headers.get( + "x-amz-website-redirect-location" + ) new_key.set_tagging(tagging) template = self.response_template(S3_OBJECT_RESPONSE) @@ -948,27 +1113,24 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _key_response_head(self, bucket_name, query, key_name, headers): response_headers = {} - version_id = query.get('versionId', [None])[0] - part_number = query.get('partNumber', [None])[0] + version_id = query.get("versionId", [None])[0] + part_number = query.get("partNumber", [None])[0] if part_number: part_number = int(part_number) - if_modified_since = headers.get('If-Modified-Since', None) + if_modified_since = headers.get("If-Modified-Since", None) if if_modified_since: if_modified_since = str_to_rfc_1123_datetime(if_modified_since) key = self.backend.get_key( - bucket_name, - key_name, - version_id=version_id, - part_number=part_number + bucket_name, key_name, version_id=version_id, part_number=part_number ) if key: response_headers.update(key.metadata) response_headers.update(key.response_dict) if if_modified_since and key.last_modified < if_modified_since: - return 304, response_headers, 'Not Modified' + return 304, response_headers, "Not Modified" else: return 200, response_headers, "" else: @@ -991,20 +1153,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"): raise MalformedACLError() - permissions = [ - "READ", - "WRITE", - "READ_ACP", - "WRITE_ACP", - "FULL_CONTROL" - ] + permissions = ["READ", "WRITE", "READ_ACP", "WRITE_ACP", "FULL_CONTROL"] - if not isinstance(parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], list): - parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] = \ - [parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"]] + if not isinstance( + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], list + ): + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] = [ + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] + ] - grants = self._get_grants_from_xml(parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], - MalformedACLError, permissions) + grants = self._get_grants_from_xml( + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], + MalformedACLError, + permissions, + ) return FakeAcl(grants) def _get_grants_from_xml(self, grant_list, exception_type, permissions): @@ -1013,42 +1175,54 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if grant.get("Permission", "") not in permissions: raise exception_type() - if grant["Grantee"].get("@xsi:type", "") not in ["CanonicalUser", "AmazonCustomerByEmail", "Group"]: + if grant["Grantee"].get("@xsi:type", "") not in [ + "CanonicalUser", + "AmazonCustomerByEmail", + "Group", + ]: raise exception_type() # TODO: Verify that the proper grantee data is supplied based on the type. - grants.append(FakeGrant( - [FakeGrantee(id=grant["Grantee"].get("ID", ""), display_name=grant["Grantee"].get("DisplayName", ""), - uri=grant["Grantee"].get("URI", ""))], - [grant["Permission"]]) + grants.append( + FakeGrant( + [ + FakeGrantee( + id=grant["Grantee"].get("ID", ""), + display_name=grant["Grantee"].get("DisplayName", ""), + uri=grant["Grantee"].get("URI", ""), + ) + ], + [grant["Permission"]], + ) ) return grants def _acl_from_headers(self, headers): - canned_acl = headers.get('x-amz-acl', '') + canned_acl = headers.get("x-amz-acl", "") if canned_acl: return get_canned_acl(canned_acl) grants = [] for header, value in headers.items(): - if not header.startswith('x-amz-grant-'): + if not header.startswith("x-amz-grant-"): continue permission = { - 'read': 'READ', - 'write': 'WRITE', - 'read-acp': 'READ_ACP', - 'write-acp': 'WRITE_ACP', - 'full-control': 'FULL_CONTROL', - }[header[len('x-amz-grant-'):]] + "read": "READ", + "write": "WRITE", + "read-acp": "READ_ACP", + "write-acp": "WRITE_ACP", + "full-control": "FULL_CONTROL", + }[header[len("x-amz-grant-") :]] grantees = [] for key_and_value in value.split(","): key, value = re.match( - '([^=]+)="([^"]+)"', key_and_value.strip()).groups() - if key.lower() == 'id': + '([^=]+)="([^"]+)"', key_and_value.strip() + ).groups() + if key.lower() == "id": grantees.append(FakeGrantee(id=value)) else: grantees.append(FakeGrantee(uri=value)) @@ -1060,8 +1234,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return None def _tagging_from_headers(self, headers): - if headers.get('x-amz-tagging'): - parsed_header = parse_qs(headers['x-amz-tagging'], keep_blank_values=True) + if headers.get("x-amz-tagging"): + parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True) tags = [] for tag in parsed_header.items(): tags.append(FakeTag(tag[0], tag[1][0])) @@ -1073,11 +1247,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return FakeTagging() def _tagging_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml, force_list={'Tag': True}) + parsed_xml = xmltodict.parse(xml, force_list={"Tag": True}) tags = [] - for tag in parsed_xml['Tagging']['TagSet']['Tag']: - tags.append(FakeTag(tag['Key'], tag['Value'])) + for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]: + tags.append(FakeTag(tag["Key"], tag["Value"])) tag_set = FakeTagSet(tags) tagging = FakeTagging(tag_set) @@ -1088,14 +1262,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): tags = [] # Optional if no tags are being sent: - if parsed_xml['Tagging'].get('TagSet'): + if parsed_xml["Tagging"].get("TagSet"): # If there is only 1 tag, then it's not a list: - if not isinstance(parsed_xml['Tagging']['TagSet']['Tag'], list): - tags.append(FakeTag(parsed_xml['Tagging']['TagSet']['Tag']['Key'], - parsed_xml['Tagging']['TagSet']['Tag']['Value'])) + if not isinstance(parsed_xml["Tagging"]["TagSet"]["Tag"], list): + tags.append( + FakeTag( + parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"], + parsed_xml["Tagging"]["TagSet"]["Tag"]["Value"], + ) + ) else: - for tag in parsed_xml['Tagging']['TagSet']['Tag']: - tags.append(FakeTag(tag['Key'], tag['Value'])) + for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]: + tags.append(FakeTag(tag["Key"], tag["Value"])) tag_set = FakeTagSet(tags) tagging = FakeTagging(tag_set) @@ -1123,25 +1301,34 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # Get the ACLs: if parsed_xml["BucketLoggingStatus"]["LoggingEnabled"].get("TargetGrants"): - permissions = [ - "READ", - "WRITE", - "FULL_CONTROL" - ] - if not isinstance(parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"], list): + permissions = ["READ", "WRITE", "FULL_CONTROL"] + if not isinstance( + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"][ + "Grant" + ], + list, + ): target_grants = self._get_grants_from_xml( - [parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"]], + [ + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"][ + "TargetGrants" + ]["Grant"] + ], MalformedXML, - permissions + permissions, ) else: target_grants = self._get_grants_from_xml( - parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"], + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"][ + "Grant" + ], MalformedXML, - permissions + permissions, ) - parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"] = target_grants + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"][ + "TargetGrants" + ] = target_grants return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"] @@ -1156,31 +1343,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): notification_fields = [ ("Topic", "sns"), ("Queue", "sqs"), - ("CloudFunction", "lambda") + ("CloudFunction", "lambda"), ] event_names = [ - 's3:ReducedRedundancyLostObject', - 's3:ObjectCreated:*', - 's3:ObjectCreated:Put', - 's3:ObjectCreated:Post', - 's3:ObjectCreated:Copy', - 's3:ObjectCreated:CompleteMultipartUpload', - 's3:ObjectRemoved:*', - 's3:ObjectRemoved:Delete', - 's3:ObjectRemoved:DeleteMarkerCreated' + "s3:ReducedRedundancyLostObject", + "s3:ObjectCreated:*", + "s3:ObjectCreated:Put", + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:CompleteMultipartUpload", + "s3:ObjectRemoved:*", + "s3:ObjectRemoved:Delete", + "s3:ObjectRemoved:DeleteMarkerCreated", ] - found_notifications = 0 # Tripwire -- if this is not ever set, then there were no notifications + found_notifications = ( + 0 # Tripwire -- if this is not ever set, then there were no notifications + ) for name, arn_string in notification_fields: # 1st verify that the proper notification configuration has been passed in (with an ARN that is close # to being correct -- nothing too complex in the ARN logic): - the_notification = parsed_xml["NotificationConfiguration"].get("{}Configuration".format(name)) + the_notification = parsed_xml["NotificationConfiguration"].get( + "{}Configuration".format(name) + ) if the_notification: found_notifications += 1 if not isinstance(the_notification, list): - the_notification = parsed_xml["NotificationConfiguration"]["{}Configuration".format(name)] \ - = [the_notification] + the_notification = parsed_xml["NotificationConfiguration"][ + "{}Configuration".format(name) + ] = [the_notification] for n in the_notification: if not n[name].startswith("arn:aws:{}:".format(arn_string)): @@ -1202,7 +1394,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise KeyError() if not isinstance(n["Filter"]["S3Key"]["FilterRule"], list): - n["Filter"]["S3Key"]["FilterRule"] = [n["Filter"]["S3Key"]["FilterRule"]] + n["Filter"]["S3Key"]["FilterRule"] = [ + n["Filter"]["S3Key"]["FilterRule"] + ] for filter_rule in n["Filter"]["S3Key"]["FilterRule"]: assert filter_rule["Name"] in ["suffix", "prefix"] @@ -1215,61 +1409,55 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _accelerate_config_from_xml(self, xml): parsed_xml = xmltodict.parse(xml) - config = parsed_xml['AccelerateConfiguration'] - return config['Status'] + config = parsed_xml["AccelerateConfiguration"] + return config["Status"] def _key_response_delete(self, bucket_name, query, key_name): self._set_action("KEY", "DELETE", query) self._authenticate_and_authorize_s3_action() - if query.get('uploadId'): - upload_id = query['uploadId'][0] + if query.get("uploadId"): + upload_id = query["uploadId"][0] self.backend.cancel_multipart(bucket_name, upload_id) return 204, {}, "" - version_id = query.get('versionId', [None])[0] + version_id = query.get("versionId", [None])[0] self.backend.delete_key(bucket_name, key_name, version_id=version_id) template = self.response_template(S3_DELETE_OBJECT_SUCCESS) return 204, {}, template.render() def _complete_multipart_body(self, body): - ps = minidom.parseString(body).getElementsByTagName('Part') + ps = minidom.parseString(body).getElementsByTagName("Part") prev = 0 for p in ps: - pn = int(p.getElementsByTagName( - 'PartNumber')[0].firstChild.wholeText) + pn = int(p.getElementsByTagName("PartNumber")[0].firstChild.wholeText) if pn <= prev: raise InvalidPartOrder() - yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText) + yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText) def _key_response_post(self, request, body, bucket_name, query, key_name): self._set_action("KEY", "POST", query) self._authenticate_and_authorize_s3_action() - if body == b'' and 'uploads' in query: + if body == b"" and "uploads" in query: metadata = metadata_from_headers(request.headers) - multipart = self.backend.initiate_multipart( - bucket_name, key_name, metadata) + multipart = self.backend.initiate_multipart(bucket_name, key_name, metadata) template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) response = template.render( - bucket_name=bucket_name, - key_name=key_name, - upload_id=multipart.id, + bucket_name=bucket_name, key_name=key_name, upload_id=multipart.id ) return 200, {}, response - if query.get('uploadId'): + if query.get("uploadId"): body = self._complete_multipart_body(body) - upload_id = query['uploadId'][0] + upload_id = query["uploadId"][0] key = self.backend.complete_multipart(bucket_name, upload_id, body) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) return template.render( - bucket_name=bucket_name, - key_name=key.name, - etag=key.etag, + bucket_name=bucket_name, key_name=key.name, etag=key.etag ) - elif 'restore' in query: - es = minidom.parseString(body).getElementsByTagName('Days') + elif "restore" in query: + es = minidom.parseString(body).getElementsByTagName("Days") days = es[0].childNodes[0].wholeText key = self.backend.get_key(bucket_name, key_name) r = 202 @@ -1279,7 +1467,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return r, {}, "" else: raise NotImplementedError( - "Method POST had only been implemented for multipart uploads and restore operations, so far") + "Method POST had only been implemented for multipart uploads and restore operations, so far" + ) S3ResponseInstance = ResponseObject(s3_backend) @@ -1293,7 +1482,7 @@ S3_ALL_BUCKETS = """ {{ max_keys }} {{ delimiter }} {{ is_truncated }} + {% if next_marker %} + {{ next_marker }} + {% endif %} {% for key in result_keys %} {{ key.name }} @@ -1399,7 +1591,9 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """ {{ rule.id }} {% if rule.filter %} + {% if rule.filter.prefix != None %} {{ rule.filter.prefix }} + {% endif %} {% if rule.filter.tag %} {{ rule.filter.tag.key }} @@ -1408,7 +1602,9 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """ {% endif %} {% if rule.filter.and_filter %} + {% if rule.filter.and_filter.prefix != None %} {{ rule.filter.and_filter.prefix }} + {% endif %} {% for tag in rule.filter.and_filter.tags %} {{ tag.key }} @@ -1419,7 +1615,9 @@ S3_BUCKET_LIFECYCLE_CONFIGURATION = """ {% endif %} {% else %} - {{ rule.prefix if rule.prefix != None }} + {% if rule.prefix != None %} + {{ rule.prefix }} + {% endif %} {% endif %} {{ rule.status }} {% if rule.storage_class %} @@ -1689,7 +1887,8 @@ S3_MULTIPART_COMPLETE_RESPONSE = """ """ -S3_ALL_MULTIPARTS = """ +S3_ALL_MULTIPARTS = ( + """ {{ bucket_name }} @@ -1701,7 +1900,9 @@ S3_ALL_MULTIPARTS = """ {{ upload.key_name }} {{ upload.id }} - arn:aws:iam::123456789012:user/user1-11111a31-17b5-4fb7-9df5-b111111f13de + arn:aws:iam::""" + + ACCOUNT_ID + + """:user/user1-11111a31-17b5-4fb7-9df5-b111111f13de user1-11111a31-17b5-4fb7-9df5-b111111f13de @@ -1714,6 +1915,7 @@ S3_ALL_MULTIPARTS = """ {% endfor %} """ +) S3_NO_POLICY = """ @@ -1886,3 +2088,12 @@ S3_BUCKET_ACCELERATE = """ S3_BUCKET_ACCELERATE_NOT_SET = """ """ + +S3_PUBLIC_ACCESS_BLOCK_CONFIGURATION = """ + + {{public_block_config.block_public_acls}} + {{public_block_config.ignore_public_acls}} + {{public_block_config.block_public_policy}} + {{public_block_config.restrict_public_buckets}} + +""" diff --git a/moto/s3/urls.py b/moto/s3/urls.py index fa81568a4..7241dbef1 100644 --- a/moto/s3/urls.py +++ b/moto/s3/urls.py @@ -4,15 +4,16 @@ from .responses import S3ResponseInstance url_bases = [ "https?://s3(.*).amazonaws.com", - r"https?://(?P[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com" + r"https?://(?P[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com", ] url_paths = { # subdomain bucket - '{0}/$': S3ResponseInstance.bucket_response, - + "{0}/$": S3ResponseInstance.bucket_response, # subdomain key of path-based bucket - '{0}/(?P[^/]+)/?$': S3ResponseInstance.ambiguous_response, + "{0}/(?P[^/]+)/?$": S3ResponseInstance.ambiguous_response, # path-based bucket + key - '{0}/(?P[^/]+)/(?P.+)': S3ResponseInstance.key_response, + "{0}/(?P[^/]+)/(?P.+)": S3ResponseInstance.key_response, + # subdomain bucket + key with empty first part of path + "{0}//(?P.*)$": S3ResponseInstance.key_response, } diff --git a/moto/s3/utils.py b/moto/s3/utils.py index 85a812aad..e7d9e5580 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -5,7 +5,7 @@ import os from boto.s3.key import Key import re import six -from six.moves.urllib.parse import urlparse, unquote +from six.moves.urllib.parse import urlparse, unquote, quote import sys @@ -16,19 +16,19 @@ bucket_name_regex = re.compile("(.+).s3(.*).amazonaws.com") def bucket_name_from_url(url): - if os.environ.get('S3_IGNORE_SUBDOMAIN_BUCKETNAME', '') in ['1', 'true']: + if os.environ.get("S3_IGNORE_SUBDOMAIN_BUCKETNAME", "") in ["1", "true"]: return None domain = urlparse(url).netloc - if domain.startswith('www.'): + if domain.startswith("www."): domain = domain[4:] - if 'amazonaws.com' in domain: + if "amazonaws.com" in domain: bucket_result = bucket_name_regex.search(domain) if bucket_result: return bucket_result.groups()[0] else: - if '.' in domain: + if "." in domain: return domain.split(".")[0] else: # No subdomain found. @@ -36,23 +36,23 @@ def bucket_name_from_url(url): REGION_URL_REGEX = re.compile( - r'^https?://(s3[-\.](?P.+)\.amazonaws\.com/(.+)|' - r'(.+)\.s3-(?P.+)\.amazonaws\.com)/?') + r"^https?://(s3[-\.](?P.+)\.amazonaws\.com/(.+)|" + r"(.+)\.s3-(?P.+)\.amazonaws\.com)/?" +) def parse_region_from_url(url): match = REGION_URL_REGEX.search(url) if match: - region = match.group('region1') or match.group('region2') + region = match.group("region1") or match.group("region2") else: - region = 'us-east-1' + region = "us-east-1" return region def metadata_from_headers(headers): metadata = {} - meta_regex = re.compile( - '^x-amz-meta-([a-zA-Z0-9\-_]+)$', flags=re.IGNORECASE) + meta_regex = re.compile("^x-amz-meta-([a-zA-Z0-9\-_]+)$", flags=re.IGNORECASE) for header, value in headers.items(): if isinstance(header, six.string_types): result = meta_regex.match(header) @@ -70,11 +70,16 @@ def metadata_from_headers(headers): def clean_key_name(key_name): if six.PY2: - return unquote(key_name.encode('utf-8')).decode('utf-8') - + return unquote(key_name.encode("utf-8")).decode("utf-8") return unquote(key_name) +def undo_clean_key_name(key_name): + if six.PY2: + return quote(key_name.encode("utf-8")).decode("utf-8") + return quote(key_name) + + class _VersionedKeyStore(dict): """ A simplified/modified version of Django's `MultiValueDict` taken from: @@ -135,6 +140,7 @@ class _VersionedKeyStore(dict): values = itervalues = _itervalues if sys.version_info[0] < 3: + def items(self): return list(self.iteritems()) diff --git a/moto/s3bucket_path/utils.py b/moto/s3bucket_path/utils.py index 1b9a034f4..d514a1b35 100644 --- a/moto/s3bucket_path/utils.py +++ b/moto/s3bucket_path/utils.py @@ -17,8 +17,10 @@ def parse_key_name(path): def is_delete_keys(request, path, bucket_name): return ( - path == u'/' + bucket_name + u'/?delete' or - path == u'/' + bucket_name + u'?delete' or - (path == u'/' + bucket_name and - getattr(request, "query_string", "") == "delete") + path == "/" + bucket_name + "/?delete" + or path == "/" + bucket_name + "?delete" + or ( + path == "/" + bucket_name + and getattr(request, "query_string", "") == "delete" + ) ) diff --git a/moto/secretsmanager/__init__.py b/moto/secretsmanager/__init__.py index c7fbb2869..5d41d07ae 100644 --- a/moto/secretsmanager/__init__.py +++ b/moto/secretsmanager/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import secretsmanager_backends from ..core.models import base_decorator -secretsmanager_backend = secretsmanager_backends['us-east-1'] +secretsmanager_backend = secretsmanager_backends["us-east-1"] mock_secretsmanager = base_decorator(secretsmanager_backends) diff --git a/moto/secretsmanager/exceptions.py b/moto/secretsmanager/exceptions.py index fa81b6d8b..bf717e20c 100644 --- a/moto/secretsmanager/exceptions.py +++ b/moto/secretsmanager/exceptions.py @@ -7,38 +7,53 @@ class SecretsManagerClientError(JsonRESTError): class ResourceNotFoundException(SecretsManagerClientError): - def __init__(self): + def __init__(self, message): self.code = 404 super(ResourceNotFoundException, self).__init__( + "ResourceNotFoundException", message + ) + + +class SecretNotFoundException(SecretsManagerClientError): + def __init__(self): + self.code = 404 + super(SecretNotFoundException, self).__init__( "ResourceNotFoundException", - "Secrets Manager can't find the specified secret" + message="Secrets Manager can't find the specified secret.", + ) + + +class SecretHasNoValueException(SecretsManagerClientError): + def __init__(self, version_stage): + self.code = 404 + super(SecretHasNoValueException, self).__init__( + "ResourceNotFoundException", + message="Secrets Manager can't find the specified secret " + "value for staging label: {}".format(version_stage), ) class ClientError(SecretsManagerClientError): def __init__(self, message): - super(ClientError, self).__init__( - 'InvalidParameterValue', - message) + super(ClientError, self).__init__("InvalidParameterValue", message) class InvalidParameterException(SecretsManagerClientError): def __init__(self, message): super(InvalidParameterException, self).__init__( - 'InvalidParameterException', - message) + "InvalidParameterException", message + ) class ResourceExistsException(SecretsManagerClientError): def __init__(self, message): super(ResourceExistsException, self).__init__( - 'ResourceExistsException', - message + "ResourceExistsException", message ) class InvalidRequestException(SecretsManagerClientError): def __init__(self, message): super(InvalidRequestException, self).__init__( - 'InvalidRequestException', - message) + "InvalidRequestException", message + ) diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index 3e0424b6b..2a1a336d9 100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals import time @@ -9,27 +10,40 @@ import boto3 from moto.core import BaseBackend, BaseModel from .exceptions import ( - ResourceNotFoundException, + SecretNotFoundException, + SecretHasNoValueException, InvalidParameterException, ResourceExistsException, InvalidRequestException, - ClientError + ClientError, ) -from .utils import random_password, secret_arn +from .utils import random_password, secret_arn, get_secret_name_from_arn class SecretsManager(BaseModel): - def __init__(self, region_name, **kwargs): self.region = region_name -class SecretsManagerBackend(BaseBackend): +class SecretsStore(dict): + def __setitem__(self, key, value): + new_key = get_secret_name_from_arn(key) + super(SecretsStore, self).__setitem__(new_key, value) + def __getitem__(self, key): + new_key = get_secret_name_from_arn(key) + return super(SecretsStore, self).__getitem__(new_key) + + def __contains__(self, key): + new_key = get_secret_name_from_arn(key) + return dict.__contains__(self, new_key) + + +class SecretsManagerBackend(BaseBackend): def __init__(self, region_name=None, **kwargs): super(SecretsManagerBackend, self).__init__() self.region = region_name - self.secrets = {} + self.secrets = SecretsStore() def reset(self): region_name = self.region @@ -44,163 +58,197 @@ class SecretsManagerBackend(BaseBackend): return (dt - epoch).total_seconds() def get_secret_value(self, secret_id, version_id, version_stage): - if not self._is_valid_identifier(secret_id): - raise ResourceNotFoundException() + raise SecretNotFoundException() if not version_id and version_stage: # set version_id to match version_stage - versions_dict = self.secrets[secret_id]['versions'] + versions_dict = self.secrets[secret_id]["versions"] for ver_id, ver_val in versions_dict.items(): - if version_stage in ver_val['version_stages']: + if version_stage in ver_val["version_stages"]: version_id = ver_id break if not version_id: - raise ResourceNotFoundException() + raise SecretNotFoundException() # TODO check this part - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the GetSecretValue operation: You tried to \ perform the operation on a secret that's currently marked deleted." ) secret = self.secrets[secret_id] - version_id = version_id or secret['default_version_id'] + version_id = version_id or secret["default_version_id"] - secret_version = secret['versions'][version_id] + secret_version = secret["versions"][version_id] response_data = { - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "VersionId": secret_version['version_id'], - "VersionStages": secret_version['version_stages'], - "CreatedDate": secret_version['createdate'], + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "VersionId": secret_version["version_id"], + "VersionStages": secret_version["version_stages"], + "CreatedDate": secret_version["createdate"], } - if 'secret_string' in secret_version: - response_data["SecretString"] = secret_version['secret_string'] + if "secret_string" in secret_version: + response_data["SecretString"] = secret_version["secret_string"] - if 'secret_binary' in secret_version: - response_data["SecretBinary"] = secret_version['secret_binary'] + if "secret_binary" in secret_version: + response_data["SecretBinary"] = secret_version["secret_binary"] + + if ( + "secret_string" not in secret_version + and "secret_binary" not in secret_version + ): + raise SecretHasNoValueException(version_stage or "AWSCURRENT") response = json.dumps(response_data) return response - def create_secret(self, name, secret_string=None, secret_binary=None, tags=[], **kwargs): + def create_secret( + self, name, secret_string=None, secret_binary=None, tags=[], **kwargs + ): # error if secret exists if name in self.secrets.keys(): - raise ResourceExistsException('A resource with the ID you requested already exists.') + raise ResourceExistsException( + "A resource with the ID you requested already exists." + ) - version_id = self._add_secret(name, secret_string=secret_string, secret_binary=secret_binary, tags=tags) + version_id = self._add_secret( + name, secret_string=secret_string, secret_binary=secret_binary, tags=tags + ) - response = json.dumps({ - "ARN": secret_arn(self.region, name), - "Name": name, - "VersionId": version_id, - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, name), + "Name": name, + "VersionId": version_id, + } + ) return response - def _add_secret(self, secret_id, secret_string=None, secret_binary=None, tags=[], version_id=None, version_stages=None): + def _add_secret( + self, + secret_id, + secret_string=None, + secret_binary=None, + tags=[], + version_id=None, + version_stages=None, + ): if version_stages is None: - version_stages = ['AWSCURRENT'] + version_stages = ["AWSCURRENT"] if not version_id: version_id = str(uuid.uuid4()) secret_version = { - 'createdate': int(time.time()), - 'version_id': version_id, - 'version_stages': version_stages, + "createdate": int(time.time()), + "version_id": version_id, + "version_stages": version_stages, } if secret_string is not None: - secret_version['secret_string'] = secret_string + secret_version["secret_string"] = secret_string if secret_binary is not None: - secret_version['secret_binary'] = secret_binary + secret_version["secret_binary"] = secret_binary if secret_id in self.secrets: # remove all old AWSPREVIOUS stages - for secret_verion_to_look_at in self.secrets[secret_id]['versions'].values(): - if 'AWSPREVIOUS' in secret_verion_to_look_at['version_stages']: - secret_verion_to_look_at['version_stages'].remove('AWSPREVIOUS') + for secret_verion_to_look_at in self.secrets[secret_id][ + "versions" + ].values(): + if "AWSPREVIOUS" in secret_verion_to_look_at["version_stages"]: + secret_verion_to_look_at["version_stages"].remove("AWSPREVIOUS") # set old AWSCURRENT secret to AWSPREVIOUS - previous_current_version_id = self.secrets[secret_id]['default_version_id'] - self.secrets[secret_id]['versions'][previous_current_version_id]['version_stages'] = ['AWSPREVIOUS'] + previous_current_version_id = self.secrets[secret_id]["default_version_id"] + self.secrets[secret_id]["versions"][previous_current_version_id][ + "version_stages" + ] = ["AWSPREVIOUS"] - self.secrets[secret_id]['versions'][version_id] = secret_version - self.secrets[secret_id]['default_version_id'] = version_id + self.secrets[secret_id]["versions"][version_id] = secret_version + self.secrets[secret_id]["default_version_id"] = version_id else: self.secrets[secret_id] = { - 'versions': { - version_id: secret_version - }, - 'default_version_id': version_id, + "versions": {version_id: secret_version}, + "default_version_id": version_id, } secret = self.secrets[secret_id] - secret['secret_id'] = secret_id - secret['name'] = secret_id - secret['rotation_enabled'] = False - secret['rotation_lambda_arn'] = '' - secret['auto_rotate_after_days'] = 0 - secret['tags'] = tags + secret["secret_id"] = secret_id + secret["name"] = secret_id + secret["rotation_enabled"] = False + secret["rotation_lambda_arn"] = "" + secret["auto_rotate_after_days"] = 0 + secret["tags"] = tags return version_id - def put_secret_value(self, secret_id, secret_string, version_stages): + def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages): - version_id = self._add_secret(secret_id, secret_string, version_stages=version_stages) + version_id = self._add_secret( + secret_id, secret_string, secret_binary, version_stages=version_stages + ) - response = json.dumps({ - 'ARN': secret_arn(self.region, secret_id), - 'Name': secret_id, - 'VersionId': version_id, - 'VersionStages': version_stages - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret_id), + "Name": secret_id, + "VersionId": version_id, + "VersionStages": version_stages, + } + ) return response def describe_secret(self, secret_id): if not self._is_valid_identifier(secret_id): - raise ResourceNotFoundException + raise SecretNotFoundException() secret = self.secrets[secret_id] - response = json.dumps({ - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "Description": "", - "KmsKeyId": "", - "RotationEnabled": secret['rotation_enabled'], - "RotationLambdaARN": secret['rotation_lambda_arn'], - "RotationRules": { - "AutomaticallyAfterDays": secret['auto_rotate_after_days'] - }, - "LastRotatedDate": None, - "LastChangedDate": None, - "LastAccessedDate": None, - "DeletedDate": secret.get('deleted_date', None), - "Tags": secret['tags'] - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "Description": "", + "KmsKeyId": "", + "RotationEnabled": secret["rotation_enabled"], + "RotationLambdaARN": secret["rotation_lambda_arn"], + "RotationRules": { + "AutomaticallyAfterDays": secret["auto_rotate_after_days"] + }, + "LastRotatedDate": None, + "LastChangedDate": None, + "LastAccessedDate": None, + "DeletedDate": secret.get("deleted_date", None), + "Tags": secret["tags"], + } + ) return response - def rotate_secret(self, secret_id, client_request_token=None, - rotation_lambda_arn=None, rotation_rules=None): + def rotate_secret( + self, + secret_id, + client_request_token=None, + rotation_lambda_arn=None, + rotation_rules=None, + ): - rotation_days = 'AutomaticallyAfterDays' + rotation_days = "AutomaticallyAfterDays" if not self._is_valid_identifier(secret_id): - raise ResourceNotFoundException + raise SecretNotFoundException() - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the RotateSecret operation: You tried to \ perform the operation on a secret that's currently marked deleted." @@ -209,18 +257,12 @@ class SecretsManagerBackend(BaseBackend): if client_request_token: token_length = len(client_request_token) if token_length < 32 or token_length > 64: - msg = ( - 'ClientRequestToken ' - 'must be 32-64 characters long.' - ) + msg = "ClientRequestToken " "must be 32-64 characters long." raise InvalidParameterException(msg) if rotation_lambda_arn: if len(rotation_lambda_arn) > 2048: - msg = ( - 'RotationLambdaARN ' - 'must <= 2048 characters long.' - ) + msg = "RotationLambdaARN " "must <= 2048 characters long." raise InvalidParameterException(msg) if rotation_rules: @@ -228,61 +270,82 @@ class SecretsManagerBackend(BaseBackend): rotation_period = rotation_rules[rotation_days] if rotation_period < 1 or rotation_period > 1000: msg = ( - 'RotationRules.AutomaticallyAfterDays ' - 'must be within 1-1000.' + "RotationRules.AutomaticallyAfterDays " "must be within 1-1000." ) raise InvalidParameterException(msg) secret = self.secrets[secret_id] - old_secret_version = secret['versions'][secret['default_version_id']] + old_secret_version = secret["versions"][secret["default_version_id"]] new_version_id = client_request_token or str(uuid.uuid4()) - self._add_secret(secret_id, old_secret_version['secret_string'], secret['tags'], version_id=new_version_id, version_stages=['AWSCURRENT']) + self._add_secret( + secret_id, + old_secret_version["secret_string"], + secret["tags"], + version_id=new_version_id, + version_stages=["AWSCURRENT"], + ) - secret['rotation_lambda_arn'] = rotation_lambda_arn or '' + secret["rotation_lambda_arn"] = rotation_lambda_arn or "" if rotation_rules: - secret['auto_rotate_after_days'] = rotation_rules.get(rotation_days, 0) - if secret['auto_rotate_after_days'] > 0: - secret['rotation_enabled'] = True + secret["auto_rotate_after_days"] = rotation_rules.get(rotation_days, 0) + if secret["auto_rotate_after_days"] > 0: + secret["rotation_enabled"] = True - if 'AWSCURRENT' in old_secret_version['version_stages']: - old_secret_version['version_stages'].remove('AWSCURRENT') + if "AWSCURRENT" in old_secret_version["version_stages"]: + old_secret_version["version_stages"].remove("AWSCURRENT") - response = json.dumps({ - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "VersionId": new_version_id - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "VersionId": new_version_id, + } + ) return response - def get_random_password(self, password_length, - exclude_characters, exclude_numbers, - exclude_punctuation, exclude_uppercase, - exclude_lowercase, include_space, - require_each_included_type): + def get_random_password( + self, + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, + ): # password size must have value less than or equal to 4096 if password_length > 4096: raise ClientError( "ClientError: An error occurred (ValidationException) \ when calling the GetRandomPassword operation: 1 validation error detected: Value '{}' at 'passwordLength' \ - failed to satisfy constraint: Member must have value less than or equal to 4096".format(password_length)) + failed to satisfy constraint: Member must have value less than or equal to 4096".format( + password_length + ) + ) if password_length < 4: raise InvalidParameterException( "InvalidParameterException: An error occurred (InvalidParameterException) \ - when calling the GetRandomPassword operation: Password length is too short based on the required types.") + when calling the GetRandomPassword operation: Password length is too short based on the required types." + ) - response = json.dumps({ - "RandomPassword": random_password(password_length, - exclude_characters, - exclude_numbers, - exclude_punctuation, - exclude_uppercase, - exclude_lowercase, - include_space, - require_each_included_type) - }) + response = json.dumps( + { + "RandomPassword": random_password( + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, + ) + } + ) return response @@ -290,20 +353,24 @@ class SecretsManagerBackend(BaseBackend): secret = self.secrets[secret_id] version_list = [] - for version_id, version in secret['versions'].items(): - version_list.append({ - 'CreatedDate': int(time.time()), - 'LastAccessedDate': int(time.time()), - 'VersionId': version_id, - 'VersionStages': version['version_stages'], - }) + for version_id, version in secret["versions"].items(): + version_list.append( + { + "CreatedDate": int(time.time()), + "LastAccessedDate": int(time.time()), + "VersionId": version_id, + "VersionStages": version["version_stages"], + } + ) - response = json.dumps({ - 'ARN': secret['secret_id'], - 'Name': secret['name'], - 'NextToken': '', - 'Versions': version_list, - }) + response = json.dumps( + { + "ARN": secret["secret_id"], + "Name": secret["name"], + "NextToken": "", + "Versions": version_list, + } + ) return response @@ -314,35 +381,39 @@ class SecretsManagerBackend(BaseBackend): for secret in self.secrets.values(): versions_to_stages = {} - for version_id, version in secret['versions'].items(): - versions_to_stages[version_id] = version['version_stages'] + for version_id, version in secret["versions"].items(): + versions_to_stages[version_id] = version["version_stages"] - secret_list.append({ - "ARN": secret_arn(self.region, secret['secret_id']), - "DeletedDate": secret.get('deleted_date', None), - "Description": "", - "KmsKeyId": "", - "LastAccessedDate": None, - "LastChangedDate": None, - "LastRotatedDate": None, - "Name": secret['name'], - "RotationEnabled": secret['rotation_enabled'], - "RotationLambdaARN": secret['rotation_lambda_arn'], - "RotationRules": { - "AutomaticallyAfterDays": secret['auto_rotate_after_days'] - }, - "SecretVersionsToStages": versions_to_stages, - "Tags": secret['tags'] - }) + secret_list.append( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "DeletedDate": secret.get("deleted_date", None), + "Description": "", + "KmsKeyId": "", + "LastAccessedDate": None, + "LastChangedDate": None, + "LastRotatedDate": None, + "Name": secret["name"], + "RotationEnabled": secret["rotation_enabled"], + "RotationLambdaARN": secret["rotation_lambda_arn"], + "RotationRules": { + "AutomaticallyAfterDays": secret["auto_rotate_after_days"] + }, + "SecretVersionsToStages": versions_to_stages, + "Tags": secret["tags"], + } + ) return secret_list, None - def delete_secret(self, secret_id, recovery_window_in_days, force_delete_without_recovery): + def delete_secret( + self, secret_id, recovery_window_in_days, force_delete_without_recovery + ): if not self._is_valid_identifier(secret_id): - raise ResourceNotFoundException + raise SecretNotFoundException() - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the DeleteSecret operation: You tried to \ perform the operation on a secret that's currently marked deleted." @@ -354,7 +425,9 @@ class SecretsManagerBackend(BaseBackend): use ForceDeleteWithoutRecovery in conjunction with RecoveryWindowInDays." ) - if recovery_window_in_days and (recovery_window_in_days < 7 or recovery_window_in_days > 30): + if recovery_window_in_days and ( + recovery_window_in_days < 7 or recovery_window_in_days > 30 + ): raise InvalidParameterException( "An error occurred (InvalidParameterException) when calling the DeleteSecret operation: The \ RecoveryWindowInDays value must be between 7 and 30 days (inclusive)." @@ -366,34 +439,59 @@ class SecretsManagerBackend(BaseBackend): secret = self.secrets.pop(secret_id, None) else: deletion_date += datetime.timedelta(days=recovery_window_in_days or 30) - self.secrets[secret_id]['deleted_date'] = self._unix_time_secs(deletion_date) + self.secrets[secret_id]["deleted_date"] = self._unix_time_secs( + deletion_date + ) secret = self.secrets.get(secret_id, None) if not secret: - raise ResourceNotFoundException + raise SecretNotFoundException() - arn = secret_arn(self.region, secret['secret_id']) - name = secret['name'] + arn = secret_arn(self.region, secret["secret_id"]) + name = secret["name"] return arn, name, self._unix_time_secs(deletion_date) def restore_secret(self, secret_id): if not self._is_valid_identifier(secret_id): - raise ResourceNotFoundException + raise SecretNotFoundException() - self.secrets[secret_id].pop('deleted_date', None) + self.secrets[secret_id].pop("deleted_date", None) secret = self.secrets[secret_id] - arn = secret_arn(self.region, secret['secret_id']) - name = secret['name'] + arn = secret_arn(self.region, secret["secret_id"]) + name = secret["name"] return arn, name + @staticmethod + def get_resource_policy(secret_id): + resource_policy = { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Principal": { + "AWS": [ + "arn:aws:iam::111122223333:root", + "arn:aws:iam::444455556666:root", + ] + }, + "Action": ["secretsmanager:GetSecretValue"], + "Resource": "*", + }, + } + return json.dumps( + { + "ARN": secret_id, + "Name": secret_id, + "ResourcePolicy": json.dumps(resource_policy), + } + ) -available_regions = ( - boto3.session.Session().get_available_regions("secretsmanager") -) -secretsmanager_backends = {region: SecretsManagerBackend(region_name=region) - for region in available_regions} + +available_regions = boto3.session.Session().get_available_regions("secretsmanager") +secretsmanager_backends = { + region: SecretsManagerBackend(region_name=region) for region in available_regions +} diff --git a/moto/secretsmanager/responses.py b/moto/secretsmanager/responses.py index 090688351..28af7b91d 100644 --- a/moto/secretsmanager/responses.py +++ b/moto/secretsmanager/responses.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse +from moto.secretsmanager.exceptions import InvalidRequestException from .models import secretsmanager_backends @@ -8,38 +9,37 @@ import json class SecretsManagerResponse(BaseResponse): - def get_secret_value(self): - secret_id = self._get_param('SecretId') - version_id = self._get_param('VersionId') - version_stage = self._get_param('VersionStage') + secret_id = self._get_param("SecretId") + version_id = self._get_param("VersionId") + version_stage = self._get_param("VersionStage") return secretsmanager_backends[self.region].get_secret_value( - secret_id=secret_id, - version_id=version_id, - version_stage=version_stage) + secret_id=secret_id, version_id=version_id, version_stage=version_stage + ) def create_secret(self): - name = self._get_param('Name') - secret_string = self._get_param('SecretString') - secret_binary = self._get_param('SecretBinary') - tags = self._get_param('Tags', if_none=[]) + name = self._get_param("Name") + secret_string = self._get_param("SecretString") + secret_binary = self._get_param("SecretBinary") + tags = self._get_param("Tags", if_none=[]) return secretsmanager_backends[self.region].create_secret( name=name, secret_string=secret_string, secret_binary=secret_binary, - tags=tags + tags=tags, ) def get_random_password(self): - password_length = self._get_param('PasswordLength', if_none=32) - exclude_characters = self._get_param('ExcludeCharacters', if_none='') - exclude_numbers = self._get_param('ExcludeNumbers', if_none=False) - exclude_punctuation = self._get_param('ExcludePunctuation', if_none=False) - exclude_uppercase = self._get_param('ExcludeUppercase', if_none=False) - exclude_lowercase = self._get_param('ExcludeLowercase', if_none=False) - include_space = self._get_param('IncludeSpace', if_none=False) + password_length = self._get_param("PasswordLength", if_none=32) + exclude_characters = self._get_param("ExcludeCharacters", if_none="") + exclude_numbers = self._get_param("ExcludeNumbers", if_none=False) + exclude_punctuation = self._get_param("ExcludePunctuation", if_none=False) + exclude_uppercase = self._get_param("ExcludeUppercase", if_none=False) + exclude_lowercase = self._get_param("ExcludeLowercase", if_none=False) + include_space = self._get_param("IncludeSpace", if_none=False) require_each_included_type = self._get_param( - 'RequireEachIncludedType', if_none=True) + "RequireEachIncludedType", if_none=True + ) return secretsmanager_backends[self.region].get_random_password( password_length=password_length, exclude_characters=exclude_characters, @@ -48,39 +48,43 @@ class SecretsManagerResponse(BaseResponse): exclude_uppercase=exclude_uppercase, exclude_lowercase=exclude_lowercase, include_space=include_space, - require_each_included_type=require_each_included_type + require_each_included_type=require_each_included_type, ) def describe_secret(self): - secret_id = self._get_param('SecretId') - return secretsmanager_backends[self.region].describe_secret( - secret_id=secret_id - ) + secret_id = self._get_param("SecretId") + return secretsmanager_backends[self.region].describe_secret(secret_id=secret_id) def rotate_secret(self): - client_request_token = self._get_param('ClientRequestToken') - rotation_lambda_arn = self._get_param('RotationLambdaARN') - rotation_rules = self._get_param('RotationRules') - secret_id = self._get_param('SecretId') + client_request_token = self._get_param("ClientRequestToken") + rotation_lambda_arn = self._get_param("RotationLambdaARN") + rotation_rules = self._get_param("RotationRules") + secret_id = self._get_param("SecretId") return secretsmanager_backends[self.region].rotate_secret( secret_id=secret_id, client_request_token=client_request_token, rotation_lambda_arn=rotation_lambda_arn, - rotation_rules=rotation_rules + rotation_rules=rotation_rules, ) def put_secret_value(self): - secret_id = self._get_param('SecretId', if_none='') - secret_string = self._get_param('SecretString', if_none='') - version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT']) + secret_id = self._get_param("SecretId", if_none="") + secret_string = self._get_param("SecretString") + secret_binary = self._get_param("SecretBinary") + if not secret_binary and not secret_string: + raise InvalidRequestException( + "You must provide either SecretString or SecretBinary." + ) + version_stages = self._get_param("VersionStages", if_none=["AWSCURRENT"]) return secretsmanager_backends[self.region].put_secret_value( secret_id=secret_id, + secret_binary=secret_binary, secret_string=secret_string, version_stages=version_stages, ) def list_secret_version_ids(self): - secret_id = self._get_param('SecretId', if_none='') + secret_id = self._get_param("SecretId", if_none="") return secretsmanager_backends[self.region].list_secret_version_ids( secret_id=secret_id ) @@ -89,8 +93,7 @@ class SecretsManagerResponse(BaseResponse): max_results = self._get_int_param("MaxResults") next_token = self._get_param("NextToken") secret_list, next_token = secretsmanager_backends[self.region].list_secrets( - max_results=max_results, - next_token=next_token, + max_results=max_results, next_token=next_token ) return json.dumps(dict(SecretList=secret_list, NextToken=next_token)) @@ -108,6 +111,12 @@ class SecretsManagerResponse(BaseResponse): def restore_secret(self): secret_id = self._get_param("SecretId") arn, name = secretsmanager_backends[self.region].restore_secret( - secret_id=secret_id, + secret_id=secret_id ) return json.dumps(dict(ARN=arn, Name=name)) + + def get_resource_policy(self): + secret_id = self._get_param("SecretId") + return secretsmanager_backends[self.region].get_resource_policy( + secret_id=secret_id + ) diff --git a/moto/secretsmanager/urls.py b/moto/secretsmanager/urls.py index 9e39e7263..57cbac0e4 100644 --- a/moto/secretsmanager/urls.py +++ b/moto/secretsmanager/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import SecretsManagerResponse -url_bases = [ - "https?://secretsmanager.(.+).amazonaws.com", -] +url_bases = ["https?://secretsmanager.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SecretsManagerResponse.dispatch, -} +url_paths = {"{0}/$": SecretsManagerResponse.dispatch} diff --git a/moto/secretsmanager/utils.py b/moto/secretsmanager/utils.py index 231fea296..73275ee05 100644 --- a/moto/secretsmanager/utils.py +++ b/moto/secretsmanager/utils.py @@ -6,55 +6,83 @@ import six import re -def random_password(password_length, exclude_characters, exclude_numbers, - exclude_punctuation, exclude_uppercase, exclude_lowercase, - include_space, require_each_included_type): +def random_password( + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, +): - password = '' - required_characters = '' + password = "" + required_characters = "" if not exclude_lowercase and not exclude_uppercase: password += string.ascii_letters - required_characters += random.choice(_exclude_characters( - string.ascii_lowercase, exclude_characters)) - required_characters += random.choice(_exclude_characters( - string.ascii_uppercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_lowercase, exclude_characters) + ) + required_characters += random.choice( + _exclude_characters(string.ascii_uppercase, exclude_characters) + ) elif not exclude_lowercase: password += string.ascii_lowercase - required_characters += random.choice(_exclude_characters( - string.ascii_lowercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_lowercase, exclude_characters) + ) elif not exclude_uppercase: password += string.ascii_uppercase - required_characters += random.choice(_exclude_characters( - string.ascii_uppercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_uppercase, exclude_characters) + ) if not exclude_numbers: password += string.digits - required_characters += random.choice(_exclude_characters( - string.digits, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.digits, exclude_characters) + ) if not exclude_punctuation: password += string.punctuation - required_characters += random.choice(_exclude_characters( - string.punctuation, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.punctuation, exclude_characters) + ) if include_space: password += " " required_characters += " " - password = ''.join( - six.text_type(random.choice(password)) - for x in range(password_length)) + password = "".join( + six.text_type(random.choice(password)) for x in range(password_length) + ) if require_each_included_type: password = _add_password_require_each_included_type( - password, required_characters) + password, required_characters + ) password = _exclude_characters(password, exclude_characters) return password def secret_arn(region, secret_id): - id_string = ''.join(random.choice(string.ascii_letters) for _ in range(5)) + id_string = "".join(random.choice(string.ascii_letters) for _ in range(5)) return "arn:aws:secretsmanager:{0}:1234567890:secret:{1}-{2}".format( - region, secret_id, id_string) + region, secret_id, id_string + ) + + +def get_secret_name_from_arn(secret_id): + # can fetch by both arn and by name + # but we are storing via name + # so we need to change the arn to name + # if it starts with arn then the secret id is arn + if secret_id.startswith("arn:aws:secretsmanager:"): + # split the arn by colon + # then get the last value which is the name appended with a random string + # then remove the random string + secret_id = "-".join(secret_id.split(":")[-1].split("-")[:-1]) + return secret_id def _exclude_characters(password, exclude_characters): @@ -62,12 +90,12 @@ def _exclude_characters(password, exclude_characters): if c in string.punctuation: # Escape punctuation regex usage c = "\{0}".format(c) - password = re.sub(c, '', str(password)) + password = re.sub(c, "", str(password)) return password def _add_password_require_each_included_type(password, required_characters): - password_with_required_char = password[:-len(required_characters)] + password_with_required_char = password[: -len(required_characters)] password_with_required_char += required_characters return password_with_required_char diff --git a/moto/server.py b/moto/server.py index 89be47093..92fe6f229 100644 --- a/moto/server.py +++ b/moto/server.py @@ -21,13 +21,13 @@ from moto.core.utils import convert_flask_to_httpretty_response HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"] -DEFAULT_SERVICE_REGION = ('s3', 'us-east-1') +DEFAULT_SERVICE_REGION = ("s3", "us-east-1") # Map of unsigned calls to service-region as per AWS API docs # https://docs.aws.amazon.com/cognito/latest/developerguide/resource-permissions.html#amazon-cognito-signed-versus-unsigned-apis UNSIGNED_REQUESTS = { - 'AWSCognitoIdentityService': ('cognito-identity', 'us-east-1'), - 'AWSCognitoIdentityProviderService': ('cognito-idp', 'us-east-1'), + "AWSCognitoIdentityService": ("cognito-identity", "us-east-1"), + "AWSCognitoIdentityProviderService": ("cognito-idp", "us-east-1"), } @@ -44,7 +44,7 @@ class DomainDispatcherApplication(object): self.service = service def get_backend_for_host(self, host): - if host == 'moto_api': + if host == "moto_api": return host if self.service: @@ -55,11 +55,11 @@ class DomainDispatcherApplication(object): for backend_name, backend in BACKENDS.items(): for url_base in list(backend.values())[0].url_bases: - if re.match(url_base, 'http://%s' % host): + if re.match(url_base, "http://%s" % host): return backend_name def infer_service_region_host(self, environ): - auth = environ.get('HTTP_AUTHORIZATION') + auth = environ.get("HTTP_AUTHORIZATION") if auth: # Signed request # Parse auth header to find service assuming a SigV4 request @@ -76,43 +76,46 @@ class DomainDispatcherApplication(object): service, region = DEFAULT_SERVICE_REGION else: # Unsigned request - target = environ.get('HTTP_X_AMZ_TARGET') + target = environ.get("HTTP_X_AMZ_TARGET") if target: - service, _ = target.split('.', 1) + service, _ = target.split(".", 1) service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION) else: # S3 is the last resort when the target is also unknown service, region = DEFAULT_SERVICE_REGION - if service == 'dynamodb': - if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'): - host = 'dynamodbstreams' + if service == "dynamodb": + if environ["HTTP_X_AMZ_TARGET"].startswith("DynamoDBStreams"): + host = "dynamodbstreams" else: - dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0] + dynamo_api_version = ( + environ["HTTP_X_AMZ_TARGET"].split("_")[1].split(".")[0] + ) # If Newer API version, use dynamodb2 if dynamo_api_version > "20111205": host = "dynamodb2" else: host = "{service}.{region}.amazonaws.com".format( - service=service, region=region) + service=service, region=region + ) return host def get_application(self, environ): - path_info = environ.get('PATH_INFO', '') + path_info = environ.get("PATH_INFO", "") # The URL path might contain non-ASCII text, for instance unicode S3 bucket names if six.PY2 and isinstance(path_info, str): path_info = six.u(path_info) if six.PY3 and isinstance(path_info, six.binary_type): - path_info = path_info.decode('utf-8') + path_info = path_info.decode("utf-8") if path_info.startswith("/moto-api") or path_info == "/favicon.ico": host = "moto_api" elif path_info.startswith("/latest/meta-data/"): host = "instance_metadata" else: - host = environ['HTTP_HOST'].split(':')[0] + host = environ["HTTP_HOST"].split(":")[0] with self.lock: backend = self.get_backend_for_host(host) @@ -141,15 +144,18 @@ class RegexConverter(BaseConverter): class AWSTestHelper(FlaskClient): - def action_data(self, action_name, **kwargs): """ Method calls resource with action_name and returns data of response. """ opts = {"Action": action_name} opts.update(kwargs) - res = self.get("/?{0}".format(urlencode(opts)), - headers={"Host": "{0}.us-east-1.amazonaws.com".format(self.application.service)}) + res = self.get( + "/?{0}".format(urlencode(opts)), + headers={ + "Host": "{0}.us-east-1.amazonaws.com".format(self.application.service) + }, + ) return res.data.decode("utf-8") def action_json(self, action_name, **kwargs): @@ -171,19 +177,20 @@ def create_backend_app(service): # Reset view functions to reset the app backend_app.view_functions = {} backend_app.url_map = Map() - backend_app.url_map.converters['regex'] = RegexConverter + backend_app.url_map.converters["regex"] = RegexConverter backend = list(BACKENDS[service].values())[0] for url_path, handler in backend.flask_paths.items(): - if handler.__name__ == 'dispatch': - endpoint = '{0}.dispatch'.format(handler.__self__.__name__) + view_func = convert_flask_to_httpretty_response(handler) + if handler.__name__ == "dispatch": + endpoint = "{0}.dispatch".format(handler.__self__.__name__) else: - endpoint = None + endpoint = view_func.__name__ original_endpoint = endpoint index = 2 while endpoint in backend_app.view_functions: # HACK: Sometimes we map the same view to multiple url_paths. Flask - # requries us to have different names. + # requires us to have different names. endpoint = original_endpoint + str(index) index += 1 @@ -191,7 +198,7 @@ def create_backend_app(service): url_path, endpoint=endpoint, methods=HTTP_METHODS, - view_func=convert_flask_to_httpretty_response(handler), + view_func=view_func, strict_slashes=False, ) @@ -206,54 +213,57 @@ def main(argv=sys.argv[1:]): parser.add_argument( "service", type=str, - nargs='?', # http://stackoverflow.com/a/4480202/731592 - default=None) - parser.add_argument( - '-H', '--host', type=str, - help='Which host to bind', - default='127.0.0.1') - parser.add_argument( - '-p', '--port', type=int, - help='Port number to use for connection', - default=5000) - parser.add_argument( - '-r', '--reload', - action='store_true', - help='Reload server on a file change', - default=False + nargs="?", # http://stackoverflow.com/a/4480202/731592 + default=None, ) parser.add_argument( - '-s', '--ssl', - action='store_true', - help='Enable SSL encrypted connection with auto-generated certificate (use https://... URL)', - default=False + "-H", "--host", type=str, help="Which host to bind", default="127.0.0.1" ) parser.add_argument( - '-c', '--ssl-cert', type=str, - help='Path to SSL certificate', - default=None) + "-p", "--port", type=int, help="Port number to use for connection", default=5000 + ) parser.add_argument( - '-k', '--ssl-key', type=str, - help='Path to SSL private key', - default=None) + "-r", + "--reload", + action="store_true", + help="Reload server on a file change", + default=False, + ) + parser.add_argument( + "-s", + "--ssl", + action="store_true", + help="Enable SSL encrypted connection with auto-generated certificate (use https://... URL)", + default=False, + ) + parser.add_argument( + "-c", "--ssl-cert", type=str, help="Path to SSL certificate", default=None + ) + parser.add_argument( + "-k", "--ssl-key", type=str, help="Path to SSL private key", default=None + ) args = parser.parse_args(argv) # Wrap the main application - main_app = DomainDispatcherApplication( - create_backend_app, service=args.service) + main_app = DomainDispatcherApplication(create_backend_app, service=args.service) main_app.debug = True ssl_context = None if args.ssl_key and args.ssl_cert: ssl_context = (args.ssl_cert, args.ssl_key) elif args.ssl: - ssl_context = 'adhoc' + ssl_context = "adhoc" - run_simple(args.host, args.port, main_app, - threaded=True, use_reloader=args.reload, - ssl_context=ssl_context) + run_simple( + args.host, + args.port, + main_app, + threaded=True, + use_reloader=args.reload, + ssl_context=ssl_context, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/moto/ses/exceptions.py b/moto/ses/exceptions.py index f888af9f6..a905039e2 100644 --- a/moto/ses/exceptions.py +++ b/moto/ses/exceptions.py @@ -6,5 +6,4 @@ class MessageRejectedError(RESTError): code = 400 def __init__(self, message): - super(MessageRejectedError, self).__init__( - "MessageRejected", message) + super(MessageRejectedError, self).__init__("MessageRejected", message) diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py index 2d32f9ce0..b0fa293e7 100644 --- a/moto/ses/feedback.py +++ b/moto/ses/feedback.py @@ -1,3 +1,5 @@ +from moto.core import ACCOUNT_ID + """ SES Feedback messages Extracted from https://docs.aws.amazon.com/ses/latest/DeveloperGuide/notification-contents.html @@ -10,33 +12,21 @@ COMMON_MAIL = { "source": "sender@example.com", "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com", "sourceIp": "127.0.3.0", - "sendingAccountId": "123456789012", - "destination": [ - "recipient@example.com" - ], + "sendingAccountId": ACCOUNT_ID, + "destination": ["recipient@example.com"], "headersTruncated": False, "headers": [ - { - "name": "From", - "value": "\"Sender Name\" " - }, - { - "name": "To", - "value": "\"Recipient Name\" " - } + {"name": "From", "value": '"Sender Name" '}, + {"name": "To", "value": '"Recipient Name" '}, ], "commonHeaders": { - "from": [ - "Sender Name " - ], + "from": ["Sender Name "], "date": "Mon, 08 Oct 2018 14:05:45 +0000", - "to": [ - "Recipient Name " - ], + "to": ["Recipient Name "], "messageId": " custom-message-ID", - "subject": "Message sent using Amazon SES" - } - } + "subject": "Message sent using Amazon SES", + }, + }, } BOUNCE = { "bounceType": "Permanent", @@ -46,30 +36,26 @@ BOUNCE = { "status": "5.0.0", "action": "failed", "diagnosticCode": "smtp; 550 user unknown", - "emailAddress": "recipient1@example.com" + "emailAddress": "recipient1@example.com", }, { "status": "4.0.0", "action": "delayed", - "emailAddress": "recipient2@example.com" - } + "emailAddress": "recipient2@example.com", + }, ], "reportingMTA": "example.com", "timestamp": "2012-05-25T14:59:38.605Z", "feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000", - "remoteMtaIp": "127.0.2.0" + "remoteMtaIp": "127.0.2.0", } COMPLAINT = { "userAgent": "AnyCompany Feedback Loop (V0.01)", - "complainedRecipients": [ - { - "emailAddress": "recipient1@example.com" - } - ], + "complainedRecipients": [{"emailAddress": "recipient1@example.com"}], "complaintFeedbackType": "abuse", "arrivalDate": "2009-12-03T04:24:21.000-05:00", "timestamp": "2012-05-25T14:59:38.623Z", - "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000" + "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000", } DELIVERY = { "timestamp": "2014-05-28T22:41:01.184Z", @@ -77,5 +63,5 @@ DELIVERY = { "recipients": ["success@simulator.amazonses.com"], "smtpResponse": "250 ok: Message 64111812 accepted", "reportingMTA": "a8-70.smtp-out.amazonses.com", - "remoteMtaIp": "127.0.2.0" + "remoteMtaIp": "127.0.2.0", } diff --git a/moto/ses/models.py b/moto/ses/models.py index 0544ac278..eacdd8458 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -40,7 +40,6 @@ class SESFeedback(BaseModel): class Message(BaseModel): - def __init__(self, message_id, source, subject, body, destinations): self.id = message_id self.source = source @@ -49,8 +48,16 @@ class Message(BaseModel): self.destinations = destinations -class RawMessage(BaseModel): +class TemplateMessage(BaseModel): + def __init__(self, message_id, source, template, template_data, destinations): + self.id = message_id + self.source = source + self.template = template + self.template_data = template_data + self.destinations = destinations + +class RawMessage(BaseModel): def __init__(self, message_id, source, destinations, raw_data): self.id = message_id self.source = source @@ -59,7 +66,6 @@ class RawMessage(BaseModel): class SESQuota(BaseModel): - def __init__(self, sent): self.sent = sent @@ -69,7 +75,6 @@ class SESQuota(BaseModel): class SESBackend(BaseBackend): - def __init__(self): self.addresses = [] self.email_addresses = [] @@ -82,7 +87,7 @@ class SESBackend(BaseBackend): _, address = parseaddr(source) if address in self.addresses: return True - user, host = address.split('@', 1) + user, host = address.split("@", 1) return host in self.domains def verify_email_identity(self, address): @@ -101,7 +106,7 @@ class SESBackend(BaseBackend): return self.email_addresses def delete_identity(self, identity): - if '@' in identity: + if "@" in identity: self.addresses.remove(identity) else: self.domains.remove(identity) @@ -109,11 +114,9 @@ class SESBackend(BaseBackend): def send_email(self, source, subject, body, destinations, region): recipient_count = sum(map(len, destinations.values())) if recipient_count > RECIPIENT_LIMIT: - raise MessageRejectedError('Too many recipients.') + raise MessageRejectedError("Too many recipients.") if not self._is_verified_address(source): - raise MessageRejectedError( - "Email address not verified %s" % source - ) + raise MessageRejectedError("Email address not verified %s" % source) self.__process_sns_feedback__(source, destinations, region) @@ -123,10 +126,33 @@ class SESBackend(BaseBackend): self.sent_message_count += recipient_count return message + def send_templated_email( + self, source, template, template_data, destinations, region + ): + recipient_count = sum(map(len, destinations.values())) + if recipient_count > RECIPIENT_LIMIT: + raise MessageRejectedError("Too many recipients.") + if not self._is_verified_address(source): + raise MessageRejectedError("Email address not verified %s" % source) + + self.__process_sns_feedback__(source, destinations, region) + + message_id = get_random_message_id() + message = TemplateMessage( + message_id, source, template, template_data, destinations + ) + self.sent_messages.append(message) + self.sent_message_count += recipient_count + return message + def __type_of_message__(self, destinations): - """Checks the destination for any special address that could indicate delivery, complaint or bounce - like in SES simualtor""" - alladdress = destinations.get("ToAddresses", []) + destinations.get("CcAddresses", []) + destinations.get("BccAddresses", []) + """Checks the destination for any special address that could indicate delivery, + complaint or bounce like in SES simulator""" + alladdress = ( + destinations.get("ToAddresses", []) + + destinations.get("CcAddresses", []) + + destinations.get("BccAddresses", []) + ) for addr in alladdress: if SESFeedback.SUCCESS_ADDR in addr: return SESFeedback.DELIVERY @@ -159,30 +185,29 @@ class SESBackend(BaseBackend): _, source_email_address = parseaddr(source) if source_email_address not in self.addresses: raise MessageRejectedError( - "Did not have authority to send from email %s" % source_email_address + "Did not have authority to send from email %s" + % source_email_address ) recipient_count = len(destinations) message = email.message_from_string(raw_data) if source is None: - if message['from'] is None: - raise MessageRejectedError( - "Source not specified" - ) + if message["from"] is None: + raise MessageRejectedError("Source not specified") - _, source_email_address = parseaddr(message['from']) + _, source_email_address = parseaddr(message["from"]) if source_email_address not in self.addresses: raise MessageRejectedError( - "Did not have authority to send from email %s" % source_email_address + "Did not have authority to send from email %s" + % source_email_address ) - for header in 'TO', 'CC', 'BCC': + for header in "TO", "CC", "BCC": recipient_count += sum( - d.strip() and 1 or 0 - for d in message.get(header, '').split(',') + d.strip() and 1 or 0 for d in message.get(header, "").split(",") ) if recipient_count > RECIPIENT_LIMIT: - raise MessageRejectedError('Too many recipients.') + raise MessageRejectedError("Too many recipients.") self.__process_sns_feedback__(source, destinations, region) diff --git a/moto/ses/responses.py b/moto/ses/responses.py index d2dda55f1..1034aeb0d 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -8,15 +8,14 @@ from .models import ses_backend class EmailResponse(BaseResponse): - def verify_email_identity(self): - address = self.querystring.get('EmailAddress')[0] + address = self.querystring.get("EmailAddress")[0] ses_backend.verify_email_identity(address) template = self.response_template(VERIFY_EMAIL_IDENTITY) return template.render() def verify_email_address(self): - address = self.querystring.get('EmailAddress')[0] + address = self.querystring.get("EmailAddress")[0] ses_backend.verify_email_address(address) template = self.response_template(VERIFY_EMAIL_ADDRESS) return template.render() @@ -32,67 +31,88 @@ class EmailResponse(BaseResponse): return template.render(email_addresses=email_addresses) def verify_domain_dkim(self): - domain = self.querystring.get('Domain')[0] + domain = self.querystring.get("Domain")[0] ses_backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE) return template.render() def verify_domain_identity(self): - domain = self.querystring.get('Domain')[0] + domain = self.querystring.get("Domain")[0] ses_backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE) return template.render() def delete_identity(self): - domain = self.querystring.get('Identity')[0] + domain = self.querystring.get("Identity")[0] ses_backend.delete_identity(domain) template = self.response_template(DELETE_IDENTITY_RESPONSE) return template.render() def send_email(self): - bodydatakey = 'Message.Body.Text.Data' - if 'Message.Body.Html.Data' in self.querystring: - bodydatakey = 'Message.Body.Html.Data' + bodydatakey = "Message.Body.Text.Data" + if "Message.Body.Html.Data" in self.querystring: + bodydatakey = "Message.Body.Html.Data" body = self.querystring.get(bodydatakey)[0] - source = self.querystring.get('Source')[0] - subject = self.querystring.get('Message.Subject.Data')[0] - destinations = { - 'ToAddresses': [], - 'CcAddresses': [], - 'BccAddresses': [], - } + source = self.querystring.get("Source")[0] + subject = self.querystring.get("Message.Subject.Data")[0] + destinations = {"ToAddresses": [], "CcAddresses": [], "BccAddresses": []} for dest_type in destinations: # consume up to 51 to allow exception for i in six.moves.range(1, 52): - field = 'Destination.%s.member.%s' % (dest_type, i) + field = "Destination.%s.member.%s" % (dest_type, i) address = self.querystring.get(field) if address is None: break destinations[dest_type].append(address[0]) - message = ses_backend.send_email(source, subject, body, destinations, self.region) + message = ses_backend.send_email( + source, subject, body, destinations, self.region + ) template = self.response_template(SEND_EMAIL_RESPONSE) return template.render(message=message) - def send_raw_email(self): - source = self.querystring.get('Source') - if source is not None: - source, = source + def send_templated_email(self): + source = self.querystring.get("Source")[0] + template = self.querystring.get("Template") + template_data = self.querystring.get("TemplateData") - raw_data = self.querystring.get('RawMessage.Data')[0] + destinations = {"ToAddresses": [], "CcAddresses": [], "BccAddresses": []} + for dest_type in destinations: + # consume up to 51 to allow exception + for i in six.moves.range(1, 52): + field = "Destination.%s.member.%s" % (dest_type, i) + address = self.querystring.get(field) + if address is None: + break + destinations[dest_type].append(address[0]) + + message = ses_backend.send_templated_email( + source, template, template_data, destinations, self.region + ) + template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE) + return template.render(message=message) + + def send_raw_email(self): + source = self.querystring.get("Source") + if source is not None: + (source,) = source + + raw_data = self.querystring.get("RawMessage.Data")[0] raw_data = base64.b64decode(raw_data) if six.PY3: - raw_data = raw_data.decode('utf-8') + raw_data = raw_data.decode("utf-8") destinations = [] # consume up to 51 to allow exception for i in six.moves.range(1, 52): - field = 'Destinations.member.%s' % i + field = "Destinations.member.%s" % i address = self.querystring.get(field) if address is None: break destinations.append(address[0]) - message = ses_backend.send_raw_email(source, destinations, raw_data, self.region) + message = ses_backend.send_raw_email( + source, destinations, raw_data, self.region + ) template = self.response_template(SEND_RAW_EMAIL_RESPONSE) return template.render(message=message) @@ -193,6 +213,15 @@ SEND_EMAIL_RESPONSE = """ + + {{ message.id }} + + + d5964849-c866-11e0-9beb-01a62d68c57f + +""" + SEND_RAW_EMAIL_RESPONSE = """ {{ message.id }} diff --git a/moto/ses/urls.py b/moto/ses/urls.py index adfb4c6e4..5c26d2152 100644 --- a/moto/ses/urls.py +++ b/moto/ses/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import EmailResponse -url_bases = [ - "https?://email.(.+).amazonaws.com", - "https?://ses.(.+).amazonaws.com", -] +url_bases = ["https?://email.(.+).amazonaws.com", "https?://ses.(.+).amazonaws.com"] -url_paths = { - '{0}/$': EmailResponse.dispatch, -} +url_paths = {"{0}/$": EmailResponse.dispatch} diff --git a/moto/ses/utils.py b/moto/ses/utils.py index c674892d1..6d9151cea 100644 --- a/moto/ses/utils.py +++ b/moto/ses/utils.py @@ -4,16 +4,16 @@ import string def random_hex(length): - return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + return "".join(random.choice(string.ascii_lowercase) for x in range(length)) def get_random_message_id(): return "{0}-{1}-{2}-{3}-{4}-{5}-{6}".format( - random_hex(16), - random_hex(8), - random_hex(4), - random_hex(4), - random_hex(4), - random_hex(12), - random_hex(6), + random_hex(16), + random_hex(8), + random_hex(4), + random_hex(4), + random_hex(4), + random_hex(12), + random_hex(6), ) diff --git a/moto/settings.py b/moto/settings.py index 12402dc80..707c61397 100644 --- a/moto/settings.py +++ b/moto/settings.py @@ -1,4 +1,6 @@ import os -TEST_SERVER_MODE = os.environ.get('TEST_SERVER_MODE', '0').lower() == 'true' -INITIAL_NO_AUTH_ACTION_COUNT = float(os.environ.get('INITIAL_NO_AUTH_ACTION_COUNT', float('inf'))) +TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true" +INITIAL_NO_AUTH_ACTION_COUNT = float( + os.environ.get("INITIAL_NO_AUTH_ACTION_COUNT", float("inf")) +) diff --git a/moto/sns/__init__.py b/moto/sns/__init__.py index bd36cb23d..896735b43 100644 --- a/moto/sns/__init__.py +++ b/moto/sns/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import sns_backends from ..core.models import base_decorator, deprecated_base_decorator -sns_backend = sns_backends['us-east-1'] +sns_backend = sns_backends["us-east-1"] mock_sns = base_decorator(sns_backends) mock_sns_deprecated = deprecated_base_decorator(sns_backends) diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 0e7a0bdcf..187865220 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -6,37 +6,58 @@ class SNSNotFoundError(RESTError): code = 404 def __init__(self, message): - super(SNSNotFoundError, self).__init__( - "NotFound", message) + super(SNSNotFoundError, self).__init__("NotFound", message) + + +class ResourceNotFoundError(RESTError): + code = 404 + + def __init__(self): + super(ResourceNotFoundError, self).__init__( + "ResourceNotFound", "Resource does not exist" + ) class DuplicateSnsEndpointError(RESTError): code = 400 def __init__(self, message): - super(DuplicateSnsEndpointError, self).__init__( - "DuplicateEndpoint", message) + super(DuplicateSnsEndpointError, self).__init__("DuplicateEndpoint", message) class SnsEndpointDisabled(RESTError): code = 400 def __init__(self, message): - super(SnsEndpointDisabled, self).__init__( - "EndpointDisabled", message) + super(SnsEndpointDisabled, self).__init__("EndpointDisabled", message) class SNSInvalidParameter(RESTError): code = 400 def __init__(self, message): - super(SNSInvalidParameter, self).__init__( - "InvalidParameter", message) + super(SNSInvalidParameter, self).__init__("InvalidParameter", message) class InvalidParameterValue(RESTError): code = 400 def __init__(self, message): - super(InvalidParameterValue, self).__init__( - "InvalidParameterValue", message) + super(InvalidParameterValue, self).__init__("InvalidParameterValue", message) + + +class TagLimitExceededError(RESTError): + code = 400 + + def __init__(self): + super(TagLimitExceededError, self).__init__( + "TagLimitExceeded", + "Could not complete request: tag quota of per resource exceeded", + ) + + +class InternalError(RESTError): + code = 500 + + def __init__(self, message): + super(InternalError, self).__init__("InternalFailure", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index f1293eb0f..cdc50f640 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -12,49 +12,66 @@ from boto3 import Session from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores +from moto.core.utils import ( + iso_8601_datetime_with_milliseconds, + camelcase_to_underscores, +) from moto.sqs import sqs_backends from moto.awslambda import lambda_backends from .exceptions import ( - SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter, - InvalidParameterValue + SNSNotFoundError, + DuplicateSnsEndpointError, + SnsEndpointDisabled, + SNSInvalidParameter, + InvalidParameterValue, + InternalError, + ResourceNotFoundError, + TagLimitExceededError, ) -from .utils import make_arn_for_topic, make_arn_for_subscription +from .utils import make_arn_for_topic, make_arn_for_subscription, is_e164 + +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID -DEFAULT_ACCOUNT_ID = 123456789012 DEFAULT_PAGE_SIZE = 100 MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB class Topic(BaseModel): - def __init__(self, name, sns_backend): self.name = name self.sns_backend = sns_backend self.account_id = DEFAULT_ACCOUNT_ID self.display_name = "" - self.policy = json.dumps(DEFAULT_TOPIC_POLICY) self.delivery_policy = "" self.effective_delivery_policy = json.dumps(DEFAULT_EFFECTIVE_DELIVERY_POLICY) - self.arn = make_arn_for_topic( - self.account_id, name, sns_backend.region_name) + self.arn = make_arn_for_topic(self.account_id, name, sns_backend.region_name) self.subscriptions_pending = 0 self.subscriptions_confimed = 0 self.subscriptions_deleted = 0 + self._policy_json = self._create_default_topic_policy( + sns_backend.region_name, self.account_id, name + ) + self._tags = {} + def publish(self, message, subject=None, message_attributes=None): message_id = six.text_type(uuid.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: - subscription.publish(message, message_id, subject=subject, - message_attributes=message_attributes) + subscription.publish( + message, + message_id, + subject=subject, + message_attributes=message_attributes, + ) return message_id def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'TopicName': + + if attribute_name == "TopicName": return self.name raise UnformattedGetAttTemplateException() @@ -62,22 +79,56 @@ class Topic(BaseModel): def physical_resource_id(self): return self.arn - @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - sns_backend = sns_backends[region_name] - properties = cloudformation_json['Properties'] + @property + def policy(self): + return json.dumps(self._policy_json) - topic = sns_backend.create_topic( - properties.get("TopicName") - ) + @policy.setter + def policy(self, policy): + self._policy_json = json.loads(policy) + + @classmethod + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + sns_backend = sns_backends[region_name] + properties = cloudformation_json["Properties"] + + topic = sns_backend.create_topic(properties.get("TopicName")) for subscription in properties.get("Subscription", []): - sns_backend.subscribe(topic.arn, subscription[ - 'Endpoint'], subscription['Protocol']) + sns_backend.subscribe( + topic.arn, subscription["Endpoint"], subscription["Protocol"] + ) return topic + def _create_default_topic_policy(self, region_name, account_id, name): + return { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": make_arn_for_topic(self.account_id, name, region_name), + "Condition": {"StringEquals": {"AWS:SourceOwner": str(account_id)}}, + } + ], + } + class Subscription(BaseModel): - def __init__(self, topic, endpoint, protocol): self.topic = topic self.endpoint = endpoint @@ -87,39 +138,54 @@ class Subscription(BaseModel): self._filter_policy = None # filter policy as a dict, not json. self.confirmed = False - def publish(self, message, message_id, subject=None, - message_attributes=None): + def publish(self, message, message_id, subject=None, message_attributes=None): if not self._matches_filter_policy(message_attributes): return - if self.protocol == 'sqs': + if self.protocol == "sqs": queue_name = self.endpoint.split(":")[-1] region = self.endpoint.split(":")[3] - if self.attributes.get('RawMessageDelivery') != 'true': - enveloped_message = json.dumps(self.get_post_data(message, message_id, subject, message_attributes=message_attributes), sort_keys=True, indent=2, separators=(',', ': ')) + if self.attributes.get("RawMessageDelivery") != "true": + enveloped_message = json.dumps( + self.get_post_data( + message, + message_id, + subject, + message_attributes=message_attributes, + ), + sort_keys=True, + indent=2, + separators=(",", ": "), + ) else: enveloped_message = message sqs_backends[region].send_message(queue_name, enveloped_message) - elif self.protocol in ['http', 'https']: + elif self.protocol in ["http", "https"]: post_data = self.get_post_data(message, message_id, subject) - requests.post(self.endpoint, json=post_data, headers={'Content-Type': 'text/plain; charset=UTF-8'}) - elif self.protocol == 'lambda': + requests.post( + self.endpoint, + json=post_data, + headers={"Content-Type": "text/plain; charset=UTF-8"}, + ) + elif self.protocol == "lambda": # TODO: support bad function name # http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html arr = self.endpoint.split(":") region = arr[3] qualifier = None if len(arr) == 7: - assert arr[5] == 'function' + assert arr[5] == "function" function_name = arr[-1] elif len(arr) == 8: - assert arr[5] == 'function' + assert arr[5] == "function" qualifier = arr[-1] function_name = arr[-2] else: assert False - lambda_backends[region].send_sns_message(function_name, message, subject=subject, qualifier=qualifier) + lambda_backends[region].send_sns_message( + function_name, message, subject=subject, qualifier=qualifier + ) def _matches_filter_policy(self, message_attributes): # TODO: support Anything-but matching, prefix matching and @@ -131,31 +197,72 @@ class Subscription(BaseModel): message_attributes = {} def _field_match(field, rules, message_attributes): - if field not in message_attributes: - return False for rule in rules: + # TODO: boolean value matching is not supported, SNS behavior unknown if isinstance(rule, six.string_types): - # only string value matching is supported - if message_attributes[field]['Value'] == rule: + if field not in message_attributes: + return False + if message_attributes[field]["Value"] == rule: return True + try: + json_data = json.loads(message_attributes[field]["Value"]) + if rule in json_data: + return True + except (ValueError, TypeError): + pass + if isinstance(rule, (six.integer_types, float)): + if field not in message_attributes: + return False + if message_attributes[field]["Type"] == "Number": + attribute_values = [message_attributes[field]["Value"]] + elif message_attributes[field]["Type"] == "String.Array": + try: + attribute_values = json.loads( + message_attributes[field]["Value"] + ) + if not isinstance(attribute_values, list): + attribute_values = [attribute_values] + except (ValueError, TypeError): + return False + else: + return False + + for attribute_values in attribute_values: + # Even the official documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if int(attribute_values * 1000000) == int(rule * 1000000): + return True + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == "exists": + if attributes and field in message_attributes: + return True + elif not attributes and field not in message_attributes: + return True return False - return all(_field_match(field, rules, message_attributes) - for field, rules in six.iteritems(self._filter_policy)) + return all( + _field_match(field, rules, message_attributes) + for field, rules in six.iteritems(self._filter_policy) + ) - def get_post_data( - self, message, message_id, subject, message_attributes=None): + def get_post_data(self, message, message_id, subject, message_attributes=None): post_data = { "Type": "Notification", "MessageId": message_id, "TopicArn": self.topic.arn, "Subject": subject or "my subject", "Message": message, - "Timestamp": iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()), + "Timestamp": iso_8601_datetime_with_milliseconds( + datetime.datetime.utcnow() + ), "SignatureVersion": "1", "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=", "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem", - "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55" + "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:{}:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55".format( + DEFAULT_ACCOUNT_ID + ), } if message_attributes: post_data["MessageAttributes"] = message_attributes @@ -163,7 +270,6 @@ class Subscription(BaseModel): class PlatformApplication(BaseModel): - def __init__(self, region, name, platform, attributes): self.region = region self.name = name @@ -172,15 +278,15 @@ class PlatformApplication(BaseModel): @property def arn(self): - return "arn:aws:sns:{region}:123456789012:app/{platform}/{name}".format( + return "arn:aws:sns:{region}:{AccountId}:app/{platform}/{name}".format( region=self.region, platform=self.platform, name=self.name, + AccountId=DEFAULT_ACCOUNT_ID, ) class PlatformEndpoint(BaseModel): - def __init__(self, region, application, custom_user_data, token, attributes): self.region = region self.application = application @@ -194,19 +300,20 @@ class PlatformEndpoint(BaseModel): def __fixup_attributes(self): # When AWS returns the attributes dict, it always contains these two elements, so we need to # automatically ensure they exist as well. - if 'Token' not in self.attributes: - self.attributes['Token'] = self.token - if 'Enabled' not in self.attributes: - self.attributes['Enabled'] = 'True' + if "Token" not in self.attributes: + self.attributes["Token"] = self.token + if "Enabled" not in self.attributes: + self.attributes["Enabled"] = "True" @property def enabled(self): - return json.loads(self.attributes.get('Enabled', 'true').lower()) + return json.loads(self.attributes.get("Enabled", "true").lower()) @property def arn(self): - return "arn:aws:sns:{region}:123456789012:endpoint/{platform}/{name}/{id}".format( + return "arn:aws:sns:{region}:{AccountId}:endpoint/{platform}/{name}/{id}".format( region=self.region, + AccountId=DEFAULT_ACCOUNT_ID, platform=self.application.platform, name=self.application.name, id=self.id, @@ -223,7 +330,6 @@ class PlatformEndpoint(BaseModel): class SNSBackend(BaseBackend): - def __init__(self, region_name): super(SNSBackend, self).__init__() self.topics = OrderedDict() @@ -232,8 +338,16 @@ class SNSBackend(BaseBackend): self.platform_endpoints = {} self.region_name = region_name self.sms_attributes = {} - self.opt_out_numbers = ['+447420500600', '+447420505401', '+447632960543', '+447632960028', '+447700900149', '+447700900550', '+447700900545', '+447700900907'] - self.permissions = {} + self.opt_out_numbers = [ + "+447420500600", + "+447420505401", + "+447632960543", + "+447632960028", + "+447700900149", + "+447700900550", + "+447700900545", + "+447700900907", + ] def reset(self): region_name = self.region_name @@ -243,14 +357,22 @@ class SNSBackend(BaseBackend): def update_sms_attributes(self, attrs): self.sms_attributes.update(attrs) - def create_topic(self, name, attributes=None): - fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name) + def create_topic(self, name, attributes=None, tags=None): + fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}$", name) if fails_constraints: - raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.") + raise InvalidParameterValue( + "Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." + ) candidate_topic = Topic(name, self) if attributes: for attribute in attributes: - setattr(candidate_topic, camelcase_to_underscores(attribute), attributes[attribute]) + setattr( + candidate_topic, + camelcase_to_underscores(attribute), + attributes[attribute], + ) + if tags: + candidate_topic._tags = tags if candidate_topic.arn in self.topics: return self.topics[candidate_topic.arn] else: @@ -261,8 +383,7 @@ class SNSBackend(BaseBackend): if next_token is None or not next_token: next_token = 0 next_token = int(next_token) - values = list(values_map.values())[ - next_token: next_token + DEFAULT_PAGE_SIZE] + values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE] if len(values) == DEFAULT_PAGE_SIZE: next_token = next_token + DEFAULT_PAGE_SIZE else: @@ -290,27 +411,57 @@ class SNSBackend(BaseBackend): def get_topic_from_phone_number(self, number): for subscription in self.subscriptions.values(): - if subscription.protocol == 'sms' and subscription.endpoint == number: + if subscription.protocol == "sms" and subscription.endpoint == number: return subscription.topic.arn - raise SNSNotFoundError('Could not find valid subscription') + raise SNSNotFoundError("Could not find valid subscription") def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) setattr(topic, attribute_name, attribute_value) def subscribe(self, topic_arn, endpoint, protocol): + if protocol == "sms": + if re.search(r"[./-]{2,}", endpoint) or re.search( + r"(^[./-]|[./-]$)", endpoint + ): + raise SNSInvalidParameter("Invalid SMS endpoint: {}".format(endpoint)) + + reduced_endpoint = re.sub(r"[./-]", "", endpoint) + + if not is_e164(reduced_endpoint): + raise SNSInvalidParameter("Invalid SMS endpoint: {}".format(endpoint)) + # AWS doesn't create duplicates old_subscription = self._find_subscription(topic_arn, endpoint, protocol) if old_subscription: return old_subscription topic = self.get_topic(topic_arn) subscription = Subscription(topic, endpoint, protocol) + attributes = { + "PendingConfirmation": "false", + "ConfirmationWasAuthenticated": "true", + "Endpoint": endpoint, + "TopicArn": topic_arn, + "Protocol": protocol, + "SubscriptionArn": subscription.arn, + "Owner": DEFAULT_ACCOUNT_ID, + "RawMessageDelivery": "false", + } + + if protocol in ["http", "https"]: + attributes["EffectiveDeliveryPolicy"] = topic.effective_delivery_policy + + subscription.attributes = attributes self.subscriptions[subscription.arn] = subscription return subscription def _find_subscription(self, topic_arn, endpoint, protocol): for subscription in self.subscriptions.values(): - if subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol: + if ( + subscription.topic.arn == topic_arn + and subscription.endpoint == endpoint + and subscription.protocol == protocol + ): return subscription return None @@ -321,7 +472,8 @@ class SNSBackend(BaseBackend): if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( - [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)]) + [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)] + ) return self._get_values_nexttoken(filtered, next_token) else: return self._get_values_nexttoken(self.subscriptions, next_token) @@ -329,15 +481,18 @@ class SNSBackend(BaseBackend): def publish(self, arn, message, subject=None, message_attributes=None): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 - raise ValueError('Subject must be less than 100 characters') + raise ValueError("Subject must be less than 100 characters") if len(message) > MAXIMUM_MESSAGE_LENGTH: - raise InvalidParameterValue("An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long") + raise InvalidParameterValue( + "An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long" + ) try: topic = self.get_topic(arn) - message_id = topic.publish(message, subject=subject, - message_attributes=message_attributes) + message_id = topic.publish( + message, subject=subject, message_attributes=message_attributes + ) except SNSNotFoundError: endpoint = self.get_endpoint(arn) message_id = endpoint.publish(message) @@ -352,8 +507,7 @@ class SNSBackend(BaseBackend): try: return self.applications[arn] except KeyError: - raise SNSNotFoundError( - "Application with arn {0} not found".format(arn)) + raise SNSNotFoundError("Application with arn {0} not found".format(arn)) def set_application_attributes(self, arn, attributes): application = self.get_application(arn) @@ -366,18 +520,23 @@ class SNSBackend(BaseBackend): def delete_platform_application(self, platform_arn): self.applications.pop(platform_arn) - def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): - if any(token == endpoint.token for endpoint in self.platform_endpoints.values()): + def create_platform_endpoint( + self, region, application, custom_user_data, token, attributes + ): + if any( + token == endpoint.token for endpoint in self.platform_endpoints.values() + ): raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) platform_endpoint = PlatformEndpoint( - region, application, custom_user_data, token, attributes) + region, application, custom_user_data, token, attributes + ) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return platform_endpoint def list_endpoints_by_platform_application(self, application_arn): return [ - endpoint for endpoint - in self.platform_endpoints.values() + endpoint + for endpoint in self.platform_endpoints.values() if endpoint.application.arn == application_arn ] @@ -385,8 +544,7 @@ class SNSBackend(BaseBackend): try: return self.platform_endpoints[arn] except KeyError: - raise SNSNotFoundError( - "Endpoint with arn {0} not found".format(arn)) + raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def set_endpoint_attributes(self, arn, attributes): endpoint = self.get_endpoint(arn) @@ -397,8 +555,7 @@ class SNSBackend(BaseBackend): try: del self.platform_endpoints[arn] except KeyError: - raise SNSNotFoundError( - "Endpoint with arn {0} not found".format(arn)) + raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def get_subscription_attributes(self, arn): _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] @@ -409,8 +566,8 @@ class SNSBackend(BaseBackend): return subscription.attributes def set_subscription_attributes(self, arn, name, value): - if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']: - raise SNSInvalidParameter('AttributeName') + if name not in ["RawMessageDelivery", "DeliveryPolicy", "FilterPolicy"]: + raise SNSInvalidParameter("AttributeName") # TODO: should do validation _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] @@ -420,55 +577,162 @@ class SNSBackend(BaseBackend): subscription.attributes[name] = value - if name == 'FilterPolicy': - subscription._filter_policy = json.loads(value) + if name == "FilterPolicy": + filter_policy = json.loads(value) + self._validate_filter_policy(filter_policy) + subscription._filter_policy = filter_policy + + def _validate_filter_policy(self, value): + # TODO: extend validation checks + combinations = 1 + for rules in six.itervalues(value): + combinations *= len(rules) + # Even the official documentation states the total combination of values must not exceed 100, in reality it is 150 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if combinations > 150: + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) + + for field, rules in six.iteritems(value): + for rule in rules: + if rule is None: + continue + if isinstance(rule, six.string_types): + continue + if isinstance(rule, bool): + continue + if isinstance(rule, (six.integer_types, float)): + if rule <= -1000000000 or rule >= 1000000000: + raise InternalError("Unknown") + continue + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == "anything-but": + continue + elif keyword == "exists": + if not isinstance(attributes, bool): + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: exists match pattern must be either true or false." + ) + continue + elif keyword == "numeric": + continue + elif keyword == "prefix": + continue + else: + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Unrecognized match type {type}".format( + type=keyword + ) + ) + + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" + ) + + def add_permission(self, topic_arn, label, aws_account_ids, action_names): + if topic_arn not in self.topics: + raise SNSNotFoundError("Topic does not exist") + + policy = self.topics[topic_arn]._policy_json + statement = next( + ( + statement + for statement in policy["Statement"] + if statement["Sid"] == label + ), + None, + ) + + if statement: + raise SNSInvalidParameter("Statement already exists") + + if any(action_name not in VALID_POLICY_ACTIONS for action_name in action_names): + raise SNSInvalidParameter("Policy statement action out of service scope!") + + principals = [ + "arn:aws:iam::{}:root".format(account_id) for account_id in aws_account_ids + ] + actions = ["SNS:{}".format(action_name) for action_name in action_names] + + statement = { + "Sid": label, + "Effect": "Allow", + "Principal": {"AWS": principals[0] if len(principals) == 1 else principals}, + "Action": actions[0] if len(actions) == 1 else actions, + "Resource": topic_arn, + } + + self.topics[topic_arn]._policy_json["Statement"].append(statement) + + def remove_permission(self, topic_arn, label): + if topic_arn not in self.topics: + raise SNSNotFoundError("Topic does not exist") + + statements = self.topics[topic_arn]._policy_json["Statement"] + statements = [ + statement for statement in statements if statement["Sid"] != label + ] + + self.topics[topic_arn]._policy_json["Statement"] = statements + + def list_tags_for_resource(self, resource_arn): + if resource_arn not in self.topics: + raise ResourceNotFoundError + + return self.topics[resource_arn]._tags + + def tag_resource(self, resource_arn, tags): + if resource_arn not in self.topics: + raise ResourceNotFoundError + + updated_tags = self.topics[resource_arn]._tags.copy() + updated_tags.update(tags) + + if len(updated_tags) > 50: + raise TagLimitExceededError + + self.topics[resource_arn]._tags = updated_tags + + def untag_resource(self, resource_arn, tag_keys): + if resource_arn not in self.topics: + raise ResourceNotFoundError + + for key in tag_keys: + self.topics[resource_arn]._tags.pop(key, None) sns_backends = {} -for region in Session().get_available_regions('sns'): +for region in Session().get_available_regions("sns"): sns_backends[region] = SNSBackend(region) -DEFAULT_TOPIC_POLICY = { - "Version": "2008-10-17", - "Id": "us-east-1/698519295917/test__default_policy_ID", - "Statement": [{ - "Effect": "Allow", - "Sid": "us-east-1/698519295917/test__default_statement_ID", - "Principal": { - "AWS": "*" - }, - "Action": [ - "SNS:GetTopicAttributes", - "SNS:SetTopicAttributes", - "SNS:AddPermission", - "SNS:RemovePermission", - "SNS:DeleteTopic", - "SNS:Subscribe", - "SNS:ListSubscriptionsByTopic", - "SNS:Publish", - "SNS:Receive", - ], - "Resource": "arn:aws:sns:us-east-1:698519295917:test", - "Condition": { - "StringLike": { - "AWS:SourceArn": "arn:aws:*:*:698519295917:*" - } - } - }] +DEFAULT_EFFECTIVE_DELIVERY_POLICY = { + "defaultHealthyRetryPolicy": { + "numNoDelayRetries": 0, + "numMinDelayRetries": 0, + "minDelayTarget": 20, + "maxDelayTarget": 20, + "numMaxDelayRetries": 0, + "numRetries": 3, + "backoffFunction": "linear", + }, + "sicklyRetryPolicy": None, + "throttlePolicy": None, + "guaranteed": False, } -DEFAULT_EFFECTIVE_DELIVERY_POLICY = { - 'http': { - 'disableSubscriptionOverrides': False, - 'defaultHealthyRetryPolicy': { - 'numNoDelayRetries': 0, - 'numMinDelayRetries': 0, - 'minDelayTarget': 20, - 'maxDelayTarget': 20, - 'numMaxDelayRetries': 0, - 'numRetries': 3, - 'backoffFunction': 'linear' - } - } -} + +VALID_POLICY_ACTIONS = [ + "GetTopicAttributes", + "SetTopicAttributes", + "AddPermission", + "RemovePermission", + "DeleteTopic", + "Subscribe", + "ListSubscriptionsByTopic", + "Publish", + "Receive", +] diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 440115429..c2eb3e7c3 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -11,548 +11,612 @@ from .utils import is_e164 class SNSResponse(BaseResponse): - SMS_ATTR_REGEX = re.compile(r'^attributes\.entry\.(?P\d+)\.(?Pkey|value)$') - OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r'^\+?\d+$') + SMS_ATTR_REGEX = re.compile( + r"^attributes\.entry\.(?P\d+)\.(?Pkey|value)$" + ) + OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$") @property def backend(self): return sns_backends[self.region] - def _error(self, code, message, sender='Sender'): + def _error(self, code, message, sender="Sender"): template = self.response_template(ERROR_RESPONSE) return template.render(code=code, message=message, sender=sender) def _get_attributes(self): - attributes = self._get_list_prefix('Attributes.entry') - return dict( - (attribute['key'], attribute['value']) - for attribute - in attributes - ) + attributes = self._get_list_prefix("Attributes.entry") + return dict((attribute["key"], attribute["value"]) for attribute in attributes) - def _parse_message_attributes(self, prefix='', value_namespace='Value.'): + def _get_tags(self): + tags = self._get_list_prefix("Tags.member") + return {tag["key"]: tag["value"] for tag in tags} + + def _parse_message_attributes(self, prefix="", value_namespace="Value."): message_attributes = self._get_object_map( - 'MessageAttributes.entry', - name='Name', - value='Value' + "MessageAttributes.entry", name="Name", value="Value" ) # SNS converts some key names before forwarding messages # DataType -> Type, StringValue -> Value, BinaryValue -> Value transformed_message_attributes = {} for name, value in message_attributes.items(): # validation - data_type = value['DataType'] + data_type = value["DataType"] if not data_type: raise InvalidParameterValue( "The message attribute '{0}' must contain non-empty " - "message attribute value.".format(name)) + "message attribute value.".format(name) + ) - data_type_parts = data_type.split('.') - if (len(data_type_parts) > 2 or - data_type_parts[0] not in ['String', 'Binary', 'Number']): + data_type_parts = data_type.split(".") + if len(data_type_parts) > 2 or data_type_parts[0] not in [ + "String", + "Binary", + "Number", + ]: raise InvalidParameterValue( "The message attribute '{0}' has an invalid message " "attribute type, the set of supported type prefixes is " - "Binary, Number, and String.".format(name)) + "Binary, Number, and String.".format(name) + ) transform_value = None - if 'StringValue' in value: - transform_value = value['StringValue'] - elif 'BinaryValue' in value: - transform_value = value['BinaryValue'] - if not transform_value: + if "StringValue" in value: + if data_type == "Number": + try: + transform_value = float(value["StringValue"]) + except ValueError: + raise InvalidParameterValue( + "An error occurred (ParameterValueInvalid) " + "when calling the Publish operation: " + "Could not cast message attribute '{0}' value to number.".format( + name + ) + ) + else: + transform_value = value["StringValue"] + elif "BinaryValue" in value: + transform_value = value["BinaryValue"] + if transform_value == "": raise InvalidParameterValue( "The message attribute '{0}' must contain non-empty " "message attribute value for message attribute " - "type '{1}'.".format(name, data_type[0])) + "type '{1}'.".format(name, data_type[0]) + ) # transformation transformed_message_attributes[name] = { - 'Type': data_type, 'Value': transform_value + "Type": data_type, + "Value": transform_value, } return transformed_message_attributes def create_topic(self): - name = self._get_param('Name') + name = self._get_param("Name") attributes = self._get_attributes() - topic = self.backend.create_topic(name, attributes) + tags = self._get_tags() + topic = self.backend.create_topic(name, attributes, tags) if self.request_json: - return json.dumps({ - 'CreateTopicResponse': { - 'CreateTopicResult': { - 'TopicArn': topic.arn, - }, - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "CreateTopicResponse": { + "CreateTopicResult": {"TopicArn": topic.arn}, + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(CREATE_TOPIC_TEMPLATE) return template.render(topic=topic) def list_topics(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") topics, next_token = self.backend.list_topics(next_token=next_token) if self.request_json: - return json.dumps({ - 'ListTopicsResponse': { - 'ListTopicsResult': { - 'Topics': [{'TopicArn': topic.arn} for topic in topics], - 'NextToken': next_token, - } - }, - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "ListTopicsResponse": { + "ListTopicsResult": { + "Topics": [{"TopicArn": topic.arn} for topic in topics], + "NextToken": next_token, + } + }, + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + }, } - }) + ) template = self.response_template(LIST_TOPICS_TEMPLATE) return template.render(topics=topics, next_token=next_token) def delete_topic(self): - topic_arn = self._get_param('TopicArn') + topic_arn = self._get_param("TopicArn") self.backend.delete_topic(topic_arn) if self.request_json: - return json.dumps({ - 'DeleteTopicResponse': { - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "DeleteTopicResponse": { + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + } } } - }) + ) template = self.response_template(DELETE_TOPIC_TEMPLATE) return template.render() def get_topic_attributes(self): - topic_arn = self._get_param('TopicArn') + topic_arn = self._get_param("TopicArn") topic = self.backend.get_topic(topic_arn) if self.request_json: - return json.dumps({ - "GetTopicAttributesResponse": { - "GetTopicAttributesResult": { - "Attributes": { - "Owner": topic.account_id, - "Policy": topic.policy, - "TopicArn": topic.arn, - "DisplayName": topic.display_name, - "SubscriptionsPending": topic.subscriptions_pending, - "SubscriptionsConfirmed": topic.subscriptions_confimed, - "SubscriptionsDeleted": topic.subscriptions_deleted, - "DeliveryPolicy": topic.delivery_policy, - "EffectiveDeliveryPolicy": topic.effective_delivery_policy, - } - }, - "ResponseMetadata": { - "RequestId": "057f074c-33a7-11df-9540-99d0768312d3" + return json.dumps( + { + "GetTopicAttributesResponse": { + "GetTopicAttributesResult": { + "Attributes": { + "Owner": topic.account_id, + "Policy": topic.policy, + "TopicArn": topic.arn, + "DisplayName": topic.display_name, + "SubscriptionsPending": topic.subscriptions_pending, + "SubscriptionsConfirmed": topic.subscriptions_confimed, + "SubscriptionsDeleted": topic.subscriptions_deleted, + "DeliveryPolicy": topic.delivery_policy, + "EffectiveDeliveryPolicy": topic.effective_delivery_policy, + } + }, + "ResponseMetadata": { + "RequestId": "057f074c-33a7-11df-9540-99d0768312d3" + }, } } - }) + ) template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render(topic=topic) def set_topic_attributes(self): - topic_arn = self._get_param('TopicArn') - attribute_name = self._get_param('AttributeName') + topic_arn = self._get_param("TopicArn") + attribute_name = self._get_param("AttributeName") attribute_name = camelcase_to_underscores(attribute_name) - attribute_value = self._get_param('AttributeValue') - self.backend.set_topic_attribute( - topic_arn, attribute_name, attribute_value) + attribute_value = self._get_param("AttributeValue") + self.backend.set_topic_attribute(topic_arn, attribute_name, attribute_value) if self.request_json: - return json.dumps({ - "SetTopicAttributesResponse": { - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "SetTopicAttributesResponse": { + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + } } } - }) + ) template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render() def subscribe(self): - topic_arn = self._get_param('TopicArn') - endpoint = self._get_param('Endpoint') - protocol = self._get_param('Protocol') + topic_arn = self._get_param("TopicArn") + endpoint = self._get_param("Endpoint") + protocol = self._get_param("Protocol") attributes = self._get_attributes() - if protocol == 'sms' and not is_e164(endpoint): - return self._error( - 'InvalidParameter', - 'Phone number does not meet the E164 format' - ), dict(status=400) - subscription = self.backend.subscribe(topic_arn, endpoint, protocol) if attributes is not None: for attr_name, attr_value in attributes.items(): - self.backend.set_subscription_attributes(subscription.arn, attr_name, attr_value) + self.backend.set_subscription_attributes( + subscription.arn, attr_name, attr_value + ) if self.request_json: - return json.dumps({ - "SubscribeResponse": { - "SubscribeResult": { - "SubscriptionArn": subscription.arn, - }, - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "SubscribeResponse": { + "SubscribeResult": {"SubscriptionArn": subscription.arn}, + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + }, } } - }) + ) template = self.response_template(SUBSCRIBE_TEMPLATE) return template.render(subscription=subscription) def unsubscribe(self): - subscription_arn = self._get_param('SubscriptionArn') + subscription_arn = self._get_param("SubscriptionArn") self.backend.unsubscribe(subscription_arn) if self.request_json: - return json.dumps({ - "UnsubscribeResponse": { - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "UnsubscribeResponse": { + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + } } } - }) + ) template = self.response_template(UNSUBSCRIBE_TEMPLATE) return template.render() def list_subscriptions(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( - next_token=next_token) + next_token=next_token + ) if self.request_json: - return json.dumps({ - "ListSubscriptionsResponse": { - "ListSubscriptionsResult": { - "Subscriptions": [{ - "TopicArn": subscription.topic.arn, - "Protocol": subscription.protocol, - "SubscriptionArn": subscription.arn, - "Owner": subscription.topic.account_id, - "Endpoint": subscription.endpoint, - } for subscription in subscriptions], - 'NextToken': next_token, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListSubscriptionsResponse": { + "ListSubscriptionsResult": { + "Subscriptions": [ + { + "TopicArn": subscription.topic.arn, + "Protocol": subscription.protocol, + "SubscriptionArn": subscription.arn, + "Owner": subscription.topic.account_id, + "Endpoint": subscription.endpoint, + } + for subscription in subscriptions + ], + "NextToken": next_token, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE) - return template.render(subscriptions=subscriptions, - next_token=next_token) + return template.render(subscriptions=subscriptions, next_token=next_token) def list_subscriptions_by_topic(self): - topic_arn = self._get_param('TopicArn') - next_token = self._get_param('NextToken') + topic_arn = self._get_param("TopicArn") + next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( - topic_arn, next_token=next_token) + topic_arn, next_token=next_token + ) if self.request_json: - return json.dumps({ - "ListSubscriptionsByTopicResponse": { - "ListSubscriptionsByTopicResult": { - "Subscriptions": [{ - "TopicArn": subscription.topic.arn, - "Protocol": subscription.protocol, - "SubscriptionArn": subscription.arn, - "Owner": subscription.topic.account_id, - "Endpoint": subscription.endpoint, - } for subscription in subscriptions], - 'NextToken': next_token, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListSubscriptionsByTopicResponse": { + "ListSubscriptionsByTopicResult": { + "Subscriptions": [ + { + "TopicArn": subscription.topic.arn, + "Protocol": subscription.protocol, + "SubscriptionArn": subscription.arn, + "Owner": subscription.topic.account_id, + "Endpoint": subscription.endpoint, + } + for subscription in subscriptions + ], + "NextToken": next_token, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE) - return template.render(subscriptions=subscriptions, - next_token=next_token) + return template.render(subscriptions=subscriptions, next_token=next_token) def publish(self): - target_arn = self._get_param('TargetArn') - topic_arn = self._get_param('TopicArn') - phone_number = self._get_param('PhoneNumber') - subject = self._get_param('Subject') + target_arn = self._get_param("TargetArn") + topic_arn = self._get_param("TopicArn") + phone_number = self._get_param("PhoneNumber") + subject = self._get_param("Subject") message_attributes = self._parse_message_attributes() if phone_number is not None: # Check phone is correct syntax (e164) if not is_e164(phone_number): - return self._error( - 'InvalidParameter', - 'Phone number does not meet the E164 format' - ), dict(status=400) + return ( + self._error( + "InvalidParameter", "Phone number does not meet the E164 format" + ), + dict(status=400), + ) # Look up topic arn by phone number try: arn = self.backend.get_topic_from_phone_number(phone_number) except SNSNotFoundError: - return self._error( - 'ParameterValueInvalid', - 'Could not find topic associated with phone number' - ), dict(status=400) + return ( + self._error( + "ParameterValueInvalid", + "Could not find topic associated with phone number", + ), + dict(status=400), + ) elif target_arn is not None: arn = target_arn else: arn = topic_arn - message = self._get_param('Message') + message = self._get_param("Message") try: message_id = self.backend.publish( - arn, message, subject=subject, - message_attributes=message_attributes) + arn, message, subject=subject, message_attributes=message_attributes + ) except ValueError as err: - error_response = self._error('InvalidParameter', str(err)) + error_response = self._error("InvalidParameter", str(err)) return error_response, dict(status=400) if self.request_json: - return json.dumps({ - "PublishResponse": { - "PublishResult": { - "MessageId": message_id, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "PublishResponse": { + "PublishResult": {"MessageId": message_id}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(PUBLISH_TEMPLATE) return template.render(message_id=message_id) def create_platform_application(self): - name = self._get_param('Name') - platform = self._get_param('Platform') + name = self._get_param("Name") + platform = self._get_param("Platform") attributes = self._get_attributes() platform_application = self.backend.create_platform_application( - self.region, name, platform, attributes) + self.region, name, platform, attributes + ) if self.request_json: - return json.dumps({ - "CreatePlatformApplicationResponse": { - "CreatePlatformApplicationResult": { - "PlatformApplicationArn": platform_application.arn, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937b", + return json.dumps( + { + "CreatePlatformApplicationResponse": { + "CreatePlatformApplicationResult": { + "PlatformApplicationArn": platform_application.arn + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937b" + }, } } - }) + ) template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE) return template.render(platform_application=platform_application) def get_platform_application_attributes(self): - arn = self._get_param('PlatformApplicationArn') + arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(arn) if self.request_json: - return json.dumps({ - "GetPlatformApplicationAttributesResponse": { - "GetPlatformApplicationAttributesResult": { - "Attributes": application.attributes, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + return json.dumps( + { + "GetPlatformApplicationAttributesResponse": { + "GetPlatformApplicationAttributesResult": { + "Attributes": application.attributes + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f" + }, } } - }) + ) - template = self.response_template( - GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) + template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render(application=application) def set_platform_application_attributes(self): - arn = self._get_param('PlatformApplicationArn') + arn = self._get_param("PlatformApplicationArn") attributes = self._get_attributes() self.backend.set_application_attributes(arn, attributes) if self.request_json: - return json.dumps({ - "SetPlatformApplicationAttributesResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "SetPlatformApplicationAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) - template = self.response_template( - SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) + template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render() def list_platform_applications(self): applications = self.backend.list_platform_applications() if self.request_json: - return json.dumps({ - "ListPlatformApplicationsResponse": { - "ListPlatformApplicationsResult": { - "PlatformApplications": [{ - "PlatformApplicationArn": application.arn, - "attributes": application.attributes, - } for application in applications], - "NextToken": None - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937c", + return json.dumps( + { + "ListPlatformApplicationsResponse": { + "ListPlatformApplicationsResult": { + "PlatformApplications": [ + { + "PlatformApplicationArn": application.arn, + "attributes": application.attributes, + } + for application in applications + ], + "NextToken": None, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937c" + }, } } - }) + ) template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE) return template.render(applications=applications) def delete_platform_application(self): - platform_arn = self._get_param('PlatformApplicationArn') + platform_arn = self._get_param("PlatformApplicationArn") self.backend.delete_platform_application(platform_arn) if self.request_json: - return json.dumps({ - "DeletePlatformApplicationResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937e", + return json.dumps( + { + "DeletePlatformApplicationResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937e" + } } } - }) + ) template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE) return template.render() def create_platform_endpoint(self): - application_arn = self._get_param('PlatformApplicationArn') + application_arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(application_arn) - custom_user_data = self._get_param('CustomUserData') - token = self._get_param('Token') + custom_user_data = self._get_param("CustomUserData") + token = self._get_param("Token") attributes = self._get_attributes() platform_endpoint = self.backend.create_platform_endpoint( - self.region, application, custom_user_data, token, attributes) + self.region, application, custom_user_data, token, attributes + ) if self.request_json: - return json.dumps({ - "CreatePlatformEndpointResponse": { - "CreatePlatformEndpointResult": { - "EndpointArn": platform_endpoint.arn, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3779-11df-8963-01868b7c937b", + return json.dumps( + { + "CreatePlatformEndpointResponse": { + "CreatePlatformEndpointResult": { + "EndpointArn": platform_endpoint.arn + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3779-11df-8963-01868b7c937b" + }, } } - }) + ) template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE) return template.render(platform_endpoint=platform_endpoint) def list_endpoints_by_platform_application(self): - application_arn = self._get_param('PlatformApplicationArn') - endpoints = self.backend.list_endpoints_by_platform_application( - application_arn) + application_arn = self._get_param("PlatformApplicationArn") + endpoints = self.backend.list_endpoints_by_platform_application(application_arn) if self.request_json: - return json.dumps({ - "ListEndpointsByPlatformApplicationResponse": { - "ListEndpointsByPlatformApplicationResult": { - "Endpoints": [ - { - "Attributes": endpoint.attributes, - "EndpointArn": endpoint.arn, - } for endpoint in endpoints - ], - "NextToken": None - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListEndpointsByPlatformApplicationResponse": { + "ListEndpointsByPlatformApplicationResult": { + "Endpoints": [ + { + "Attributes": endpoint.attributes, + "EndpointArn": endpoint.arn, + } + for endpoint in endpoints + ], + "NextToken": None, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template( - LIST_ENDPOINTS_BY_PLATFORM_APPLICATION_TEMPLATE) + LIST_ENDPOINTS_BY_PLATFORM_APPLICATION_TEMPLATE + ) return template.render(endpoints=endpoints) def get_endpoint_attributes(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") endpoint = self.backend.get_endpoint(arn) if self.request_json: - return json.dumps({ - "GetEndpointAttributesResponse": { - "GetEndpointAttributesResult": { - "Attributes": endpoint.attributes, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + return json.dumps( + { + "GetEndpointAttributesResponse": { + "GetEndpointAttributesResult": { + "Attributes": endpoint.attributes + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f" + }, } } - }) + ) template = self.response_template(GET_ENDPOINT_ATTRIBUTES_TEMPLATE) return template.render(endpoint=endpoint) def set_endpoint_attributes(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") attributes = self._get_attributes() self.backend.set_endpoint_attributes(arn, attributes) if self.request_json: - return json.dumps({ - "SetEndpointAttributesResponse": { - "ResponseMetadata": { - "RequestId": "384bc68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "SetEndpointAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384bc68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE) return template.render() def delete_endpoint(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") self.backend.delete_endpoint(arn) if self.request_json: - return json.dumps({ - "DeleteEndpointResponse": { - "ResponseMetadata": { - "RequestId": "384bc68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "DeleteEndpointResponse": { + "ResponseMetadata": { + "RequestId": "384bc68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) template = self.response_template(DELETE_ENDPOINT_TEMPLATE) return template.render() def get_subscription_attributes(self): - arn = self._get_param('SubscriptionArn') + arn = self._get_param("SubscriptionArn") attributes = self.backend.get_subscription_attributes(arn) template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render(attributes=attributes) def set_subscription_attributes(self): - arn = self._get_param('SubscriptionArn') - attr_name = self._get_param('AttributeName') - attr_value = self._get_param('AttributeValue') + arn = self._get_param("SubscriptionArn") + attr_name = self._get_param("AttributeName") + attr_value = self._get_param("AttributeValue") self.backend.set_subscription_attributes(arn, attr_name, attr_value) template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render() @@ -566,7 +630,7 @@ class SNSResponse(BaseResponse): for key, value in self.querystring.items(): match = self.SMS_ATTR_REGEX.match(key) if match is not None: - temp_dict[match.group('index')][match.group('type')] = value[0] + temp_dict[match.group("index")][match.group("type")] = value[0] # 1: {key:X, value:Y} # to @@ -574,8 +638,8 @@ class SNSResponse(BaseResponse): # All of this, just to take into account when people provide invalid stuff. result = {} for item in temp_dict.values(): - if 'key' in item and 'value' in item: - result[item['key']] = item['value'] + if "key" in item and "value" in item: + result[item["key"]] = item["value"] self.backend.update_sms_attributes(result) @@ -585,11 +649,13 @@ class SNSResponse(BaseResponse): def get_sms_attributes(self): filter_list = set() for key, value in self.querystring.items(): - if key.startswith('attributes.member.1'): + if key.startswith("attributes.member.1"): filter_list.add(value[0]) if len(filter_list) > 0: - result = {k: v for k, v in self.backend.sms_attributes.items() if k in filter_list} + result = { + k: v for k, v in self.backend.sms_attributes.items() if k in filter_list + } else: result = self.backend.sms_attributes @@ -597,24 +663,24 @@ class SNSResponse(BaseResponse): return template.render(attributes=result) def check_if_phone_number_is_opted_out(self): - number = self._get_param('phoneNumber') + number = self._get_param("phoneNumber") if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None: error_response = self._error( - code='InvalidParameter', - message='Invalid parameter: PhoneNumber Reason: input incorrectly formatted' + code="InvalidParameter", + message="Invalid parameter: PhoneNumber Reason: input incorrectly formatted", ) return error_response, dict(status=400) # There should be a nicer way to set if a nubmer has opted out template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE) - return template.render(opt_out=str(number.endswith('99')).lower()) + return template.render(opt_out=str(number.endswith("99")).lower()) def list_phone_numbers_opted_out(self): template = self.response_template(LIST_OPTOUT_TEMPLATE) return template.render(opt_outs=self.backend.opt_out_numbers) def opt_in_phone_number(self): - number = self._get_param('phoneNumber') + number = self._get_param("phoneNumber") try: self.backend.opt_out_numbers.remove(number) @@ -625,43 +691,30 @@ class SNSResponse(BaseResponse): return template.render() def add_permission(self): - arn = self._get_param('TopicArn') - label = self._get_param('Label') - accounts = self._get_multi_param('AWSAccountId.member.') - action = self._get_multi_param('ActionName.member.') + topic_arn = self._get_param("TopicArn") + label = self._get_param("Label") + aws_account_ids = self._get_multi_param("AWSAccountId.member.") + action_names = self._get_multi_param("ActionName.member.") - if arn not in self.backend.topics: - error_response = self._error('NotFound', 'Topic does not exist') - return error_response, dict(status=404) - - key = (arn, label) - self.backend.permissions[key] = {'accounts': accounts, 'action': action} + self.backend.add_permission(topic_arn, label, aws_account_ids, action_names) template = self.response_template(ADD_PERMISSION_TEMPLATE) return template.render() def remove_permission(self): - arn = self._get_param('TopicArn') - label = self._get_param('Label') + topic_arn = self._get_param("TopicArn") + label = self._get_param("Label") - if arn not in self.backend.topics: - error_response = self._error('NotFound', 'Topic does not exist') - return error_response, dict(status=404) - - try: - key = (arn, label) - del self.backend.permissions[key] - except KeyError: - pass + self.backend.remove_permission(topic_arn, label) template = self.response_template(DEL_PERMISSION_TEMPLATE) return template.render() def confirm_subscription(self): - arn = self._get_param('TopicArn') + arn = self._get_param("TopicArn") if arn not in self.backend.topics: - error_response = self._error('NotFound', 'Topic does not exist') + error_response = self._error("NotFound", "Topic does not exist") return error_response, dict(status=404) # Once Tokens are stored by the `subscribe` endpoint and distributed @@ -680,7 +733,33 @@ class SNSResponse(BaseResponse): # return error_response, dict(status=400) template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE) - return template.render(sub_arn='{0}:68762e72-e9b1-410a-8b3b-903da69ee1d5'.format(arn)) + return template.render( + sub_arn="{0}:68762e72-e9b1-410a-8b3b-903da69ee1d5".format(arn) + ) + + def list_tags_for_resource(self): + arn = self._get_param("ResourceArn") + + result = self.backend.list_tags_for_resource(arn) + + template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) + return template.render(tags=result) + + def tag_resource(self): + arn = self._get_param("ResourceArn") + tags = self._get_tags() + + self.backend.tag_resource(arn, tags) + + return self.response_template(TAG_RESOURCE_TEMPLATE).render() + + def untag_resource(self): + arn = self._get_param("ResourceArn") + tag_keys = self._get_multi_param("TagKeys.member") + + self.backend.untag_resource(arn, tag_keys) + + return self.response_template(UNTAG_RESOURCE_TEMPLATE).render() CREATE_TOPIC_TEMPLATE = """ @@ -1063,3 +1142,33 @@ CONFIRM_SUBSCRIPTION_TEMPLATE = """ + + + {% for name, value in tags.items() %} + + {{ name }} + {{ value }} + + {% endfor %} + + + + 97fa763f-861b-5223-a946-20251f2a42e2 + +""" + +TAG_RESOURCE_TEMPLATE = """ + + + fd4ab1da-692f-50a7-95ad-e7c665877d98 + +""" + +UNTAG_RESOURCE_TEMPLATE = """ + + + 14eb7b1a-4cbd-5a56-80db-2d06412df769 + +""" diff --git a/moto/sns/urls.py b/moto/sns/urls.py index 518531c55..8c38bb12c 100644 --- a/moto/sns/urls.py +++ b/moto/sns/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import SNSResponse -url_bases = [ - "https?://sns.(.+).amazonaws.com", -] +url_bases = ["https?://sns.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SNSResponse.dispatch, -} +url_paths = {"{0}/$": SNSResponse.dispatch} diff --git a/moto/sns/utils.py b/moto/sns/utils.py index 7793b0f6d..a46b84ac2 100644 --- a/moto/sns/utils.py +++ b/moto/sns/utils.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import re import uuid -E164_REGEX = re.compile(r'^\+?[1-9]\d{1,14}$') +E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$") def make_arn_for_topic(account_id, name, region_name): diff --git a/moto/sqs/__init__.py b/moto/sqs/__init__.py index 46c83133f..b2617b4e4 100644 --- a/moto/sqs/__init__.py +++ b/moto/sqs/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import sqs_backends from ..core.models import base_decorator, deprecated_base_decorator -sqs_backend = sqs_backends['us-east-1'] +sqs_backend = sqs_backends["us-east-1"] mock_sqs = base_decorator(sqs_backends) mock_sqs_deprecated = deprecated_base_decorator(sqs_backends) diff --git a/moto/sqs/exceptions.py b/moto/sqs/exceptions.py index 5f1cc46b2..01123d777 100644 --- a/moto/sqs/exceptions.py +++ b/moto/sqs/exceptions.py @@ -7,9 +7,13 @@ class MessageNotInflight(Exception): status_code = 400 -class ReceiptHandleIsInvalid(Exception): - description = "The receipt handle provided is not valid." - status_code = 400 +class ReceiptHandleIsInvalid(RESTError): + code = 400 + + def __init__(self): + super(ReceiptHandleIsInvalid, self).__init__( + "ReceiptHandleIsInvalid", "The input receipt handle is invalid." + ) class MessageAttributesInvalid(Exception): @@ -19,14 +23,79 @@ class MessageAttributesInvalid(Exception): self.description = description -class QueueDoesNotExist(Exception): - status_code = 404 - description = "The specified queue does not exist for this wsdl version." +class QueueDoesNotExist(RESTError): + code = 404 + + def __init__(self): + super(QueueDoesNotExist, self).__init__( + "QueueDoesNotExist", + "The specified queue does not exist for this wsdl version.", + ) class QueueAlreadyExists(RESTError): code = 400 def __init__(self, message): - super(QueueAlreadyExists, self).__init__( - "QueueAlreadyExists", message) + super(QueueAlreadyExists, self).__init__("QueueAlreadyExists", message) + + +class EmptyBatchRequest(RESTError): + code = 400 + + def __init__(self): + super(EmptyBatchRequest, self).__init__( + "EmptyBatchRequest", + "There should be at least one SendMessageBatchRequestEntry in the request.", + ) + + +class InvalidBatchEntryId(RESTError): + code = 400 + + def __init__(self): + super(InvalidBatchEntryId, self).__init__( + "InvalidBatchEntryId", + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", + ) + + +class BatchRequestTooLong(RESTError): + code = 400 + + def __init__(self, length): + super(BatchRequestTooLong, self).__init__( + "BatchRequestTooLong", + "Batch requests cannot be longer than 262144 bytes. " + "You have sent {} bytes.".format(length), + ) + + +class BatchEntryIdsNotDistinct(RESTError): + code = 400 + + def __init__(self, entry_id): + super(BatchEntryIdsNotDistinct, self).__init__( + "BatchEntryIdsNotDistinct", "Id {} repeated.".format(entry_id) + ) + + +class TooManyEntriesInBatchRequest(RESTError): + code = 400 + + def __init__(self, number): + super(TooManyEntriesInBatchRequest, self).__init__( + "TooManyEntriesInBatchRequest", + "Maximum number of entries per request are 10. " + "You have sent {}.".format(number), + ) + + +class InvalidAttributeName(RESTError): + code = 400 + + def __init__(self, attribute_name): + super(InvalidAttributeName, self).__init__( + "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) + ) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index e774e261c..4e6282f56 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -12,7 +12,12 @@ import boto.sqs from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis +from moto.core.utils import ( + camelcase_to_underscores, + get_random_message_id, + unix_time, + unix_time_millis, +) from .utils import generate_receipt_handle from .exceptions import ( MessageAttributesInvalid, @@ -20,16 +25,23 @@ from .exceptions import ( QueueDoesNotExist, QueueAlreadyExists, ReceiptHandleIsInvalid, + InvalidBatchEntryId, + BatchRequestTooLong, + BatchEntryIdsNotDistinct, + TooManyEntriesInBatchRequest, + InvalidAttributeName, ) -DEFAULT_ACCOUNT_ID = 123456789012 +from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID + DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" -TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'} +MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB + +TRANSPORT_TYPE_ENCODINGS = {"String": b"\x01", "Binary": b"\x02", "Number": b"\x01"} class Message(BaseModel): - def __init__(self, message_id, body): self.id = message_id self._body = body @@ -47,7 +59,7 @@ class Message(BaseModel): @property def body_md5(self): md5 = hashlib.md5() - md5.update(self._body.encode('utf-8')) + md5.update(self._body.encode("utf-8")) return md5.hexdigest() @property @@ -61,17 +73,19 @@ class Message(BaseModel): Not yet implemented: List types (https://github.com/aws/aws-sdk-java/blob/7844c64cf248aed889811bf2e871ad6b276a89ca/aws-java-sdk-sqs/src/main/java/com/amazonaws/services/sqs/MessageMD5ChecksumHandler.java#L58k) """ + def utf8(str): if isinstance(str, six.string_types): - return str.encode('utf-8') + return str.encode("utf-8") return str + md5 = hashlib.md5() - struct_format = "!I".encode('ascii') # ensure it's a bytestring + struct_format = "!I".encode("ascii") # ensure it's a bytestring for name in sorted(self.message_attributes.keys()): attr = self.message_attributes[name] - data_type = attr['data_type'] + data_type = attr["data_type"] - encoded = utf8('') + encoded = utf8("") # Each part of each attribute is encoded right after it's # own length is packed into a 4-byte integer # 'timestamp' -> b'\x00\x00\x00\t' @@ -81,18 +95,22 @@ class Message(BaseModel): encoded += struct.pack(struct_format, len(data_type)) + utf8(data_type) encoded += TRANSPORT_TYPE_ENCODINGS[data_type] - if data_type == 'String' or data_type == 'Number': - value = attr['string_value'] - elif data_type == 'Binary': - print(data_type, attr['binary_value'], type(attr['binary_value'])) - value = base64.b64decode(attr['binary_value']) + if data_type == "String" or data_type == "Number": + value = attr["string_value"] + elif data_type == "Binary": + print(data_type, attr["binary_value"], type(attr["binary_value"])) + value = base64.b64decode(attr["binary_value"]) else: - print("Moto hasn't implemented MD5 hashing for {} attributes".format(data_type)) + print( + "Moto hasn't implemented MD5 hashing for {} attributes".format( + data_type + ) + ) # The following should be enough of a clue to users that # they are not, in fact, looking at a correct MD5 while # also following the character and length constraints of # MD5 so as not to break client softwre - return('deadbeefdeadbeefdeadbeefdeadbeef') + return "deadbeefdeadbeefdeadbeefdeadbeef" encoded += struct.pack(struct_format, len(utf8(value))) + utf8(value) @@ -155,24 +173,30 @@ class Message(BaseModel): class Queue(BaseModel): - base_attributes = ['ApproximateNumberOfMessages', - 'ApproximateNumberOfMessagesDelayed', - 'ApproximateNumberOfMessagesNotVisible', - 'CreatedTimestamp', - 'DelaySeconds', - 'LastModifiedTimestamp', - 'MaximumMessageSize', - 'MessageRetentionPeriod', - 'QueueArn', - 'ReceiveMessageWaitTimeSeconds', - 'VisibilityTimeout'] - fifo_attributes = ['FifoQueue', - 'ContentBasedDeduplication'] - kms_attributes = ['KmsDataKeyReusePeriodSeconds', - 'KmsMasterKeyId'] - ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage', - 'GetQueueAttributes', 'GetQueueUrl', - 'ReceiveMessage', 'SendMessage') + BASE_ATTRIBUTES = [ + "ApproximateNumberOfMessages", + "ApproximateNumberOfMessagesDelayed", + "ApproximateNumberOfMessagesNotVisible", + "CreatedTimestamp", + "DelaySeconds", + "LastModifiedTimestamp", + "MaximumMessageSize", + "MessageRetentionPeriod", + "QueueArn", + "ReceiveMessageWaitTimeSeconds", + "VisibilityTimeout", + ] + FIFO_ATTRIBUTES = ["FifoQueue", "ContentBasedDeduplication"] + KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"] + ALLOWED_PERMISSIONS = ( + "*", + "ChangeMessageVisibility", + "DeleteMessage", + "GetQueueAttributes", + "GetQueueUrl", + "ReceiveMessage", + "SendMessage", + ) def __init__(self, name, region, **kwargs): self.name = name @@ -185,33 +209,36 @@ class Queue(BaseModel): now = unix_time() self.created_timestamp = now - self.queue_arn = 'arn:aws:sqs:{0}:123456789012:{1}'.format(self.region, - self.name) + self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format( + self.region, DEFAULT_ACCOUNT_ID, self.name + ) self.dead_letter_queue = None self.lambda_event_source_mappings = {} # default settings for a non fifo queue defaults = { - 'ContentBasedDeduplication': 'false', - 'DelaySeconds': 0, - 'FifoQueue': 'false', - 'KmsDataKeyReusePeriodSeconds': 300, # five minutes - 'KmsMasterKeyId': None, - 'MaximumMessageSize': int(64 << 10), - 'MessageRetentionPeriod': 86400 * 4, # four days - 'Policy': None, - 'ReceiveMessageWaitTimeSeconds': 0, - 'RedrivePolicy': None, - 'VisibilityTimeout': 30, + "ContentBasedDeduplication": "false", + "DelaySeconds": 0, + "FifoQueue": "false", + "KmsDataKeyReusePeriodSeconds": 300, # five minutes + "KmsMasterKeyId": None, + "MaximumMessageSize": int(64 << 10), + "MessageRetentionPeriod": 86400 * 4, # four days + "Policy": None, + "ReceiveMessageWaitTimeSeconds": 0, + "RedrivePolicy": None, + "VisibilityTimeout": 30, } defaults.update(kwargs) self._set_attributes(defaults, now) # Check some conditions - if self.fifo_queue and not self.name.endswith('.fifo'): - raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues') + if self.fifo_queue and not self.name.endswith(".fifo"): + raise MessageAttributesInvalid( + "Queue name must end in .fifo for FIFO queues" + ) @property def pending_messages(self): @@ -219,18 +246,25 @@ class Queue(BaseModel): @property def pending_message_groups(self): - return set(message.group_id - for message in self._pending_messages - if message.group_id is not None) + return set( + message.group_id + for message in self._pending_messages + if message.group_id is not None + ) def _set_attributes(self, attributes, now=None): if not now: now = unix_time() - integer_fields = ('DelaySeconds', 'KmsDataKeyreusePeriodSeconds', - 'MaximumMessageSize', 'MessageRetentionPeriod', - 'ReceiveMessageWaitTime', 'VisibilityTimeout') - bool_fields = ('ContentBasedDeduplication', 'FifoQueue') + integer_fields = ( + "DelaySeconds", + "KmsDataKeyreusePeriodSeconds", + "MaximumMessageSize", + "MessageRetentionPeriod", + "ReceiveMessageWaitTime", + "VisibilityTimeout", + ) + bool_fields = ("ContentBasedDeduplication", "FifoQueue") for key, value in six.iteritems(attributes): if key in integer_fields: @@ -238,13 +272,13 @@ class Queue(BaseModel): if key in bool_fields: value = value == "true" - if key == 'RedrivePolicy' and value is not None: + if key == "RedrivePolicy" and value is not None: continue setattr(self, camelcase_to_underscores(key), value) - if attributes.get('RedrivePolicy', None): - self._setup_dlq(attributes['RedrivePolicy']) + if attributes.get("RedrivePolicy", None): + self._setup_dlq(attributes["RedrivePolicy"]) self.last_modified_timestamp = now @@ -254,56 +288,86 @@ class Queue(BaseModel): try: self.redrive_policy = json.loads(policy) except ValueError: - raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json') + raise RESTError( + "InvalidParameterValue", + "Redrive policy is not a dict or valid json", + ) elif isinstance(policy, dict): self.redrive_policy = policy else: - raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json') + raise RESTError( + "InvalidParameterValue", "Redrive policy is not a dict or valid json" + ) - if 'deadLetterTargetArn' not in self.redrive_policy: - raise RESTError('InvalidParameterValue', 'Redrive policy does not contain deadLetterTargetArn') - if 'maxReceiveCount' not in self.redrive_policy: - raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount') + if "deadLetterTargetArn" not in self.redrive_policy: + raise RESTError( + "InvalidParameterValue", + "Redrive policy does not contain deadLetterTargetArn", + ) + if "maxReceiveCount" not in self.redrive_policy: + raise RESTError( + "InvalidParameterValue", + "Redrive policy does not contain maxReceiveCount", + ) + + # 'maxReceiveCount' is stored as int + self.redrive_policy["maxReceiveCount"] = int( + self.redrive_policy["maxReceiveCount"] + ) for queue in sqs_backends[self.region].queues.values(): - if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']: + if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]: self.dead_letter_queue = queue if self.fifo_queue and not queue.fifo_queue: - raise RESTError('InvalidParameterCombination', 'Fifo queues cannot use non fifo dead letter queues') + raise RESTError( + "InvalidParameterCombination", + "Fifo queues cannot use non fifo dead letter queues", + ) break else: - raise RESTError('AWS.SimpleQueueService.NonExistentQueue', 'Could not find DLQ for {0}'.format(self.redrive_policy['deadLetterTargetArn'])) + raise RESTError( + "AWS.SimpleQueueService.NonExistentQueue", + "Could not find DLQ for {0}".format( + self.redrive_policy["deadLetterTargetArn"] + ), + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] sqs_backend = sqs_backends[region_name] return sqs_backend.create_queue( - name=properties['QueueName'], - region=region_name, - **properties + name=properties["QueueName"], region=region_name, **properties ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - queue_name = properties['QueueName'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + queue_name = properties["QueueName"] sqs_backend = sqs_backends[region_name] queue = sqs_backend.get_queue(queue_name) - if 'VisibilityTimeout' in properties: - queue.visibility_timeout = int(properties['VisibilityTimeout']) + if "VisibilityTimeout" in properties: + queue.visibility_timeout = int(properties["VisibilityTimeout"]) - if 'ReceiveMessageWaitTimeSeconds' in properties: - queue.receive_message_wait_time_seconds = int(properties['ReceiveMessageWaitTimeSeconds']) + if "ReceiveMessageWaitTimeSeconds" in properties: + queue.receive_message_wait_time_seconds = int( + properties["ReceiveMessageWaitTimeSeconds"] + ) return queue @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - queue_name = properties['QueueName'] + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + queue_name = properties["QueueName"] sqs_backend = sqs_backends[region_name] sqs_backend.delete_queue(queue_name) @@ -327,25 +391,25 @@ class Queue(BaseModel): def attributes(self): result = {} - for attribute in self.base_attributes: + for attribute in self.BASE_ATTRIBUTES: attr = getattr(self, camelcase_to_underscores(attribute)) result[attribute] = attr if self.fifo_queue: - for attribute in self.fifo_attributes: + for attribute in self.FIFO_ATTRIBUTES: attr = getattr(self, camelcase_to_underscores(attribute)) result[attribute] = attr if self.kms_master_key_id: - for attribute in self.kms_attributes: + for attribute in self.KMS_ATTRIBUTES: attr = getattr(self, camelcase_to_underscores(attribute)) result[attribute] = attr if self.policy: - result['Policy'] = self.policy + result["Policy"] = self.policy if self.redrive_policy: - result['RedrivePolicy'] = json.dumps(self.redrive_policy) + result["RedrivePolicy"] = json.dumps(self.redrive_policy) for key in result: if isinstance(result[key], bool): @@ -354,15 +418,22 @@ class Queue(BaseModel): return result def url(self, request_url): - return "{0}://{1}/123456789012/{2}".format(request_url.scheme, request_url.netloc, self.name) + return "{0}://{1}/{2}/{3}".format( + request_url.scheme, request_url.netloc, DEFAULT_ACCOUNT_ID, self.name + ) @property def messages(self): - return [message for message in self._messages if message.visible and not message.delayed] + return [ + message + for message in self._messages + if message.visible and not message.delayed + ] def add_message(self, message): self._messages.append(message) from moto.awslambda import lambda_backends + for arn, esm in self.lambda_event_source_mappings.items(): backend = sqs_backends[self.region] @@ -380,27 +451,28 @@ class Queue(BaseModel): ) result = lambda_backends[self.region].send_sqs_batch( - arn, - messages, - self.queue_arn, + arn, messages, self.queue_arn ) if result: [backend.delete_message(self.name, m.receipt_handle) for m in messages] else: - [backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages] + [ + backend.change_message_visibility(self.name, m.receipt_handle, 0) + for m in messages + ] def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.queue_arn - elif attribute_name == 'QueueName': + elif attribute_name == "QueueName": return self.name raise UnformattedGetAttTemplateException() class SQSBackend(BaseBackend): - def __init__(self, region_name): self.region_name = region_name self.queues = {} @@ -412,11 +484,11 @@ class SQSBackend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def create_queue(self, name, **kwargs): + def create_queue(self, name, tags=None, **kwargs): queue = self.queues.get(name) if queue: try: - kwargs.pop('region') + kwargs.pop("region") except KeyError: pass @@ -424,26 +496,44 @@ class SQSBackend(BaseBackend): queue_attributes = queue.attributes new_queue_attributes = new_queue.attributes + static_attributes = ( + "DelaySeconds", + "MaximumMessageSize", + "MessageRetentionPeriod", + "Policy", + "QueueArn", + "ReceiveMessageWaitTimeSeconds", + "RedrivePolicy", + "VisibilityTimeout", + "KmsMasterKeyId", + "KmsDataKeyReusePeriodSeconds", + "FifoQueue", + "ContentBasedDeduplication", + ) - for key in ['CreatedTimestamp', 'LastModifiedTimestamp']: - queue_attributes.pop(key) - new_queue_attributes.pop(key) - - if queue_attributes != new_queue_attributes: - raise QueueAlreadyExists("The specified queue already exists.") + for key in static_attributes: + if queue_attributes.get(key) != new_queue_attributes.get(key): + raise QueueAlreadyExists("The specified queue already exists.") else: try: - kwargs.pop('region') + kwargs.pop("region") except KeyError: pass queue = Queue(name, region=self.region_name, **kwargs) self.queues[name] = queue + + if tags: + queue.tags = tags + return queue + def get_queue_url(self, queue_name): + return self.get_queue(queue_name) + def list_queues(self, queue_name_prefix): - re_str = '.*' + re_str = ".*" if queue_name_prefix: - re_str = '^{0}.*'.format(queue_name_prefix) + re_str = "^{0}.*".format(queue_name_prefix) prefix_re = re.compile(re_str) qs = [] for name, q in self.queues.items(): @@ -462,12 +552,49 @@ class SQSBackend(BaseBackend): return self.queues.pop(queue_name) return False + def get_queue_attributes(self, queue_name, attribute_names): + queue = self.get_queue(queue_name) + + if not len(attribute_names): + attribute_names.append("All") + + valid_names = ( + ["All"] + + queue.BASE_ATTRIBUTES + + queue.FIFO_ATTRIBUTES + + queue.KMS_ATTRIBUTES + ) + invalid_name = next( + (name for name in attribute_names if name not in valid_names), None + ) + + if invalid_name or invalid_name == "": + raise InvalidAttributeName(invalid_name) + + attributes = {} + + if "All" in attribute_names: + attributes = queue.attributes + else: + for name in (name for name in attribute_names if name in queue.attributes): + attributes[name] = queue.attributes.get(name) + + return attributes + def set_queue_attributes(self, queue_name, attributes): queue = self.get_queue(queue_name) queue._set_attributes(attributes) return queue - def send_message(self, queue_name, message_body, message_attributes=None, delay_seconds=None, deduplication_id=None, group_id=None): + def send_message( + self, + queue_name, + message_body, + message_attributes=None, + delay_seconds=None, + deduplication_id=None, + group_id=None, + ): queue = self.get_queue(queue_name) @@ -488,15 +615,66 @@ class SQSBackend(BaseBackend): if message_attributes: message.message_attributes = message_attributes - message.mark_sent( - delay_seconds=delay_seconds - ) + message.mark_sent(delay_seconds=delay_seconds) queue.add_message(message) return message - def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout): + def send_message_batch(self, queue_name, entries): + self.get_queue(queue_name) + + if any( + not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values() + ): + raise InvalidBatchEntryId() + + body_length = next( + ( + len(entry["MessageBody"]) + for entry in entries.values() + if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH + ), + False, + ) + if body_length: + raise BatchRequestTooLong(body_length) + + duplicate_id = self._get_first_duplicate_id( + [entry["Id"] for entry in entries.values()] + ) + if duplicate_id: + raise BatchEntryIdsNotDistinct(duplicate_id) + + if len(entries) > 10: + raise TooManyEntriesInBatchRequest(len(entries)) + + messages = [] + for index, entry in entries.items(): + # Loop through looking for messages + message = self.send_message( + queue_name, + entry["MessageBody"], + message_attributes=entry["MessageAttributes"], + delay_seconds=entry["DelaySeconds"], + ) + message.user_id = entry["Id"] + + messages.append(message) + + return messages + + def _get_first_duplicate_id(self, ids): + unique_ids = set() + for id in ids: + if id in unique_ids: + return id + unique_ids.add(id) + return None + + def receive_messages( + self, queue_name, count, wait_seconds_timeout, visibility_timeout + ): """ Attempt to retrieve visible messages from a queue. @@ -542,13 +720,15 @@ class SQSBackend(BaseBackend): queue.pending_messages.add(message) - if queue.dead_letter_queue is not None and message.approximate_receive_count >= queue.redrive_policy['maxReceiveCount']: + if ( + queue.dead_letter_queue is not None + and message.approximate_receive_count + >= queue.redrive_policy["maxReceiveCount"] + ): messages_to_dlq.append(message) continue - message.mark_received( - visibility_timeout=visibility_timeout - ) + message.mark_received(visibility_timeout=visibility_timeout) result.append(message) if len(result) >= count: break @@ -564,6 +744,7 @@ class SQSBackend(BaseBackend): break import time + time.sleep(0.01) continue @@ -573,9 +754,15 @@ class SQSBackend(BaseBackend): def delete_message(self, queue_name, receipt_handle): queue = self.get_queue(queue_name) + + if not any( + message.receipt_handle == receipt_handle for message in queue._messages + ): + raise ReceiptHandleIsInvalid() + new_messages = [] for message in queue._messages: - # Only delete message if it is not visible and the reciept_handle + # Only delete message if it is not visible and the receipt_handle # matches. if message.receipt_handle == receipt_handle: queue.pending_messages.remove(message) @@ -615,12 +802,12 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if actions is None or len(actions) == 0: - raise RESTError('InvalidParameterValue', 'Need at least one Action') + raise RESTError("InvalidParameterValue", "Need at least one Action") if account_ids is None or len(account_ids) == 0: - raise RESTError('InvalidParameterValue', 'Need at least one Account ID') + raise RESTError("InvalidParameterValue", "Need at least one Account ID") if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): - raise RESTError('InvalidParameterValue', 'Invalid permissions') + raise RESTError("InvalidParameterValue", "Invalid permissions") queue.permissions[label] = (account_ids, actions) @@ -628,22 +815,46 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if label not in queue.permissions: - raise RESTError('InvalidParameterValue', 'Permission doesnt exist for the given label') + raise RESTError( + "InvalidParameterValue", "Permission doesnt exist for the given label" + ) del queue.permissions[label] def tag_queue(self, queue_name, tags): queue = self.get_queue(queue_name) + + if not len(tags): + raise RESTError( + "MissingParameter", "The request must contain the parameter Tags." + ) + + if len(tags) > 50: + raise RESTError( + "InvalidParameterValue", + "Too many tags added for queue {}.".format(queue_name), + ) + queue.tags.update(tags) def untag_queue(self, queue_name, tag_keys): queue = self.get_queue(queue_name) + + if not len(tag_keys): + raise RESTError( + "InvalidParameterValue", + "Tag keys must be between 1 and 128 characters in length.", + ) + for key in tag_keys: try: del queue.tags[key] except KeyError: pass + def list_queue_tags(self, queue_name): + return self.get_queue(queue_name) + sqs_backends = {} for region in boto.sqs.regions(): diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index 5ddaf8849..8acea0799 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -1,18 +1,20 @@ from __future__ import unicode_literals import re -from six.moves.urllib.parse import urlparse from moto.core.responses import BaseResponse from moto.core.utils import amz_crc32, amzn_request_id -from .utils import parse_message_attributes -from .models import sqs_backends +from six.moves.urllib.parse import urlparse + from .exceptions import ( + EmptyBatchRequest, + InvalidAttributeName, MessageAttributesInvalid, MessageNotInflight, - QueueDoesNotExist, ReceiptHandleIsInvalid, ) +from .models import sqs_backends +from .utils import parse_message_attributes MAXIMUM_VISIBILTY_TIMEOUT = 43200 MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB @@ -21,7 +23,7 @@ DEFAULT_RECEIVED_MESSAGES = 1 class SQSResponse(BaseResponse): - region_regex = re.compile(r'://(.+?)\.queue\.amazonaws\.com') + region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") @property def sqs_backend(self): @@ -29,13 +31,21 @@ class SQSResponse(BaseResponse): @property def attribute(self): - if not hasattr(self, '_attribute'): - self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value') + if not hasattr(self, "_attribute"): + self._attribute = self._get_map_prefix( + "Attribute", key_end=".Name", value_end=".Value" + ) return self._attribute + @property + def tags(self): + if not hasattr(self, "_tags"): + self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") + return self._tags + def _get_queue_name(self): try: - queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1] + queue_name = self.querystring.get("QueueUrl")[0].split("/")[-1] except TypeError: # Fallback to reading from the URL queue_name = self.path.split("/")[-1] @@ -73,39 +83,34 @@ class SQSResponse(BaseResponse): queue_name = self._get_param("QueueName") try: - queue = self.sqs_backend.create_queue(queue_name, **self.attribute) + queue = self.sqs_backend.create_queue( + queue_name, self.tags, **self.attribute + ) except MessageAttributesInvalid as e: - return self._error('InvalidParameterValue', e.description) + return self._error("InvalidParameterValue", e.description) template = self.response_template(CREATE_QUEUE_RESPONSE) - return template.render(queue=queue, request_url=request_url) + return template.render(queue_url=queue.url(request_url)) def get_queue_url(self): request_url = urlparse(self.uri) queue_name = self._get_param("QueueName") - try: - queue = self.sqs_backend.get_queue(queue_name) - except QueueDoesNotExist as e: - return self._error('AWS.SimpleQueueService.NonExistentQueue', - e.description) + queue = self.sqs_backend.get_queue_url(queue_name) - if queue: - template = self.response_template(GET_QUEUE_URL_RESPONSE) - return template.render(queue=queue, request_url=request_url) - else: - return "", dict(status=404) + template = self.response_template(GET_QUEUE_URL_RESPONSE) + return template.render(queue_url=queue.url(request_url)) def list_queues(self): request_url = urlparse(self.uri) - queue_name_prefix = self._get_param('QueueNamePrefix') + queue_name_prefix = self._get_param("QueueNamePrefix") queues = self.sqs_backend.list_queues(queue_name_prefix) template = self.response_template(LIST_QUEUES_RESPONSE) return template.render(queues=queues, request_url=request_url) def change_message_visibility(self): queue_name = self._get_queue_name() - receipt_handle = self._get_param('ReceiptHandle') + receipt_handle = self._get_param("ReceiptHandle") try: visibility_timeout = self._get_validated_visibility_timeout() @@ -116,67 +121,80 @@ class SQSResponse(BaseResponse): self.sqs_backend.change_message_visibility( queue_name=queue_name, receipt_handle=receipt_handle, - visibility_timeout=visibility_timeout + visibility_timeout=visibility_timeout, + ) + except MessageNotInflight as e: + return ( + "Invalid request: {0}".format(e.description), + dict(status=e.status_code), ) - except (ReceiptHandleIsInvalid, MessageNotInflight) as e: - return "Invalid request: {0}".format(e.description), dict(status=e.status_code) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) return template.render() def change_message_visibility_batch(self): queue_name = self._get_queue_name() - entries = self._get_list_prefix('ChangeMessageVisibilityBatchRequestEntry') + entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") success = [] error = [] for entry in entries: try: - visibility_timeout = self._get_validated_visibility_timeout(entry['visibility_timeout']) + visibility_timeout = self._get_validated_visibility_timeout( + entry["visibility_timeout"] + ) except ValueError: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'true', - 'Code': 'InvalidParameterValue', - 'Message': 'Visibility timeout invalid' - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "true", + "Code": "InvalidParameterValue", + "Message": "Visibility timeout invalid", + } + ) continue try: self.sqs_backend.change_message_visibility( queue_name=queue_name, - receipt_handle=entry['receipt_handle'], - visibility_timeout=visibility_timeout + receipt_handle=entry["receipt_handle"], + visibility_timeout=visibility_timeout, ) - success.append(entry['id']) + success.append(entry["id"]) except ReceiptHandleIsInvalid as e: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'true', - 'Code': 'ReceiptHandleIsInvalid', - 'Message': e.description - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "true", + "Code": "ReceiptHandleIsInvalid", + "Message": e.description, + } + ) except MessageNotInflight as e: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'false', - 'Code': 'AWS.SimpleQueueService.MessageNotInflight', - 'Message': e.description - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "false", + "Code": "AWS.SimpleQueueService.MessageNotInflight", + "Message": e.description, + } + ) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) return template.render(success=success, errors=error) def get_queue_attributes(self): queue_name = self._get_queue_name() - try: - queue = self.sqs_backend.get_queue(queue_name) - except QueueDoesNotExist as e: - return self._error('AWS.SimpleQueueService.NonExistentQueue', - e.description) + + if self.querystring.get("AttributeNames"): + raise InvalidAttributeName("") + + attribute_names = self._get_multi_param("AttributeName") + + attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) template = self.response_template(GET_QUEUE_ATTRIBUTES_RESPONSE) - return template.render(queue=queue) + return template.render(attributes=attributes) def set_queue_attributes(self): # TODO validate self.get_param('QueueUrl') @@ -190,14 +208,17 @@ class SQSResponse(BaseResponse): queue_name = self._get_queue_name() queue = self.sqs_backend.delete_queue(queue_name) if not queue: - return "A queue with name {0} does not exist".format(queue_name), dict(status=404) + return ( + "A queue with name {0} does not exist".format(queue_name), + dict(status=404), + ) template = self.response_template(DELETE_QUEUE_RESPONSE) return template.render(queue=queue) def send_message(self): - message = self._get_param('MessageBody') - delay_seconds = int(self._get_param('DelaySeconds', 0)) + message = self._get_param("MessageBody") + delay_seconds = int(self._get_param("DelaySeconds", 0)) message_group_id = self._get_param("MessageGroupId") message_dedupe_id = self._get_param("MessageDeduplicationId") @@ -217,7 +238,7 @@ class SQSResponse(BaseResponse): message_attributes=message_attributes, delay_seconds=delay_seconds, deduplication_id=message_dedupe_id, - group_id=message_group_id + group_id=message_group_id, ) template = self.response_template(SEND_MESSAGE_RESPONSE) return template.render(message=message, message_attributes=message_attributes) @@ -236,33 +257,35 @@ class SQSResponse(BaseResponse): queue_name = self._get_queue_name() - messages = [] - for index in range(1, 11): - # Loop through looking for messages - message_key = 'SendMessageBatchRequestEntry.{0}.MessageBody'.format( - index) - message_body = self.querystring.get(message_key) - if not message_body: - # Found all messages - break + self.sqs_backend.get_queue(queue_name) - message_user_id_key = 'SendMessageBatchRequestEntry.{0}.Id'.format( - index) - message_user_id = self.querystring.get(message_user_id_key)[0] - delay_key = 'SendMessageBatchRequestEntry.{0}.DelaySeconds'.format( - index) - delay_seconds = self.querystring.get(delay_key, [None])[0] - message = self.sqs_backend.send_message( - queue_name, message_body[0], delay_seconds=delay_seconds) - message.user_id = message_user_id + if self.querystring.get("Entries"): + raise EmptyBatchRequest() - message_attributes = parse_message_attributes( - self.querystring, base='SendMessageBatchRequestEntry.{0}.'.format(index)) - if type(message_attributes) == tuple: - return message_attributes[0], message_attributes[1] - message.message_attributes = message_attributes + entries = {} + for key, value in self.querystring.items(): + match = re.match(r"^SendMessageBatchRequestEntry\.(\d+)\.Id", key) + if match: + index = match.group(1) - messages.append(message) + message_attributes = parse_message_attributes( + self.querystring, + base="SendMessageBatchRequestEntry.{}.".format(index), + ) + + entries[index] = { + "Id": value[0], + "MessageBody": self.querystring.get( + "SendMessageBatchRequestEntry.{}.MessageBody".format(index) + )[0], + "DelaySeconds": self.querystring.get( + "SendMessageBatchRequestEntry.{}.DelaySeconds".format(index), + [None], + )[0], + "MessageAttributes": message_attributes, + } + + messages = self.sqs_backend.send_message_batch(queue_name, entries) template = self.response_template(SEND_MESSAGE_BATCH_RESPONSE) return template.render(messages=messages) @@ -289,8 +312,9 @@ class SQSResponse(BaseResponse): message_ids = [] for index in range(1, 11): # Loop through looking for messages - receipt_key = 'DeleteMessageBatchRequestEntry.{0}.ReceiptHandle'.format( - index) + receipt_key = "DeleteMessageBatchRequestEntry.{0}.ReceiptHandle".format( + index + ) receipt_handle = self.querystring.get(receipt_key) if not receipt_handle: # Found all messages @@ -298,8 +322,7 @@ class SQSResponse(BaseResponse): self.sqs_backend.delete_message(queue_name, receipt_handle[0]) - message_user_id_key = 'DeleteMessageBatchRequestEntry.{0}.Id'.format( - index) + message_user_id_key = "DeleteMessageBatchRequestEntry.{0}.Id".format(index) message_user_id = self.querystring.get(message_user_id_key)[0] message_ids.append(message_user_id) @@ -315,10 +338,7 @@ class SQSResponse(BaseResponse): def receive_message(self): queue_name = self._get_queue_name() - try: - queue = self.sqs_backend.get_queue(queue_name) - except QueueDoesNotExist as e: - return self._error('QueueDoesNotExist', e.description) + queue = self.sqs_backend.get_queue(queue_name) try: message_count = int(self.querystring.get("MaxNumberOfMessages")[0]) @@ -331,7 +351,8 @@ class SQSResponse(BaseResponse): "An error occurred (InvalidParameterValue) when calling " "the ReceiveMessage operation: Value %s for parameter " "MaxNumberOfMessages is invalid. Reason: must be between " - "1 and 10, if provided." % message_count) + "1 and 10, if provided." % message_count, + ) try: wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) @@ -344,7 +365,8 @@ class SQSResponse(BaseResponse): "An error occurred (InvalidParameterValue) when calling " "the ReceiveMessage operation: Value %s for parameter " "WaitTimeSeconds is invalid. Reason: must be <= 0 and " - ">= 20 if provided." % wait_time) + ">= 20 if provided." % wait_time, + ) try: visibility_timeout = self._get_validated_visibility_timeout() @@ -354,7 +376,8 @@ class SQSResponse(BaseResponse): return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) messages = self.sqs_backend.receive_messages( - queue_name, message_count, wait_time, visibility_timeout) + queue_name, message_count, wait_time, visibility_timeout + ) template = self.response_template(RECEIVE_MESSAGE_RESPONSE) return template.render(messages=messages) @@ -369,9 +392,9 @@ class SQSResponse(BaseResponse): def add_permission(self): queue_name = self._get_queue_name() - actions = self._get_multi_param('ActionName') - account_ids = self._get_multi_param('AWSAccountId') - label = self._get_param('Label') + actions = self._get_multi_param("ActionName") + account_ids = self._get_multi_param("AWSAccountId") + label = self._get_param("Label") self.sqs_backend.add_permission(queue_name, actions, account_ids, label) @@ -380,7 +403,7 @@ class SQSResponse(BaseResponse): def remove_permission(self): queue_name = self._get_queue_name() - label = self._get_param('Label') + label = self._get_param("Label") self.sqs_backend.remove_permission(queue_name, label) @@ -389,7 +412,7 @@ class SQSResponse(BaseResponse): def tag_queue(self): queue_name = self._get_queue_name() - tags = self._get_map_prefix('Tag', key_end='.Key', value_end='.Value') + tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") self.sqs_backend.tag_queue(queue_name, tags) @@ -398,7 +421,7 @@ class SQSResponse(BaseResponse): def untag_queue(self): queue_name = self._get_queue_name() - tag_keys = self._get_multi_param('TagKey') + tag_keys = self._get_multi_param("TagKey") self.sqs_backend.untag_queue(queue_name, tag_keys) @@ -408,7 +431,7 @@ class SQSResponse(BaseResponse): def list_queue_tags(self): queue_name = self._get_queue_name() - queue = self.sqs_backend.get_queue(queue_name) + queue = self.sqs_backend.list_queue_tags(queue_name) template = self.response_template(LIST_QUEUE_TAGS_RESPONSE) return template.render(tags=queue.tags) @@ -416,8 +439,7 @@ class SQSResponse(BaseResponse): CREATE_QUEUE_RESPONSE = """ - {{ queue.url(request_url) }} - {{ queue.visibility_timeout }} + {{ queue_url }} @@ -426,7 +448,7 @@ CREATE_QUEUE_RESPONSE = """ GET_QUEUE_URL_RESPONSE = """ - {{ queue.url(request_url) }} + {{ queue_url }} @@ -452,7 +474,7 @@ DELETE_QUEUE_RESPONSE = """ GET_QUEUE_ATTRIBUTES_RESPONSE = """ - {% for key, value in queue.attributes.items() %} + {% for key, value in attributes.items() %} {{ key }} {{ value }} @@ -677,7 +699,8 @@ ERROR_TOO_LONG_RESPONSE = """ diff --git a/moto/sqs/urls.py b/moto/sqs/urls.py index 9ec014a80..3acf8591a 100644 --- a/moto/sqs/urls.py +++ b/moto/sqs/urls.py @@ -1,13 +1,11 @@ from __future__ import unicode_literals from .responses import SQSResponse -url_bases = [ - "https?://(.*?)(queue|sqs)(.*?).amazonaws.com" -] +url_bases = ["https?://(.*?)(queue|sqs)(.*?).amazonaws.com"] dispatch = SQSResponse().dispatch url_paths = { - '{0}/$': dispatch, - '{0}/(?P\d+)/(?P[a-zA-Z0-9\-_\.]+)': dispatch, + "{0}/$": dispatch, + "{0}/(?P\d+)/(?P[a-zA-Z0-9\-_\.]+)": dispatch, } diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 78be5f629..f3b8bbfe8 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -8,46 +8,62 @@ from .exceptions import MessageAttributesInvalid def generate_receipt_handle(): # http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles length = 185 - return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + return "".join(random.choice(string.ascii_lowercase) for x in range(length)) -def parse_message_attributes(querystring, base='', value_namespace='Value.'): +def parse_message_attributes(querystring, base="", value_namespace="Value."): message_attributes = {} index = 1 while True: # Loop through looking for message attributes - name_key = base + 'MessageAttribute.{0}.Name'.format(index) + name_key = base + "MessageAttribute.{0}.Name".format(index) name = querystring.get(name_key) if not name: # Found all attributes break - data_type_key = base + \ - 'MessageAttribute.{0}.{1}DataType'.format(index, value_namespace) + data_type_key = base + "MessageAttribute.{0}.{1}DataType".format( + index, value_namespace + ) data_type = querystring.get(data_type_key) if not data_type: raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value.".format(name[0])) + "The message attribute '{0}' must contain non-empty message attribute value.".format( + name[0] + ) + ) - data_type_parts = data_type[0].split('.') - if len(data_type_parts) > 2 or data_type_parts[0] not in ['String', 'Binary', 'Number']: + data_type_parts = data_type[0].split(".") + if len(data_type_parts) > 2 or data_type_parts[0] not in [ + "String", + "Binary", + "Number", + ]: raise MessageAttributesInvalid( - "The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format(name[0])) + "The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format( + name[0] + ) + ) - type_prefix = 'String' - if data_type_parts[0] == 'Binary': - type_prefix = 'Binary' + type_prefix = "String" + if data_type_parts[0] == "Binary": + type_prefix = "Binary" - value_key = base + \ - 'MessageAttribute.{0}.{1}{2}Value'.format( - index, value_namespace, type_prefix) + value_key = base + "MessageAttribute.{0}.{1}{2}Value".format( + index, value_namespace, type_prefix + ) value = querystring.get(value_key) if not value: raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format(name[0], data_type[0])) + "The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format( + name[0], data_type[0] + ) + ) - message_attributes[name[0]] = {'data_type': data_type[ - 0], type_prefix.lower() + '_value': value[0]} + message_attributes[name[0]] = { + "data_type": data_type[0], + type_prefix.lower() + "_value": value[0], + } index += 1 diff --git a/moto/ssm/__init__.py b/moto/ssm/__init__.py index c42f3b780..18112544a 100644 --- a/moto/ssm/__init__.py +++ b/moto/ssm/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import ssm_backends from ..core.models import base_decorator -ssm_backend = ssm_backends['us-east-1'] +ssm_backend = ssm_backends["us-east-1"] mock_ssm = base_decorator(ssm_backends) diff --git a/moto/ssm/exceptions.py b/moto/ssm/exceptions.py new file mode 100644 index 000000000..3458fe7d3 --- /dev/null +++ b/moto/ssm/exceptions.py @@ -0,0 +1,30 @@ +from __future__ import unicode_literals +from moto.core.exceptions import JsonRESTError + + +class InvalidFilterKey(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidFilterKey, self).__init__("InvalidFilterKey", message) + + +class InvalidFilterOption(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidFilterOption, self).__init__("InvalidFilterOption", message) + + +class InvalidFilterValue(JsonRESTError): + code = 400 + + def __init__(self, message): + super(InvalidFilterValue, self).__init__("InvalidFilterValue", message) + + +class ValidationException(JsonRESTError): + code = 400 + + def __init__(self, message): + super(ValidationException, self).__init__("ValidationException", message) diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 2f316a3ac..60c47f021 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import re from collections import defaultdict from moto.core import BaseBackend, BaseModel @@ -12,10 +13,27 @@ import time import uuid import itertools +from .utils import parameter_arn +from .exceptions import ( + ValidationException, + InvalidFilterValue, + InvalidFilterOption, + InvalidFilterKey, +) + class Parameter(BaseModel): - def __init__(self, name, value, type, description, allowed_pattern, keyid, - last_modified_date, version): + def __init__( + self, + name, + value, + type, + description, + allowed_pattern, + keyid, + last_modified_date, + version, + ): self.name = name self.type = type self.description = description @@ -24,45 +42,52 @@ class Parameter(BaseModel): self.last_modified_date = last_modified_date self.version = version - if self.type == 'SecureString': + if self.type == "SecureString": + if not self.keyid: + self.keyid = "alias/aws/ssm" + self.value = self.encrypt(value) else: self.value = value def encrypt(self, value): - return 'kms:{}:'.format(self.keyid or 'default') + value + return "kms:{}:".format(self.keyid) + value def decrypt(self, value): - if self.type != 'SecureString': + if self.type != "SecureString": return value - prefix = 'kms:{}:'.format(self.keyid or 'default') + prefix = "kms:{}:".format(self.keyid or "default") if value.startswith(prefix): - return value[len(prefix):] + return value[len(prefix) :] - def response_object(self, decrypt=False): + def response_object(self, decrypt=False, region=None): r = { - 'Name': self.name, - 'Type': self.type, - 'Value': self.decrypt(self.value) if decrypt else self.value, - 'Version': self.version, + "Name": self.name, + "Type": self.type, + "Value": self.decrypt(self.value) if decrypt else self.value, + "Version": self.version, + "LastModifiedDate": round(self.last_modified_date, 3), } + if region: + r["ARN"] = parameter_arn(region, self.name) + return r def describe_response_object(self, decrypt=False): r = self.response_object(decrypt) - r['LastModifiedDate'] = int(self.last_modified_date) - r['LastModifiedUser'] = 'N/A' + r["LastModifiedDate"] = round(self.last_modified_date, 3) + r["LastModifiedUser"] = "N/A" if self.description: - r['Description'] = self.description + r["Description"] = self.description if self.keyid: - r['KeyId'] = self.keyid + r["KeyId"] = self.keyid if self.allowed_pattern: - r['AllowedPattern'] = self.allowed_pattern + r["AllowedPattern"] = self.allowed_pattern return r @@ -71,11 +96,23 @@ MAX_TIMEOUT_SECONDS = 3600 class Command(BaseModel): - def __init__(self, comment='', document_name='', timeout_seconds=MAX_TIMEOUT_SECONDS, - instance_ids=None, max_concurrency='', max_errors='', - notification_config=None, output_s3_bucket_name='', - output_s3_key_prefix='', output_s3_region='', parameters=None, - service_role_arn='', targets=None, backend_region='us-east-1'): + def __init__( + self, + comment="", + document_name="", + timeout_seconds=MAX_TIMEOUT_SECONDS, + instance_ids=None, + max_concurrency="", + max_errors="", + notification_config=None, + output_s3_bucket_name="", + output_s3_key_prefix="", + output_s3_region="", + parameters=None, + service_role_arn="", + targets=None, + backend_region="us-east-1", + ): if instance_ids is None: instance_ids = [] @@ -93,12 +130,14 @@ class Command(BaseModel): self.completed_count = len(instance_ids) self.target_count = len(instance_ids) self.command_id = str(uuid.uuid4()) - self.status = 'Success' - self.status_details = 'Details placeholder' + self.status = "Success" + self.status_details = "Details placeholder" self.requested_date_time = datetime.datetime.now() self.requested_date_time_iso = self.requested_date_time.isoformat() - expires_after = self.requested_date_time + datetime.timedelta(0, timeout_seconds) + expires_after = self.requested_date_time + datetime.timedelta( + 0, timeout_seconds + ) self.expires_after = expires_after.isoformat() self.comment = comment @@ -116,9 +155,11 @@ class Command(BaseModel): self.backend_region = backend_region # Get instance ids from a cloud formation stack target. - stack_instance_ids = [self.get_instance_ids_by_stack_ids(target['Values']) for - target in self.targets if - target['Key'] == 'tag:aws:cloudformation:stack-name'] + stack_instance_ids = [ + self.get_instance_ids_by_stack_ids(target["Values"]) + for target in self.targets + if target["Key"] == "tag:aws:cloudformation:stack-name" + ] self.instance_ids += list(itertools.chain.from_iterable(stack_instance_ids)) @@ -126,7 +167,8 @@ class Command(BaseModel): self.invocations = [] for instance_id in self.instance_ids: self.invocations.append( - self.invocation_response(instance_id, "aws:runShellScript")) + self.invocation_response(instance_id, "aws:runShellScript") + ) def get_instance_ids_by_stack_ids(self, stack_ids): instance_ids = [] @@ -134,34 +176,36 @@ class Command(BaseModel): for stack_id in stack_ids: stack_resources = cloudformation_backend.list_stack_resources(stack_id) instance_resources = [ - instance.id for instance in stack_resources - if instance.type == "AWS::EC2::Instance"] + instance.id + for instance in stack_resources + if instance.type == "AWS::EC2::Instance" + ] instance_ids.extend(instance_resources) return instance_ids def response_object(self): r = { - 'CommandId': self.command_id, - 'Comment': self.comment, - 'CompletedCount': self.completed_count, - 'DocumentName': self.document_name, - 'ErrorCount': self.error_count, - 'ExpiresAfter': self.expires_after, - 'InstanceIds': self.instance_ids, - 'MaxConcurrency': self.max_concurrency, - 'MaxErrors': self.max_errors, - 'NotificationConfig': self.notification_config, - 'OutputS3Region': self.output_s3_region, - 'OutputS3BucketName': self.output_s3_bucket_name, - 'OutputS3KeyPrefix': self.output_s3_key_prefix, - 'Parameters': self.parameters, - 'RequestedDateTime': self.requested_date_time_iso, - 'ServiceRole': self.service_role_arn, - 'Status': self.status, - 'StatusDetails': self.status_details, - 'TargetCount': self.target_count, - 'Targets': self.targets, + "CommandId": self.command_id, + "Comment": self.comment, + "CompletedCount": self.completed_count, + "DocumentName": self.document_name, + "ErrorCount": self.error_count, + "ExpiresAfter": self.expires_after, + "InstanceIds": self.instance_ids, + "MaxConcurrency": self.max_concurrency, + "MaxErrors": self.max_errors, + "NotificationConfig": self.notification_config, + "OutputS3Region": self.output_s3_region, + "OutputS3BucketName": self.output_s3_bucket_name, + "OutputS3KeyPrefix": self.output_s3_key_prefix, + "Parameters": self.parameters, + "RequestedDateTime": self.requested_date_time_iso, + "ServiceRole": self.service_role_arn, + "Status": self.status, + "StatusDetails": self.status_details, + "TargetCount": self.target_count, + "Targets": self.targets, } return r @@ -175,48 +219,58 @@ class Command(BaseModel): end_time = self.requested_date_time + elapsed_time_delta r = { - 'CommandId': self.command_id, - 'InstanceId': instance_id, - 'Comment': self.comment, - 'DocumentName': self.document_name, - 'PluginName': plugin_name, - 'ResponseCode': 0, - 'ExecutionStartDateTime': self.requested_date_time_iso, - 'ExecutionElapsedTime': elapsed_time_iso, - 'ExecutionEndDateTime': end_time.isoformat(), - 'Status': 'Success', - 'StatusDetails': 'Success', - 'StandardOutputContent': '', - 'StandardOutputUrl': '', - 'StandardErrorContent': '', + "CommandId": self.command_id, + "InstanceId": instance_id, + "Comment": self.comment, + "DocumentName": self.document_name, + "PluginName": plugin_name, + "ResponseCode": 0, + "ExecutionStartDateTime": self.requested_date_time_iso, + "ExecutionElapsedTime": elapsed_time_iso, + "ExecutionEndDateTime": end_time.isoformat(), + "Status": "Success", + "StatusDetails": "Success", + "StandardOutputContent": "", + "StandardOutputUrl": "", + "StandardErrorContent": "", } return r def get_invocation(self, instance_id, plugin_name): invocation = next( - (invocation for invocation in self.invocations - if invocation['InstanceId'] == instance_id), None) + ( + invocation + for invocation in self.invocations + if invocation["InstanceId"] == instance_id + ), + None, + ) if invocation is None: raise RESTError( - 'InvocationDoesNotExist', - 'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation') + "InvocationDoesNotExist", + "An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation", + ) - if plugin_name is not None and invocation['PluginName'] != plugin_name: - raise RESTError( - 'InvocationDoesNotExist', - 'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation') + if plugin_name is not None and invocation["PluginName"] != plugin_name: + raise RESTError( + "InvocationDoesNotExist", + "An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation", + ) return invocation class SimpleSystemManagerBackend(BaseBackend): - def __init__(self): - self._parameters = {} + # each value is a list of all of the versions for a parameter + # to get the current value, grab the last item of the list + self._parameters = defaultdict(list) + self._resource_tags = defaultdict(lambda: defaultdict(dict)) self._commands = [] + self._errors = [] # figure out what region we're in for region, backend in ssm_backends.items(): @@ -239,6 +293,215 @@ class SimpleSystemManagerBackend(BaseBackend): pass return result + def describe_parameters(self, filters, parameter_filters): + if filters and parameter_filters: + raise ValidationException( + "You can use either Filters or ParameterFilters in a single request." + ) + + self._validate_parameter_filters(parameter_filters, by_path=False) + + result = [] + for param_name in self._parameters: + ssm_parameter = self.get_parameter(param_name, False) + if not self._match_filters(ssm_parameter, parameter_filters): + continue + + if filters: + for filter in filters: + if filter["Key"] == "Name": + k = ssm_parameter.name + for v in filter["Values"]: + if k.startswith(v): + result.append(ssm_parameter) + break + elif filter["Key"] == "Type": + k = ssm_parameter.type + for v in filter["Values"]: + if k == v: + result.append(ssm_parameter) + break + elif filter["Key"] == "KeyId": + k = ssm_parameter.keyid + if k: + for v in filter["Values"]: + if k == v: + result.append(ssm_parameter) + break + continue + + result.append(ssm_parameter) + + return result + + def _validate_parameter_filters(self, parameter_filters, by_path): + for index, filter_obj in enumerate(parameter_filters or []): + key = filter_obj["Key"] + values = filter_obj.get("Values", []) + + if key == "Path": + option = filter_obj.get("Option", "OneLevel") + else: + option = filter_obj.get("Option", "Equals") + + if not re.match(r"^tag:.+|Name|Type|KeyId|Path|Label|Tier$", key): + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.key".format( + index=(index + 1) + ), + value=key, + constraint="Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", + ) + ) + + if len(key) > 132: + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.key".format( + index=(index + 1) + ), + value=key, + constraint="Member must have length less than or equal to 132", + ) + ) + + if len(option) > 10: + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.option".format( + index=(index + 1) + ), + value="over 10 chars", + constraint="Member must have length less than or equal to 10", + ) + ) + + if len(values) > 50: + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.values".format( + index=(index + 1) + ), + value=values, + constraint="Member must have length less than or equal to 50", + ) + ) + + if any(len(value) > 1024 for value in values): + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.values".format( + index=(index + 1) + ), + value=values, + constraint="[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]", + ) + ) + + self._raise_errors() + + filter_keys = [] + for filter_obj in parameter_filters or []: + key = filter_obj["Key"] + values = filter_obj.get("Values") + + if key == "Path": + option = filter_obj.get("Option", "OneLevel") + else: + option = filter_obj.get("Option", "Equals") + + if not by_path and key == "Label": + raise InvalidFilterKey( + "The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier]." + ) + + if not values: + raise InvalidFilterValue( + "The following filter values are missing : null for filter key Name." + ) + + if key in filter_keys: + raise InvalidFilterKey( + "The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter." + ) + + if key == "Path": + if option not in ["Recursive", "OneLevel"]: + raise InvalidFilterOption( + "The following filter option is not valid: {option}. Valid options include: [Recursive, OneLevel].".format( + option=option + ) + ) + if any(value.lower().startswith(("/aws", "/ssm")) for value in values): + raise ValidationException( + 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "When using global parameters, please specify within a global namespace." + ) + for value in values: + if value.lower().startswith(("/aws", "/ssm")): + raise ValidationException( + 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "When using global parameters, please specify within a global namespace." + ) + if ( + "//" in value + or not value.startswith("/") + or not re.match("^[a-zA-Z0-9_.-/]*$", value) + ): + raise ValidationException( + 'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". ' + 'It can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). " + 'Special characters are not allowed. All sub-paths, if specified, must use the forward slash symbol "/". ' + "Valid example: /get/parameters2-/by1./path0_." + ) + + if key == "Tier": + for value in values: + if value not in ["Standard", "Advanced", "Intelligent-Tiering"]: + raise InvalidFilterOption( + "The following filter value is not valid: {value}. Valid values include: [Standard, Advanced, Intelligent-Tiering].".format( + value=value + ) + ) + + if key == "Type": + for value in values: + if value not in ["String", "StringList", "SecureString"]: + raise InvalidFilterOption( + "The following filter value is not valid: {value}. Valid values include: [String, StringList, SecureString].".format( + value=value + ) + ) + + if key != "Path" and option not in ["Equals", "BeginsWith"]: + raise InvalidFilterOption( + "The following filter option is not valid: {option}. Valid options include: [BeginsWith, Equals].".format( + option=option + ) + ) + + filter_keys.append(key) + + def _format_error(self, key, value, constraint): + return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format( + constraint=constraint, key=key, value=value + ) + + def _raise_errors(self): + if self._errors: + count = len(self._errors) + plural = "s" if len(self._errors) > 1 else "" + errors = "; ".join(self._errors) + self._errors = [] # reset collected errors + + raise ValidationException( + "{count} validation error{plural} detected: {errors}".format( + count=count, plural=plural, errors=errors + ) + ) + def get_all_parameters(self): result = [] for k, _ in self._parameters.items(): @@ -249,68 +512,138 @@ class SimpleSystemManagerBackend(BaseBackend): result = [] for name in names: if name in self._parameters: - result.append(self._parameters[name]) + result.append(self.get_parameter(name, with_decryption)) return result - def get_parameters_by_path(self, path, with_decryption, recursive, filters=None): + def get_parameters_by_path( + self, + path, + with_decryption, + recursive, + filters=None, + next_token=None, + max_results=10, + ): """Implement the get-parameters-by-path-API in the backend.""" result = [] # path could be with or without a trailing /. we handle this # difference here. - path = path.rstrip('/') + '/' - for param in self._parameters: - if path != '/' and not param.startswith(path): + path = path.rstrip("/") + "/" + for param_name in self._parameters: + if path != "/" and not param_name.startswith(path): continue - if '/' in param[len(path) + 1:] and not recursive: + if "/" in param_name[len(path) + 1 :] and not recursive: continue - if not self._match_filters(self._parameters[param], filters): + if not self._match_filters( + self.get_parameter(param_name, with_decryption), filters + ): continue - result.append(self._parameters[param]) + result.append(self.get_parameter(param_name, with_decryption)) - return result + return self._get_values_nexttoken(result, max_results, next_token) - @staticmethod - def _match_filters(parameter, filters=None): + def _get_values_nexttoken(self, values_list, max_results, next_token=None): + if next_token is None: + next_token = 0 + next_token = int(next_token) + max_results = int(max_results) + values = values_list[next_token : next_token + max_results] + if len(values) == max_results: + next_token = str(next_token + max_results) + else: + next_token = None + return values, next_token + + def get_parameter_history(self, name, with_decryption): + if name in self._parameters: + return self._parameters[name] + return None + + def _match_filters(self, parameter, filters=None): """Return True if the given parameter matches all the filters""" - for filter_obj in (filters or []): - key = filter_obj['Key'] - option = filter_obj.get('Option', 'Equals') - values = filter_obj.get('Values', []) + for filter_obj in filters or []: + key = filter_obj["Key"] + values = filter_obj.get("Values", []) + + if key == "Path": + option = filter_obj.get("Option", "OneLevel") + else: + option = filter_obj.get("Option", "Equals") what = None - if key == 'Type': - what = parameter.type - elif key == 'KeyId': + if key == "KeyId": what = parameter.keyid + elif key == "Name": + what = "/" + parameter.name.lstrip("/") + values = ["/" + value.lstrip("/") for value in values] + elif key == "Path": + what = "/" + parameter.name.lstrip("/") + values = ["/" + value.strip("/") for value in values] + elif key == "Type": + what = parameter.type - if option == 'Equals'\ - and not any(what == value for value in values): + if what is None: return False - elif option == 'BeginsWith'\ - and not any(what.startswith(value) for value in values): + elif option == "BeginsWith" and not any( + what.startswith(value) for value in values + ): return False + elif option == "Equals" and not any(what == value for value in values): + return False + elif option == "OneLevel": + if any(value == "/" and len(what.split("/")) == 2 for value in values): + continue + elif any( + value != "/" + and what.startswith(value + "/") + and len(what.split("/")) - 1 == len(value.split("/")) + for value in values + ): + continue + else: + return False + elif option == "Recursive": + if any(value == "/" for value in values): + continue + elif any(what.startswith(value + "/") for value in values): + continue + else: + return False # True if no false match (or no filters at all) return True def get_parameter(self, name, with_decryption): if name in self._parameters: - return self._parameters[name] + return self._parameters[name][-1] return None - def put_parameter(self, name, description, value, type, allowed_pattern, - keyid, overwrite): - previous_parameter = self._parameters.get(name) - version = 1 - - if previous_parameter: + def put_parameter( + self, name, description, value, type, allowed_pattern, keyid, overwrite + ): + previous_parameter_versions = self._parameters[name] + if len(previous_parameter_versions) == 0: + previous_parameter = None + version = 1 + else: + previous_parameter = previous_parameter_versions[-1] version = previous_parameter.version + 1 if not overwrite: return last_modified_date = time.time() - self._parameters[name] = Parameter(name, value, type, description, - allowed_pattern, keyid, last_modified_date, version) + self._parameters[name].append( + Parameter( + name, + value, + type, + description, + allowed_pattern, + keyid, + last_modified_date, + version, + ) + ) return version def add_tags_to_resource(self, resource_type, resource_id, tags): @@ -328,29 +661,31 @@ class SimpleSystemManagerBackend(BaseBackend): def send_command(self, **kwargs): command = Command( - comment=kwargs.get('Comment', ''), - document_name=kwargs.get('DocumentName'), - timeout_seconds=kwargs.get('TimeoutSeconds', 3600), - instance_ids=kwargs.get('InstanceIds', []), - max_concurrency=kwargs.get('MaxConcurrency', '50'), - max_errors=kwargs.get('MaxErrors', '0'), - notification_config=kwargs.get('NotificationConfig', { - 'NotificationArn': 'string', - 'NotificationEvents': ['Success'], - 'NotificationType': 'Command' - }), - output_s3_bucket_name=kwargs.get('OutputS3BucketName', ''), - output_s3_key_prefix=kwargs.get('OutputS3KeyPrefix', ''), - output_s3_region=kwargs.get('OutputS3Region', ''), - parameters=kwargs.get('Parameters', {}), - service_role_arn=kwargs.get('ServiceRoleArn', ''), - targets=kwargs.get('Targets', []), - backend_region=self._region) + comment=kwargs.get("Comment", ""), + document_name=kwargs.get("DocumentName"), + timeout_seconds=kwargs.get("TimeoutSeconds", 3600), + instance_ids=kwargs.get("InstanceIds", []), + max_concurrency=kwargs.get("MaxConcurrency", "50"), + max_errors=kwargs.get("MaxErrors", "0"), + notification_config=kwargs.get( + "NotificationConfig", + { + "NotificationArn": "string", + "NotificationEvents": ["Success"], + "NotificationType": "Command", + }, + ), + output_s3_bucket_name=kwargs.get("OutputS3BucketName", ""), + output_s3_key_prefix=kwargs.get("OutputS3KeyPrefix", ""), + output_s3_region=kwargs.get("OutputS3Region", ""), + parameters=kwargs.get("Parameters", {}), + service_role_arn=kwargs.get("ServiceRoleArn", ""), + targets=kwargs.get("Targets", []), + backend_region=self._region, + ) self._commands.append(command) - return { - 'Command': command.response_object() - } + return {"Command": command.response_object()} def list_commands(self, **kwargs): """ @@ -358,39 +693,38 @@ class SimpleSystemManagerBackend(BaseBackend): """ commands = self._commands - command_id = kwargs.get('CommandId', None) + command_id = kwargs.get("CommandId", None) if command_id: commands = [self.get_command_by_id(command_id)] - instance_id = kwargs.get('InstanceId', None) + instance_id = kwargs.get("InstanceId", None) if instance_id: commands = self.get_commands_by_instance_id(instance_id) - return { - 'Commands': [command.response_object() for command in commands] - } + return {"Commands": [command.response_object() for command in commands]} def get_command_by_id(self, id): command = next( - (command for command in self._commands if command.command_id == id), None) + (command for command in self._commands if command.command_id == id), None + ) if command is None: - raise RESTError('InvalidCommandId', 'Invalid command id.') + raise RESTError("InvalidCommandId", "Invalid command id.") return command def get_commands_by_instance_id(self, instance_id): return [ - command for command in self._commands - if instance_id in command.instance_ids] + command for command in self._commands if instance_id in command.instance_ids + ] def get_command_invocation(self, **kwargs): """ https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetCommandInvocation.html """ - command_id = kwargs.get('CommandId') - instance_id = kwargs.get('InstanceId') - plugin_name = kwargs.get('PluginName', None) + command_id = kwargs.get("CommandId") + instance_id = kwargs.get("InstanceId") + plugin_name = kwargs.get("PluginName", None) command = self.get_command_by_id(command_id) return command.get_invocation(instance_id, plugin_name) diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index c47d4127a..1b13780a8 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -6,7 +6,6 @@ from .models import ssm_backends class SimpleSystemManagerResponse(BaseResponse): - @property def ssm_backend(self): return ssm_backends[self.region] @@ -22,195 +21,180 @@ class SimpleSystemManagerResponse(BaseResponse): return self.request_params.get(param, default) def delete_parameter(self): - name = self._get_param('Name') + name = self._get_param("Name") self.ssm_backend.delete_parameter(name) return json.dumps({}) def delete_parameters(self): - names = self._get_param('Names') + names = self._get_param("Names") result = self.ssm_backend.delete_parameters(names) - response = { - 'DeletedParameters': [], - 'InvalidParameters': [] - } + response = {"DeletedParameters": [], "InvalidParameters": []} for name in names: if name in result: - response['DeletedParameters'].append(name) + response["DeletedParameters"].append(name) else: - response['InvalidParameters'].append(name) + response["InvalidParameters"].append(name) return json.dumps(response) def get_parameter(self): - name = self._get_param('Name') - with_decryption = self._get_param('WithDecryption') + name = self._get_param("Name") + with_decryption = self._get_param("WithDecryption") result = self.ssm_backend.get_parameter(name, with_decryption) if result is None: error = { - '__type': 'ParameterNotFound', - 'message': 'Parameter {0} not found.'.format(name) + "__type": "ParameterNotFound", + "message": "Parameter {0} not found.".format(name), } return json.dumps(error), dict(status=400) - response = { - 'Parameter': result.response_object(with_decryption) - } + response = {"Parameter": result.response_object(with_decryption, self.region)} return json.dumps(response) def get_parameters(self): - names = self._get_param('Names') - with_decryption = self._get_param('WithDecryption') + names = self._get_param("Names") + with_decryption = self._get_param("WithDecryption") result = self.ssm_backend.get_parameters(names, with_decryption) - response = { - 'Parameters': [], - 'InvalidParameters': [], - } + response = {"Parameters": [], "InvalidParameters": []} for parameter in result: - param_data = parameter.response_object(with_decryption) - response['Parameters'].append(param_data) + param_data = parameter.response_object(with_decryption, self.region) + response["Parameters"].append(param_data) param_names = [param.name for param in result] for name in names: if name not in param_names: - response['InvalidParameters'].append(name) + response["InvalidParameters"].append(name) return json.dumps(response) def get_parameters_by_path(self): - path = self._get_param('Path') - with_decryption = self._get_param('WithDecryption') - recursive = self._get_param('Recursive', False) - filters = self._get_param('ParameterFilters') + path = self._get_param("Path") + with_decryption = self._get_param("WithDecryption") + recursive = self._get_param("Recursive", False) + filters = self._get_param("ParameterFilters") + token = self._get_param("NextToken") + max_results = self._get_param("MaxResults", 10) - result = self.ssm_backend.get_parameters_by_path( - path, with_decryption, recursive, filters + result, next_token = self.ssm_backend.get_parameters_by_path( + path, + with_decryption, + recursive, + filters, + next_token=token, + max_results=max_results, ) - response = { - 'Parameters': [], - } + response = {"Parameters": [], "NextToken": next_token} for parameter in result: - param_data = parameter.response_object(with_decryption) - response['Parameters'].append(param_data) + param_data = parameter.response_object(with_decryption, self.region) + response["Parameters"].append(param_data) return json.dumps(response) def describe_parameters(self): page_size = 10 - filters = self._get_param('Filters') - token = self._get_param('NextToken') - if hasattr(token, 'strip'): + filters = self._get_param("Filters") + parameter_filters = self._get_param("ParameterFilters") + token = self._get_param("NextToken") + if hasattr(token, "strip"): token = token.strip() if not token: - token = '0' + token = "0" token = int(token) - result = self.ssm_backend.get_all_parameters() - response = { - 'Parameters': [], - } + result = self.ssm_backend.describe_parameters(filters, parameter_filters) + + response = {"Parameters": []} end = token + page_size for parameter in result[token:]: - param_data = parameter.describe_response_object(False) - add = False - - if filters: - for filter in filters: - if filter['Key'] == 'Name': - k = param_data['Name'] - for v in filter['Values']: - if k.startswith(v): - add = True - break - elif filter['Key'] == 'Type': - k = param_data['Type'] - for v in filter['Values']: - if k == v: - add = True - break - elif filter['Key'] == 'KeyId': - k = param_data.get('KeyId') - if k: - for v in filter['Values']: - if k == v: - add = True - break - else: - add = True - - if add: - response['Parameters'].append(param_data) + response["Parameters"].append(parameter.describe_response_object(False)) token = token + 1 - if len(response['Parameters']) == page_size: - response['NextToken'] = str(end) + if len(response["Parameters"]) == page_size: + response["NextToken"] = str(end) break return json.dumps(response) def put_parameter(self): - name = self._get_param('Name') - description = self._get_param('Description') - value = self._get_param('Value') - type_ = self._get_param('Type') - allowed_pattern = self._get_param('AllowedPattern') - keyid = self._get_param('KeyId') - overwrite = self._get_param('Overwrite', False) + name = self._get_param("Name") + description = self._get_param("Description") + value = self._get_param("Value") + type_ = self._get_param("Type") + allowed_pattern = self._get_param("AllowedPattern") + keyid = self._get_param("KeyId") + overwrite = self._get_param("Overwrite", False) result = self.ssm_backend.put_parameter( - name, description, value, type_, allowed_pattern, keyid, overwrite) + name, description, value, type_, allowed_pattern, keyid, overwrite + ) if result is None: error = { - '__type': 'ParameterAlreadyExists', - 'message': 'Parameter {0} already exists.'.format(name) + "__type": "ParameterAlreadyExists", + "message": "Parameter {0} already exists.".format(name), } return json.dumps(error), dict(status=400) - response = {'Version': result} + response = {"Version": result} + return json.dumps(response) + + def get_parameter_history(self): + name = self._get_param("Name") + with_decryption = self._get_param("WithDecryption") + + result = self.ssm_backend.get_parameter_history(name, with_decryption) + + if result is None: + error = { + "__type": "ParameterNotFound", + "message": "Parameter {0} not found.".format(name), + } + return json.dumps(error), dict(status=400) + + response = {"Parameters": []} + for parameter_version in result: + param_data = parameter_version.describe_response_object( + decrypt=with_decryption + ) + response["Parameters"].append(param_data) + return json.dumps(response) def add_tags_to_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - tags = {t['Key']: t['Value'] for t in self._get_param('Tags')} - self.ssm_backend.add_tags_to_resource( - resource_id, resource_type, tags) + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")} + self.ssm_backend.add_tags_to_resource(resource_id, resource_type, tags) return json.dumps({}) def remove_tags_from_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - keys = self._get_param('TagKeys') - self.ssm_backend.remove_tags_from_resource( - resource_id, resource_type, keys) + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + keys = self._get_param("TagKeys") + self.ssm_backend.remove_tags_from_resource(resource_id, resource_type, keys) return json.dumps({}) def list_tags_for_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - tags = self.ssm_backend.list_tags_for_resource( - resource_id, resource_type) - tag_list = [{'Key': k, 'Value': v} for (k, v) in tags.items()] - response = {'TagList': tag_list} + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + tags = self.ssm_backend.list_tags_for_resource(resource_id, resource_type) + tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()] + response = {"TagList": tag_list} return json.dumps(response) def send_command(self): - return json.dumps( - self.ssm_backend.send_command(**self.request_params) - ) + return json.dumps(self.ssm_backend.send_command(**self.request_params)) def list_commands(self): - return json.dumps( - self.ssm_backend.list_commands(**self.request_params) - ) + return json.dumps(self.ssm_backend.list_commands(**self.request_params)) def get_command_invocation(self): return json.dumps( diff --git a/moto/ssm/urls.py b/moto/ssm/urls.py index 9ac327325..bd6706dfa 100644 --- a/moto/ssm/urls.py +++ b/moto/ssm/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import SimpleSystemManagerResponse -url_bases = [ - "https?://ssm.(.+).amazonaws.com", - "https?://ssm.(.+).amazonaws.com.cn", -] +url_bases = ["https?://ssm.(.+).amazonaws.com", "https?://ssm.(.+).amazonaws.com.cn"] -url_paths = { - '{0}/$': SimpleSystemManagerResponse.dispatch, -} +url_paths = {"{0}/$": SimpleSystemManagerResponse.dispatch} diff --git a/moto/ssm/utils.py b/moto/ssm/utils.py new file mode 100644 index 000000000..3f3762d1c --- /dev/null +++ b/moto/ssm/utils.py @@ -0,0 +1,9 @@ +ACCOUNT_ID = "1234567890" + + +def parameter_arn(region, parameter_name): + if parameter_name[0] == "/": + parameter_name = parameter_name[1:] + return "arn:aws:ssm:{0}:{1}:parameter/{2}".format( + region, ACCOUNT_ID, parameter_name + ) diff --git a/moto/stepfunctions/__init__.py b/moto/stepfunctions/__init__.py new file mode 100644 index 000000000..6dd50c9dc --- /dev/null +++ b/moto/stepfunctions/__init__.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .models import stepfunction_backends +from ..core.models import base_decorator + +stepfunction_backend = stepfunction_backends["us-east-1"] +mock_stepfunctions = base_decorator(stepfunction_backends) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py new file mode 100644 index 000000000..704e4ea83 --- /dev/null +++ b/moto/stepfunctions/exceptions.py @@ -0,0 +1,38 @@ +from __future__ import unicode_literals +import json + + +class AWSError(Exception): + TYPE = None + STATUS = 400 + + def __init__(self, message, type=None, status=None): + self.message = message + self.type = type if type is not None else self.TYPE + self.status = status if status is not None else self.STATUS + + def response(self): + return ( + json.dumps({"__type": self.type, "message": self.message}), + dict(status=self.status), + ) + + +class ExecutionDoesNotExist(AWSError): + TYPE = "ExecutionDoesNotExist" + STATUS = 400 + + +class InvalidArn(AWSError): + TYPE = "InvalidArn" + STATUS = 400 + + +class InvalidName(AWSError): + TYPE = "InvalidName" + STATUS = 400 + + +class StateMachineDoesNotExist(AWSError): + TYPE = "StateMachineDoesNotExist" + STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py new file mode 100644 index 000000000..665f3b777 --- /dev/null +++ b/moto/stepfunctions/models.py @@ -0,0 +1,286 @@ +import boto +import re +from datetime import datetime +from moto.core import BaseBackend +from moto.core.utils import iso_8601_datetime_without_milliseconds +from moto.sts.models import ACCOUNT_ID +from uuid import uuid4 +from .exceptions import ( + ExecutionDoesNotExist, + InvalidArn, + InvalidName, + StateMachineDoesNotExist, +) + + +class StateMachine: + def __init__(self, arn, name, definition, roleArn, tags=None): + self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.arn = arn + self.name = name + self.definition = definition + self.roleArn = roleArn + self.tags = tags + + +class Execution: + def __init__( + self, + region_name, + account_id, + state_machine_name, + execution_name, + state_machine_arn, + ): + execution_arn = "arn:aws:states:{}:{}:execution:{}:{}" + execution_arn = execution_arn.format( + region_name, account_id, state_machine_name, execution_name + ) + self.execution_arn = execution_arn + self.name = execution_name + self.start_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.state_machine_arn = state_machine_arn + self.status = "RUNNING" + self.stop_date = None + + def stop(self): + self.status = "SUCCEEDED" + self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now()) + + +class StepFunctionBackend(BaseBackend): + + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine + # A name must not contain: + # whitespace + # brackets < > { } [ ] + # wildcard characters ? * + # special characters " # % \ ^ | ~ ` $ & , ; : / + invalid_chars_for_name = [ + " ", + "{", + "}", + "[", + "]", + "<", + ">", + "?", + "*", + '"', + "#", + "%", + "\\", + "^", + "|", + "~", + "`", + "$", + "&", + ",", + ";", + ":", + "/", + ] + # control characters (U+0000-001F , U+007F-009F ) + invalid_unicodes_for_name = [ + u"\u0000", + u"\u0001", + u"\u0002", + u"\u0003", + u"\u0004", + u"\u0005", + u"\u0006", + u"\u0007", + u"\u0008", + u"\u0009", + u"\u000A", + u"\u000B", + u"\u000C", + u"\u000D", + u"\u000E", + u"\u000F", + u"\u0010", + u"\u0011", + u"\u0012", + u"\u0013", + u"\u0014", + u"\u0015", + u"\u0016", + u"\u0017", + u"\u0018", + u"\u0019", + u"\u001A", + u"\u001B", + u"\u001C", + u"\u001D", + u"\u001E", + u"\u001F", + u"\u007F", + u"\u0080", + u"\u0081", + u"\u0082", + u"\u0083", + u"\u0084", + u"\u0085", + u"\u0086", + u"\u0087", + u"\u0088", + u"\u0089", + u"\u008A", + u"\u008B", + u"\u008C", + u"\u008D", + u"\u008E", + u"\u008F", + u"\u0090", + u"\u0091", + u"\u0092", + u"\u0093", + u"\u0094", + u"\u0095", + u"\u0096", + u"\u0097", + u"\u0098", + u"\u0099", + u"\u009A", + u"\u009B", + u"\u009C", + u"\u009D", + u"\u009E", + u"\u009F", + ] + accepted_role_arn_format = re.compile( + "arn:aws:iam::(?P[0-9]{12}):role/.+" + ) + accepted_mchn_arn_format = re.compile( + "arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+" + ) + accepted_exec_arn_format = re.compile( + "arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+" + ) + + def __init__(self, region_name): + self.state_machines = [] + self.executions = [] + self.region_name = region_name + self._account_id = None + + def create_state_machine(self, name, definition, roleArn, tags=None): + self._validate_name(name) + self._validate_role_arn(roleArn) + arn = ( + "arn:aws:states:" + + self.region_name + + ":" + + str(self._get_account_id()) + + ":stateMachine:" + + name + ) + try: + return self.describe_state_machine(arn) + except StateMachineDoesNotExist: + state_machine = StateMachine(arn, name, definition, roleArn, tags) + self.state_machines.append(state_machine) + return state_machine + + def list_state_machines(self): + return self.state_machines + + def describe_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if not sm: + raise StateMachineDoesNotExist( + "State Machine Does Not Exist: '" + arn + "'" + ) + return sm + + def delete_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if sm: + self.state_machines.remove(sm) + + def start_execution(self, state_machine_arn, name=None): + state_machine_name = self.describe_state_machine(state_machine_arn).name + execution = Execution( + region_name=self.region_name, + account_id=self._get_account_id(), + state_machine_name=state_machine_name, + execution_name=name or str(uuid4()), + state_machine_arn=state_machine_arn, + ) + self.executions.append(execution) + return execution + + def stop_execution(self, execution_arn): + execution = next( + (x for x in self.executions if x.execution_arn == execution_arn), None + ) + if not execution: + raise ExecutionDoesNotExist( + "Execution Does Not Exist: '" + execution_arn + "'" + ) + execution.stop() + return execution + + def list_executions(self, state_machine_arn): + return [ + execution + for execution in self.executions + if execution.state_machine_arn == state_machine_arn + ] + + def describe_execution(self, arn): + self._validate_execution_arn(arn) + exctn = next((x for x in self.executions if x.execution_arn == arn), None) + if not exctn: + raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'") + return exctn + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def _validate_name(self, name): + if any(invalid_char in name for invalid_char in self.invalid_chars_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + def _validate_role_arn(self, role_arn): + self._validate_arn( + arn=role_arn, + regex=self.accepted_role_arn_format, + invalid_msg="Invalid Role Arn: '" + role_arn + "'", + ) + + def _validate_machine_arn(self, machine_arn): + self._validate_arn( + arn=machine_arn, + regex=self.accepted_mchn_arn_format, + invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'", + ) + + def _validate_execution_arn(self, execution_arn): + self._validate_arn( + arn=execution_arn, + regex=self.accepted_exec_arn_format, + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", + ) + + def _validate_arn(self, arn, regex, invalid_msg): + match = regex.match(arn) + if not arn or not match: + raise InvalidArn(invalid_msg) + + def _get_account_id(self): + return ACCOUNT_ID + + +stepfunction_backends = { + _region.name: StepFunctionBackend(_region.name) + for _region in boto.awslambda.regions() +} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py new file mode 100644 index 000000000..689961d5a --- /dev/null +++ b/moto/stepfunctions/responses.py @@ -0,0 +1,154 @@ +from __future__ import unicode_literals + +import json + +from moto.core.responses import BaseResponse +from moto.core.utils import amzn_request_id +from .exceptions import AWSError +from .models import stepfunction_backends + + +class StepFunctionResponse(BaseResponse): + @property + def stepfunction_backend(self): + return stepfunction_backends[self.region] + + @amzn_request_id + def create_state_machine(self): + name = self._get_param("name") + definition = self._get_param("definition") + roleArn = self._get_param("roleArn") + tags = self._get_param("tags") + try: + state_machine = self.stepfunction_backend.create_state_machine( + name=name, definition=definition, roleArn=roleArn, tags=tags + ) + response = { + "creationDate": state_machine.creation_date, + "stateMachineArn": state_machine.arn, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def list_state_machines(self): + list_all = self.stepfunction_backend.list_state_machines() + list_all = sorted( + [ + { + "creationDate": sm.creation_date, + "name": sm.name, + "stateMachineArn": sm.arn, + } + for sm in list_all + ], + key=lambda x: x["name"], + ) + response = {"stateMachines": list_all} + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_state_machine(self): + arn = self._get_param("stateMachineArn") + return self._describe_state_machine(arn) + + @amzn_request_id + def _describe_state_machine(self, state_machine_arn): + try: + state_machine = self.stepfunction_backend.describe_state_machine( + state_machine_arn + ) + response = { + "creationDate": state_machine.creation_date, + "stateMachineArn": state_machine.arn, + "definition": state_machine.definition, + "name": state_machine.name, + "roleArn": state_machine.roleArn, + "status": "ACTIVE", + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def delete_state_machine(self): + arn = self._get_param("stateMachineArn") + try: + self.stepfunction_backend.delete_state_machine(arn) + return 200, {}, json.dumps("{}") + except AWSError as err: + return err.response() + + @amzn_request_id + def list_tags_for_resource(self): + arn = self._get_param("resourceArn") + try: + state_machine = self.stepfunction_backend.describe_state_machine(arn) + tags = state_machine.tags or [] + except AWSError: + tags = [] + response = {"tags": tags} + return 200, {}, json.dumps(response) + + @amzn_request_id + def start_execution(self): + arn = self._get_param("stateMachineArn") + name = self._get_param("name") + execution = self.stepfunction_backend.start_execution(arn, name) + response = { + "executionArn": execution.execution_arn, + "startDate": execution.start_date, + } + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_executions(self): + arn = self._get_param("stateMachineArn") + state_machine = self.stepfunction_backend.describe_state_machine(arn) + executions = self.stepfunction_backend.list_executions(arn) + executions = [ + { + "executionArn": execution.execution_arn, + "name": execution.name, + "startDate": execution.start_date, + "stateMachineArn": state_machine.arn, + "status": execution.status, + } + for execution in executions + ] + return 200, {}, json.dumps({"executions": executions}) + + @amzn_request_id + def describe_execution(self): + arn = self._get_param("executionArn") + try: + execution = self.stepfunction_backend.describe_execution(arn) + response = { + "executionArn": arn, + "input": "{}", + "name": execution.name, + "startDate": execution.start_date, + "stateMachineArn": execution.state_machine_arn, + "status": execution.status, + "stopDate": execution.stop_date, + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_state_machine_for_execution(self): + arn = self._get_param("executionArn") + try: + execution = self.stepfunction_backend.describe_execution(arn) + return self._describe_state_machine(execution.state_machine_arn) + except AWSError as err: + return err.response() + + @amzn_request_id + def stop_execution(self): + arn = self._get_param("executionArn") + execution = self.stepfunction_backend.stop_execution(arn) + response = {"stopDate": execution.stop_date} + return 200, {}, json.dumps(response) diff --git a/moto/stepfunctions/urls.py b/moto/stepfunctions/urls.py new file mode 100644 index 000000000..46dfd4e24 --- /dev/null +++ b/moto/stepfunctions/urls.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .responses import StepFunctionResponse + +url_bases = ["https?://states.(.+).amazonaws.com"] + +url_paths = {"{0}/$": StepFunctionResponse.dispatch} diff --git a/moto/sts/exceptions.py b/moto/sts/exceptions.py index bddb56e3f..1acda9288 100644 --- a/moto/sts/exceptions.py +++ b/moto/sts/exceptions.py @@ -7,9 +7,5 @@ class STSClientError(RESTError): class STSValidationError(STSClientError): - def __init__(self, *args, **kwargs): - super(STSValidationError, self).__init__( - "ValidationError", - *args, **kwargs - ) + super(STSValidationError, self).__init__("ValidationError", *args, **kwargs) diff --git a/moto/sts/models.py b/moto/sts/models.py index c2ff7a8d3..12824b2ed 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -2,12 +2,16 @@ from __future__ import unicode_literals import datetime from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds -from moto.iam.models import ACCOUNT_ID -from moto.sts.utils import random_access_key_id, random_secret_access_key, random_session_token, random_assumed_role_id +from moto.core import ACCOUNT_ID +from moto.sts.utils import ( + random_access_key_id, + random_secret_access_key, + random_session_token, + random_assumed_role_id, +) class Token(BaseModel): - def __init__(self, duration, name=None, policy=None): now = datetime.datetime.utcnow() self.expiration = now + datetime.timedelta(seconds=duration) @@ -20,7 +24,6 @@ class Token(BaseModel): class AssumedRole(BaseModel): - def __init__(self, role_session_name, role_arn, policy, duration, external_id): self.session_name = role_session_name self.role_arn = role_arn @@ -46,12 +49,11 @@ class AssumedRole(BaseModel): return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( account_id=ACCOUNT_ID, role_name=self.role_arn.split("/")[-1], - session_name=self.session_name + session_name=self.session_name, ) class STSBackend(BaseBackend): - def __init__(self): self.assumed_roles = [] diff --git a/moto/sts/responses.py b/moto/sts/responses.py index ebdc4321c..f36799b03 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.iam.models import ACCOUNT_ID +from moto.core import ACCOUNT_ID from moto.iam import iam_backend from .exceptions import STSValidationError from .models import sts_backend @@ -10,38 +10,38 @@ MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 class TokenResponse(BaseResponse): - def get_session_token(self): - duration = int(self.querystring.get('DurationSeconds', [43200])[0]) + duration = int(self.querystring.get("DurationSeconds", [43200])[0]) token = sts_backend.get_session_token(duration=duration) template = self.response_template(GET_SESSION_TOKEN_RESPONSE) return template.render(token=token) def get_federation_token(self): - duration = int(self.querystring.get('DurationSeconds', [43200])[0]) - policy = self.querystring.get('Policy', [None])[0] + duration = int(self.querystring.get("DurationSeconds", [43200])[0]) + policy = self.querystring.get("Policy", [None])[0] if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH: raise STSValidationError( "1 validation error detected: Value " - "'{\"Version\": \"2012-10-17\", \"Statement\": [...]}' " + '\'{"Version": "2012-10-17", "Statement": [...]}\' ' "at 'policy' failed to satisfy constraint: Member must have length less than or " " equal to %s" % MAX_FEDERATION_TOKEN_POLICY_LENGTH ) - name = self.querystring.get('Name')[0] + name = self.querystring.get("Name")[0] token = sts_backend.get_federation_token( - duration=duration, name=name, policy=policy) + duration=duration, name=name, policy=policy + ) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) return template.render(token=token, account_id=ACCOUNT_ID) def assume_role(self): - role_session_name = self.querystring.get('RoleSessionName')[0] - role_arn = self.querystring.get('RoleArn')[0] + role_session_name = self.querystring.get("RoleSessionName")[0] + role_arn = self.querystring.get("RoleArn")[0] - policy = self.querystring.get('Policy', [None])[0] - duration = int(self.querystring.get('DurationSeconds', [3600])[0]) - external_id = self.querystring.get('ExternalId', [None])[0] + policy = self.querystring.get("Policy", [None])[0] + duration = int(self.querystring.get("DurationSeconds", [3600])[0]) + external_id = self.querystring.get("ExternalId", [None])[0] role = sts_backend.assume_role( role_session_name=role_session_name, @@ -54,12 +54,12 @@ class TokenResponse(BaseResponse): return template.render(role=role) def assume_role_with_web_identity(self): - role_session_name = self.querystring.get('RoleSessionName')[0] - role_arn = self.querystring.get('RoleArn')[0] + role_session_name = self.querystring.get("RoleSessionName")[0] + role_arn = self.querystring.get("RoleArn")[0] - policy = self.querystring.get('Policy', [None])[0] - duration = int(self.querystring.get('DurationSeconds', [3600])[0]) - external_id = self.querystring.get('ExternalId', [None])[0] + policy = self.querystring.get("Policy", [None])[0] + duration = int(self.querystring.get("DurationSeconds", [3600])[0]) + external_id = self.querystring.get("ExternalId", [None])[0] role = sts_backend.assume_role_with_web_identity( role_session_name=role_session_name, @@ -128,8 +128,7 @@ GET_FEDERATION_TOKEN_RESPONSE = """ +ASSUME_ROLE_RESPONSE = """ {{ role.session_token }} diff --git a/moto/sts/urls.py b/moto/sts/urls.py index 2078e0b2c..e110f39df 100644 --- a/moto/sts/urls.py +++ b/moto/sts/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import TokenResponse -url_bases = [ - "https?://sts(.*).amazonaws.com" -] +url_bases = ["https?://sts(.*).amazonaws.com"] -url_paths = { - '{0}/$': TokenResponse.dispatch, -} +url_paths = {"{0}/$": TokenResponse.dispatch} diff --git a/moto/sts/utils.py b/moto/sts/utils.py index 50767729f..1e8a13569 100644 --- a/moto/sts/utils.py +++ b/moto/sts/utils.py @@ -19,17 +19,20 @@ def random_secret_access_key(): def random_session_token(): - return SESSION_TOKEN_PREFIX + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX):].decode() + return ( + SESSION_TOKEN_PREFIX + + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode() + ) def random_assumed_role_id(): - return ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) + return ( + ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) + ) def _random_uppercase_or_digit_sequence(length): - return ''.join( - six.text_type( - random.choice( - string.ascii_uppercase + string.digits - )) for _ in range(length) + return "".join( + six.text_type(random.choice(string.ascii_uppercase + string.digits)) + for _ in range(length) ) diff --git a/moto/swf/__init__.py b/moto/swf/__init__.py index 0d626690a..2a500458e 100644 --- a/moto/swf/__init__.py +++ b/moto/swf/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import swf_backends from ..core.models import base_decorator, deprecated_base_decorator -swf_backend = swf_backends['us-east-1'] +swf_backend = swf_backends["us-east-1"] mock_swf = base_decorator(swf_backends) mock_swf_deprecated = deprecated_base_decorator(swf_backends) diff --git a/moto/swf/constants.py b/moto/swf/constants.py index b9f680d39..80e384d3c 100644 --- a/moto/swf/constants.py +++ b/moto/swf/constants.py @@ -3,9 +3,7 @@ # See http://docs.aws.amazon.com/amazonswf/latest/apireference/API_RespondDecisionTaskCompleted.html # and subsequent docs for each decision type. DECISIONS_FIELDS = { - "cancelTimerDecisionAttributes": { - "timerId": {"type": "string", "required": True} - }, + "cancelTimerDecisionAttributes": {"timerId": {"type": "string", "required": True}}, "cancelWorkflowExecutionDecisionAttributes": { "details": {"type": "string", "required": False} }, @@ -21,15 +19,15 @@ DECISIONS_FIELDS = { "taskList": {"type": "TaskList", "required": False}, "taskPriority": {"type": "string", "required": False}, "taskStartToCloseTimeout": {"type": "string", "required": False}, - "workflowTypeVersion": {"type": "string", "required": False} + "workflowTypeVersion": {"type": "string", "required": False}, }, "failWorkflowExecutionDecisionAttributes": { "details": {"type": "string", "required": False}, - "reason": {"type": "string", "required": False} + "reason": {"type": "string", "required": False}, }, "recordMarkerDecisionAttributes": { "details": {"type": "string", "required": False}, - "markerName": {"type": "string", "required": True} + "markerName": {"type": "string", "required": True}, }, "requestCancelActivityTaskDecisionAttributes": { "activityId": {"type": "string", "required": True} @@ -37,7 +35,7 @@ DECISIONS_FIELDS = { "requestCancelExternalWorkflowExecutionDecisionAttributes": { "control": {"type": "string", "required": False}, "runId": {"type": "string", "required": False}, - "workflowId": {"type": "string", "required": True} + "workflowId": {"type": "string", "required": True}, }, "scheduleActivityTaskDecisionAttributes": { "activityId": {"type": "string", "required": True}, @@ -49,20 +47,20 @@ DECISIONS_FIELDS = { "scheduleToStartTimeout": {"type": "string", "required": False}, "startToCloseTimeout": {"type": "string", "required": False}, "taskList": {"type": "TaskList", "required": False}, - "taskPriority": {"type": "string", "required": False} + "taskPriority": {"type": "string", "required": False}, }, "scheduleLambdaFunctionDecisionAttributes": { "id": {"type": "string", "required": True}, "input": {"type": "string", "required": False}, "name": {"type": "string", "required": True}, - "startToCloseTimeout": {"type": "string", "required": False} + "startToCloseTimeout": {"type": "string", "required": False}, }, "signalExternalWorkflowExecutionDecisionAttributes": { "control": {"type": "string", "required": False}, "input": {"type": "string", "required": False}, "runId": {"type": "string", "required": False}, "signalName": {"type": "string", "required": True}, - "workflowId": {"type": "string", "required": True} + "workflowId": {"type": "string", "required": True}, }, "startChildWorkflowExecutionDecisionAttributes": { "childPolicy": {"type": "string", "required": False}, @@ -75,11 +73,11 @@ DECISIONS_FIELDS = { "taskPriority": {"type": "string", "required": False}, "taskStartToCloseTimeout": {"type": "string", "required": False}, "workflowId": {"type": "string", "required": True}, - "workflowType": {"type": "WorkflowType", "required": True} + "workflowType": {"type": "WorkflowType", "required": True}, }, "startTimerDecisionAttributes": { "control": {"type": "string", "required": False}, "startToFireTimeout": {"type": "string", "required": True}, - "timerId": {"type": "string", "required": True} - } + "timerId": {"type": "string", "required": True}, + }, } diff --git a/moto/swf/exceptions.py b/moto/swf/exceptions.py index 232b1f237..def30b313 100644 --- a/moto/swf/exceptions.py +++ b/moto/swf/exceptions.py @@ -8,71 +8,59 @@ class SWFClientError(JsonRESTError): class SWFUnknownResourceFault(SWFClientError): - def __init__(self, resource_type, resource_name=None): if resource_name: message = "Unknown {0}: {1}".format(resource_type, resource_name) else: message = "Unknown {0}".format(resource_type) super(SWFUnknownResourceFault, self).__init__( - "com.amazonaws.swf.base.model#UnknownResourceFault", - message, + "com.amazonaws.swf.base.model#UnknownResourceFault", message ) class SWFDomainAlreadyExistsFault(SWFClientError): - def __init__(self, domain_name): super(SWFDomainAlreadyExistsFault, self).__init__( - "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", - domain_name, + "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name ) class SWFDomainDeprecatedFault(SWFClientError): - def __init__(self, domain_name): super(SWFDomainDeprecatedFault, self).__init__( - "com.amazonaws.swf.base.model#DomainDeprecatedFault", - domain_name, + "com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name ) class SWFSerializationException(SWFClientError): - def __init__(self, value): message = "class java.lang.Foo can not be converted to an String " - message += " (not a real SWF exception ; happened on: {0})".format( - value) + message += " (not a real SWF exception ; happened on: {0})".format(value) __type = "com.amazonaws.swf.base.model#SerializationException" - super(SWFSerializationException, self).__init__( - __type, - message, - ) + super(SWFSerializationException, self).__init__(__type, message) class SWFTypeAlreadyExistsFault(SWFClientError): - def __init__(self, _type): super(SWFTypeAlreadyExistsFault, self).__init__( "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", "{0}=[name={1}, version={2}]".format( - _type.__class__.__name__, _type.name, _type.version), + _type.__class__.__name__, _type.name, _type.version + ), ) class SWFTypeDeprecatedFault(SWFClientError): - def __init__(self, _type): super(SWFTypeDeprecatedFault, self).__init__( "com.amazonaws.swf.base.model#TypeDeprecatedFault", "{0}=[name={1}, version={2}]".format( - _type.__class__.__name__, _type.name, _type.version), + _type.__class__.__name__, _type.name, _type.version + ), ) class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): - def __init__(self): super(SWFWorkflowExecutionAlreadyStartedFault, self).__init__( "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", @@ -81,7 +69,6 @@ class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): class SWFDefaultUndefinedFault(SWFClientError): - def __init__(self, key): # TODO: move that into moto.core.utils maybe? words = key.split("_") @@ -89,22 +76,18 @@ class SWFDefaultUndefinedFault(SWFClientError): for word in words: key_camel_case += word.capitalize() super(SWFDefaultUndefinedFault, self).__init__( - "com.amazonaws.swf.base.model#DefaultUndefinedFault", - key_camel_case, + "com.amazonaws.swf.base.model#DefaultUndefinedFault", key_camel_case ) class SWFValidationException(SWFClientError): - def __init__(self, message): super(SWFValidationException, self).__init__( - "com.amazon.coral.validate#ValidationException", - message, + "com.amazon.coral.validate#ValidationException", message ) class SWFDecisionValidationException(SWFClientError): - def __init__(self, problems): # messages messages = [] @@ -122,8 +105,7 @@ class SWFDecisionValidationException(SWFClientError): ) else: raise ValueError( - "Unhandled decision constraint type: {0}".format(pb[ - "type"]) + "Unhandled decision constraint type: {0}".format(pb["type"]) ) # prefix count = len(problems) @@ -138,6 +120,5 @@ class SWFDecisionValidationException(SWFClientError): class SWFWorkflowExecutionClosedError(Exception): - def __str__(self): return repr("Cannot change this object because the WorkflowExecution is closed") diff --git a/moto/swf/models/__init__.py b/moto/swf/models/__init__.py index a8bc57f40..50cc29bb3 100644 --- a/moto/swf/models/__init__.py +++ b/moto/swf/models/__init__.py @@ -12,25 +12,21 @@ from ..exceptions import ( SWFTypeDeprecatedFault, SWFValidationException, ) -from .activity_task import ActivityTask # flake8: noqa -from .activity_type import ActivityType # flake8: noqa -from .decision_task import DecisionTask # flake8: noqa -from .domain import Domain # flake8: noqa -from .generic_type import GenericType # flake8: noqa -from .history_event import HistoryEvent # flake8: noqa -from .timeout import Timeout # flake8: noqa -from .workflow_type import WorkflowType # flake8: noqa -from .workflow_execution import WorkflowExecution # flake8: noqa +from .activity_task import ActivityTask # noqa +from .activity_type import ActivityType # noqa +from .decision_task import DecisionTask # noqa +from .domain import Domain # noqa +from .generic_type import GenericType # noqa +from .history_event import HistoryEvent # noqa +from .timeout import Timeout # noqa +from .workflow_type import WorkflowType # noqa +from .workflow_execution import WorkflowExecution # noqa from time import sleep -KNOWN_SWF_TYPES = { - "activity": ActivityType, - "workflow": WorkflowType, -} +KNOWN_SWF_TYPES = {"activity": ActivityType, "workflow": WorkflowType} class SWFBackend(BaseBackend): - def __init__(self, region_name): self.region_name = region_name self.domains = [] @@ -55,46 +51,53 @@ class SWFBackend(BaseBackend): wfe._process_timeouts() def list_domains(self, status, reverse_order=None): - domains = [domain for domain in self.domains - if domain.status == status] + domains = [domain for domain in self.domains if domain.status == status] domains = sorted(domains, key=lambda domain: domain.name) if reverse_order: domains = reversed(domains) return domains - def list_open_workflow_executions(self, domain_name, maximum_page_size, - tag_filter, reverse_order, **kwargs): + def list_open_workflow_executions( + self, domain_name, maximum_page_size, tag_filter, reverse_order, **kwargs + ): self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(domain_name) open_wfes = [ - wfe for wfe in domain.workflow_executions - if wfe.execution_status == 'OPEN' + wfe for wfe in domain.workflow_executions if wfe.execution_status == "OPEN" ] if tag_filter: for open_wfe in open_wfes: - if tag_filter['tag'] not in open_wfe.tag_list: + if tag_filter["tag"] not in open_wfe.tag_list: open_wfes.remove(open_wfe) if reverse_order: open_wfes = reversed(open_wfes) return open_wfes[0:maximum_page_size] - def list_closed_workflow_executions(self, domain_name, close_time_filter, - tag_filter, close_status_filter, maximum_page_size, reverse_order, - **kwargs): + def list_closed_workflow_executions( + self, + domain_name, + close_time_filter, + tag_filter, + close_status_filter, + maximum_page_size, + reverse_order, + **kwargs + ): self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(domain_name) closed_wfes = [ - wfe for wfe in domain.workflow_executions - if wfe.execution_status == 'CLOSED' + wfe + for wfe in domain.workflow_executions + if wfe.execution_status == "CLOSED" ] if tag_filter: for closed_wfe in closed_wfes: - if tag_filter['tag'] not in closed_wfe.tag_list: + if tag_filter["tag"] not in closed_wfe.tag_list: closed_wfes.remove(closed_wfe) if close_status_filter: for closed_wfe in closed_wfes: @@ -104,12 +107,12 @@ class SWFBackend(BaseBackend): closed_wfes = reversed(closed_wfes) return closed_wfes[0:maximum_page_size] - def register_domain(self, name, workflow_execution_retention_period_in_days, - description=None): + def register_domain( + self, name, workflow_execution_retention_period_in_days, description=None + ): if self._get_domain(name, ignore_empty=True): raise SWFDomainAlreadyExistsFault(name) - domain = Domain(name, workflow_execution_retention_period_in_days, - description) + domain = Domain(name, workflow_execution_retention_period_in_days, description) self.domains.append(domain) def deprecate_domain(self, name): @@ -149,15 +152,23 @@ class SWFBackend(BaseBackend): domain = self._get_domain(domain_name) return domain.get_type(kind, name, version) - def start_workflow_execution(self, domain_name, workflow_id, - workflow_name, workflow_version, - tag_list=None, input=None, **kwargs): + def start_workflow_execution( + self, + domain_name, + workflow_id, + workflow_name, + workflow_version, + tag_list=None, + input=None, + **kwargs + ): domain = self._get_domain(domain_name) wf_type = domain.get_type("workflow", workflow_name, workflow_version) if wf_type.status == "DEPRECATED": raise SWFTypeDeprecatedFault(wf_type) - wfe = WorkflowExecution(domain, wf_type, workflow_id, - tag_list=tag_list, input=input, **kwargs) + wfe = WorkflowExecution( + domain, wf_type, workflow_id, tag_list=tag_list, input=input, **kwargs + ) domain.add_workflow_execution(wfe) wfe.start() @@ -213,9 +224,9 @@ class SWFBackend(BaseBackend): count += wfe.open_counts["openDecisionTasks"] return count - def respond_decision_task_completed(self, task_token, - decisions=None, - execution_context=None): + def respond_decision_task_completed( + self, task_token, decisions=None, execution_context=None + ): # process timeouts on all objects self._process_timeouts() # let's find decision task @@ -244,14 +255,15 @@ class SWFBackend(BaseBackend): "execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( wfe.workflow_id, wfe.run_id - ) + ), ) # decision task found, but already completed if decision_task.state != "STARTED": if decision_task.state == "COMPLETED": raise SWFUnknownResourceFault( "decision task, scheduledEventId = {0}".format( - decision_task.scheduled_event_id) + decision_task.scheduled_event_id + ) ) else: raise ValueError( @@ -263,9 +275,11 @@ class SWFBackend(BaseBackend): # everything's good if decision_task: wfe = decision_task.workflow_execution - wfe.complete_decision_task(decision_task.task_token, - decisions=decisions, - execution_context=execution_context) + wfe.complete_decision_task( + decision_task.task_token, + decisions=decisions, + execution_context=execution_context, + ) def poll_for_activity_task(self, domain_name, task_list, identity=None): # process timeouts on all objects @@ -308,8 +322,7 @@ class SWFBackend(BaseBackend): count = 0 for _task_list, tasks in domain.activity_task_lists.items(): if _task_list == task_list: - pending = [t for t in tasks if t.state in [ - "SCHEDULED", "STARTED"]] + pending = [t for t in tasks if t.state in ["SCHEDULED", "STARTED"]] count += len(pending) return count @@ -333,14 +346,15 @@ class SWFBackend(BaseBackend): "execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( wfe.workflow_id, wfe.run_id - ) + ), ) # activity task found, but already completed if activity_task.state != "STARTED": if activity_task.state == "COMPLETED": raise SWFUnknownResourceFault( "activity, scheduledEventId = {0}".format( - activity_task.scheduled_event_id) + activity_task.scheduled_event_id + ) ) else: raise ValueError( @@ -364,18 +378,24 @@ class SWFBackend(BaseBackend): self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) wfe = activity_task.workflow_execution - wfe.fail_activity_task(activity_task.task_token, - reason=reason, details=details) + wfe.fail_activity_task(activity_task.task_token, reason=reason, details=details) - def terminate_workflow_execution(self, domain_name, workflow_id, child_policy=None, - details=None, reason=None, run_id=None): + def terminate_workflow_execution( + self, + domain_name, + workflow_id, + child_policy=None, + details=None, + reason=None, + run_id=None, + ): # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( - workflow_id, run_id=run_id, raise_if_closed=True) - wfe.terminate(child_policy=child_policy, - details=details, reason=reason) + workflow_id, run_id=run_id, raise_if_closed=True + ) + wfe.terminate(child_policy=child_policy, details=details, reason=reason) def record_activity_task_heartbeat(self, task_token, details=None): # process timeouts on all objects @@ -385,12 +405,15 @@ class SWFBackend(BaseBackend): if details: activity_task.details = details - def signal_workflow_execution(self, domain_name, signal_name, workflow_id, input=None, run_id=None): + def signal_workflow_execution( + self, domain_name, signal_name, workflow_id, input=None, run_id=None + ): # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( - workflow_id, run_id=run_id, raise_if_closed=True) + workflow_id, run_id=run_id, raise_if_closed=True + ) wfe.signal(signal_name, input) diff --git a/moto/swf/models/activity_task.py b/moto/swf/models/activity_task.py index 0c1f283ca..93e300ae5 100644 --- a/moto/swf/models/activity_task.py +++ b/moto/swf/models/activity_task.py @@ -10,9 +10,15 @@ from .timeout import Timeout class ActivityTask(BaseModel): - - def __init__(self, activity_id, activity_type, scheduled_event_id, - workflow_execution, timeouts, input=None): + def __init__( + self, + activity_id, + activity_type, + scheduled_event_id, + workflow_execution, + timeouts, + input=None, + ): self.activity_id = activity_id self.activity_type = activity_type self.details = None @@ -68,8 +74,9 @@ class ActivityTask(BaseModel): if not self.open or not self.workflow_execution.open: return None # TODO: handle the "NONE" case - heartbeat_timeout_at = (self.last_heartbeat_timestamp + - int(self.timeouts["heartbeatTimeout"])) + heartbeat_timeout_at = self.last_heartbeat_timestamp + int( + self.timeouts["heartbeatTimeout"] + ) _timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT") if _timeout.reached: return _timeout diff --git a/moto/swf/models/activity_type.py b/moto/swf/models/activity_type.py index eb1bbfa68..95a83ca7a 100644 --- a/moto/swf/models/activity_type.py +++ b/moto/swf/models/activity_type.py @@ -2,7 +2,6 @@ from .generic_type import GenericType class ActivityType(GenericType): - @property def _configuration_keys(self): return [ diff --git a/moto/swf/models/decision_task.py b/moto/swf/models/decision_task.py index 9255dd6f2..c8c9824a2 100644 --- a/moto/swf/models/decision_task.py +++ b/moto/swf/models/decision_task.py @@ -10,7 +10,6 @@ from .timeout import Timeout class DecisionTask(BaseModel): - def __init__(self, workflow_execution, scheduled_event_id): self.workflow_execution = workflow_execution self.workflow_type = workflow_execution.workflow_type @@ -19,7 +18,9 @@ class DecisionTask(BaseModel): self.previous_started_event_id = 0 self.started_event_id = None self.started_timestamp = None - self.start_to_close_timeout = self.workflow_execution.task_start_to_close_timeout + self.start_to_close_timeout = ( + self.workflow_execution.task_start_to_close_timeout + ) self.state = "SCHEDULED" # this is *not* necessarily coherent with workflow execution history, # but that shouldn't be a problem for tests @@ -37,9 +38,7 @@ class DecisionTask(BaseModel): def to_full_dict(self, reverse_order=False): events = self.workflow_execution.events(reverse_order=reverse_order) hsh = { - "events": [ - evt.to_dict() for evt in events - ], + "events": [evt.to_dict() for evt in events], "taskToken": self.task_token, "previousStartedEventId": self.previous_started_event_id, "workflowExecution": self.workflow_execution.to_short_dict(), @@ -62,8 +61,7 @@ class DecisionTask(BaseModel): if not self.started or not self.workflow_execution.open: return None # TODO: handle the "NONE" case - start_to_close_at = self.started_timestamp + \ - int(self.start_to_close_timeout) + start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout diff --git a/moto/swf/models/domain.py b/moto/swf/models/domain.py index 0aa62f4f0..54347b22b 100644 --- a/moto/swf/models/domain.py +++ b/moto/swf/models/domain.py @@ -9,16 +9,12 @@ from ..exceptions import ( class Domain(BaseModel): - def __init__(self, name, retention, description=None): self.name = name self.retention = retention self.description = description self.status = "REGISTERED" - self.types = { - "activity": defaultdict(dict), - "workflow": defaultdict(dict), - } + self.types = {"activity": defaultdict(dict), "workflow": defaultdict(dict)} # Workflow executions have an id, which unicity is guaranteed # at domain level (not super clear in the docs, but I checked # that against SWF API) ; hence the storage method as a dict @@ -32,10 +28,7 @@ class Domain(BaseModel): return "Domain(name: %(name)s, status: %(status)s)" % self.__dict__ def to_short_dict(self): - hsh = { - "name": self.name, - "status": self.status, - } + hsh = {"name": self.name, "status": self.status} if self.description: hsh["description"] = self.description return hsh @@ -43,9 +36,7 @@ class Domain(BaseModel): def to_full_dict(self): return { "domainInfo": self.to_short_dict(), - "configuration": { - "workflowExecutionRetentionPeriodInDays": self.retention, - } + "configuration": {"workflowExecutionRetentionPeriodInDays": self.retention}, } def get_type(self, kind, name, version, ignore_empty=False): @@ -57,7 +48,7 @@ class Domain(BaseModel): "type", "{0}Type=[name={1}, version={2}]".format( kind.capitalize(), name, version - ) + ), ) def add_type(self, _type): @@ -77,15 +68,22 @@ class Domain(BaseModel): raise SWFWorkflowExecutionAlreadyStartedFault() self.workflow_executions.append(workflow_execution) - def get_workflow_execution(self, workflow_id, run_id=None, - raise_if_none=True, raise_if_closed=False): + def get_workflow_execution( + self, workflow_id, run_id=None, raise_if_none=True, raise_if_closed=False + ): # query if run_id: - _all = [w for w in self.workflow_executions - if w.workflow_id == workflow_id and w.run_id == run_id] + _all = [ + w + for w in self.workflow_executions + if w.workflow_id == workflow_id and w.run_id == run_id + ] else: - _all = [w for w in self.workflow_executions - if w.workflow_id == workflow_id and w.open] + _all = [ + w + for w in self.workflow_executions + if w.workflow_id == workflow_id and w.open + ] # reduce wfe = _all[0] if _all else None # raise if closed / none @@ -93,8 +91,12 @@ class Domain(BaseModel): wfe = None if not wfe and raise_if_none: if run_id: - args = ["execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( - workflow_id, run_id)] + args = [ + "execution", + "WorkflowExecution=[workflowId={0}, runId={1}]".format( + workflow_id, run_id + ), + ] else: args = ["execution, workflowId = {0}".format(workflow_id)] raise SWFUnknownResourceFault(*args) diff --git a/moto/swf/models/generic_type.py b/moto/swf/models/generic_type.py index a56220ed6..8ae6ebc08 100644 --- a/moto/swf/models/generic_type.py +++ b/moto/swf/models/generic_type.py @@ -5,7 +5,6 @@ from moto.core.utils import camelcase_to_underscores class GenericType(BaseModel): - def __init__(self, name, version, **kwargs): self.name = name self.version = version @@ -24,7 +23,9 @@ class GenericType(BaseModel): def __repr__(self): cls = self.__class__.__name__ - attrs = "name: %(name)s, version: %(version)s, status: %(status)s" % self.__dict__ + attrs = ( + "name: %(name)s, version: %(version)s, status: %(status)s" % self.__dict__ + ) return "{0}({1})".format(cls, attrs) @property @@ -36,10 +37,7 @@ class GenericType(BaseModel): raise NotImplementedError() def to_short_dict(self): - return { - "name": self.name, - "version": self.version, - } + return {"name": self.name, "version": self.version} def to_medium_dict(self): hsh = { @@ -54,10 +52,7 @@ class GenericType(BaseModel): return hsh def to_full_dict(self): - hsh = { - "typeInfo": self.to_medium_dict(), - "configuration": {} - } + hsh = {"typeInfo": self.to_medium_dict(), "configuration": {}} if self.task_list: hsh["configuration"]["defaultTaskList"] = {"name": self.task_list} for key in self._configuration_keys: diff --git a/moto/swf/models/history_event.py b/moto/swf/models/history_event.py index e7ddfd924..f259ea94e 100644 --- a/moto/swf/models/history_event.py +++ b/moto/swf/models/history_event.py @@ -25,17 +25,17 @@ SUPPORTED_HISTORY_EVENT_TYPES = ( "ActivityTaskTimedOut", "DecisionTaskTimedOut", "WorkflowExecutionTimedOut", - "WorkflowExecutionSignaled" + "WorkflowExecutionSignaled", ) class HistoryEvent(BaseModel): - def __init__(self, event_id, event_type, event_timestamp=None, **kwargs): if event_type not in SUPPORTED_HISTORY_EVENT_TYPES: raise NotImplementedError( "HistoryEvent does not implement attributes for type '{0}'".format( - event_type) + event_type + ) ) self.event_id = event_id self.event_type = event_type @@ -61,7 +61,7 @@ class HistoryEvent(BaseModel): "eventId": self.event_id, "eventType": self.event_type, "eventTimestamp": self.event_timestamp, - self._attributes_key(): self.event_attributes + self._attributes_key(): self.event_attributes, } def _attributes_key(self): diff --git a/moto/swf/models/timeout.py b/moto/swf/models/timeout.py index f26c8a38b..bc576bb64 100644 --- a/moto/swf/models/timeout.py +++ b/moto/swf/models/timeout.py @@ -3,7 +3,6 @@ from moto.core.utils import unix_time class Timeout(BaseModel): - def __init__(self, obj, timestamp, kind): self.obj = obj self.timestamp = timestamp diff --git a/moto/swf/models/workflow_execution.py b/moto/swf/models/workflow_execution.py index 3d01f9192..4d91b1f6f 100644 --- a/moto/swf/models/workflow_execution.py +++ b/moto/swf/models/workflow_execution.py @@ -4,9 +4,7 @@ import uuid from moto.core import BaseModel from moto.core.utils import camelcase_to_underscores, unix_time -from ..constants import ( - DECISIONS_FIELDS, -) +from ..constants import DECISIONS_FIELDS from ..exceptions import ( SWFDefaultUndefinedFault, SWFValidationException, @@ -38,7 +36,7 @@ class WorkflowExecution(BaseModel): "FailWorkflowExecution", "RequestCancelActivityTask", "StartChildWorkflowExecution", - "CancelWorkflowExecution" + "CancelWorkflowExecution", ] def __init__(self, domain, workflow_type, workflow_id, **kwargs): @@ -66,11 +64,10 @@ class WorkflowExecution(BaseModel): # param is set, # SWF will raise DefaultUndefinedFault errors in the # same order as the few lines that follow) self._set_from_kwargs_or_workflow_type( - kwargs, "execution_start_to_close_timeout") - self._set_from_kwargs_or_workflow_type( - kwargs, "task_list", "task_list") - self._set_from_kwargs_or_workflow_type( - kwargs, "task_start_to_close_timeout") + kwargs, "execution_start_to_close_timeout" + ) + self._set_from_kwargs_or_workflow_type(kwargs, "task_list", "task_list") + self._set_from_kwargs_or_workflow_type(kwargs, "task_start_to_close_timeout") self._set_from_kwargs_or_workflow_type(kwargs, "child_policy") self.input = kwargs.get("input") # counters @@ -89,7 +86,9 @@ class WorkflowExecution(BaseModel): def __repr__(self): return "WorkflowExecution(run_id: {0})".format(self.run_id) - def _set_from_kwargs_or_workflow_type(self, kwargs, local_key, workflow_type_key=None): + def _set_from_kwargs_or_workflow_type( + self, kwargs, local_key, workflow_type_key=None + ): if workflow_type_key is None: workflow_type_key = "default_" + local_key value = kwargs.get(local_key) @@ -109,10 +108,7 @@ class WorkflowExecution(BaseModel): ] def to_short_dict(self): - return { - "workflowId": self.workflow_id, - "runId": self.run_id - } + return {"workflowId": self.workflow_id, "runId": self.run_id} def to_medium_dict(self): hsh = { @@ -129,9 +125,7 @@ class WorkflowExecution(BaseModel): def to_full_dict(self): hsh = { "executionInfo": self.to_medium_dict(), - "executionConfiguration": { - "taskList": {"name": self.task_list} - } + "executionConfiguration": {"taskList": {"name": self.task_list}}, } # configuration for key in self._configuration_keys: @@ -152,23 +146,20 @@ class WorkflowExecution(BaseModel): def to_list_dict(self): hsh = { - 'execution': { - 'workflowId': self.workflow_id, - 'runId': self.run_id, - }, - 'workflowType': self.workflow_type.to_short_dict(), - 'startTimestamp': self.start_timestamp, - 'executionStatus': self.execution_status, - 'cancelRequested': self.cancel_requested, + "execution": {"workflowId": self.workflow_id, "runId": self.run_id}, + "workflowType": self.workflow_type.to_short_dict(), + "startTimestamp": self.start_timestamp, + "executionStatus": self.execution_status, + "cancelRequested": self.cancel_requested, } if self.tag_list: - hsh['tagList'] = self.tag_list + hsh["tagList"] = self.tag_list if self.parent: - hsh['parent'] = self.parent + hsh["parent"] = self.parent if self.close_status: - hsh['closeStatus'] = self.close_status + hsh["closeStatus"] = self.close_status if self.close_timestamp: - hsh['closeTimestamp'] = self.close_timestamp + hsh["closeTimestamp"] = self.close_timestamp return hsh def _process_timeouts(self): @@ -206,10 +197,7 @@ class WorkflowExecution(BaseModel): # now find the first timeout to process first_timeout = None if timeout_candidates: - first_timeout = min( - timeout_candidates, - key=lambda t: t.timestamp - ) + first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) if first_timeout: should_schedule_decision_next = False @@ -258,7 +246,7 @@ class WorkflowExecution(BaseModel): task_list=self.task_list, task_start_to_close_timeout=self.task_start_to_close_timeout, workflow_type=self.workflow_type, - input=self.input + input=self.input, ) self.schedule_decision_task() @@ -269,8 +257,7 @@ class WorkflowExecution(BaseModel): task_list=self.task_list, ) self.domain.add_to_decision_task_list( - self.task_list, - DecisionTask(self, evt.event_id), + self.task_list, DecisionTask(self, evt.event_id) ) self.open_counts["openDecisionTasks"] += 1 @@ -285,32 +272,30 @@ class WorkflowExecution(BaseModel): @property def decision_tasks(self): - return [t for t in self.domain.decision_tasks - if t.workflow_execution == self] + return [t for t in self.domain.decision_tasks if t.workflow_execution == self] @property def activity_tasks(self): - return [t for t in self.domain.activity_tasks - if t.workflow_execution == self] + return [t for t in self.domain.activity_tasks if t.workflow_execution == self] def _find_decision_task(self, task_token): for dt in self.decision_tasks: if dt.task_token == task_token: return dt - raise ValueError( - "No decision task with token: {0}".format(task_token) - ) + raise ValueError("No decision task with token: {0}".format(task_token)) def start_decision_task(self, task_token, identity=None): dt = self._find_decision_task(task_token) evt = self._add_event( "DecisionTaskStarted", scheduled_event_id=dt.scheduled_event_id, - identity=identity + identity=identity, ) dt.start(evt.event_id) - def complete_decision_task(self, task_token, decisions=None, execution_context=None): + def complete_decision_task( + self, task_token, decisions=None, execution_context=None + ): # 'decisions' can be None per boto.swf defaults, so replace it with something iterable if not decisions: decisions = [] @@ -336,12 +321,14 @@ class WorkflowExecution(BaseModel): constraints = DECISIONS_FIELDS.get(kind, {}) for key, constraint in constraints.items(): if constraint["required"] and not value.get(key): - problems.append({ - "type": "null_value", - "where": "decisions.{0}.member.{1}.{2}".format( - decision_id, kind, key - ) - }) + problems.append( + { + "type": "null_value", + "where": "decisions.{0}.member.{1}.{2}".format( + decision_id, kind, key + ), + } + ) return problems def validate_decisions(self, decisions): @@ -362,9 +349,7 @@ class WorkflowExecution(BaseModel): "CancelWorkflowExecution", ] if dcs["decisionType"] in close_decision_types: - raise SWFValidationException( - "Close must be last decision in list" - ) + raise SWFValidationException("Close must be last decision in list") decision_number = 0 for dcs in decisions: @@ -372,24 +357,29 @@ class WorkflowExecution(BaseModel): # check decision types mandatory attributes # NB: the real SWF service seems to check attributes even for attributes list # that are not in line with the decisionType, so we do the same - attrs_to_check = [ - d for d in dcs.keys() if d.endswith("DecisionAttributes")] + attrs_to_check = [d for d in dcs.keys() if d.endswith("DecisionAttributes")] if dcs["decisionType"] in self.KNOWN_DECISION_TYPES: decision_type = dcs["decisionType"] decision_attr = "{0}DecisionAttributes".format( - decapitalize(decision_type)) + decapitalize(decision_type) + ) attrs_to_check.append(decision_attr) for attr in attrs_to_check: problems += self._check_decision_attributes( - attr, dcs.get(attr, {}), decision_number) + attr, dcs.get(attr, {}), decision_number + ) # check decision type is correct if dcs["decisionType"] not in self.KNOWN_DECISION_TYPES: - problems.append({ - "type": "bad_decision_type", - "value": dcs["decisionType"], - "where": "decisions.{0}.member.decisionType".format(decision_number), - "possible_values": ", ".join(self.KNOWN_DECISION_TYPES), - }) + problems.append( + { + "type": "bad_decision_type", + "value": dcs["decisionType"], + "where": "decisions.{0}.member.decisionType".format( + decision_number + ), + "possible_values": ", ".join(self.KNOWN_DECISION_TYPES), + } + ) # raise if any problem if any(problems): @@ -403,14 +393,12 @@ class WorkflowExecution(BaseModel): # handle each decision separately, in order for decision in decisions: decision_type = decision["decisionType"] - attributes_key = "{0}DecisionAttributes".format( - decapitalize(decision_type)) + attributes_key = "{0}DecisionAttributes".format(decapitalize(decision_type)) attributes = decision.get(attributes_key, {}) if decision_type == "CompleteWorkflowExecution": self.complete(event_id, attributes.get("result")) elif decision_type == "FailWorkflowExecution": - self.fail(event_id, attributes.get( - "details"), attributes.get("reason")) + self.fail(event_id, attributes.get("details"), attributes.get("reason")) elif decision_type == "ScheduleActivityTask": self.schedule_activity_task(event_id, attributes) else: @@ -425,7 +413,8 @@ class WorkflowExecution(BaseModel): # TODO: implement Decision type: StartChildWorkflowExecution # TODO: implement Decision type: StartTimer raise NotImplementedError( - "Cannot handle decision: {0}".format(decision_type)) + "Cannot handle decision: {0}".format(decision_type) + ) # finally decrement counter if and only if everything went well self.open_counts["openDecisionTasks"] -= 1 @@ -441,7 +430,7 @@ class WorkflowExecution(BaseModel): ) def fail(self, event_id, details=None, reason=None): - # TODO: implement lenght constraints on details/reason + # TODO: implement length constraints on details/reason self.execution_status = "CLOSED" self.close_status = "FAILED" self.close_timestamp = unix_time() @@ -475,18 +464,21 @@ class WorkflowExecution(BaseModel): ignore_empty=True, ) if not activity_type: - fake_type = ActivityType(attributes["activityType"]["name"], - attributes["activityType"]["version"]) - fail_schedule_activity_task(fake_type, - "ACTIVITY_TYPE_DOES_NOT_EXIST") + fake_type = ActivityType( + attributes["activityType"]["name"], + attributes["activityType"]["version"], + ) + fail_schedule_activity_task(fake_type, "ACTIVITY_TYPE_DOES_NOT_EXIST") return if activity_type.status == "DEPRECATED": - fail_schedule_activity_task(activity_type, - "ACTIVITY_TYPE_DEPRECATED") + fail_schedule_activity_task(activity_type, "ACTIVITY_TYPE_DEPRECATED") return - if any(at for at in self.activity_tasks if at.activity_id == attributes["activityId"]): - fail_schedule_activity_task(activity_type, - "ACTIVITY_ID_ALREADY_IN_USE") + if any( + at + for at in self.activity_tasks + if at.activity_id == attributes["activityId"] + ): + fail_schedule_activity_task(activity_type, "ACTIVITY_ID_ALREADY_IN_USE") return # find task list or default task list, else fail @@ -494,20 +486,25 @@ class WorkflowExecution(BaseModel): if not task_list and activity_type.task_list: task_list = activity_type.task_list if not task_list: - fail_schedule_activity_task(activity_type, - "DEFAULT_TASK_LIST_UNDEFINED") + fail_schedule_activity_task(activity_type, "DEFAULT_TASK_LIST_UNDEFINED") return # find timeouts or default timeout, else fail timeouts = {} - for _type in ["scheduleToStartTimeout", "scheduleToCloseTimeout", "startToCloseTimeout", "heartbeatTimeout"]: + for _type in [ + "scheduleToStartTimeout", + "scheduleToCloseTimeout", + "startToCloseTimeout", + "heartbeatTimeout", + ]: default_key = "default_task_" + camelcase_to_underscores(_type) default_value = getattr(activity_type, default_key) timeouts[_type] = attributes.get(_type, default_value) if not timeouts[_type]: error_key = default_key.replace("default_task_", "default_") - fail_schedule_activity_task(activity_type, - "{0}_UNDEFINED".format(error_key.upper())) + fail_schedule_activity_task( + activity_type, "{0}_UNDEFINED".format(error_key.upper()) + ) return # Only add event and increment counters now that nothing went wrong @@ -541,16 +538,14 @@ class WorkflowExecution(BaseModel): for task in self.activity_tasks: if task.task_token == task_token: return task - raise ValueError( - "No activity task with token: {0}".format(task_token) - ) + raise ValueError("No activity task with token: {0}".format(task_token)) def start_activity_task(self, task_token, identity=None): task = self._find_activity_task(task_token) evt = self._add_event( "ActivityTaskStarted", scheduled_event_id=task.scheduled_event_id, - identity=identity + identity=identity, ) task.start(evt.event_id) @@ -601,17 +596,16 @@ class WorkflowExecution(BaseModel): def signal(self, signal_name, input): self._add_event( - "WorkflowExecutionSignaled", - signal_name=signal_name, - input=input, + "WorkflowExecutionSignaled", signal_name=signal_name, input=input ) self.schedule_decision_task() def first_timeout(self): if not self.open or not self.start_timestamp: return None - start_to_close_at = self.start_timestamp + \ - int(self.execution_start_to_close_timeout) + start_to_close_at = self.start_timestamp + int( + self.execution_start_to_close_timeout + ) _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout diff --git a/moto/swf/models/workflow_type.py b/moto/swf/models/workflow_type.py index 18d18d415..ddb2475b2 100644 --- a/moto/swf/models/workflow_type.py +++ b/moto/swf/models/workflow_type.py @@ -2,7 +2,6 @@ from .generic_type import GenericType class WorkflowType(GenericType): - @property def _configuration_keys(self): return [ diff --git a/moto/swf/responses.py b/moto/swf/responses.py index 6f002d3d4..98b736cda 100644 --- a/moto/swf/responses.py +++ b/moto/swf/responses.py @@ -8,7 +8,6 @@ from .models import swf_backends class SWFResponse(BaseResponse): - @property def swf_backend(self): return swf_backends[self.region] @@ -51,11 +50,12 @@ class SWFResponse(BaseResponse): return keys = kwargs.keys() if len(keys) == 2: - message = 'Cannot specify both a {0} and a {1}'.format(keys[0], - keys[1]) + message = "Cannot specify both a {0} and a {1}".format(keys[0], keys[1]) else: - message = 'Cannot specify more than one exclusive filters in the' \ - ' same query: {0}'.format(keys) + message = ( + "Cannot specify more than one exclusive filters in the" + " same query: {0}".format(keys) + ) raise SWFValidationException(message) def _list_types(self, kind): @@ -65,10 +65,9 @@ class SWFResponse(BaseResponse): self._check_string(domain_name) self._check_string(status) types = self.swf_backend.list_types( - kind, domain_name, status, reverse_order=reverse_order) - return json.dumps({ - "typeInfos": [_type.to_medium_dict() for _type in types] - }) + kind, domain_name, status, reverse_order=reverse_order + ) + return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]}) def _describe_type(self, kind): domain = self._params["domain"] @@ -98,50 +97,51 @@ class SWFResponse(BaseResponse): status = self._params["registrationStatus"] self._check_string(status) reverse_order = self._params.get("reverseOrder", None) - domains = self.swf_backend.list_domains( - status, reverse_order=reverse_order) - return json.dumps({ - "domainInfos": [domain.to_short_dict() for domain in domains] - }) + domains = self.swf_backend.list_domains(status, reverse_order=reverse_order) + return json.dumps( + {"domainInfos": [domain.to_short_dict() for domain in domains]} + ) def list_closed_workflow_executions(self): - domain = self._params['domain'] - start_time_filter = self._params.get('startTimeFilter', None) - close_time_filter = self._params.get('closeTimeFilter', None) - execution_filter = self._params.get('executionFilter', None) - workflow_id = execution_filter[ - 'workflowId'] if execution_filter else None - maximum_page_size = self._params.get('maximumPageSize', 1000) - reverse_order = self._params.get('reverseOrder', None) - tag_filter = self._params.get('tagFilter', None) - type_filter = self._params.get('typeFilter', None) - close_status_filter = self._params.get('closeStatusFilter', None) + domain = self._params["domain"] + start_time_filter = self._params.get("startTimeFilter", None) + close_time_filter = self._params.get("closeTimeFilter", None) + execution_filter = self._params.get("executionFilter", None) + workflow_id = execution_filter["workflowId"] if execution_filter else None + maximum_page_size = self._params.get("maximumPageSize", 1000) + reverse_order = self._params.get("reverseOrder", None) + tag_filter = self._params.get("tagFilter", None) + type_filter = self._params.get("typeFilter", None) + close_status_filter = self._params.get("closeStatusFilter", None) self._check_string(domain) self._check_none_or_string(workflow_id) - self._check_exclusivity(executionFilter=execution_filter, - typeFilter=type_filter, - tagFilter=tag_filter, - closeStatusFilter=close_status_filter) - self._check_exclusivity(startTimeFilter=start_time_filter, - closeTimeFilter=close_time_filter) + self._check_exclusivity( + executionFilter=execution_filter, + typeFilter=type_filter, + tagFilter=tag_filter, + closeStatusFilter=close_status_filter, + ) + self._check_exclusivity( + startTimeFilter=start_time_filter, closeTimeFilter=close_time_filter + ) if start_time_filter is None and close_time_filter is None: - raise SWFValidationException('Must specify time filter') + raise SWFValidationException("Must specify time filter") if start_time_filter: - self._check_float_or_int(start_time_filter['oldestDate']) - if 'latestDate' in start_time_filter: - self._check_float_or_int(start_time_filter['latestDate']) + self._check_float_or_int(start_time_filter["oldestDate"]) + if "latestDate" in start_time_filter: + self._check_float_or_int(start_time_filter["latestDate"]) if close_time_filter: - self._check_float_or_int(close_time_filter['oldestDate']) - if 'latestDate' in close_time_filter: - self._check_float_or_int(close_time_filter['latestDate']) + self._check_float_or_int(close_time_filter["oldestDate"]) + if "latestDate" in close_time_filter: + self._check_float_or_int(close_time_filter["latestDate"]) if tag_filter: - self._check_string(tag_filter['tag']) + self._check_string(tag_filter["tag"]) if type_filter: - self._check_string(type_filter['name']) - self._check_string(type_filter['version']) + self._check_string(type_filter["name"]) + self._check_string(type_filter["version"]) if close_status_filter: - self._check_string(close_status_filter['status']) + self._check_string(close_status_filter["status"]) self._check_int(maximum_page_size) workflow_executions = self.swf_backend.list_closed_workflow_executions( @@ -154,37 +154,38 @@ class SWFResponse(BaseResponse): maximum_page_size=maximum_page_size, reverse_order=reverse_order, workflow_id=workflow_id, - close_status_filter=close_status_filter + close_status_filter=close_status_filter, ) - return json.dumps({ - 'executionInfos': [wfe.to_list_dict() for wfe in workflow_executions] - }) + return json.dumps( + {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} + ) def list_open_workflow_executions(self): - domain = self._params['domain'] - start_time_filter = self._params['startTimeFilter'] - execution_filter = self._params.get('executionFilter', None) - workflow_id = execution_filter[ - 'workflowId'] if execution_filter else None - maximum_page_size = self._params.get('maximumPageSize', 1000) - reverse_order = self._params.get('reverseOrder', None) - tag_filter = self._params.get('tagFilter', None) - type_filter = self._params.get('typeFilter', None) + domain = self._params["domain"] + start_time_filter = self._params["startTimeFilter"] + execution_filter = self._params.get("executionFilter", None) + workflow_id = execution_filter["workflowId"] if execution_filter else None + maximum_page_size = self._params.get("maximumPageSize", 1000) + reverse_order = self._params.get("reverseOrder", None) + tag_filter = self._params.get("tagFilter", None) + type_filter = self._params.get("typeFilter", None) self._check_string(domain) self._check_none_or_string(workflow_id) - self._check_exclusivity(executionFilter=execution_filter, - typeFilter=type_filter, - tagFilter=tag_filter) - self._check_float_or_int(start_time_filter['oldestDate']) - if 'latestDate' in start_time_filter: - self._check_float_or_int(start_time_filter['latestDate']) + self._check_exclusivity( + executionFilter=execution_filter, + typeFilter=type_filter, + tagFilter=tag_filter, + ) + self._check_float_or_int(start_time_filter["oldestDate"]) + if "latestDate" in start_time_filter: + self._check_float_or_int(start_time_filter["latestDate"]) if tag_filter: - self._check_string(tag_filter['tag']) + self._check_string(tag_filter["tag"]) if type_filter: - self._check_string(type_filter['name']) - self._check_string(type_filter['version']) + self._check_string(type_filter["name"]) + self._check_string(type_filter["version"]) self._check_int(maximum_page_size) workflow_executions = self.swf_backend.list_open_workflow_executions( @@ -195,12 +196,12 @@ class SWFResponse(BaseResponse): type_filter=type_filter, maximum_page_size=maximum_page_size, reverse_order=reverse_order, - workflow_id=workflow_id + workflow_id=workflow_id, ) - return json.dumps({ - 'executionInfos': [wfe.to_list_dict() for wfe in workflow_executions] - }) + return json.dumps( + {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} + ) def register_domain(self): name = self._params["name"] @@ -209,8 +210,7 @@ class SWFResponse(BaseResponse): self._check_string(retention) self._check_string(name) self._check_none_or_string(description) - self.swf_backend.register_domain(name, retention, - description=description) + self.swf_backend.register_domain(name, retention, description=description) return "" def deprecate_domain(self): @@ -238,14 +238,16 @@ class SWFResponse(BaseResponse): task_list = default_task_list.get("name") else: task_list = None - default_task_heartbeat_timeout = self._params.get( - "defaultTaskHeartbeatTimeout") + default_task_heartbeat_timeout = self._params.get("defaultTaskHeartbeatTimeout") default_task_schedule_to_close_timeout = self._params.get( - "defaultTaskScheduleToCloseTimeout") + "defaultTaskScheduleToCloseTimeout" + ) default_task_schedule_to_start_timeout = self._params.get( - "defaultTaskScheduleToStartTimeout") + "defaultTaskScheduleToStartTimeout" + ) default_task_start_to_close_timeout = self._params.get( - "defaultTaskStartToCloseTimeout") + "defaultTaskStartToCloseTimeout" + ) description = self._params.get("description") self._check_string(domain) @@ -260,7 +262,11 @@ class SWFResponse(BaseResponse): # TODO: add defaultTaskPriority when boto gets to support it self.swf_backend.register_type( - "activity", domain, name, version, task_list=task_list, + "activity", + domain, + name, + version, + task_list=task_list, default_task_heartbeat_timeout=default_task_heartbeat_timeout, default_task_schedule_to_close_timeout=default_task_schedule_to_close_timeout, default_task_schedule_to_start_timeout=default_task_schedule_to_start_timeout, @@ -289,9 +295,11 @@ class SWFResponse(BaseResponse): task_list = None default_child_policy = self._params.get("defaultChildPolicy") default_task_start_to_close_timeout = self._params.get( - "defaultTaskStartToCloseTimeout") + "defaultTaskStartToCloseTimeout" + ) default_execution_start_to_close_timeout = self._params.get( - "defaultExecutionStartToCloseTimeout") + "defaultExecutionStartToCloseTimeout" + ) description = self._params.get("description") self._check_string(domain) @@ -306,7 +314,11 @@ class SWFResponse(BaseResponse): # TODO: add defaultTaskPriority when boto gets to support it # TODO: add defaultLambdaRole when boto gets to support it self.swf_backend.register_type( - "workflow", domain, name, version, task_list=task_list, + "workflow", + domain, + name, + version, + task_list=task_list, default_child_policy=default_child_policy, default_task_start_to_close_timeout=default_task_start_to_close_timeout, default_execution_start_to_close_timeout=default_execution_start_to_close_timeout, @@ -333,11 +345,11 @@ class SWFResponse(BaseResponse): task_list = None child_policy = self._params.get("childPolicy") execution_start_to_close_timeout = self._params.get( - "executionStartToCloseTimeout") + "executionStartToCloseTimeout" + ) input_ = self._params.get("input") tag_list = self._params.get("tagList") - task_start_to_close_timeout = self._params.get( - "taskStartToCloseTimeout") + task_start_to_close_timeout = self._params.get("taskStartToCloseTimeout") self._check_string(domain) self._check_string(workflow_id) @@ -351,16 +363,19 @@ class SWFResponse(BaseResponse): self._check_none_or_string(task_start_to_close_timeout) wfe = self.swf_backend.start_workflow_execution( - domain, workflow_id, workflow_name, workflow_version, - task_list=task_list, child_policy=child_policy, + domain, + workflow_id, + workflow_name, + workflow_version, + task_list=task_list, + child_policy=child_policy, execution_start_to_close_timeout=execution_start_to_close_timeout, - input=input_, tag_list=tag_list, - task_start_to_close_timeout=task_start_to_close_timeout + input=input_, + tag_list=tag_list, + task_start_to_close_timeout=task_start_to_close_timeout, ) - return json.dumps({ - "runId": wfe.run_id - }) + return json.dumps({"runId": wfe.run_id}) def describe_workflow_execution(self): domain_name = self._params["domain"] @@ -373,7 +388,8 @@ class SWFResponse(BaseResponse): self._check_string(workflow_id) wfe = self.swf_backend.describe_workflow_execution( - domain_name, run_id, workflow_id) + domain_name, run_id, workflow_id + ) return json.dumps(wfe.to_full_dict()) def get_workflow_execution_history(self): @@ -383,11 +399,10 @@ class SWFResponse(BaseResponse): workflow_id = _workflow_execution["workflowId"] reverse_order = self._params.get("reverseOrder", None) wfe = self.swf_backend.describe_workflow_execution( - domain_name, run_id, workflow_id) + domain_name, run_id, workflow_id + ) events = wfe.events(reverse_order=reverse_order) - return json.dumps({ - "events": [evt.to_dict() for evt in events] - }) + return json.dumps({"events": [evt.to_dict() for evt in events]}) def poll_for_decision_task(self): domain_name = self._params["domain"] @@ -402,9 +417,7 @@ class SWFResponse(BaseResponse): domain_name, task_list, identity=identity ) if decision: - return json.dumps( - decision.to_full_dict(reverse_order=reverse_order) - ) + return json.dumps(decision.to_full_dict(reverse_order=reverse_order)) else: return json.dumps({"previousStartedEventId": 0, "startedEventId": 0}) @@ -413,8 +426,7 @@ class SWFResponse(BaseResponse): task_list = self._params["taskList"]["name"] self._check_string(domain_name) self._check_string(task_list) - count = self.swf_backend.count_pending_decision_tasks( - domain_name, task_list) + count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) def respond_decision_task_completed(self): @@ -439,9 +451,7 @@ class SWFResponse(BaseResponse): domain_name, task_list, identity=identity ) if activity_task: - return json.dumps( - activity_task.to_full_dict() - ) + return json.dumps(activity_task.to_full_dict()) else: return json.dumps({"startedEventId": 0}) @@ -450,8 +460,7 @@ class SWFResponse(BaseResponse): task_list = self._params["taskList"]["name"] self._check_string(domain_name) self._check_string(task_list) - count = self.swf_backend.count_pending_activity_tasks( - domain_name, task_list) + count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) def respond_activity_task_completed(self): @@ -459,9 +468,7 @@ class SWFResponse(BaseResponse): result = self._params.get("result") self._check_string(task_token) self._check_none_or_string(result) - self.swf_backend.respond_activity_task_completed( - task_token, result=result - ) + self.swf_backend.respond_activity_task_completed(task_token, result=result) return "" def respond_activity_task_failed(self): @@ -492,8 +499,12 @@ class SWFResponse(BaseResponse): self._check_none_or_string(reason) self._check_none_or_string(run_id) self.swf_backend.terminate_workflow_execution( - domain_name, workflow_id, child_policy=child_policy, - details=details, reason=reason, run_id=run_id + domain_name, + workflow_id, + child_policy=child_policy, + details=details, + reason=reason, + run_id=run_id, ) return "" @@ -502,9 +513,7 @@ class SWFResponse(BaseResponse): details = self._params.get("details") self._check_string(task_token) self._check_none_or_string(details) - self.swf_backend.record_activity_task_heartbeat( - task_token, details=details - ) + self.swf_backend.record_activity_task_heartbeat(task_token, details=details) # TODO: make it dynamic when we implement activity tasks cancellation return json.dumps({"cancelRequested": False}) @@ -522,5 +531,6 @@ class SWFResponse(BaseResponse): self._check_none_or_string(run_id) self.swf_backend.signal_workflow_execution( - domain_name, signal_name, workflow_id, _input, run_id) + domain_name, signal_name, workflow_id, _input, run_id + ) return "" diff --git a/moto/swf/urls.py b/moto/swf/urls.py index 582c874fc..cafc39ad3 100644 --- a/moto/swf/urls.py +++ b/moto/swf/urls.py @@ -1,9 +1,5 @@ from .responses import SWFResponse -url_bases = [ - "https?://swf.(.+).amazonaws.com", -] +url_bases = ["https?://swf.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SWFResponse.dispatch, -} +url_paths = {"{0}/$": SWFResponse.dispatch} diff --git a/moto/swf/utils.py b/moto/swf/utils.py index de628ce50..1b85f4ca9 100644 --- a/moto/swf/utils.py +++ b/moto/swf/utils.py @@ -1,3 +1,2 @@ - def decapitalize(key): return key[0].lower() + key[1:] diff --git a/moto/xray/__init__.py b/moto/xray/__init__.py index 41f00af58..c6c612250 100644 --- a/moto/xray/__init__.py +++ b/moto/xray/__init__.py @@ -3,5 +3,5 @@ from .models import xray_backends from ..core.models import base_decorator from .mock_client import mock_xray_client, XRaySegment # noqa -xray_backend = xray_backends['us-east-1'] +xray_backend = xray_backends["us-east-1"] mock_xray = base_decorator(xray_backends) diff --git a/moto/xray/exceptions.py b/moto/xray/exceptions.py index 24f700178..8b5c87e36 100644 --- a/moto/xray/exceptions.py +++ b/moto/xray/exceptions.py @@ -11,11 +11,14 @@ class AWSError(Exception): self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + return ( + json.dumps({"__type": self.code, "message": self.message}), + dict(status=self.status), + ) class InvalidRequestException(AWSError): - CODE = 'InvalidRequestException' + CODE = "InvalidRequestException" class BadSegmentException(Exception): @@ -25,15 +28,15 @@ class BadSegmentException(Exception): self.message = message def __repr__(self): - return ''.format('-'.join([self.id, self.code, self.message])) + return "".format("-".join([self.id, self.code, self.message])) def to_dict(self): result = {} if self.id is not None: - result['Id'] = self.id + result["Id"] = self.id if self.code is not None: - result['ErrorCode'] = self.code + result["ErrorCode"] = self.code if self.message is not None: - result['Message'] = self.message + result["Message"] = self.message return result diff --git a/moto/xray/mock_client.py b/moto/xray/mock_client.py index 135796054..9e042c594 100644 --- a/moto/xray/mock_client.py +++ b/moto/xray/mock_client.py @@ -10,8 +10,11 @@ class MockEmitter(UDPEmitter): """ Replaces the code that sends UDP to local X-Ray daemon """ - def __init__(self, daemon_address='127.0.0.1:2000'): - address = os.getenv('AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE', daemon_address) + + def __init__(self, daemon_address="127.0.0.1:2000"): + address = os.getenv( + "AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address + ) self._ip, self._port = self._parse_address(address) def _xray_backend(self, region): @@ -26,7 +29,7 @@ class MockEmitter(UDPEmitter): pass def _send_data(self, data): - raise RuntimeError('Should not be running this') + raise RuntimeError("Should not be running this") def mock_xray_client(f): @@ -39,12 +42,13 @@ def mock_xray_client(f): We also patch the Emitter by subclassing the UDPEmitter class replacing its methods and pushing that itno the recorder instance. """ + @wraps(f) def _wrapped(*args, **kwargs): print("Starting X-Ray Patch") - old_xray_context_var = os.environ.get('AWS_XRAY_CONTEXT_MISSING') - os.environ['AWS_XRAY_CONTEXT_MISSING'] = 'LOG_ERROR' + old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING") + os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR" old_xray_context = aws_xray_sdk.core.xray_recorder._context old_xray_emitter = aws_xray_sdk.core.xray_recorder._emitter aws_xray_sdk.core.xray_recorder._context = AWSContext() @@ -55,9 +59,9 @@ def mock_xray_client(f): finally: if old_xray_context_var is None: - del os.environ['AWS_XRAY_CONTEXT_MISSING'] + del os.environ["AWS_XRAY_CONTEXT_MISSING"] else: - os.environ['AWS_XRAY_CONTEXT_MISSING'] = old_xray_context_var + os.environ["AWS_XRAY_CONTEXT_MISSING"] = old_xray_context_var aws_xray_sdk.core.xray_recorder._emitter = old_xray_emitter aws_xray_sdk.core.xray_recorder._context = old_xray_context @@ -74,8 +78,11 @@ class XRaySegment(object): During testing we're going to have to control the start and end of a segment via context managers. """ + def __enter__(self): - aws_xray_sdk.core.xray_recorder.begin_segment(name='moto_mock', traceid=None, parent_id=None, sampling=1) + aws_xray_sdk.core.xray_recorder.begin_segment( + name="moto_mock", traceid=None, parent_id=None, sampling=1 + ) return self diff --git a/moto/xray/models.py b/moto/xray/models.py index b2d418232..33a271f9b 100644 --- a/moto/xray/models.py +++ b/moto/xray/models.py @@ -18,18 +18,36 @@ class TelemetryRecords(BaseModel): @classmethod def from_json(cls, json): - instance_id = json.get('EC2InstanceId', None) - hostname = json.get('Hostname') - resource_arn = json.get('ResourceARN') - telemetry_records = json['TelemetryRecords'] + instance_id = json.get("EC2InstanceId", None) + hostname = json.get("Hostname") + resource_arn = json.get("ResourceARN") + telemetry_records = json["TelemetryRecords"] return cls(instance_id, hostname, resource_arn, telemetry_records) # https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html class TraceSegment(BaseModel): - def __init__(self, name, segment_id, trace_id, start_time, raw, end_time=None, in_progress=False, service=None, user=None, - origin=None, parent_id=None, http=None, aws=None, metadata=None, annotations=None, subsegments=None, **kwargs): + def __init__( + self, + name, + segment_id, + trace_id, + start_time, + raw, + end_time=None, + in_progress=False, + service=None, + user=None, + origin=None, + parent_id=None, + http=None, + aws=None, + metadata=None, + annotations=None, + subsegments=None, + **kwargs + ): self.name = name self.id = segment_id self.trace_id = trace_id @@ -61,14 +79,16 @@ class TraceSegment(BaseModel): @property def trace_version(self): if self._trace_version is None: - self._trace_version = int(self.trace_id.split('-', 1)[0]) + self._trace_version = int(self.trace_id.split("-", 1)[0]) return self._trace_version @property def request_start_date(self): if self._original_request_start_time is None: - start_time = int(self.trace_id.split('-')[1], 16) - self._original_request_start_time = datetime.datetime.fromtimestamp(start_time) + start_time = int(self.trace_id.split("-")[1], 16) + self._original_request_start_time = datetime.datetime.fromtimestamp( + start_time + ) return self._original_request_start_time @property @@ -86,19 +106,27 @@ class TraceSegment(BaseModel): @classmethod def from_dict(cls, data, raw): # Check manditory args - if 'id' not in data: - raise BadSegmentException(code='MissingParam', message='Missing segment ID') - seg_id = data['id'] - data['segment_id'] = seg_id # Just adding this key for future convenience + if "id" not in data: + raise BadSegmentException(code="MissingParam", message="Missing segment ID") + seg_id = data["id"] + data["segment_id"] = seg_id # Just adding this key for future convenience - for arg in ('name', 'trace_id', 'start_time'): + for arg in ("name", "trace_id", "start_time"): if arg not in data: - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing segment ID') + raise BadSegmentException( + seg_id=seg_id, code="MissingParam", message="Missing segment ID" + ) - if 'end_time' not in data and 'in_progress' not in data: - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time or in_progress') - if 'end_time' not in data and data['in_progress'] == 'false': - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time') + if "end_time" not in data and "in_progress" not in data: + raise BadSegmentException( + seg_id=seg_id, + code="MissingParam", + message="Missing end_time or in_progress", + ) + if "end_time" not in data and data["in_progress"] == "false": + raise BadSegmentException( + seg_id=seg_id, code="MissingParam", message="Missing end_time" + ) return cls(raw=raw, **data) @@ -110,65 +138,79 @@ class SegmentCollection(object): @staticmethod def _new_trace_item(): return { - 'start_date': datetime.datetime(1970, 1, 1), - 'end_date': datetime.datetime(1970, 1, 1), - 'finished': False, - 'trace_id': None, - 'segments': [] + "start_date": datetime.datetime(1970, 1, 1), + "end_date": datetime.datetime(1970, 1, 1), + "finished": False, + "trace_id": None, + "segments": [], } def put_segment(self, segment): # insert into a sorted list - bisect.insort_left(self._traces[segment.trace_id]['segments'], segment) + bisect.insort_left(self._traces[segment.trace_id]["segments"], segment) # Get the last segment (takes into account incorrect ordering) # and if its the last one, mark trace as complete - if self._traces[segment.trace_id]['segments'][-1].end_time is not None: - self._traces[segment.trace_id]['finished'] = True + if self._traces[segment.trace_id]["segments"][-1].end_time is not None: + self._traces[segment.trace_id]["finished"] = True - start_time = self._traces[segment.trace_id]['segments'][0].start_date - end_time = self._traces[segment.trace_id]['segments'][-1].end_date - self._traces[segment.trace_id]['start_date'] = start_time - self._traces[segment.trace_id]['end_date'] = end_time - self._traces[segment.trace_id]['trace_id'] = segment.trace_id + start_time = self._traces[segment.trace_id]["segments"][0].start_date + end_time = self._traces[segment.trace_id]["segments"][-1].end_date + self._traces[segment.trace_id]["start_date"] = start_time + self._traces[segment.trace_id]["end_date"] = end_time + self._traces[segment.trace_id]["trace_id"] = segment.trace_id # Todo consolidate trace segments into a trace. # not enough working knowledge of xray to do this def summary(self, start_time, end_time, filter_expression=None, sampling=False): # This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax if filter_expression is not None: - raise AWSError('Not implemented yet - moto', code='InternalFailure', status=500) + raise AWSError( + "Not implemented yet - moto", code="InternalFailure", status=500 + ) summaries = [] for tid, trace in self._traces.items(): - if trace['finished'] and start_time < trace['start_date'] and trace['end_date'] < end_time: - duration = int((trace['end_date'] - trace['start_date']).total_seconds()) + if ( + trace["finished"] + and start_time < trace["start_date"] + and trace["end_date"] < end_time + ): + duration = int( + (trace["end_date"] - trace["start_date"]).total_seconds() + ) # this stuff is mostly guesses, refer to TODO above - has_error = any(['error' in seg.misc for seg in trace['segments']]) - has_fault = any(['fault' in seg.misc for seg in trace['segments']]) - has_throttle = any(['throttle' in seg.misc for seg in trace['segments']]) + has_error = any(["error" in seg.misc for seg in trace["segments"]]) + has_fault = any(["fault" in seg.misc for seg in trace["segments"]]) + has_throttle = any( + ["throttle" in seg.misc for seg in trace["segments"]] + ) # Apparently all of these options are optional summary_part = { - 'Annotations': {}, # Not implemented yet - 'Duration': duration, - 'HasError': has_error, - 'HasFault': has_fault, - 'HasThrottle': has_throttle, - 'Http': {}, # Not implemented yet - 'Id': tid, - 'IsParital': False, # needs lots more work to work on partials - 'ResponseTime': 1, # definitely 1ms resposnetime - 'ServiceIds': [], # Not implemented yet - 'Users': {} # Not implemented yet + "Annotations": {}, # Not implemented yet + "Duration": duration, + "HasError": has_error, + "HasFault": has_fault, + "HasThrottle": has_throttle, + "Http": {}, # Not implemented yet + "Id": tid, + "IsParital": False, # needs lots more work to work on partials + "ResponseTime": 1, # definitely 1ms resposnetime + "ServiceIds": [], # Not implemented yet + "Users": {}, # Not implemented yet } summaries.append(summary_part) result = { - "ApproximateTime": int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds()), + "ApproximateTime": int( + ( + datetime.datetime.now() - datetime.datetime(1970, 1, 1) + ).total_seconds() + ), "TracesProcessedCount": len(summaries), - "TraceSummaries": summaries + "TraceSummaries": summaries, } return result @@ -189,59 +231,57 @@ class SegmentCollection(object): class XRayBackend(BaseBackend): - def __init__(self): self._telemetry_records = [] self._segment_collection = SegmentCollection() def add_telemetry_records(self, json): - self._telemetry_records.append( - TelemetryRecords.from_json(json) - ) + self._telemetry_records.append(TelemetryRecords.from_json(json)) def process_segment(self, doc): try: data = json.loads(doc) except ValueError: - raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + raise BadSegmentException(code="JSONFormatError", message="Bad JSON data") try: # Get Segment Object segment = TraceSegment.from_dict(data, raw=doc) except ValueError: - raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + raise BadSegmentException(code="JSONFormatError", message="Bad JSON data") try: # Store Segment Object self._segment_collection.put_segment(segment) except Exception as err: - raise BadSegmentException(seg_id=segment.id, code='InternalFailure', message=str(err)) + raise BadSegmentException( + seg_id=segment.id, code="InternalFailure", message=str(err) + ) def get_trace_summary(self, start_time, end_time, filter_expression, summaries): - return self._segment_collection.summary(start_time, end_time, filter_expression, summaries) + return self._segment_collection.summary( + start_time, end_time, filter_expression, summaries + ) def get_trace_ids(self, trace_ids, next_token): traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids) - result = { - 'Traces': [], - 'UnprocessedTraceIds': unprocessed_ids - - } + result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids} for trace in traces: segments = [] - for segment in trace['segments']: - segments.append({ - 'Id': segment.id, - 'Document': segment.raw - }) + for segment in trace["segments"]: + segments.append({"Id": segment.id, "Document": segment.raw}) - result['Traces'].append({ - 'Duration': int((trace['end_date'] - trace['start_date']).total_seconds()), - 'Id': trace['trace_id'], - 'Segments': segments - }) + result["Traces"].append( + { + "Duration": int( + (trace["end_date"] - trace["start_date"]).total_seconds() + ), + "Id": trace["trace_id"], + "Segments": segments, + } + ) return result diff --git a/moto/xray/responses.py b/moto/xray/responses.py index 328a266bf..118f2de2f 100644 --- a/moto/xray/responses.py +++ b/moto/xray/responses.py @@ -10,9 +10,8 @@ from .exceptions import AWSError, BadSegmentException class XRayResponse(BaseResponse): - def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) @property def xray_backend(self): @@ -32,7 +31,7 @@ class XRayResponse(BaseResponse): # Amazon is just calling urls like /TelemetryRecords etc... # This uses the value after / as the camalcase action, which then # gets converted in call_action to find the following methods - return urlsplit(self.uri).path.lstrip('/') + return urlsplit(self.uri).path.lstrip("/") # PutTelemetryRecords def telemetry_records(self): @@ -41,15 +40,18 @@ class XRayResponse(BaseResponse): except AWSError as err: return err.response() - return '' + return "" # PutTraceSegments def trace_segments(self): - docs = self._get_param('TraceSegmentDocuments') + docs = self._get_param("TraceSegmentDocuments") if docs is None: - msg = 'Parameter TraceSegmentDocuments is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceSegmentDocuments is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) # Raises an exception that contains info about a bad segment, # the object also has a to_dict() method @@ -60,91 +62,120 @@ class XRayResponse(BaseResponse): except BadSegmentException as bad_seg: bad_segments.append(bad_seg) except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) - result = {'UnprocessedTraceSegments': [x.to_dict() for x in bad_segments]} + result = {"UnprocessedTraceSegments": [x.to_dict() for x in bad_segments]} return json.dumps(result) # GetTraceSummaries def trace_summaries(self): - start_time = self._get_param('StartTime') - end_time = self._get_param('EndTime') + start_time = self._get_param("StartTime") + end_time = self._get_param("EndTime") if start_time is None: - msg = 'Parameter StartTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter StartTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) if end_time is None: - msg = 'Parameter EndTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter EndTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - filter_expression = self._get_param('FilterExpression') - sampling = self._get_param('Sampling', 'false') == 'true' + filter_expression = self._get_param("FilterExpression") + sampling = self._get_param("Sampling", "false") == "true" try: start_time = datetime.datetime.fromtimestamp(int(start_time)) end_time = datetime.datetime.fromtimestamp(int(end_time)) except ValueError: - msg = 'start_time and end_time are not integers' - return json.dumps({'__type': 'InvalidParameterValue', 'message': msg}), dict(status=400) + msg = "start_time and end_time are not integers" + return ( + json.dumps({"__type": "InvalidParameterValue", "message": msg}), + dict(status=400), + ) except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) try: - result = self.xray_backend.get_trace_summary(start_time, end_time, filter_expression, sampling) + result = self.xray_backend.get_trace_summary( + start_time, end_time, filter_expression, sampling + ) except AWSError as err: return err.response() except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) return json.dumps(result) # BatchGetTraces def traces(self): - trace_ids = self._get_param('TraceIds') - next_token = self._get_param('NextToken') # not implemented yet + trace_ids = self._get_param("TraceIds") + next_token = self._get_param("NextToken") # not implemented yet if trace_ids is None: - msg = 'Parameter TraceIds is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceIds is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: result = self.xray_backend.get_trace_ids(trace_ids, next_token) except AWSError as err: return err.response() except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) return json.dumps(result) # GetServiceGraph - just a dummy response for now def service_graph(self): - start_time = self._get_param('StartTime') - end_time = self._get_param('EndTime') + start_time = self._get_param("StartTime") + end_time = self._get_param("EndTime") # next_token = self._get_param('NextToken') # not implemented yet if start_time is None: - msg = 'Parameter StartTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter StartTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) if end_time is None: - msg = 'Parameter EndTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter EndTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - result = { - 'StartTime': start_time, - 'EndTime': end_time, - 'Services': [] - } + result = {"StartTime": start_time, "EndTime": end_time, "Services": []} return json.dumps(result) # GetTraceGraph - just a dummy response for now def trace_graph(self): - trace_ids = self._get_param('TraceIds') + trace_ids = self._get_param("TraceIds") # next_token = self._get_param('NextToken') # not implemented yet if trace_ids is None: - msg = 'Parameter TraceIds is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceIds is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - result = { - 'Services': [] - } + result = {"Services": []} return json.dumps(result) diff --git a/moto/xray/urls.py b/moto/xray/urls.py index b0f13a980..4a3d4b253 100644 --- a/moto/xray/urls.py +++ b/moto/xray/urls.py @@ -1,15 +1,13 @@ from __future__ import unicode_literals from .responses import XRayResponse -url_bases = [ - "https?://xray.(.+).amazonaws.com", -] +url_bases = ["https?://xray.(.+).amazonaws.com"] url_paths = { - '{0}/TelemetryRecords$': XRayResponse.dispatch, - '{0}/TraceSegments$': XRayResponse.dispatch, - '{0}/Traces$': XRayResponse.dispatch, - '{0}/ServiceGraph$': XRayResponse.dispatch, - '{0}/TraceGraph$': XRayResponse.dispatch, - '{0}/TraceSummaries$': XRayResponse.dispatch, + "{0}/TelemetryRecords$": XRayResponse.dispatch, + "{0}/TraceSegments$": XRayResponse.dispatch, + "{0}/Traces$": XRayResponse.dispatch, + "{0}/ServiceGraph$": XRayResponse.dispatch, + "{0}/TraceGraph$": XRayResponse.dispatch, + "{0}/TraceSummaries$": XRayResponse.dispatch, } diff --git a/requirements-dev.txt b/requirements-dev.txt index f87ab3db6..c5f055a26 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,15 +1,18 @@ -r requirements.txt mock nose +black; python_version >= '3.6' +regex==2019.11.1; python_version >= '3.6' # Needed for black sure==1.4.11 -coverage -flake8==3.5.0 +coverage==4.5.4 +flake8==3.7.8 freezegun flask boto>=2.45.0 boto3>=1.4.4 botocore>=1.12.13 six>=1.9 +parameterized>=0.7.0 prompt-toolkit==1.0.14 click==6.7 inflection==0.3.1 diff --git a/scripts/implementation_coverage.py b/scripts/implementation_coverage.py index 0e1816088..4552ec18e 100755 --- a/scripts/implementation_coverage.py +++ b/scripts/implementation_coverage.py @@ -7,16 +7,18 @@ import boto3 script_dir = os.path.dirname(os.path.abspath(__file__)) +alternative_service_names = {'lambda': 'awslambda'} def get_moto_implementation(service_name): - service_name_standardized = service_name.replace("-", "") if "-" in service_name else service_name - if not hasattr(moto, service_name_standardized): + service_name = service_name.replace("-", "") if "-" in service_name else service_name + alt_service_name = alternative_service_names[service_name] if service_name in alternative_service_names else service_name + if not hasattr(moto, alt_service_name): return None - module = getattr(moto, service_name_standardized) + module = getattr(moto, alt_service_name) if module is None: return None - mock = getattr(module, "mock_{}".format(service_name_standardized)) + mock = getattr(module, "mock_{}".format(service_name)) if mock is None: return None backends = list(mock().backends.values()) @@ -71,16 +73,16 @@ def print_implementation_coverage(coverage): def write_implementation_coverage_to_file(coverage): + implementation_coverage_file = "{}/../IMPLEMENTATION_COVERAGE.md".format(script_dir) + # rewrite the implementation coverage file with updated values # try deleting the implementation coverage file try: - os.remove("../IMPLEMENTATION_COVERAGE.md") + os.remove(implementation_coverage_file) except OSError: pass - implementation_coverage_file = "{}/../IMPLEMENTATION_COVERAGE.md".format(script_dir) - # rewrite the implementation coverage file with updated values print("Writing to {}".format(implementation_coverage_file)) - with open(implementation_coverage_file, "a+") as file: + with open(implementation_coverage_file, "w+") as file: for service_name in sorted(coverage): implemented = coverage.get(service_name)['implemented'] not_implemented = coverage.get(service_name)['not_implemented'] diff --git a/scripts/scaffold.py b/scripts/scaffold.py index 6c83eeb50..be154f103 100755 --- a/scripts/scaffold.py +++ b/scripts/scaffold.py @@ -119,7 +119,7 @@ def append_mock_to_init_py(service): filtered_lines = [_ for _ in lines if re.match('^from.*mock.*$', _)] last_import_line_index = lines.index(filtered_lines[-1]) - new_line = 'from .{} import mock_{} # flake8: noqa'.format(get_escaped_service(service), get_escaped_service(service)) + new_line = 'from .{} import mock_{} # noqa'.format(get_escaped_service(service), get_escaped_service(service)) lines.insert(last_import_line_index + 1, new_line) body = '\n'.join(lines) + '\n' diff --git a/setup.py b/setup.py index 77ebdfcf7..97a6341ff 100755 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ install_requires = [ "werkzeug", "PyYAML>=5.1", "pytz", - "python-dateutil<3.0.0,>=2.1", + "python-dateutil<2.8.1,>=2.1", "python-jose<4.0.0", "mock", "docker>=2.5.1", @@ -94,4 +94,7 @@ setup( "License :: OSI Approved :: Apache Software License", "Topic :: Software Development :: Testing", ], -) \ No newline at end of file + project_urls={ + "Documentation": "http://docs.getmoto.org/en/latest/", + }, +) diff --git a/tests/__init__.py b/tests/__init__.py index bf582e0b3..05b1d476b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging + # Disable extra logging for tests -logging.getLogger('boto').setLevel(logging.CRITICAL) -logging.getLogger('boto3').setLevel(logging.CRITICAL) -logging.getLogger('botocore').setLevel(logging.CRITICAL) -logging.getLogger('nose').setLevel(logging.CRITICAL) +logging.getLogger("boto").setLevel(logging.CRITICAL) +logging.getLogger("boto3").setLevel(logging.CRITICAL) +logging.getLogger("botocore").setLevel(logging.CRITICAL) +logging.getLogger("nose").setLevel(logging.CRITICAL) diff --git a/tests/backport_assert_raises.py b/tests/backport_assert_raises.py index 9b20edf9d..bfed51308 100644 --- a/tests/backport_assert_raises.py +++ b/tests/backport_assert_raises.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + """ Patch courtesy of: https://marmida.com/blog/index.php/2012/08/08/monkey-patching-assert_raises/ @@ -19,7 +20,6 @@ try: except TypeError: # this version of assert_raises doesn't support the 1-arg version class AssertRaisesContext(object): - def __init__(self, expected): self.expected = expected diff --git a/tests/helpers.py b/tests/helpers.py index 50615b094..ffe27103d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -29,7 +29,6 @@ class requires_boto_gte(object): class disable_on_py3(object): - def __call__(self, test): if not six.PY3: return test diff --git a/tests/test_acm/resources/star_moto_com.pem b/tests/test_acm/resources/star_moto_com.pem index 6d599d53e..646972e95 100644 --- a/tests/test_acm/resources/star_moto_com.pem +++ b/tests/test_acm/resources/star_moto_com.pem @@ -1,8 +1,8 @@ -----BEGIN CERTIFICATE----- -MIIEUDCCAjgCCQDfXZHMio+6oDANBgkqhkiG9w0BAQ0FADBjMQswCQYDVQQGEwJH +MIIEUDCCAjgCCQDfXZHMio+6oDANBgkqhkiG9w0BAQsFADBjMQswCQYDVQQGEwJH QjESMBAGA1UECAwJQmVya3NoaXJlMQ8wDQYDVQQHDAZTbG91Z2gxEzARBgNVBAoM -Ck1vdG9TZXJ2ZXIxCzAJBgNVBAsMAlFBMQ0wCwYDVQQDDARNb3RvMB4XDTE3MDky -MTIxMjQ1MFoXDTI3MDkxOTIxMjQ1MFowcTELMAkGA1UEBhMCR0IxEjAQBgNVBAgM +Ck1vdG9TZXJ2ZXIxCzAJBgNVBAsMAlFBMQ0wCwYDVQQDDARNb3RvMB4XDTE5MTAy +MTEzMjczMVoXDTQ5MTIzMTEzMjczNFowcTELMAkGA1UEBhMCR0IxEjAQBgNVBAgM CUJlcmtzaGlyZTEPMA0GA1UEBwwGU2xvdWdoMRMwEQYDVQQKDApNb3RvU2VydmVy MRMwEQYDVQQLDApPcGVyYXRpb25zMRMwEQYDVQQDDAoqLm1vdG8uY29tMIIBIjAN BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzC/oBkzwiIBEceSC/tSD7hkqs8AW @@ -11,16 +11,16 @@ niDXbMgAQE9oxUxtkFESxiNa+EbAMLBFtBkPRvc3iKXh/cfLo7yP8VdqEIDmJCB/ vpjJvf6HnrNJ7keQR+oGJNf7jVaCgOVdJ4lt7+98YDVde7jLx1DN+QbvViJQl60n K3bmfuLiiw8154Eyi9DOcJE8AB+W7KpPdrmbPisR1EiqY0i0L62ZixN0rPi5hHF+ ozwURL1axcmLjlhIFi8YhBCNcY6ThE7jrqgLIq1n6d8ezRxjDKmqfH1spQIDAQAB -MA0GCSqGSIb3DQEBDQUAA4ICAQCgl/EfjE0Jh3cqQgoOlaFq6L1iJVgy5sYKCC4r -OU4dHgifZ6/grqCJesGiS1Vh4L8XklN++C2aSL73lVtxXoCSopP8Yj0rOGeA6b+7 -Fetm4ZQYF61QtahC0L2fkvKXR+uz1I85ndSoMJPT8lbm7sYJuL81Si32NOo6kC6y -4eKzV4KznxdAf6XaQMKtMIyXO3PWTrjm5ayzS6UsmnBvULGDCaAQznFlVFdGNSHx -CaENICR0CBcB+vbL7FPC683a4afceM+aMcMVElWG5q8fxtgbL/aPhzfonhDGWOM4 -Rdg8x+yDdi7swxmWlcW5wlP8LpLxN/S3GR9j9IyelxUGmb20yTph3i1K6RM/Fm2W -PI8xdneA6qycUAJo93NfaCuNK7yBfK3uDLqmWlGh3xCG+I1JETLRbxYBWiqeVTb3 -qjHMrsgqTqjcaCiKR/5H2eVkdcr8mLxrV5niyBItDl1xGxj4LF8hDLormhaCjiBb -N1cMq5saj/BpoIanlqOWby6uRMYlZvuhwKQGPVWgfuRWKFzGbMWyPCxATbiU89Wb -IykNkT1zTCE/eZwH12T4A7jrBiWq8WNfIST0Z7MReE6Oz+M9Pxx7DyDzSb2Y1RmU -xNYd8CavZLCfns00xZSo+10deMoKVS9GgxSHcS4ELaVaBQwu35emiMJSLcK7iNGE -I4WVSA== +MA0GCSqGSIb3DQEBCwUAA4ICAQAOwvJjY1cLIBVGCDPkkxH4xCP6+QRdm7bqF7X5 +DNZ70YcJ27GldrEPmKX8C1RvkC4oCsaytl8Hlw3ZcS1GvwBxTVlnYIE6nLPPi1ix +LvYYgoq+Mjk/2XPCnU/6cqJhb5INskg9s0o15jv27cUIgWVMnj+d5lvSiy1HhdYM +wvuQzXELjhe/rHw1/BFGaBV2vd7einUQwla50UZLcsj6FwWSIsv7EB4GaY/G0XqC +Mai2PltBgBPFqsZo27uBeVfxqMZtwAQlr4iWwWZm1haDy6D4GFCSR8E/gtlyhiN4 +MOk1cmr9PSOMB3CWqKjkx7lPMOQT/f+gxlCnupNHsHcZGvQV4mCPiU+lLwp+8z/s +bupQwRvu1SwSUD2rIsVeUuSP3hbMcfhiZA50lenQNApimgrThdPUoFXi07FUdL+F +1QCk6cvA48KzGRo+bPSfZQusj51k/2+hl4sHHZdWg6mGAIY9InMKmPDE4VzM8hro +fr2fJLqKQ4h+xKbEYnvPEPttUdJbvUgr9TKKVw+m3lmW9SktzE5KtvWvN6daTj9Z +oHDJkOyko3uyTzk+HwWDC/pQ2cC+iF1MjIHi72U9ibObSODg/d9cMH3XJTnZ9W3+ +He9iuH4dJpKnVjnJ5NKt7IOrPHID77160hpwF1dim22ZRp508eYapRzgawPMpCcd +a6YipQ== -----END CERTIFICATE----- diff --git a/tests/test_acm/test_acm.py b/tests/test_acm/test_acm.py index cdd8682e1..b38cd1843 100644 --- a/tests/test_acm/test_acm.py +++ b/tests/test_acm/test_acm.py @@ -9,112 +9,109 @@ import uuid from botocore.exceptions import ClientError from moto import mock_acm +from moto.core import ACCOUNT_ID -RESOURCE_FOLDER = os.path.join(os.path.dirname(__file__), 'resources') -_GET_RESOURCE = lambda x: open(os.path.join(RESOURCE_FOLDER, x), 'rb').read() -CA_CRT = _GET_RESOURCE('ca.pem') -CA_KEY = _GET_RESOURCE('ca.key') -SERVER_CRT = _GET_RESOURCE('star_moto_com.pem') -SERVER_COMMON_NAME = '*.moto.com' -SERVER_CRT_BAD = _GET_RESOURCE('star_moto_com-bad.pem') -SERVER_KEY = _GET_RESOURCE('star_moto_com.key') -BAD_ARN = 'arn:aws:acm:us-east-2:123456789012:certificate/_0000000-0000-0000-0000-000000000000' +RESOURCE_FOLDER = os.path.join(os.path.dirname(__file__), "resources") +_GET_RESOURCE = lambda x: open(os.path.join(RESOURCE_FOLDER, x), "rb").read() +CA_CRT = _GET_RESOURCE("ca.pem") +CA_KEY = _GET_RESOURCE("ca.key") +SERVER_CRT = _GET_RESOURCE("star_moto_com.pem") +SERVER_COMMON_NAME = "*.moto.com" +SERVER_CRT_BAD = _GET_RESOURCE("star_moto_com-bad.pem") +SERVER_KEY = _GET_RESOURCE("star_moto_com.key") +BAD_ARN = "arn:aws:acm:us-east-2:{}:certificate/_0000000-0000-0000-0000-000000000000".format( + ACCOUNT_ID +) def _import_cert(client): response = client.import_certificate( - Certificate=SERVER_CRT, - PrivateKey=SERVER_KEY, - CertificateChain=CA_CRT + Certificate=SERVER_CRT, PrivateKey=SERVER_KEY, CertificateChain=CA_CRT ) - return response['CertificateArn'] + return response["CertificateArn"] # Also tests GetCertificate @mock_acm def test_import_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") resp = client.import_certificate( - Certificate=SERVER_CRT, - PrivateKey=SERVER_KEY, - CertificateChain=CA_CRT + Certificate=SERVER_CRT, PrivateKey=SERVER_KEY, CertificateChain=CA_CRT ) - resp = client.get_certificate(CertificateArn=resp['CertificateArn']) + resp = client.get_certificate(CertificateArn=resp["CertificateArn"]) - resp['Certificate'].should.equal(SERVER_CRT.decode()) - resp.should.contain('CertificateChain') + resp["Certificate"].should.equal(SERVER_CRT.decode()) + resp.should.contain("CertificateChain") @mock_acm def test_import_bad_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: - client.import_certificate( - Certificate=SERVER_CRT_BAD, - PrivateKey=SERVER_KEY, - ) + client.import_certificate(Certificate=SERVER_CRT_BAD, PrivateKey=SERVER_KEY) except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationException') + err.response["Error"]["Code"].should.equal("ValidationException") else: - raise RuntimeError('Should of raised ValidationException') + raise RuntimeError("Should of raised ValidationException") @mock_acm def test_list_certificates(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) resp = client.list_certificates() - len(resp['CertificateSummaryList']).should.equal(1) + len(resp["CertificateSummaryList"]).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(arn) - resp['CertificateSummaryList'][0]['DomainName'].should.equal(SERVER_COMMON_NAME) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(arn) + resp["CertificateSummaryList"][0]["DomainName"].should.equal(SERVER_COMMON_NAME) @mock_acm def test_list_certificates_by_status(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") issued_arn = _import_cert(client) - pending_arn = client.request_certificate(DomainName='google.com')['CertificateArn'] + pending_arn = client.request_certificate(DomainName="google.com")["CertificateArn"] resp = client.list_certificates() - len(resp['CertificateSummaryList']).should.equal(2) - resp = client.list_certificates(CertificateStatuses=['EXPIRED', 'INACTIVE']) - len(resp['CertificateSummaryList']).should.equal(0) - resp = client.list_certificates(CertificateStatuses=['PENDING_VALIDATION']) - len(resp['CertificateSummaryList']).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(pending_arn) + len(resp["CertificateSummaryList"]).should.equal(2) + resp = client.list_certificates(CertificateStatuses=["EXPIRED", "INACTIVE"]) + len(resp["CertificateSummaryList"]).should.equal(0) + resp = client.list_certificates(CertificateStatuses=["PENDING_VALIDATION"]) + len(resp["CertificateSummaryList"]).should.equal(1) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(pending_arn) - resp = client.list_certificates(CertificateStatuses=['ISSUED']) - len(resp['CertificateSummaryList']).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(issued_arn) - resp = client.list_certificates(CertificateStatuses=['ISSUED', 'PENDING_VALIDATION']) - len(resp['CertificateSummaryList']).should.equal(2) - arns = {cert['CertificateArn'] for cert in resp['CertificateSummaryList']} + resp = client.list_certificates(CertificateStatuses=["ISSUED"]) + len(resp["CertificateSummaryList"]).should.equal(1) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(issued_arn) + resp = client.list_certificates( + CertificateStatuses=["ISSUED", "PENDING_VALIDATION"] + ) + len(resp["CertificateSummaryList"]).should.equal(2) + arns = {cert["CertificateArn"] for cert in resp["CertificateSummaryList"]} arns.should.contain(issued_arn) arns.should.contain(pending_arn) - @mock_acm def test_get_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.get_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") # Also tests deleting invalid certificate @mock_acm def test_delete_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) # If it does not raise an error and the next call does, all is fine @@ -123,222 +120,209 @@ def test_delete_certificate(): try: client.delete_certificate(CertificateArn=arn) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_describe_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) resp = client.describe_certificate(CertificateArn=arn) - resp['Certificate']['CertificateArn'].should.equal(arn) - resp['Certificate']['DomainName'].should.equal(SERVER_COMMON_NAME) - resp['Certificate']['Issuer'].should.equal('Moto') - resp['Certificate']['KeyAlgorithm'].should.equal('RSA_2048') - resp['Certificate']['Status'].should.equal('ISSUED') - resp['Certificate']['Type'].should.equal('IMPORTED') + resp["Certificate"]["CertificateArn"].should.equal(arn) + resp["Certificate"]["DomainName"].should.equal(SERVER_COMMON_NAME) + resp["Certificate"]["Issuer"].should.equal("Moto") + resp["Certificate"]["KeyAlgorithm"].should.equal("RSA_2048") + resp["Certificate"]["Status"].should.equal("ISSUED") + resp["Certificate"]["Type"].should.equal("IMPORTED") @mock_acm def test_describe_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.describe_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") # Also tests ListTagsForCertificate @mock_acm def test_add_tags_to_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.add_tags_to_certificate( - CertificateArn=arn, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + CertificateArn=arn, Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}] ) resp = client.list_tags_for_certificate(CertificateArn=arn) - tags = {item['Key']: item.get('Value', '__NONE__') for item in resp['Tags']} + tags = {item["Key"]: item.get("Value", "__NONE__") for item in resp["Tags"]} - tags.should.contain('key1') - tags.should.contain('key2') - tags['key1'].should.equal('value1') + tags.should.contain("key1") + tags.should.contain("key2") + tags["key1"].should.equal("value1") # This way, it ensures that we can detect if None is passed back when it shouldnt, # as we store keys without values with a value of None, but it shouldnt be passed back - tags['key2'].should.equal('__NONE__') + tags["key2"].should.equal("__NONE__") @mock_acm def test_add_tags_to_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.add_tags_to_certificate( CertificateArn=BAD_ARN, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_list_tags_for_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.list_tags_for_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_remove_tags_from_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.add_tags_to_certificate( CertificateArn=arn, Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - {'Key': 'key3', 'Value': 'value3'}, - {'Key': 'key4', 'Value': 'value4'}, - ] + {"Key": "key1", "Value": "value1"}, + {"Key": "key2"}, + {"Key": "key3", "Value": "value3"}, + {"Key": "key4", "Value": "value4"}, + ], ) client.remove_tags_from_certificate( CertificateArn=arn, Tags=[ - {'Key': 'key1', 'Value': 'value2'}, # Should not remove as doesnt match - {'Key': 'key2'}, # Single key removal - {'Key': 'key3', 'Value': 'value3'}, # Exact match removal - {'Key': 'key4'} # Partial match removal - ] + {"Key": "key1", "Value": "value2"}, # Should not remove as doesnt match + {"Key": "key2"}, # Single key removal + {"Key": "key3", "Value": "value3"}, # Exact match removal + {"Key": "key4"}, # Partial match removal + ], ) resp = client.list_tags_for_certificate(CertificateArn=arn) - tags = {item['Key']: item.get('Value', '__NONE__') for item in resp['Tags']} + tags = {item["Key"]: item.get("Value", "__NONE__") for item in resp["Tags"]} - for key in ('key2', 'key3', 'key4'): + for key in ("key2", "key3", "key4"): tags.should_not.contain(key) - tags.should.contain('key1') + tags.should.contain("key1") @mock_acm def test_remove_tags_from_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.remove_tags_from_certificate( CertificateArn=BAD_ARN, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_resend_validation_email(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.resend_validation_email( - CertificateArn=arn, - Domain='*.moto.com', - ValidationDomain='NOTUSEDYET' + CertificateArn=arn, Domain="*.moto.com", ValidationDomain="NOTUSEDYET" ) # Returns nothing, boto would raise Exceptions otherwise @mock_acm def test_resend_validation_email_invalid(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) try: client.resend_validation_email( CertificateArn=arn, - Domain='no-match.moto.com', - ValidationDomain='NOTUSEDYET' + Domain="no-match.moto.com", + ValidationDomain="NOTUSEDYET", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidDomainValidationOptionsException') + err.response["Error"]["Code"].should.equal( + "InvalidDomainValidationOptionsException" + ) else: - raise RuntimeError('Should of raised InvalidDomainValidationOptionsException') + raise RuntimeError("Should of raised InvalidDomainValidationOptionsException") try: client.resend_validation_email( CertificateArn=BAD_ARN, - Domain='no-match.moto.com', - ValidationDomain='NOTUSEDYET' + Domain="no-match.moto.com", + ValidationDomain="NOTUSEDYET", ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_request_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") token = str(uuid.uuid4()) resp = client.request_certificate( - DomainName='google.com', + DomainName="google.com", IdempotencyToken=token, - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - resp.should.contain('CertificateArn') - arn = resp['CertificateArn'] + resp.should.contain("CertificateArn") + arn = resp["CertificateArn"] arn.should.match(r"arn:aws:acm:eu-central-1:\d{12}:certificate/") resp = client.request_certificate( - DomainName='google.com', + DomainName="google.com", IdempotencyToken=token, - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - resp['CertificateArn'].should.equal(arn) + resp["CertificateArn"].should.equal(arn) @mock_acm def test_request_certificate_no_san(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") - resp = client.request_certificate( - DomainName='google.com' - ) - resp.should.contain('CertificateArn') + resp = client.request_certificate(DomainName="google.com") + resp.should.contain("CertificateArn") + + resp2 = client.describe_certificate(CertificateArn=resp["CertificateArn"]) + resp2.should.contain("Certificate") - resp2 = client.describe_certificate( - CertificateArn=resp['CertificateArn'] - ) - resp2.should.contain('Certificate') # # Also tests the SAN code # # requires Pull: https://github.com/spulec/freezegun/pull/210 diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 0a33f2f9f..59c0c07f6 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -9,649 +9,537 @@ from botocore.exceptions import ClientError import responses from moto import mock_apigateway, settings +from moto.core import ACCOUNT_ID +from nose.tools import assert_raises @freeze_time("2015-01-01") @mock_apigateway def test_create_and_get_rest_api(): - client = boto3.client('apigateway', region_name='us-west-2') + client = boto3.client("apigateway", region_name="us-west-2") - response = client.create_rest_api( - name='my_api', - description='this is my api', + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + + response = client.get_rest_api(restApiId=api_id) + + response.pop("ResponseMetadata") + response.pop("createdDate") + response.should.equal( + {"id": api_id, "name": "my_api", "description": "this is my api"} ) - api_id = response['id'] - - response = client.get_rest_api( - restApiId=api_id - ) - - response.pop('ResponseMetadata') - response.pop('createdDate') - response.should.equal({ - 'id': api_id, - 'name': 'my_api', - 'description': 'this is my api', - }) @mock_apigateway def test_list_and_delete_apis(): - client = boto3.client('apigateway', region_name='us-west-2') + client = boto3.client("apigateway", region_name="us-west-2") - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] - client.create_rest_api( - name='my_api2', - description='this is my api2', - ) + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + client.create_rest_api(name="my_api2", description="this is my api2") response = client.get_rest_apis() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(2) - client.delete_rest_api( - restApiId=api_id - ) + client.delete_rest_api(restApiId=api_id) response = client.get_rest_apis() - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) + + +@mock_apigateway +def test_create_resource__validate_name(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + + resources = client.get_resources(restApiId=api_id) + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] + + invalid_names = ["/users", "users/", "users/{user_id}", "us{er"] + valid_names = ["users", "{user_id}", "user_09", "good-dog"] + # All invalid names should throw an exception + for name in invalid_names: + with assert_raises(ClientError) as ex: + client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Resource's path part only allow a-zA-Z0-9._- and curly braces at the beginning and the end." + ) + # All valid names should go through + for name in valid_names: + client.create_resource(restApiId=api_id, parentId=root_id, pathPart=name) @mock_apigateway def test_create_resource(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] - root_resource = client.get_resource( - restApiId=api_id, - resourceId=root_id, - ) + root_resource = client.get_resource(restApiId=api_id, resourceId=root_id) # this is hard to match against, so remove it - root_resource['ResponseMetadata'].pop('HTTPHeaders', None) - root_resource['ResponseMetadata'].pop('RetryAttempts', None) - root_resource.should.equal({ - 'path': '/', - 'id': root_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'resourceMethods': { - 'GET': {} + root_resource["ResponseMetadata"].pop("HTTPHeaders", None) + root_resource["ResponseMetadata"].pop("RetryAttempts", None) + root_resource.should.equal( + { + "path": "/", + "id": root_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "resourceMethods": {"GET": {}}, } - }) - - response = client.create_resource( - restApiId=api_id, - parentId=root_id, - pathPart='/users', ) - resources = client.get_resources(restApiId=api_id)['items'] + client.create_resource(restApiId=api_id, parentId=root_id, pathPart="users") + + resources = client.get_resources(restApiId=api_id)["items"] len(resources).should.equal(2) - non_root_resource = [ - resource for resource in resources if resource['path'] != '/'][0] + non_root_resource = [resource for resource in resources if resource["path"] != "/"][ + 0 + ] - response = client.delete_resource( - restApiId=api_id, - resourceId=non_root_resource['id'] - ) + client.delete_resource(restApiId=api_id, resourceId=non_root_resource["id"]) - len(client.get_resources(restApiId=api_id)['items']).should.equal(1) + len(client.get_resources(restApiId=api_id)["items"]).should.equal(1) @mock_apigateway def test_child_resource(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] response = client.create_resource( - restApiId=api_id, - parentId=root_id, - pathPart='users', + restApiId=api_id, parentId=root_id, pathPart="users" ) - users_id = response['id'] + users_id = response["id"] response = client.create_resource( - restApiId=api_id, - parentId=users_id, - pathPart='tags', + restApiId=api_id, parentId=users_id, pathPart="tags" ) - tags_id = response['id'] + tags_id = response["id"] - child_resource = client.get_resource( - restApiId=api_id, - resourceId=tags_id, - ) + child_resource = client.get_resource(restApiId=api_id, resourceId=tags_id) # this is hard to match against, so remove it - child_resource['ResponseMetadata'].pop('HTTPHeaders', None) - child_resource['ResponseMetadata'].pop('RetryAttempts', None) - child_resource.should.equal({ - 'path': '/users/tags', - 'pathPart': 'tags', - 'parentId': users_id, - 'id': tags_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'resourceMethods': {'GET': {}}, - }) + child_resource["ResponseMetadata"].pop("HTTPHeaders", None) + child_resource["ResponseMetadata"].pop("RetryAttempts", None) + child_resource.should.equal( + { + "path": "/users/tags", + "pathPart": "tags", + "parentId": users_id, + "id": tags_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "resourceMethods": {"GET": {}}, + } + ) @mock_apigateway def test_create_method(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'httpMethod': 'GET', - 'authorizationType': 'none', - 'ResponseMetadata': {'HTTPStatusCode': 200} - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "httpMethod": "GET", + "authorizationType": "none", + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + ) @mock_apigateway def test_create_method_response(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") response = client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'statusCode': '200' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + {"ResponseMetadata": {"HTTPStatusCode": 200}, "statusCode": "200"} + ) response = client.get_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'statusCode': '200' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + {"ResponseMetadata": {"HTTPStatusCode": 200}, "statusCode": "200"} + ) response = client.delete_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({'ResponseMetadata': {'HTTPStatusCode': 200}}) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal({"ResponseMetadata": {"HTTPStatusCode": 200}}) @mock_apigateway def test_integrations(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } + ) response = client.get_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' + restApiId=api_id, resourceId=root_id, httpMethod="GET" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) - - response = client.get_resource( - restApiId=api_id, - resourceId=root_id, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } ) + + response = client.get_resource(restApiId=api_id, resourceId=root_id) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['resourceMethods']['GET']['methodIntegration'].should.equal({ - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) - - client.delete_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["resourceMethods"]["GET"]["methodIntegration"].should.equal( + { + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } ) - response = client.get_resource( - restApiId=api_id, - resourceId=root_id, - ) - response['resourceMethods']['GET'].shouldnt.contain("methodIntegration") + client.delete_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET") + + response = client.get_resource(restApiId=api_id, resourceId=root_id) + response["resourceMethods"]["GET"].shouldnt.contain("methodIntegration") # Create a new integration with a requestTemplates config client.put_method( restApiId=api_id, resourceId=root_id, - httpMethod='POST', - authorizationType='none', + httpMethod="POST", + authorizationType="none", ) templates = { # example based on # http://docs.aws.amazon.com/apigateway/latest/developerguide/api-as-kinesis-proxy-export-swagger-with-extensions.html - 'application/json': "{\n \"StreamName\": \"$input.params('stream-name')\",\n \"Records\": []\n}" + "application/json": '{\n "StreamName": "$input.params(\'stream-name\')",\n "Records": []\n}' } - test_uri = 'http://example.com/foobar.txt' + test_uri = "http://example.com/foobar.txt" response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='POST', - type='HTTP', + httpMethod="POST", + type="HTTP", uri=test_uri, - requestTemplates=templates + requestTemplates=templates, + integrationHttpMethod="POST", ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['ResponseMetadata'].should.equal({'HTTPStatusCode': 200}) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].should.equal({"HTTPStatusCode": 200}) response = client.get_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='POST' + restApiId=api_id, resourceId=root_id, httpMethod="POST" ) - response['uri'].should.equal(test_uri) - response['requestTemplates'].should.equal(templates) + response["uri"].should.equal(test_uri) + response["requestTemplates"].should.equal(templates) @mock_apigateway def test_integration_response(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) - response = client.put_integration( + client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", ) response = client.put_integration_response( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - statusCode='200', - selectionPattern='foobar', + httpMethod="GET", + statusCode="200", + selectionPattern="foobar", + responseTemplates={}, ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'statusCode': '200', - 'selectionPattern': 'foobar', - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'responseTemplates': { - 'application/json': None + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "statusCode": "200", + "selectionPattern": "foobar", + "ResponseMetadata": {"HTTPStatusCode": 200}, + "responseTemplates": {"application/json": None}, } - }) + ) response = client.get_integration_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'statusCode': '200', - 'selectionPattern': 'foobar', - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'responseTemplates': { - 'application/json': None + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "statusCode": "200", + "selectionPattern": "foobar", + "ResponseMetadata": {"HTTPStatusCode": 200}, + "responseTemplates": {"application/json": None}, } - }) + ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['methodIntegration']['integrationResponses'].should.equal({ - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'selectionPattern': 'foobar', - 'statusCode': '200' + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["methodIntegration"]["integrationResponses"].should.equal( + { + "200": { + "responseTemplates": {"application/json": None}, + "selectionPattern": "foobar", + "statusCode": "200", + } } - }) + ) response = client.delete_integration_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - ) - response['methodIntegration']['integrationResponses'].should.equal({}) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") + response["methodIntegration"]["integrationResponses"].should.equal({}) @mock_apigateway def test_update_stage_configuration(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + create_method_integration(client, api_id) response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - description="1.0.1" + restApiId=api_id, stageName=stage_name, description="1.0.1" ) - deployment_id = response['id'] + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '1.0.1' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "1.0.1", + } + ) response = client.create_deployment( + restApiId=api_id, stageName=stage_name, description="1.0.2" + ) + deployment_id2 = response["id"] + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id2) + stage.shouldnt.have.key("cacheClusterSize") + + client.update_stage( restApiId=api_id, stageName=stage_name, - description="1.0.2" + patchOperations=[ + {"op": "replace", "path": "/cacheClusterEnabled", "value": "True"} + ], ) - deployment_id2 = response['id'] - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id2) - stage.shouldnt.have.key('cacheClusterSize') - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/cacheClusterEnabled", - "value": "True" - } - ]) - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - - stage.should.have.key('cacheClusterSize').which.should.equal("0.5") - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/cacheClusterSize", - "value": "1.6" - } - ]) - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - - stage.should.have.key('cacheClusterSize').which.should.equal("1.6") - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/deploymentId", - "value": deployment_id - }, - { - "op": "replace", - "path": "/variables/environment", - "value": "dev" - }, - { - "op": "replace", - "path": "/variables/region", - "value": "eu-west-1" - }, - { - "op": "replace", - "path": "/*/*/caching/dataEncrypted", - "value": "True" - }, - { - "op": "replace", - "path": "/cacheClusterEnabled", - "value": "True" - }, - { - "op": "replace", - "path": "/description", - "value": "stage description update" - }, - { - "op": "replace", - "path": "/cacheClusterSize", - "value": "1.6" - } - ]) - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "remove", - "path": "/variables/region", - "value": "eu-west-1" - } - ]) stage = client.get_stage(restApiId=api_id, stageName=stage_name) - stage['description'].should.match('stage description update') - stage['cacheClusterSize'].should.equal("1.6") - stage['variables']['environment'].should.match('dev') - stage['variables'].should_not.have.key('region') - stage['cacheClusterEnabled'].should.be.true - stage['deploymentId'].should.match(deployment_id) - stage['methodSettings'].should.have.key('*/*') - stage['methodSettings'][ - '*/*'].should.have.key('cacheDataEncrypted').which.should.be.true + stage.should.have.key("cacheClusterSize").which.should.equal("0.5") + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "replace", "path": "/cacheClusterSize", "value": "1.6"} + ], + ) + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + + stage.should.have.key("cacheClusterSize").which.should.equal("1.6") + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "replace", "path": "/deploymentId", "value": deployment_id}, + {"op": "replace", "path": "/variables/environment", "value": "dev"}, + {"op": "replace", "path": "/variables/region", "value": "eu-west-1"}, + {"op": "replace", "path": "/*/*/caching/dataEncrypted", "value": "True"}, + {"op": "replace", "path": "/cacheClusterEnabled", "value": "True"}, + { + "op": "replace", + "path": "/description", + "value": "stage description update", + }, + {"op": "replace", "path": "/cacheClusterSize", "value": "1.6"}, + ], + ) + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "remove", "path": "/variables/region", "value": "eu-west-1"} + ], + ) + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + + stage["description"].should.match("stage description update") + stage["cacheClusterSize"].should.equal("1.6") + stage["variables"]["environment"].should.match("dev") + stage["variables"].should_not.have.key("region") + stage["cacheClusterEnabled"].should.be.true + stage["deploymentId"].should.match(deployment_id) + stage["methodSettings"].should.have.key("*/*") + stage["methodSettings"]["*/*"].should.have.key( + "cacheDataEncrypted" + ).which.should.be.true try: - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "add", - "path": "/notasetting", - "value": "eu-west-1" - } - ]) + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "add", "path": "/notasetting", "value": "eu-west-1"} + ], + ) assert False.should.be.ok # Fail, should not be here except Exception: assert True.should.be.ok @@ -659,478 +547,830 @@ def test_update_stage_configuration(): @mock_apigateway def test_non_existent_stage(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] - client.get_stage.when.called_with( - restApiId=api_id, stageName='xxx').should.throw(ClientError) + client.get_stage.when.called_with(restApiId=api_id, stageName="xxx").should.throw( + ClientError + ) @mock_apigateway def test_create_stage(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) - deployment_id = response['id'] + create_method_integration(client, api_id) + response = client.create_deployment(restApiId=api_id, stageName=stage_name) + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '' - }) - - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + } ) - deployment_id2 = response['id'] + response = client.create_deployment(restApiId=api_id, stageName=stage_name) - response = client.get_deployments( - restApiId=api_id, - ) + deployment_id2 = response["id"] + + response = client.get_deployments(restApiId=api_id) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - response['items'][0].pop('createdDate') - response['items'][1].pop('createdDate') - response['items'][0]['id'].should.match( - r"{0}|{1}".format(deployment_id2, deployment_id)) - response['items'][1]['id'].should.match( - r"{0}|{1}".format(deployment_id2, deployment_id)) + response["items"][0].pop("createdDate") + response["items"][1].pop("createdDate") + response["items"][0]["id"].should.match( + r"{0}|{1}".format(deployment_id2, deployment_id) + ) + response["items"][1]["id"].should.match( + r"{0}|{1}".format(deployment_id2, deployment_id) + ) - new_stage_name = 'current' + new_stage_name = "current" response = client.create_stage( - restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2) - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '', - 'cacheClusterEnabled': False - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name - ) - stage['stageName'].should.equal(new_stage_name) - stage['deploymentId'].should.equal(deployment_id2) - - new_stage_name_with_vars = 'stage_with_vars' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name_with_vars, deploymentId=deployment_id2, variables={ - "env": "dev" - }) - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name_with_vars, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '', - 'cacheClusterEnabled': False - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name_with_vars - ) - stage['stageName'].should.equal(new_stage_name_with_vars) - stage['deploymentId'].should.equal(deployment_id2) - stage['variables'].should.have.key('env').which.should.match("dev") - - new_stage_name = 'stage_with_vars_and_cache_settings' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2, variables={ - "env": "dev" - }, cacheClusterEnabled=True, description="hello moto") - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': 'hello moto', - 'cacheClusterEnabled': True, - 'cacheClusterSize': "0.5" - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name + restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2 ) - stage['cacheClusterSize'].should.equal("0.5") + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - new_stage_name = 'stage_with_vars_and_cache_settings_and_size' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2, variables={ - "env": "dev" - }, cacheClusterEnabled=True, cacheClusterSize="1.6", description="hello moto") + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + "cacheClusterEnabled": False, + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + stage["stageName"].should.equal(new_stage_name) + stage["deploymentId"].should.equal(deployment_id2) + + new_stage_name_with_vars = "stage_with_vars" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name_with_vars, + deploymentId=deployment_id2, + variables={"env": "dev"}, + ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': 'hello moto', - 'cacheClusterEnabled': True, - 'cacheClusterSize': "1.6" - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name + response.should.equal( + { + "stageName": new_stage_name_with_vars, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + "cacheClusterEnabled": False, + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name_with_vars) + stage["stageName"].should.equal(new_stage_name_with_vars) + stage["deploymentId"].should.equal(deployment_id2) + stage["variables"].should.have.key("env").which.should.match("dev") + + new_stage_name = "stage_with_vars_and_cache_settings" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name, + deploymentId=deployment_id2, + variables={"env": "dev"}, + cacheClusterEnabled=True, + description="hello moto", + ) + + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "hello moto", + "cacheClusterEnabled": True, + "cacheClusterSize": "0.5", + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + + stage["cacheClusterSize"].should.equal("0.5") + + new_stage_name = "stage_with_vars_and_cache_settings_and_size" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name, + deploymentId=deployment_id2, + variables={"env": "dev"}, + cacheClusterEnabled=True, + cacheClusterSize="1.6", + description="hello moto", + ) + + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "hello moto", + "cacheClusterEnabled": True, + "cacheClusterSize": "1.6", + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + stage["stageName"].should.equal(new_stage_name) + stage["deploymentId"].should.equal(deployment_id2) + stage["variables"].should.have.key("env").which.should.match("dev") + stage["cacheClusterSize"].should.equal("1.6") + + +@mock_apigateway +def test_create_deployment_requires_REST_methods(): + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + + with assert_raises(ClientError) as ex: + client.create_deployment(restApiId=api_id, stageName=stage_name)["id"] + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "The REST API doesn't contain any methods" + ) + + +@mock_apigateway +def test_create_deployment_requires_REST_method_integrations(): + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + resources = client.get_resources(restApiId=api_id) + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] + + client.put_method( + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="NONE" + ) + + with assert_raises(ClientError) as ex: + client.create_deployment(restApiId=api_id, stageName=stage_name)["id"] + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "No integration defined for method" + ) + + +@mock_apigateway +def test_create_simple_deployment_with_get_method(): + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + create_method_integration(client, api_id) + deployment = client.create_deployment(restApiId=api_id, stageName=stage_name) + assert "id" in deployment + + +@mock_apigateway +def test_create_simple_deployment_with_post_method(): + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + create_method_integration(client, api_id, httpMethod="POST") + deployment = client.create_deployment(restApiId=api_id, stageName=stage_name) + assert "id" in deployment + + +@mock_apigateway +# https://github.com/aws/aws-sdk-js/issues/2588 +def test_put_integration_response_requires_responseTemplate(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + resources = client.get_resources(restApiId=api_id) + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] + + client.put_method( + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="NONE" + ) + client.put_method_response( + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" + ) + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", + ) + + with assert_raises(ClientError) as ex: + client.put_integration_response( + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal("Invalid request input") + # Works fine if responseTemplate is defined + client.put_integration_response( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + statusCode="200", + responseTemplates={}, + ) + + +@mock_apigateway +def test_put_integration_validation(): + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + resources = client.get_resources(restApiId=api_id) + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] + + client.put_method( + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="NONE" + ) + client.put_method_response( + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" + ) + + http_types = ["HTTP", "HTTP_PROXY"] + aws_types = ["AWS", "AWS_PROXY"] + types_requiring_integration_method = http_types + aws_types + types_not_requiring_integration_method = ["MOCK"] + + for type in types_requiring_integration_method: + # Ensure that integrations of these types fail if no integrationHttpMethod is provided + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="http://httpbin.org/robots.txt", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Enumeration value for HttpMethod must be non-empty" + ) + for type in types_not_requiring_integration_method: + # Ensure that integrations of these types do not need the integrationHttpMethod + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="http://httpbin.org/robots.txt", + ) + for type in http_types: + # Ensure that it works fine when providing the integrationHttpMethod-argument + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", + ) + for type in ["AWS"]: + # Ensure that it works fine when providing the integrationHttpMethod + credentials + client.put_integration( + restApiId=api_id, + resourceId=root_id, + credentials="arn:aws:iam::{}:role/service-role/testfunction-role-oe783psq".format( + ACCOUNT_ID + ), + httpMethod="GET", + type=type, + uri="arn:aws:apigateway:us-west-2:s3:path/b/k", + integrationHttpMethod="POST", + ) + for type in aws_types: + # Ensure that credentials are not required when URI points to a Lambda stream + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="arn:aws:apigateway:eu-west-1:lambda:path/2015-03-31/functions/arn:aws:lambda:eu-west-1:012345678901:function:MyLambda/invocations", + integrationHttpMethod="POST", + ) + for type in ["AWS_PROXY"]: + # Ensure that aws_proxy does not support S3 + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + credentials="arn:aws:iam::{}:role/service-role/testfunction-role-oe783psq".format( + ACCOUNT_ID + ), + httpMethod="GET", + type=type, + uri="arn:aws:apigateway:us-west-2:s3:path/b/k", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Integrations of type 'AWS_PROXY' currently only supports Lambda function and Firehose stream invocations." + ) + for type in aws_types: + # Ensure that the Role ARN is for the current account + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + credentials="arn:aws:iam::000000000000:role/service-role/testrole", + httpMethod="GET", + type=type, + uri="arn:aws:apigateway:us-west-2:s3:path/b/k", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("AccessDeniedException") + ex.exception.response["Error"]["Message"].should.equal( + "Cross-account pass role is not allowed." + ) + for type in ["AWS"]: + # Ensure that the Role ARN is specified for aws integrations + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="arn:aws:apigateway:us-west-2:s3:path/b/k", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Role ARN must be specified for AWS integrations" + ) + for type in http_types: + # Ensure that the URI is valid HTTP + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="non-valid-http", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Invalid HTTP endpoint specified for URI" + ) + for type in aws_types: + # Ensure that the URI is an ARN + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="non-valid-arn", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "Invalid ARN specified in the request" + ) + for type in aws_types: + # Ensure that the URI is a valid ARN + with assert_raises(ClientError) as ex: + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod="GET", + type=type, + uri="arn:aws:iam::0000000000:role/service-role/asdf", + integrationHttpMethod="POST", + ) + ex.exception.response["Error"]["Code"].should.equal("BadRequestException") + ex.exception.response["Error"]["Message"].should.equal( + "AWS ARN for integration must contain path or action" + ) + + +@mock_apigateway +def test_delete_stage(): + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + create_method_integration(client, api_id) + deployment_id1 = client.create_deployment(restApiId=api_id, stageName=stage_name)[ + "id" + ] + deployment_id2 = client.create_deployment(restApiId=api_id, stageName=stage_name)[ + "id" + ] + + new_stage_name = "current" + client.create_stage( + restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id1 + ) + + new_stage_name_with_vars = "stage_with_vars" + client.create_stage( + restApiId=api_id, + stageName=new_stage_name_with_vars, + deploymentId=deployment_id2, + variables={"env": "dev"}, + ) + stages = client.get_stages(restApiId=api_id)["item"] + sorted([stage["stageName"] for stage in stages]).should.equal( + sorted([new_stage_name, new_stage_name_with_vars, stage_name]) + ) + # delete stage + response = client.delete_stage(restApiId=api_id, stageName=new_stage_name_with_vars) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(202) + # verify other stage still exists + stages = client.get_stages(restApiId=api_id)["item"] + sorted([stage["stageName"] for stage in stages]).should.equal( + sorted([new_stage_name, stage_name]) ) - stage['stageName'].should.equal(new_stage_name) - stage['deploymentId'].should.equal(deployment_id2) - stage['variables'].should.have.key('env').which.should.match("dev") - stage['cacheClusterSize'].should.equal("1.6") @mock_apigateway def test_deployment(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + create_method_integration(client, api_id) - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) - deployment_id = response['id'] + response = client.create_deployment(restApiId=api_id, stageName=stage_name) + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '' - }) - - response = client.get_deployments( - restApiId=api_id, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + } ) - response['items'][0].pop('createdDate') - response['items'].should.equal([ - {'id': deployment_id, 'description': ''} - ]) + response = client.get_deployments(restApiId=api_id) - response = client.delete_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response["items"][0].pop("createdDate") + response["items"].should.equal([{"id": deployment_id, "description": ""}]) - response = client.get_deployments( - restApiId=api_id, - ) - len(response['items']).should.equal(0) + client.delete_deployment(restApiId=api_id, deploymentId=deployment_id) + + response = client.get_deployments(restApiId=api_id) + len(response["items"]).should.equal(0) # test deployment stages - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id) + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id) - stage = client.update_stage( + client.update_stage( restApiId=api_id, stageName=stage_name, patchOperations=[ - { - 'op': 'replace', - 'path': '/description', - 'value': '_new_description_' - }, - ] + {"op": "replace", "path": "/description", "value": "_new_description_"} + ], ) - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id) - stage['description'].should.equal('_new_description_') + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id) + stage["description"].should.equal("_new_description_") @mock_apigateway def test_http_proxying_integration(): responses.add( - responses.GET, "http://httpbin.org/robots.txt", body='a fake response' + responses.GET, "http://httpbin.org/robots.txt", body="a fake response" ) - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", ) - stage_name = 'staging' - client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) + stage_name = "staging" + client.create_deployment(restApiId=api_id, stageName=stage_name) deploy_url = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}".format( - api_id=api_id, region_name=region_name, stage_name=stage_name) + api_id=api_id, region_name=region_name, stage_name=stage_name + ) if not settings.TEST_SERVER_MODE: requests.get(deploy_url).content.should.equal(b"a fake response") @mock_apigateway -def test_api_keys(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) +def test_create_api_key(): + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + + apikey_value = "12345" + apikey_name = "TESTKEY1" + payload = {"value": apikey_value, "name": apikey_name} + + client.create_api_key(**payload) + response = client.get_api_keys() - len(response['items']).should.equal(0) + len(response["items"]).should.equal(1) - apikey_value = '12345' - apikey_name = 'TESTKEY1' - payload = {'value': apikey_value, 'name': apikey_name} - response = client.create_api_key(**payload) - apikey = client.get_api_key(apiKey=response['id']) - apikey['name'].should.equal(apikey_name) - apikey['value'].should.equal(apikey_value) + client.create_api_key.when.called_with(**payload).should.throw(ClientError) - apikey_name = 'TESTKEY2' - payload = {'name': apikey_name } - response = client.create_api_key(**payload) - apikey_id = response['id'] - apikey = client.get_api_key(apiKey=apikey_id) - apikey['name'].should.equal(apikey_name) - len(apikey['value']).should.equal(40) - apikey_name = 'TESTKEY3' - payload = {'name': apikey_name } +@mock_apigateway +def test_api_keys(): + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + response = client.get_api_keys() + len(response["items"]).should.equal(0) + + apikey_value = "12345" + apikey_name = "TESTKEY1" + payload = { + "value": apikey_value, + "name": apikey_name, + "tags": {"tag1": "test_tag1", "tag2": "1"}, + } response = client.create_api_key(**payload) - apikey_id = response['id'] + apikey_id = response["id"] + apikey = client.get_api_key(apiKey=response["id"]) + apikey["name"].should.equal(apikey_name) + apikey["value"].should.equal(apikey_value) + apikey["tags"]["tag1"].should.equal("test_tag1") + apikey["tags"]["tag2"].should.equal("1") patch_operations = [ - {'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'}, - {'op': 'replace', 'path': '/customerId', 'value': '12345'}, - {'op': 'replace', 'path': '/description', 'value': 'APIKEY UPDATE TEST'}, - {'op': 'replace', 'path': '/enabled', 'value': 'false'}, + {"op": "replace", "path": "/name", "value": "TESTKEY3_CHANGE"}, + {"op": "replace", "path": "/customerId", "value": "12345"}, + {"op": "replace", "path": "/description", "value": "APIKEY UPDATE TEST"}, + {"op": "replace", "path": "/enabled", "value": "false"}, ] response = client.update_api_key(apiKey=apikey_id, patchOperations=patch_operations) - response['name'].should.equal('TESTKEY3_CHANGE') - response['customerId'].should.equal('12345') - response['description'].should.equal('APIKEY UPDATE TEST') - response['enabled'].should.equal(False) + response["name"].should.equal("TESTKEY3_CHANGE") + response["customerId"].should.equal("12345") + response["description"].should.equal("APIKEY UPDATE TEST") + response["enabled"].should.equal(False) + + updated_api_key = client.get_api_key(apiKey=apikey_id) + updated_api_key["name"].should.equal("TESTKEY3_CHANGE") + updated_api_key["customerId"].should.equal("12345") + updated_api_key["description"].should.equal("APIKEY UPDATE TEST") + updated_api_key["enabled"].should.equal(False) response = client.get_api_keys() - len(response['items']).should.equal(3) + len(response["items"]).should.equal(1) + + payload = {"name": apikey_name} + client.create_api_key(**payload) + + response = client.get_api_keys() + len(response["items"]).should.equal(2) client.delete_api_key(apiKey=apikey_id) response = client.get_api_keys() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(1) + @mock_apigateway def test_usage_plans(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) response = client.get_usage_plans() - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) - usage_plan_name = 'TEST-PLAN' - payload = {'name': usage_plan_name} + usage_plan_name = "TEST-PLAN" + payload = {"name": usage_plan_name} response = client.create_usage_plan(**payload) - usage_plan = client.get_usage_plan(usagePlanId=response['id']) - usage_plan['name'].should.equal(usage_plan_name) - usage_plan['apiStages'].should.equal([]) + usage_plan = client.get_usage_plan(usagePlanId=response["id"]) + usage_plan["name"].should.equal(usage_plan_name) + usage_plan["apiStages"].should.equal([]) - usage_plan_name = 'TEST-PLAN-2' - usage_plan_description = 'Description' - usage_plan_quota = {'limit': 10, 'period': 'DAY', 'offset': 0} - usage_plan_throttle = {'rateLimit': 2, 'burstLimit': 1} - usage_plan_api_stages = [{'apiId': 'foo', 'stage': 'bar'}] - payload = {'name': usage_plan_name, 'description': usage_plan_description, 'quota': usage_plan_quota, 'throttle': usage_plan_throttle, 'apiStages': usage_plan_api_stages} + payload = { + "name": "TEST-PLAN-2", + "description": "Description", + "quota": {"limit": 10, "period": "DAY", "offset": 0}, + "throttle": {"rateLimit": 2, "burstLimit": 1}, + "apiStages": [{"apiId": "foo", "stage": "bar"}], + "tags": {"tag_key": "tag_value"}, + } response = client.create_usage_plan(**payload) - usage_plan_id = response['id'] + usage_plan_id = response["id"] usage_plan = client.get_usage_plan(usagePlanId=usage_plan_id) - usage_plan['name'].should.equal(usage_plan_name) - usage_plan['description'].should.equal(usage_plan_description) - usage_plan['apiStages'].should.equal(usage_plan_api_stages) - usage_plan['throttle'].should.equal(usage_plan_throttle) - usage_plan['quota'].should.equal(usage_plan_quota) + + # The payload should remain unchanged + for key, value in payload.items(): + usage_plan.should.have.key(key).which.should.equal(value) + + # Status code should be 200 + usage_plan["ResponseMetadata"].should.have.key("HTTPStatusCode").which.should.equal( + 200 + ) + + # An Id should've been generated + usage_plan.should.have.key("id").which.should_not.be.none response = client.get_usage_plans() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(2) client.delete_usage_plan(usagePlanId=usage_plan_id) response = client.get_usage_plans() - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) + @mock_apigateway def test_usage_plan_keys(): - region_name = 'us-west-2' - usage_plan_id = 'test_usage_plan_id' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) usage_plan_id = "test" # Create an API key so we can use it - key_name = 'test-api-key' + key_name = "test-api-key" response = client.create_api_key(name=key_name) key_id = response["id"] key_value = response["value"] # Get current plan keys (expect none) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) # Create usage plan key - key_type = 'API_KEY' - payload = {'usagePlanId': usage_plan_id, 'keyId': key_id, 'keyType': key_type } + key_type = "API_KEY" + payload = {"usagePlanId": usage_plan_id, "keyId": key_id, "keyType": key_type} response = client.create_usage_plan_key(**payload) usage_plan_key_id = response["id"] # Get current plan keys (expect 1) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) # Get a single usage plan key and check it matches the created one - usage_plan_key = client.get_usage_plan_key(usagePlanId=usage_plan_id, keyId=usage_plan_key_id) - usage_plan_key['name'].should.equal(key_name) - usage_plan_key['id'].should.equal(key_id) - usage_plan_key['type'].should.equal(key_type) - usage_plan_key['value'].should.equal(key_value) + usage_plan_key = client.get_usage_plan_key( + usagePlanId=usage_plan_id, keyId=usage_plan_key_id + ) + usage_plan_key["name"].should.equal(key_name) + usage_plan_key["id"].should.equal(key_id) + usage_plan_key["type"].should.equal(key_type) + usage_plan_key["value"].should.equal(key_value) # Delete usage plan key client.delete_usage_plan_key(usagePlanId=usage_plan_id, keyId=key_id) # Get current plan keys (expect none) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) + @mock_apigateway def test_create_usage_plan_key_non_existent_api_key(): - region_name = 'us-west-2' - usage_plan_id = 'test_usage_plan_id' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) usage_plan_id = "test" # Attempt to create a usage plan key for a API key that doesn't exists - payload = {'usagePlanId': usage_plan_id, 'keyId': 'non-existent', 'keyType': 'API_KEY' } + payload = { + "usagePlanId": usage_plan_id, + "keyId": "non-existent", + "keyType": "API_KEY", + } client.create_usage_plan_key.when.called_with(**payload).should.throw(ClientError) @mock_apigateway def test_get_usage_plans_using_key_id(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) # Create 2 Usage Plans # one will be attached to an API Key, the other will remain unattached - attached_plan = client.create_usage_plan(name='Attached') - unattached_plan = client.create_usage_plan(name='Unattached') + attached_plan = client.create_usage_plan(name="Attached") + unattached_plan = client.create_usage_plan(name="Unattached") # Create an API key # to attach to the usage plan - key_name = 'test-api-key' + key_name = "test-api-key" response = client.create_api_key(name=key_name) key_id = response["id"] # Create a Usage Plan Key # Attached the Usage Plan and API Key - key_type = 'API_KEY' - payload = {'usagePlanId': attached_plan['id'], 'keyId': key_id, 'keyType': key_type} + key_type = "API_KEY" + payload = {"usagePlanId": attached_plan["id"], "keyId": key_id, "keyType": key_type} response = client.create_usage_plan_key(**payload) # All usage plans should be returned when keyId is not included all_plans = client.get_usage_plans() - len(all_plans['items']).should.equal(2) + len(all_plans["items"]).should.equal(2) # Only the usage plan attached to the given api key are included only_plans_with_key = client.get_usage_plans(keyId=key_id) - len(only_plans_with_key['items']).should.equal(1) - only_plans_with_key['items'][0]['name'].should.equal(attached_plan['name']) - only_plans_with_key['items'][0]['id'].should.equal(attached_plan['id']) + len(only_plans_with_key["items"]).should.equal(1) + only_plans_with_key["items"][0]["name"].should.equal(attached_plan["name"]) + only_plans_with_key["items"][0]["id"].should.equal(attached_plan["id"]) + + +def create_method_integration(client, api_id, httpMethod="GET"): + resources = client.get_resources(restApiId=api_id) + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] + client.put_method( + restApiId=api_id, + resourceId=root_id, + httpMethod=httpMethod, + authorizationType="NONE", + ) + client.put_method_response( + restApiId=api_id, resourceId=root_id, httpMethod=httpMethod, statusCode="200" + ) + client.put_integration( + restApiId=api_id, + resourceId=root_id, + httpMethod=httpMethod, + type="HTTP", + uri="http://httpbin.org/robots.txt", + integrationHttpMethod="POST", + ) + client.put_integration_response( + restApiId=api_id, + resourceId=root_id, + httpMethod=httpMethod, + statusCode="200", + responseTemplates={}, + ) diff --git a/tests/test_apigateway/test_server.py b/tests/test_apigateway/test_server.py index 953d942cc..08b20cc61 100644 --- a/tests/test_apigateway/test_server.py +++ b/tests/test_apigateway/test_server.py @@ -4,88 +4,100 @@ import json import moto.server as server -''' +""" Test the different server responses -''' +""" def test_list_apis(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - res = test_client.get('/restapis') + res = test_client.get("/restapis") res.data.should.equal(b'{"item": []}') + def test_usage_plans_apis(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() # List usage plans (expect empty) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(0) # Create usage plan - res = test_client.post('/usageplans', data=json.dumps({'name': 'test'})) + res = test_client.post("/usageplans", data=json.dumps({"name": "test"})) created_plan = json.loads(res.data) - created_plan['name'].should.equal('test') + created_plan["name"].should.equal("test") # List usage plans (expect 1 plan) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(1) # Get single usage plan - res = test_client.get('/usageplans/{0}'.format(created_plan["id"])) + res = test_client.get("/usageplans/{0}".format(created_plan["id"])) fetched_plan = json.loads(res.data) fetched_plan.should.equal(created_plan) # Delete usage plan - res = test_client.delete('/usageplans/{0}'.format(created_plan["id"])) - res.data.should.equal(b'{}') + res = test_client.delete("/usageplans/{0}".format(created_plan["id"])) + res.data.should.equal(b"{}") # List usage plans (expect empty again) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(0) + def test_usage_plans_keys(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - usage_plan_id = 'test_plan_id' + usage_plan_id = "test_plan_id" # Create API key to be used in tests - res = test_client.post('/apikeys', data=json.dumps({'name': 'test'})) + res = test_client.post("/apikeys", data=json.dumps({"name": "test"})) created_api_key = json.loads(res.data) # List usage plans keys (expect empty) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(0) # Create usage plan key - res = test_client.post('/usageplans/{0}/keys'.format(usage_plan_id), data=json.dumps({'keyId': created_api_key["id"], 'keyType': 'API_KEY'})) + res = test_client.post( + "/usageplans/{0}/keys".format(usage_plan_id), + data=json.dumps({"keyId": created_api_key["id"], "keyType": "API_KEY"}), + ) created_usage_plan_key = json.loads(res.data) # List usage plans keys (expect 1 key) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(1) # Get single usage plan key - res = test_client.get('/usageplans/{0}/keys/{1}'.format(usage_plan_id, created_api_key["id"])) + res = test_client.get( + "/usageplans/{0}/keys/{1}".format(usage_plan_id, created_api_key["id"]) + ) fetched_plan_key = json.loads(res.data) fetched_plan_key.should.equal(created_usage_plan_key) # Delete usage plan key - res = test_client.delete('/usageplans/{0}/keys/{1}'.format(usage_plan_id, created_api_key["id"])) - res.data.should.equal(b'{}') + res = test_client.delete( + "/usageplans/{0}/keys/{1}".format(usage_plan_id, created_api_key["id"]) + ) + res.data.should.equal(b"{}") # List usage plans keys (expect to be empty again) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(0) + def test_create_usage_plans_key_non_existent_api_key(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - usage_plan_id = 'test_plan_id' + usage_plan_id = "test_plan_id" # Create usage plan key with non-existent api key - res = test_client.post('/usageplans/{0}/keys'.format(usage_plan_id), data=json.dumps({'keyId': 'non-existent', 'keyType': 'API_KEY'})) + res = test_client.post( + "/usageplans/{0}/keys".format(usage_plan_id), + data=json.dumps({"keyId": "non-existent", "keyType": "API_KEY"}), + ) res.status_code.should.equal(404) - diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py new file mode 100644 index 000000000..d36653910 --- /dev/null +++ b/tests/test_athena/test_athena.py @@ -0,0 +1,59 @@ +from __future__ import unicode_literals + +import datetime + +from botocore.exceptions import ClientError +import boto3 +import sure # noqa + +from moto import mock_athena + + +@mock_athena +def test_create_work_group(): + client = boto3.client("athena", region_name="us-east-1") + + response = client.create_work_group( + Name="athena_workgroup", + Description="Test work group", + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://bucket-name/prefix/", + "EncryptionConfiguration": { + "EncryptionOption": "SSE_KMS", + "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1", + }, + } + }, + Tags=[], + ) + + try: + # The second time should throw an error + response = client.create_work_group( + Name="athena_workgroup", + Description="duplicate", + Configuration={ + "ResultConfiguration": { + "OutputLocation": "s3://bucket-name/prefix/", + "EncryptionConfiguration": { + "EncryptionOption": "SSE_KMS", + "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1", + }, + } + }, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidRequestException") + err.response["Error"]["Message"].should.equal("WorkGroup already exists") + else: + raise RuntimeError("Should have raised ResourceNotFoundException") + + # Then test the work group appears in the work group list + response = client.list_work_groups() + + response["WorkGroups"].should.have.length_of(1) + work_group = response["WorkGroups"][0] + work_group["Name"].should.equal("athena_workgroup") + work_group["Description"].should.equal("Test work group") + work_group["State"].should.equal("ENABLED") diff --git a/tests/test_autoscaling/test_autoscaling.py b/tests/test_autoscaling/test_autoscaling.py index 2df7bf30f..c46bc7219 100644 --- a/tests/test_autoscaling/test_autoscaling.py +++ b/tests/test_autoscaling/test_autoscaling.py @@ -10,31 +10,39 @@ import sure # noqa from botocore.exceptions import ClientError from nose.tools import assert_raises -from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2 +from moto import ( + mock_autoscaling, + mock_ec2_deprecated, + mock_elb_deprecated, + mock_elb, + mock_autoscaling_deprecated, + mock_ec2, +) from tests.helpers import requires_boto_gte -from utils import setup_networking, setup_networking_deprecated, setup_instance_with_networking +from utils import ( + setup_networking, + setup_networking_deprecated, + setup_instance_with_networking, +) @mock_autoscaling_deprecated @mock_elb_deprecated def test_create_autoscaling_group(): mocked_networking = setup_networking_deprecated() - elb_conn = boto.ec2.elb.connect_to_region('us-east-1') - elb_conn.create_load_balancer( - 'test_lb', zones=[], listeners=[(80, 8080, 'http')]) + elb_conn = boto.ec2.elb.connect_to_region("us-east-1") + elb_conn.create_load_balancer("test_lb", zones=[], listeners=[(80, 8080, "http")]) - conn = boto.ec2.autoscale.connect_to_region('us-east-1') + conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a', 'us-east-1b'], + name="tester_group", + availability_zones=["us-east-1a", "us-east-1b"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -45,45 +53,44 @@ def test_create_autoscaling_group(): load_balancers=["test_lb"], placement_group="test_placement", vpc_zone_identifier="{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] ), termination_policies=["OldestInstance", "NewestInstance"], - tags=[Tag( - resource_id='tester_group', - key='test_key', - value='test_value', - propagate_at_launch=True - ) + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="test_value", + propagate_at_launch=True, + ) ], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.name.should.equal('tester_group') - set(group.availability_zones).should.equal( - set(['us-east-1a', 'us-east-1b'])) + group.name.should.equal("tester_group") + set(group.availability_zones).should.equal(set(["us-east-1a", "us-east-1b"])) group.desired_capacity.should.equal(2) group.max_size.should.equal(2) group.min_size.should.equal(2) group.instances.should.have.length_of(2) - group.vpc_zone_identifier.should.equal("{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], - )) - group.launch_config_name.should.equal('tester') + group.vpc_zone_identifier.should.equal( + "{subnet1},{subnet2}".format( + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] + ) + ) + group.launch_config_name.should.equal("tester") group.default_cooldown.should.equal(60) group.health_check_period.should.equal(100) group.health_check_type.should.equal("EC2") list(group.load_balancers).should.equal(["test_lb"]) group.placement_group.should.equal("test_placement") - list(group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + list(group.termination_policies).should.equal(["OldestInstance", "NewestInstance"]) len(list(group.tags)).should.equal(1) tag = list(group.tags)[0] - tag.resource_id.should.equal('tester_group') - tag.key.should.equal('test_key') - tag.value.should.equal('test_value') + tag.resource_id.should.equal("tester_group") + tag.key.should.equal("test_key") + tag.value.should.equal("test_value") tag.propagate_at_launch.should.equal(True) @@ -95,31 +102,29 @@ def test_create_autoscaling_groups_defaults(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.name.should.equal('tester_group') + group.name.should.equal("tester_group") group.max_size.should.equal(2) group.min_size.should.equal(2) - group.launch_config_name.should.equal('tester') + group.launch_config_name.should.equal("tester") # Defaults - list(group.availability_zones).should.equal(['us-east-1a']) # subnet1 + list(group.availability_zones).should.equal(["us-east-1a"]) # subnet1 group.desired_capacity.should.equal(2) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet1']) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet1"]) group.default_cooldown.should.equal(300) group.health_check_period.should.equal(300) group.health_check_type.should.equal("EC2") @@ -132,55 +137,61 @@ def test_create_autoscaling_groups_defaults(): @mock_autoscaling def test_list_many_autoscaling_groups(): mocked_networking = setup_networking() - conn = boto3.client('autoscaling', region_name='us-east-1') - conn.create_launch_configuration(LaunchConfigurationName='TestLC') + conn = boto3.client("autoscaling", region_name="us-east-1") + conn.create_launch_configuration(LaunchConfigurationName="TestLC") for i in range(51): - conn.create_auto_scaling_group(AutoScalingGroupName='TestGroup%d' % i, - MinSize=1, - MaxSize=2, - LaunchConfigurationName='TestLC', - VPCZoneIdentifier=mocked_networking['subnet1']) + conn.create_auto_scaling_group( + AutoScalingGroupName="TestGroup%d" % i, + MinSize=1, + MaxSize=2, + LaunchConfigurationName="TestLC", + VPCZoneIdentifier=mocked_networking["subnet1"], + ) response = conn.describe_auto_scaling_groups() groups = response["AutoScalingGroups"] marker = response["NextToken"] groups.should.have.length_of(50) - marker.should.equal(groups[-1]['AutoScalingGroupName']) + marker.should.equal(groups[-1]["AutoScalingGroupName"]) response2 = conn.describe_auto_scaling_groups(NextToken=marker) groups.extend(response2["AutoScalingGroups"]) groups.should.have.length_of(51) - assert 'NextToken' not in response2.keys() + assert "NextToken" not in response2.keys() @mock_autoscaling @mock_ec2 def test_list_many_autoscaling_groups(): mocked_networking = setup_networking() - conn = boto3.client('autoscaling', region_name='us-east-1') - conn.create_launch_configuration(LaunchConfigurationName='TestLC') + conn = boto3.client("autoscaling", region_name="us-east-1") + conn.create_launch_configuration(LaunchConfigurationName="TestLC") - conn.create_auto_scaling_group(AutoScalingGroupName='TestGroup1', - MinSize=1, - MaxSize=2, - LaunchConfigurationName='TestLC', - Tags=[{ - "ResourceId": 'TestGroup1', - "ResourceType": "auto-scaling-group", - "PropagateAtLaunch": True, - "Key": 'TestTagKey1', - "Value": 'TestTagValue1' - }], - VPCZoneIdentifier=mocked_networking['subnet1']) + conn.create_auto_scaling_group( + AutoScalingGroupName="TestGroup1", + MinSize=1, + MaxSize=2, + LaunchConfigurationName="TestLC", + Tags=[ + { + "ResourceId": "TestGroup1", + "ResourceType": "auto-scaling-group", + "PropagateAtLaunch": True, + "Key": "TestTagKey1", + "Value": "TestTagValue1", + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], + ) - ec2 = boto3.client('ec2', region_name='us-east-1') + ec2 = boto3.client("ec2", region_name="us-east-1") instances = ec2.describe_instances() - tags = instances['Reservations'][0]['Instances'][0]['Tags'] - tags.should.contain({u'Value': 'TestTagValue1', u'Key': 'TestTagKey1'}) - tags.should.contain({u'Value': 'TestGroup1', u'Key': 'aws:autoscaling:groupName'}) + tags = instances["Reservations"][0]["Instances"][0]["Tags"] + tags.should.contain({"Value": "TestTagValue1", "Key": "TestTagKey1"}) + tags.should.contain({"Value": "TestGroup1", "Key": "aws:autoscaling:groupName"}) @mock_autoscaling_deprecated @@ -188,27 +199,26 @@ def test_autoscaling_group_describe_filter(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) - group.name = 'tester_group2' + group.name = "tester_group2" conn.create_auto_scaling_group(group) - group.name = 'tester_group3' + group.name = "tester_group3" conn.create_auto_scaling_group(group) - conn.get_all_groups( - names=['tester_group', 'tester_group2']).should.have.length_of(2) + conn.get_all_groups(names=["tester_group", "tester_group2"]).should.have.length_of( + 2 + ) conn.get_all_groups().should.have.length_of(3) @@ -217,33 +227,31 @@ def test_autoscaling_update(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.availability_zones.should.equal(['us-east-1a']) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet1']) + group.availability_zones.should.equal(["us-east-1a"]) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet1"]) - group.availability_zones = ['us-east-1b'] - group.vpc_zone_identifier = mocked_networking['subnet2'] + group.availability_zones = ["us-east-1b"] + group.vpc_zone_identifier = mocked_networking["subnet2"] group.update() group = conn.get_all_groups()[0] - group.availability_zones.should.equal(['us-east-1b']) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet2']) + group.availability_zones.should.equal(["us-east-1b"]) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet2"]) @mock_autoscaling_deprecated @@ -251,40 +259,45 @@ def test_autoscaling_tags_update(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - tags=[Tag( - resource_id='tester_group', - key='test_key', - value='test_value', - propagate_at_launch=True - )], - vpc_zone_identifier=mocked_networking['subnet1'], + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="test_value", + propagate_at_launch=True, + ) + ], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) - conn.create_or_update_tags(tags=[Tag( - resource_id='tester_group', - key='test_key', - value='new_test_value', - propagate_at_launch=True - ), Tag( - resource_id='tester_group', - key='test_key2', - value='test_value2', - propagate_at_launch=True - )]) + conn.create_or_update_tags( + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="new_test_value", + propagate_at_launch=True, + ), + Tag( + resource_id="tester_group", + key="test_key2", + value="test_value2", + propagate_at_launch=True, + ), + ] + ) group = conn.get_all_groups()[0] group.tags.should.have.length_of(2) @@ -294,24 +307,22 @@ def test_autoscaling_group_delete(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) conn.get_all_groups().should.have.length_of(1) - conn.delete_auto_scaling_group('tester_group') + conn.delete_auto_scaling_group("tester_group") conn.get_all_groups().should.have.length_of(0) @@ -319,30 +330,28 @@ def test_autoscaling_group_delete(): @mock_autoscaling_deprecated def test_autoscaling_group_describe_instances(): mocked_networking = setup_networking_deprecated() - conn = boto.ec2.autoscale.connect_to_region('us-east-1') + conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) instances = list(conn.get_all_autoscaling_instances()) instances.should.have.length_of(2) - instances[0].launch_config_name.should.equal('tester') - instances[0].health_status.should.equal('Healthy') + instances[0].launch_config_name.should.equal("tester") + instances[0].health_status.should.equal("Healthy") autoscale_instance_ids = [instance.instance_id for instance in instances] - ec2_conn = boto.ec2.connect_to_region('us-east-1') + ec2_conn = boto.ec2.connect_to_region("us-east-1") reservations = ec2_conn.get_all_instances() instances = reservations[0].instances instances.should.have.length_of(2) @@ -357,20 +366,18 @@ def test_set_desired_capacity_up(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -393,20 +400,18 @@ def test_set_desired_capacity_down(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -429,20 +434,18 @@ def test_set_desired_capacity_the_same(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -464,26 +467,24 @@ def test_set_desired_capacity_the_same(): def test_autoscaling_group_with_elb(): mocked_networking = setup_networking_deprecated() elb_conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = elb_conn.create_load_balancer('my-lb', zones, ports) - instances_health = elb_conn.describe_instance_health('my-lb') + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = elb_conn.create_load_balancer("my-lb", zones, ports) + instances_health = elb_conn.describe_instance_health("my-lb") instances_health.should.be.empty conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, load_balancers=["my-lb"], - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] @@ -491,8 +492,7 @@ def test_autoscaling_group_with_elb(): group.desired_capacity.should.equal(2) elb.instances.should.have.length_of(2) - autoscale_instance_ids = set( - instance.instance_id for instance in group.instances) + autoscale_instance_ids = set(instance.instance_id for instance in group.instances) elb_instace_ids = set(instance.id for instance in elb.instances) autoscale_instance_ids.should.equal(elb_instace_ids) @@ -502,20 +502,19 @@ def test_autoscaling_group_with_elb(): group.desired_capacity.should.equal(3) elb.instances.should.have.length_of(3) - autoscale_instance_ids = set( - instance.instance_id for instance in group.instances) + autoscale_instance_ids = set(instance.instance_id for instance in group.instances) elb_instace_ids = set(instance.id for instance in elb.instances) autoscale_instance_ids.should.equal(elb_instace_ids) - conn.delete_auto_scaling_group('tester_group') + conn.delete_auto_scaling_group("tester_group") conn.get_all_groups().should.have.length_of(0) elb = elb_conn.get_all_load_balancers()[0] elb.instances.should.have.length_of(0) -''' +""" Boto3 -''' +""" @mock_autoscaling @@ -524,77 +523,74 @@ def test_describe_load_balancers(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', - LoadBalancerNames=['my-lb'], + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", + LoadBalancerNames=["my-lb"], MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_load_balancers(AutoScalingGroupName='test_asg') - assert response['ResponseMetadata']['RequestId'] - list(response['LoadBalancers']).should.have.length_of(1) - response['LoadBalancers'][0]['LoadBalancerName'].should.equal('my-lb') + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + assert response["ResponseMetadata"]["RequestId"] + list(response["LoadBalancers"]).should.have.length_of(1) + response["LoadBalancers"][0]["LoadBalancerName"].should.equal("my-lb") + @mock_autoscaling @mock_elb def test_create_elb_and_autoscaling_group_no_relationship(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - ELB_NAME = 'my-elb' + ELB_NAME = "my-elb" - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( LoadBalancerName=ELB_NAME, - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) # autoscaling group and elb should have no relationship - response = client.describe_load_balancers( - AutoScalingGroupName='test_asg' - ) - list(response['LoadBalancers']).should.have.length_of(0) - response = elb_client.describe_load_balancers( - LoadBalancerNames=[ELB_NAME] - ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(0) + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + list(response["LoadBalancers"]).should.have.length_of(0) + response = elb_client.describe_load_balancers(LoadBalancerNames=[ELB_NAME]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of(0) @mock_autoscaling @@ -603,47 +599,46 @@ def test_attach_load_balancer(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) response = client.attach_load_balancers( - AutoScalingGroupName='test_asg', - LoadBalancerNames=['my-lb']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - response = elb_client.describe_load_balancers( - LoadBalancerNames=['my-lb'] + AutoScalingGroupName="test_asg", LoadBalancerNames=["my-lb"] ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(INSTANCE_COUNT) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] + response = elb_client.describe_load_balancers(LoadBalancerNames=["my-lb"]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of( + INSTANCE_COUNT ) - list(response['AutoScalingGroups'][0]['LoadBalancerNames']).should.have.length_of(1) + + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + list(response["AutoScalingGroups"][0]["LoadBalancerNames"]).should.have.length_of(1) @mock_autoscaling @@ -652,740 +647,736 @@ def test_detach_load_balancer(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', - LoadBalancerNames=['my-lb'], + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", + LoadBalancerNames=["my-lb"], MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) response = client.detach_load_balancers( - AutoScalingGroupName='test_asg', - LoadBalancerNames=['my-lb']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - response = elb_client.describe_load_balancers( - LoadBalancerNames=['my-lb'] + AutoScalingGroupName="test_asg", LoadBalancerNames=["my-lb"] ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(0) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_load_balancers(AutoScalingGroupName='test_asg') - list(response['LoadBalancers']).should.have.length_of(0) + response = elb_client.describe_load_balancers(LoadBalancerNames=["my-lb"]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of(0) + + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + list(response["LoadBalancers"]).should.have.length_of(0) @mock_autoscaling def test_create_autoscaling_group_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) response = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, Tags=[ - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }, - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'not-propogated-tag-key', - 'Value': 'not-propogate-tag-value', - 'PropagateAtLaunch': False - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "not-propogated-tag-key", + "Value": "not-propogate-tag-value", + "PropagateAtLaunch": False, + }, + ], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_autoscaling def test_create_autoscaling_group_from_instance(): - autoscaling_group_name = 'test_asg' - image_id = 'ami-0cc293023f983ed53' - instance_type = 't2.micro' + autoscaling_group_name = "test_asg" + image_id = "ami-0cc293023f983ed53" + instance_type = "t2.micro" - mocked_instance_with_networking = setup_instance_with_networking(image_id, instance_type) - client = boto3.client('autoscaling', region_name='us-east-1') + mocked_instance_with_networking = setup_instance_with_networking( + image_id, instance_type + ) + client = boto3.client("autoscaling", region_name="us-east-1") response = client.create_auto_scaling_group( AutoScalingGroupName=autoscaling_group_name, - InstanceId=mocked_instance_with_networking['instance'], + InstanceId=mocked_instance_with_networking["instance"], MinSize=1, MaxSize=3, DesiredCapacity=2, Tags=[ - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }, - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'not-propogated-tag-key', - 'Value': 'not-propogate-tag-value', - 'PropagateAtLaunch': False - }], - VPCZoneIdentifier=mocked_instance_with_networking['subnet1'], + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "not-propogated-tag-key", + "Value": "not-propogate-tag-value", + "PropagateAtLaunch": False, + }, + ], + VPCZoneIdentifier=mocked_instance_with_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) describe_launch_configurations_response = client.describe_launch_configurations() - describe_launch_configurations_response['LaunchConfigurations'].should.have.length_of(1) - launch_configuration_from_instance = describe_launch_configurations_response['LaunchConfigurations'][0] - launch_configuration_from_instance['LaunchConfigurationName'].should.equal('test_asg') - launch_configuration_from_instance['ImageId'].should.equal(image_id) - launch_configuration_from_instance['InstanceType'].should.equal(instance_type) + describe_launch_configurations_response[ + "LaunchConfigurations" + ].should.have.length_of(1) + launch_configuration_from_instance = describe_launch_configurations_response[ + "LaunchConfigurations" + ][0] + launch_configuration_from_instance["LaunchConfigurationName"].should.equal( + "test_asg" + ) + launch_configuration_from_instance["ImageId"].should.equal(image_id) + launch_configuration_from_instance["InstanceType"].should.equal(instance_type) @mock_autoscaling def test_create_autoscaling_group_from_invalid_instance_id(): - invalid_instance_id = 'invalid_instance' + invalid_instance_id = "invalid_instance" mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") with assert_raises(ClientError) as ex: client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceId=invalid_instance_id, MinSize=9, MaxSize=15, DesiredCapacity=12, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Code'].should.equal('ValidationError') - ex.exception.response['Error']['Message'].should.equal('Instance [{0}] is invalid.'.format(invalid_instance_id)) + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Code"].should.equal("ValidationError") + ex.exception.response["Error"]["Message"].should.equal( + "Instance [{0}] is invalid.".format(invalid_instance_id) + ) @mock_autoscaling def test_describe_autoscaling_groups_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - group = response['AutoScalingGroups'][0] - group['AutoScalingGroupName'].should.equal('test_asg') - group['AvailabilityZones'].should.equal(['us-east-1a']) - group['VPCZoneIdentifier'].should.equal(mocked_networking['subnet1']) - group['NewInstancesProtectedFromScaleIn'].should.equal(True) - for instance in group['Instances']: - instance['AvailabilityZone'].should.equal('us-east-1a') - instance['ProtectedFromScaleIn'].should.equal(True) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + group = response["AutoScalingGroups"][0] + group["AutoScalingGroupName"].should.equal("test_asg") + group["AvailabilityZones"].should.equal(["us-east-1a"]) + group["VPCZoneIdentifier"].should.equal(mocked_networking["subnet1"]) + group["NewInstancesProtectedFromScaleIn"].should.equal(True) + for instance in group["Instances"]: + instance["AvailabilityZone"].should.equal("us-east-1a") + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_describe_autoscaling_instances_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] response = client.describe_auto_scaling_instances(InstanceIds=instance_ids) - for instance in response['AutoScalingInstances']: - instance['AutoScalingGroupName'].should.equal('test_asg') - instance['AvailabilityZone'].should.equal('us-east-1a') - instance['ProtectedFromScaleIn'].should.equal(True) + for instance in response["AutoScalingInstances"]: + instance["AutoScalingGroupName"].should.equal("test_asg") + instance["AvailabilityZone"].should.equal("us-east-1a") + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_update_autoscaling_group_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) _ = client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", MinSize=1, VPCZoneIdentifier="{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] ), NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) - group = response['AutoScalingGroups'][0] - group['MinSize'].should.equal(1) - set(group['AvailabilityZones']).should.equal({'us-east-1a', 'us-east-1b'}) - group['NewInstancesProtectedFromScaleIn'].should.equal(False) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["MinSize"].should.equal(1) + set(group["AvailabilityZones"]).should.equal({"us-east-1a", "us-east-1b"}) + group["NewInstancesProtectedFromScaleIn"].should.equal(False) @mock_autoscaling def test_update_autoscaling_group_min_size_desired_capacity_change(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=20, DesiredCapacity=3, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', - MinSize=5, - ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(5) - group['MinSize'].should.equal(5) - group['Instances'].should.have.length_of(5) + client.update_auto_scaling_group(AutoScalingGroupName="test_asg", MinSize=5) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(5) + group["MinSize"].should.equal(5) + group["Instances"].should.have.length_of(5) @mock_autoscaling def test_update_autoscaling_group_max_size_desired_capacity_change(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=20, DesiredCapacity=10, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', - MaxSize=5, - ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(5) - group['MaxSize'].should.equal(5) - group['Instances'].should.have.length_of(5) + client.update_auto_scaling_group(AutoScalingGroupName="test_asg", MaxSize=5) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(5) + group["MaxSize"].should.equal(5) + group["Instances"].should.have.length_of(5) @mock_autoscaling def test_autoscaling_taqs_update_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.create_or_update_tags(Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'updated_test_value', - "PropagateAtLaunch": True - }, { - "ResourceId": 'test_asg', - "Key": 'test_key2', - "Value": 'test_value2', - "PropagateAtLaunch": False - }]) - - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] + client.create_or_update_tags( + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "updated_test_value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "Key": "test_key2", + "Value": "test_value2", + "PropagateAtLaunch": False, + }, + ] ) - response['AutoScalingGroups'][0]['Tags'].should.have.length_of(2) + + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["AutoScalingGroups"][0]["Tags"].should.have.length_of(2) @mock_autoscaling def test_autoscaling_describe_policies_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) client.put_scaling_policy( - AutoScalingGroupName='test_asg', - PolicyName='test_policy_down', - PolicyType='SimpleScaling', - AdjustmentType='PercentChangeInCapacity', + AutoScalingGroupName="test_asg", + PolicyName="test_policy_down", + PolicyType="SimpleScaling", + AdjustmentType="PercentChangeInCapacity", ScalingAdjustment=-10, Cooldown=60, - MinAdjustmentMagnitude=1) + MinAdjustmentMagnitude=1, + ) client.put_scaling_policy( - AutoScalingGroupName='test_asg', - PolicyName='test_policy_up', - PolicyType='SimpleScaling', - AdjustmentType='PercentChangeInCapacity', + AutoScalingGroupName="test_asg", + PolicyName="test_policy_up", + PolicyType="SimpleScaling", + AdjustmentType="PercentChangeInCapacity", ScalingAdjustment=10, Cooldown=60, - MinAdjustmentMagnitude=1) + MinAdjustmentMagnitude=1, + ) response = client.describe_policies() - response['ScalingPolicies'].should.have.length_of(2) + response["ScalingPolicies"].should.have.length_of(2) - response = client.describe_policies(AutoScalingGroupName='test_asg') - response['ScalingPolicies'].should.have.length_of(2) + response = client.describe_policies(AutoScalingGroupName="test_asg") + response["ScalingPolicies"].should.have.length_of(2) - response = client.describe_policies(PolicyTypes=['StepScaling']) - response['ScalingPolicies'].should.have.length_of(0) + response = client.describe_policies(PolicyTypes=["StepScaling"]) + response["ScalingPolicies"].should.have.length_of(0) response = client.describe_policies( - AutoScalingGroupName='test_asg', - PolicyNames=['test_policy_down'], - PolicyTypes=['SimpleScaling'] + AutoScalingGroupName="test_asg", + PolicyNames=["test_policy_down"], + PolicyTypes=["SimpleScaling"], ) - response['ScalingPolicies'].should.have.length_of(1) - response['ScalingPolicies'][0][ - 'PolicyName'].should.equal('test_policy_down') + response["ScalingPolicies"].should.have.length_of(1) + response["ScalingPolicies"][0]["PolicyName"].should.equal("test_policy_down") + @mock_autoscaling @mock_ec2 def test_detach_one_instance_decrement(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=2, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instance_to_detach = response['AutoScalingGroups'][0]['Instances'][0]['InstanceId'] - instance_to_keep = response['AutoScalingGroups'][0]['Instances'][1]['InstanceId'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instance_to_detach = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"] + instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"] - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) response = client.detach_instances( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=[instance_to_detach], - ShouldDecrementDesiredCapacity=True + ShouldDecrementDesiredCapacity=True, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - response['AutoScalingGroups'][0]['Instances'].should.have.length_of(1) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["AutoScalingGroups"][0]["Instances"].should.have.length_of(1) # test to ensure tag has been removed response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(1) # test to ensure tag is present on other instance response = ec2_client.describe_instances(InstanceIds=[instance_to_keep]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(2) + @mock_autoscaling @mock_ec2 def test_detach_one_instance(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=2, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instance_to_detach = response['AutoScalingGroups'][0]['Instances'][0]['InstanceId'] - instance_to_keep = response['AutoScalingGroups'][0]['Instances'][1]['InstanceId'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instance_to_detach = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"] + instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"] - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) response = client.detach_instances( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=[instance_to_detach], - ShouldDecrementDesiredCapacity=False + ShouldDecrementDesiredCapacity=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) # test to ensure instance was replaced - response['AutoScalingGroups'][0]['Instances'].should.have.length_of(2) + response["AutoScalingGroups"][0]["Instances"].should.have.length_of(2) response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(1) response = ec2_client.describe_instances(InstanceIds=[instance_to_keep]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(2) + @mock_autoscaling @mock_ec2 def test_attach_one_instance(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=4, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - ec2 = boto3.resource('ec2', 'us-east-1') - instances_to_add = [x.id for x in ec2.create_instances(ImageId='', MinCount=1, MaxCount=1)] + ec2 = boto3.resource("ec2", "us-east-1") + instances_to_add = [ + x.id for x in ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + ] response = client.attach_instances( - AutoScalingGroupName='test_asg', - InstanceIds=instances_to_add + AutoScalingGroupName="test_asg", InstanceIds=instances_to_add ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instances = response['AutoScalingGroups'][0]['Instances'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instances = response["AutoScalingGroups"][0]["Instances"] instances.should.have.length_of(3) for instance in instances: - instance['ProtectedFromScaleIn'].should.equal(True) + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling @mock_ec2 def test_describe_instance_health(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=4, DesiredCapacity=2, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Healthy") - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Healthy') @mock_autoscaling @mock_ec2 def test_set_instance_health(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=4, DesiredCapacity=2, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Healthy") + + client.set_instance_health( + InstanceId=instance1["InstanceId"], HealthStatus="Unhealthy" ) - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Healthy') + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) - client.set_instance_health(InstanceId=instance1['InstanceId'], HealthStatus='Unhealthy') + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Unhealthy") - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Unhealthy') @mock_autoscaling def test_suspend_processes(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') - client.create_launch_configuration( - LaunchConfigurationName='lc', - ) + client = boto3.client("autoscaling", region_name="us-east-1") + client.create_launch_configuration(LaunchConfigurationName="lc") client.create_auto_scaling_group( - LaunchConfigurationName='lc', - AutoScalingGroupName='test-asg', + LaunchConfigurationName="lc", + AutoScalingGroupName="test-asg", MinSize=1, MaxSize=1, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) # When we suspend the 'Launch' process on the ASG client client.suspend_processes( - AutoScalingGroupName='test-asg', - ScalingProcesses=['Launch'] + AutoScalingGroupName="test-asg", ScalingProcesses=["Launch"] ) - res = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test-asg'] - ) + res = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test-asg"]) # The 'Launch' process should, in fact, be suspended launch_suspended = False - for proc in res['AutoScalingGroups'][0]['SuspendedProcesses']: - if proc.get('ProcessName') == 'Launch': + for proc in res["AutoScalingGroups"][0]["SuspendedProcesses"]: + if proc.get("ProcessName") == "Launch": launch_suspended = True assert launch_suspended is True + @mock_autoscaling def test_set_instance_protection(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] protected = instance_ids[:3] _ = client.set_instance_protection( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=protected, ProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - for instance in response['AutoScalingGroups'][0]['Instances']: - instance['ProtectedFromScaleIn'].should.equal( - instance['InstanceId'] in protected + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + for instance in response["AutoScalingGroups"][0]["Instances"]: + instance["ProtectedFromScaleIn"].should.equal( + instance["InstanceId"] in protected ) @mock_autoscaling def test_set_desired_capacity_up_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - _ = client.set_desired_capacity( - AutoScalingGroupName='test_asg', - DesiredCapacity=10, - ) + _ = client.set_desired_capacity(AutoScalingGroupName="test_asg", DesiredCapacity=10) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - instances = response['AutoScalingGroups'][0]['Instances'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instances = response["AutoScalingGroups"][0]["Instances"] instances.should.have.length_of(10) for instance in instances: - instance['ProtectedFromScaleIn'].should.equal(True) + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_set_desired_capacity_down_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] unprotected, protected = instance_ids[:2], instance_ids[2:] _ = client.set_instance_protection( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=unprotected, ProtectedFromScaleIn=False, ) - _ = client.set_desired_capacity( - AutoScalingGroupName='test_asg', - DesiredCapacity=1, - ) + _ = client.set_desired_capacity(AutoScalingGroupName="test_asg", DesiredCapacity=1) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(1) - instance_ids = {instance['InstanceId'] for instance in group['Instances']} + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(1) + instance_ids = {instance["InstanceId"] for instance in group["Instances"]} set(protected).should.equal(instance_ids) set(unprotected).should_not.be.within(instance_ids) # only unprotected killed @@ -1394,30 +1385,30 @@ def test_set_desired_capacity_down_boto3(): @mock_ec2 def test_terminate_instance_in_autoscaling_group(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=1, MaxSize=20, - VPCZoneIdentifier=mocked_networking['subnet1'], - NewInstancesProtectedFromScaleIn=False + VPCZoneIdentifier=mocked_networking["subnet1"], + NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) original_instance_id = next( - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ) - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") ec2_client.terminate_instances(InstanceIds=[original_instance_id]) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) replaced_instance_id = next( - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ) replaced_instance_id.should_not.equal(original_instance_id) diff --git a/tests/test_autoscaling/test_elbv2.py b/tests/test_autoscaling/test_elbv2.py index a142fd133..a3d3dba9f 100644 --- a/tests/test_autoscaling/test_elbv2.py +++ b/tests/test_autoscaling/test_elbv2.py @@ -2,127 +2,134 @@ from __future__ import unicode_literals import boto3 import sure # noqa -from moto import mock_autoscaling, mock_ec2, mock_elbv2 +from moto import mock_autoscaling, mock_ec2, mock_elbv2 from utils import setup_networking + @mock_elbv2 @mock_autoscaling def test_attach_detach_target_groups(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - client = boto3.client('autoscaling', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") response = elbv2_client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, - VpcId=mocked_networking['vpc'], - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + VpcId=mocked_networking["vpc"], + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group_arn = response['TargetGroups'][0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group_arn = response["TargetGroups"][0]["TargetGroupArn"] client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration') + LaunchConfigurationName="test_launch_configuration" + ) # create asg, attach to target group on create client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, TargetGroupARNs=[target_group_arn], - VPCZoneIdentifier=mocked_networking['subnet1']) + VPCZoneIdentifier=mocked_networking["subnet1"], + ) # create asg without attaching to target group client.create_auto_scaling_group( - AutoScalingGroupName='test_asg2', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg2", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - VPCZoneIdentifier=mocked_networking['subnet2']) + VPCZoneIdentifier=mocked_networking["subnet2"], + ) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(1) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(1) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) client.attach_load_balancer_target_groups( - AutoScalingGroupName='test_asg2', - TargetGroupARNs=[target_group_arn]) + AutoScalingGroupName="test_asg2", TargetGroupARNs=[target_group_arn] + ) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT * 2) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT * 2) response = client.detach_load_balancer_target_groups( - AutoScalingGroupName='test_asg2', - TargetGroupARNs=[target_group_arn]) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + AutoScalingGroupName="test_asg2", TargetGroupARNs=[target_group_arn] + ) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) + @mock_elbv2 @mock_autoscaling def test_detach_all_target_groups(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - client = boto3.client('autoscaling', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") response = elbv2_client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, - VpcId=mocked_networking['vpc'], - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + VpcId=mocked_networking["vpc"], + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group_arn = response['TargetGroups'][0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group_arn = response["TargetGroups"][0]["TargetGroupArn"] client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration') + LaunchConfigurationName="test_launch_configuration" + ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, TargetGroupARNs=[target_group_arn], - VPCZoneIdentifier=mocked_networking['subnet1']) + VPCZoneIdentifier=mocked_networking["subnet1"], + ) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(1) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(1) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) response = client.detach_load_balancer_target_groups( - AutoScalingGroupName='test_asg', - TargetGroupARNs=[target_group_arn]) + AutoScalingGroupName="test_asg", TargetGroupARNs=[target_group_arn] + ) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(0) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(0) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(0) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(0) diff --git a/tests/test_autoscaling/test_launch_configurations.py b/tests/test_autoscaling/test_launch_configurations.py index 931fc8a7e..ab2743f54 100644 --- a/tests/test_autoscaling/test_launch_configurations.py +++ b/tests/test_autoscaling/test_launch_configurations.py @@ -8,6 +8,7 @@ import sure # noqa from moto import mock_autoscaling_deprecated from moto import mock_autoscaling +from moto.core import ACCOUNT_ID from tests.helpers import requires_boto_gte @@ -15,29 +16,31 @@ from tests.helpers import requires_boto_gte def test_create_launch_configuration(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t1.micro', - key_name='the_keys', + name="tester", + image_id="ami-abcd1234", + instance_type="t1.micro", + key_name="the_keys", security_groups=["default", "default2"], user_data=b"This is some user_data", instance_monitoring=True, - instance_profile_name='arn:aws:iam::123456789012:instance-profile/testing', + instance_profile_name="arn:aws:iam::{}:instance-profile/testing".format( + ACCOUNT_ID + ), spot_price=0.1, ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('t1.micro') - launch_config.key_name.should.equal('the_keys') - set(launch_config.security_groups).should.equal( - set(['default', 'default2'])) + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("t1.micro") + launch_config.key_name.should.equal("the_keys") + set(launch_config.security_groups).should.equal(set(["default", "default2"])) launch_config.user_data.should.equal(b"This is some user_data") - launch_config.instance_monitoring.enabled.should.equal('true') + launch_config.instance_monitoring.enabled.should.equal("true") launch_config.instance_profile_name.should.equal( - 'arn:aws:iam::123456789012:instance-profile/testing') + "arn:aws:iam::{}:instance-profile/testing".format(ACCOUNT_ID) + ) launch_config.spot_price.should.equal(0.1) @@ -47,64 +50,67 @@ def test_create_launch_configuration_with_block_device_mappings(): block_device_mapping = BlockDeviceMapping() ephemeral_drive = BlockDeviceType() - ephemeral_drive.ephemeral_name = 'ephemeral0' - block_device_mapping['/dev/xvdb'] = ephemeral_drive + ephemeral_drive.ephemeral_name = "ephemeral0" + block_device_mapping["/dev/xvdb"] = ephemeral_drive snapshot_drive = BlockDeviceType() snapshot_drive.snapshot_id = "snap-1234abcd" snapshot_drive.volume_type = "standard" - block_device_mapping['/dev/xvdp'] = snapshot_drive + block_device_mapping["/dev/xvdp"] = snapshot_drive ebs_drive = BlockDeviceType() ebs_drive.volume_type = "io1" ebs_drive.size = 100 ebs_drive.iops = 1000 ebs_drive.delete_on_termination = False - block_device_mapping['/dev/xvdh'] = ebs_drive + block_device_mapping["/dev/xvdh"] = ebs_drive conn = boto.connect_autoscale(use_block_device_types=True) config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', - key_name='the_keys', + name="tester", + image_id="ami-abcd1234", + instance_type="m1.small", + key_name="the_keys", security_groups=["default", "default2"], user_data=b"This is some user_data", instance_monitoring=True, - instance_profile_name='arn:aws:iam::123456789012:instance-profile/testing', + instance_profile_name="arn:aws:iam::{}:instance-profile/testing".format( + ACCOUNT_ID + ), spot_price=0.1, - block_device_mappings=[block_device_mapping] + block_device_mappings=[block_device_mapping], ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('m1.small') - launch_config.key_name.should.equal('the_keys') - set(launch_config.security_groups).should.equal( - set(['default', 'default2'])) + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("m1.small") + launch_config.key_name.should.equal("the_keys") + set(launch_config.security_groups).should.equal(set(["default", "default2"])) launch_config.user_data.should.equal(b"This is some user_data") - launch_config.instance_monitoring.enabled.should.equal('true') + launch_config.instance_monitoring.enabled.should.equal("true") launch_config.instance_profile_name.should.equal( - 'arn:aws:iam::123456789012:instance-profile/testing') + "arn:aws:iam::{}:instance-profile/testing".format(ACCOUNT_ID) + ) launch_config.spot_price.should.equal(0.1) len(launch_config.block_device_mappings).should.equal(3) returned_mapping = launch_config.block_device_mappings set(returned_mapping.keys()).should.equal( - set(['/dev/xvdb', '/dev/xvdp', '/dev/xvdh'])) + set(["/dev/xvdb", "/dev/xvdp", "/dev/xvdh"]) + ) - returned_mapping['/dev/xvdh'].iops.should.equal(1000) - returned_mapping['/dev/xvdh'].size.should.equal(100) - returned_mapping['/dev/xvdh'].volume_type.should.equal("io1") - returned_mapping['/dev/xvdh'].delete_on_termination.should.be.false + returned_mapping["/dev/xvdh"].iops.should.equal(1000) + returned_mapping["/dev/xvdh"].size.should.equal(100) + returned_mapping["/dev/xvdh"].volume_type.should.equal("io1") + returned_mapping["/dev/xvdh"].delete_on_termination.should.be.false - returned_mapping['/dev/xvdp'].snapshot_id.should.equal("snap-1234abcd") - returned_mapping['/dev/xvdp'].volume_type.should.equal("standard") + returned_mapping["/dev/xvdp"].snapshot_id.should.equal("snap-1234abcd") + returned_mapping["/dev/xvdp"].volume_type.should.equal("standard") - returned_mapping['/dev/xvdb'].ephemeral_name.should.equal('ephemeral0') + returned_mapping["/dev/xvdb"].ephemeral_name.should.equal("ephemeral0") @requires_boto_gte("2.12") @@ -112,9 +118,7 @@ def test_create_launch_configuration_with_block_device_mappings(): def test_create_launch_configuration_for_2_12(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ebs_optimized=True, + name="tester", image_id="ami-abcd1234", ebs_optimized=True ) conn.create_launch_configuration(config) @@ -127,9 +131,7 @@ def test_create_launch_configuration_for_2_12(): def test_create_launch_configuration_using_ip_association(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - associate_public_ip_address=True, + name="tester", image_id="ami-abcd1234", associate_public_ip_address=True ) conn.create_launch_configuration(config) @@ -141,10 +143,7 @@ def test_create_launch_configuration_using_ip_association(): @mock_autoscaling_deprecated def test_create_launch_configuration_using_ip_association_should_default_to_false(): conn = boto.connect_autoscale() - config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ) + config = LaunchConfiguration(name="tester", image_id="ami-abcd1234") conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] @@ -157,22 +156,20 @@ def test_create_launch_configuration_defaults(): are assigned for the other attributes """ conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('m1.small') + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("m1.small") # Defaults - launch_config.key_name.should.equal('') + launch_config.key_name.should.equal("") list(launch_config.security_groups).should.equal([]) launch_config.user_data.should.equal(b"") - launch_config.instance_monitoring.enabled.should.equal('false') + launch_config.instance_monitoring.enabled.should.equal("false") launch_config.instance_profile_name.should.equal(None) launch_config.spot_price.should.equal(None) @@ -181,10 +178,7 @@ def test_create_launch_configuration_defaults(): @mock_autoscaling_deprecated def test_create_launch_configuration_defaults_for_2_12(): conn = boto.connect_autoscale() - config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ) + config = LaunchConfiguration(name="tester", image_id="ami-abcd1234") conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] @@ -195,51 +189,48 @@ def test_create_launch_configuration_defaults_for_2_12(): def test_launch_configuration_describe_filter(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) - config.name = 'tester2' + config.name = "tester2" conn.create_launch_configuration(config) - config.name = 'tester3' + config.name = "tester3" conn.create_launch_configuration(config) conn.get_all_launch_configurations( - names=['tester', 'tester2']).should.have.length_of(2) + names=["tester", "tester2"] + ).should.have.length_of(2) conn.get_all_launch_configurations().should.have.length_of(3) @mock_autoscaling def test_launch_configuration_describe_paginated(): - conn = boto3.client('autoscaling', region_name='us-east-1') + conn = boto3.client("autoscaling", region_name="us-east-1") for i in range(51): - conn.create_launch_configuration(LaunchConfigurationName='TestLC%d' % i) + conn.create_launch_configuration(LaunchConfigurationName="TestLC%d" % i) response = conn.describe_launch_configurations() lcs = response["LaunchConfigurations"] marker = response["NextToken"] lcs.should.have.length_of(50) - marker.should.equal(lcs[-1]['LaunchConfigurationName']) + marker.should.equal(lcs[-1]["LaunchConfigurationName"]) response2 = conn.describe_launch_configurations(NextToken=marker) lcs.extend(response2["LaunchConfigurations"]) lcs.should.have.length_of(51) - assert 'NextToken' not in response2.keys() + assert "NextToken" not in response2.keys() @mock_autoscaling_deprecated def test_launch_configuration_delete(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) conn.get_all_launch_configurations().should.have.length_of(1) - conn.delete_launch_configuration('tester') + conn.delete_launch_configuration("tester") conn.get_all_launch_configurations().should.have.length_of(0) diff --git a/tests/test_autoscaling/test_policies.py b/tests/test_autoscaling/test_policies.py index e6b01163f..f44938eea 100644 --- a/tests/test_autoscaling/test_policies.py +++ b/tests/test_autoscaling/test_policies.py @@ -14,18 +14,16 @@ def setup_autoscale_group(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) return group @@ -36,18 +34,18 @@ def test_create_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, cooldown=60, ) conn.create_scaling_policy(policy) policy = conn.get_all_policies()[0] - policy.name.should.equal('ScaleUp') - policy.adjustment_type.should.equal('ExactCapacity') - policy.as_name.should.equal('tester_group') + policy.name.should.equal("ScaleUp") + policy.adjustment_type.should.equal("ExactCapacity") + policy.as_name.should.equal("tester_group") policy.scaling_adjustment.should.equal(3) policy.cooldown.should.equal(60) @@ -57,15 +55,15 @@ def test_create_policy_default_values(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) policy = conn.get_all_policies()[0] - policy.name.should.equal('ScaleUp') + policy.name.should.equal("ScaleUp") # Defaults policy.cooldown.should.equal(300) @@ -76,9 +74,9 @@ def test_update_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -88,9 +86,9 @@ def test_update_policy(): # Now update it by creating another with the same name policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=2, ) conn.create_scaling_policy(policy) @@ -103,16 +101,16 @@ def test_delete_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) conn.get_all_policies().should.have.length_of(1) - conn.delete_policy('ScaleUp') + conn.delete_policy("ScaleUp") conn.get_all_policies().should.have.length_of(0) @@ -121,9 +119,9 @@ def test_execute_policy_exact_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -139,9 +137,9 @@ def test_execute_policy_positive_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ChangeInCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -157,9 +155,9 @@ def test_execute_policy_percent_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='PercentChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="PercentChangeInCapacity", + as_name="tester_group", scaling_adjustment=50, ) conn.create_scaling_policy(policy) @@ -178,9 +176,9 @@ def test_execute_policy_small_percent_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='PercentChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="PercentChangeInCapacity", + as_name="tester_group", scaling_adjustment=1, ) conn.create_scaling_policy(policy) diff --git a/tests/test_autoscaling/test_server.py b/tests/test_autoscaling/test_server.py index 2025694cd..17263af44 100644 --- a/tests/test_autoscaling/test_server.py +++ b/tests/test_autoscaling/test_server.py @@ -3,16 +3,16 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_describe_autoscaling_groups(): backend = server.create_backend_app("autoscaling") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeLaunchConfigurations') + res = test_client.get("/?Action=DescribeLaunchConfigurations") - res.data.should.contain(b'') + res.data.should.contain(b"") diff --git a/tests/test_autoscaling/utils.py b/tests/test_autoscaling/utils.py index dc38aba3d..8827d2693 100644 --- a/tests/test_autoscaling/utils.py +++ b/tests/test_autoscaling/utils.py @@ -6,43 +6,36 @@ from moto import mock_ec2, mock_ec2_deprecated @mock_ec2 def setup_networking(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.11.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.11.0.0/16") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='10.11.1.0/24', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="10.11.1.0/24", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='10.11.2.0/24', - AvailabilityZone='us-east-1b') - return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id} + VpcId=vpc.id, CidrBlock="10.11.2.0/24", AvailabilityZone="us-east-1b" + ) + return {"vpc": vpc.id, "subnet1": subnet1.id, "subnet2": subnet2.id} + @mock_ec2_deprecated def setup_networking_deprecated(): - conn = boto_vpc.connect_to_region('us-east-1') + conn = boto_vpc.connect_to_region("us-east-1") vpc = conn.create_vpc("10.11.0.0/16") - subnet1 = conn.create_subnet( - vpc.id, - "10.11.1.0/24", - availability_zone='us-east-1a') - subnet2 = conn.create_subnet( - vpc.id, - "10.11.2.0/24", - availability_zone='us-east-1b') - return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id} + subnet1 = conn.create_subnet(vpc.id, "10.11.1.0/24", availability_zone="us-east-1a") + subnet2 = conn.create_subnet(vpc.id, "10.11.2.0/24", availability_zone="us-east-1b") + return {"vpc": vpc.id, "subnet1": subnet1.id, "subnet2": subnet2.id} @mock_ec2 def setup_instance_with_networking(image_id, instance_type): mock_data = setup_networking() - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") instances = ec2.create_instances( ImageId=image_id, InstanceType=instance_type, MaxCount=1, MinCount=1, - SubnetId=mock_data['subnet1'] + SubnetId=mock_data["subnet1"], ) - mock_data['instance'] = instances[0].id + mock_data["instance"] = instances[0].id return mock_data diff --git a/tests/test_awslambda/test_lambda.py b/tests/test_awslambda/test_lambda.py index 9467b0803..6fd97e325 100644 --- a/tests/test_awslambda/test_lambda.py +++ b/tests/test_awslambda/test_lambda.py @@ -12,18 +12,29 @@ import zipfile import sure # noqa from freezegun import freeze_time -from moto import mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings, mock_sqs +from moto import ( + mock_dynamodb2, + mock_lambda, + mock_iam, + mock_s3, + mock_ec2, + mock_sns, + mock_logs, + settings, + mock_sqs, +) +from moto.sts.models import ACCOUNT_ID from nose.tools import assert_raises from botocore.exceptions import ClientError -_lambda_region = 'us-west-2' +_lambda_region = "us-west-2" boto3.setup_default_session(region_name=_lambda_region) def _process_lambda(func_str): zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.py', func_str) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.py", func_str) zip_file.close() zip_output.seek(0) return zip_output.read() @@ -49,7 +60,11 @@ def lambda_handler(event, context): print('get volume details for %s\\nVolume - %s state=%s, size=%s' % (volume_id, volume_id, vol.state, vol.size)) return event -""".format(base_url="motoserver:5000" if settings.TEST_SERVER_MODE else "ec2.us-west-2.amazonaws.com") +""".format( + base_url="motoserver:5000" + if settings.TEST_SERVER_MODE + else "ec2.us-west-2.amazonaws.com" + ) return _process_lambda(func_str) @@ -61,6 +76,7 @@ def lambda_handler(event, context): """ return _process_lambda(pfunc) + def get_test_zip_file4(): pfunc = """ def lambda_handler(event, context): @@ -71,113 +87,120 @@ def lambda_handler(event, context): @mock_lambda def test_list_functions(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") result = conn.list_functions() - result['Functions'].should.have.length_of(0) + result["Functions"].should.have.length_of(0) @mock_lambda def test_invoke_requestresponse_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file1(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - in_data = {'msg': 'So long and thanks for all the fish'} - success_result = conn.invoke(FunctionName='testFunction', InvocationType='RequestResponse', - Payload=json.dumps(in_data)) + in_data = {"msg": "So long and thanks for all the fish"} + success_result = conn.invoke( + FunctionName="testFunction", + InvocationType="RequestResponse", + Payload=json.dumps(in_data), + ) success_result["StatusCode"].should.equal(202) result_obj = json.loads( - base64.b64decode(success_result["LogResult"]).decode('utf-8')) + base64.b64decode(success_result["LogResult"]).decode("utf-8") + ) result_obj.should.equal(in_data) - payload = success_result["Payload"].read().decode('utf-8') + payload = success_result["Payload"].read().decode("utf-8") json.loads(payload).should.equal(in_data) @mock_lambda def test_invoke_event_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file1(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) conn.invoke.when.called_with( - FunctionName='notAFunction', - InvocationType='Event', - Payload='{}' + FunctionName="notAFunction", InvocationType="Event", Payload="{}" ).should.throw(botocore.client.ClientError) - in_data = {'msg': 'So long and thanks for all the fish'} + in_data = {"msg": "So long and thanks for all the fish"} success_result = conn.invoke( - FunctionName='testFunction', InvocationType='Event', Payload=json.dumps(in_data)) + FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data) + ) success_result["StatusCode"].should.equal(202) - json.loads(success_result['Payload'].read().decode( - 'utf-8')).should.equal({}) + json.loads(success_result["Payload"].read().decode("utf-8")).should.equal({}) if settings.TEST_SERVER_MODE: + @mock_ec2 @mock_lambda def test_invoke_function_get_ec2_volume(): conn = boto3.resource("ec2", "us-west-2") - vol = conn.create_volume(Size=99, AvailabilityZone='us-west-2') + vol = conn.create_volume(Size=99, AvailabilityZone="us-west-2") vol = conn.Volume(vol.id) - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file2(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python3.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file2()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - in_data = {'volume_id': vol.id} - result = conn.invoke(FunctionName='testFunction', - InvocationType='RequestResponse', Payload=json.dumps(in_data)) + in_data = {"volume_id": vol.id} + result = conn.invoke( + FunctionName="testFunction", + InvocationType="RequestResponse", + Payload=json.dumps(in_data), + ) result["StatusCode"].should.equal(202) - msg = 'get volume details for %s\nVolume - %s state=%s, size=%s\n%s' % ( - vol.id, vol.id, vol.state, vol.size, json.dumps(in_data)) + msg = "get volume details for %s\nVolume - %s state=%s, size=%s\n%s" % ( + vol.id, + vol.id, + vol.state, + vol.size, + json.dumps(in_data).replace( + " ", "" + ), # Makes the tests pass as the result is missing the whitespace + ) - log_result = base64.b64decode(result["LogResult"]).decode('utf-8') + log_result = base64.b64decode(result["LogResult"]).decode("utf-8") - # fix for running under travis (TODO: investigate why it has an extra newline) - log_result = log_result.replace('\n\n', '\n') + # The Docker lambda invocation will return an additional '\n', so need to replace it: + log_result = log_result.replace("\n\n", "\n") log_result.should.equal(msg) - payload = result['Payload'].read().decode('utf-8') + payload = result["Payload"].read().decode("utf-8") - # fix for running under travis (TODO: investigate why it has an extra newline) - payload = payload.replace('\n\n', '\n') + # The Docker lambda invocation will return an additional '\n', so need to replace it: + payload = payload.replace("\n\n", "\n") payload.should.equal(msg) @@ -191,39 +214,42 @@ def test_invoke_function_from_sns(): sns_conn.create_topic(Name="some-topic") topics_json = sns_conn.list_topics() topics = topics_json["Topics"] - topic_arn = topics[0]['TopicArn'] + topic_arn = topics[0]["TopicArn"] - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - sns_conn.subscribe(TopicArn=topic_arn, Protocol="lambda", Endpoint=result['FunctionArn']) + sns_conn.subscribe( + TopicArn=topic_arn, Protocol="lambda", Endpoint=result["FunctionArn"] + ) result = sns_conn.publish(TopicArn=topic_arn, Message=json.dumps({})) start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) == 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if event['message'] == 'get_test_zip_file3 success': + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": return time.sleep(1) @@ -233,342 +259,396 @@ def test_invoke_function_from_sns(): @mock_lambda def test_create_based_on_s3_with_missing_bucket(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function.when.called_with( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'this-bucket-does-not-exist', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "this-bucket-does-not-exist", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - VpcConfig={ - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - }, + VpcConfig={"SecurityGroupIds": ["sg-123abc"], "SubnetIds": ["subnet-123abc"]}, ).should.throw(botocore.client.ClientError) @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_create_function_from_aws_bucket(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - VpcConfig={ - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - }, + VpcConfig={"SecurityGroupIds": ["sg-123abc"], "SubnetIds": ["subnet-123abc"]}, ) # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) + result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result.pop('LastModified') - result.should.equal({ - 'FunctionName': 'testFunction', - 'FunctionArn': 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), - 'Runtime': 'python2.7', - 'Role': 'test-iam-role', - 'Handler': 'lambda_function.lambda_handler', - "CodeSha256": hashlib.sha256(zip_content).hexdigest(), - "CodeSize": len(zip_content), - 'Description': 'test lambda function', - 'Timeout': 3, - 'MemorySize': 128, - 'Version': '1', - 'VpcConfig': { - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - "VpcId": "vpc-123abc" - }, - 'ResponseMetadata': {'HTTPStatusCode': 201}, - }) + result["ResponseMetadata"].pop("RetryAttempts", None) + result.pop("LastModified") + result.should.equal( + { + "FunctionName": "testFunction", + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunction".format( + _lambda_region, ACCOUNT_ID + ), + "Runtime": "python2.7", + "Role": result["Role"], + "Handler": "lambda_function.lambda_handler", + "CodeSha256": hashlib.sha256(zip_content).hexdigest(), + "CodeSize": len(zip_content), + "Description": "test lambda function", + "Timeout": 3, + "MemorySize": 128, + "Version": "1", + "VpcConfig": { + "SecurityGroupIds": ["sg-123abc"], + "SubnetIds": ["subnet-123abc"], + "VpcId": "vpc-123abc", + }, + "ResponseMetadata": {"HTTPStatusCode": 201}, + } + ) @mock_lambda -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_create_function_from_zipfile(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) + result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result.pop('LastModified') + result["ResponseMetadata"].pop("RetryAttempts", None) + result.pop("LastModified") - result.should.equal({ - 'FunctionName': 'testFunction', - 'FunctionArn': 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), - 'Runtime': 'python2.7', - 'Role': 'test-iam-role', - 'Handler': 'lambda_function.lambda_handler', - 'CodeSize': len(zip_content), - 'Description': 'test lambda function', - 'Timeout': 3, - 'MemorySize': 128, - 'CodeSha256': hashlib.sha256(zip_content).hexdigest(), - 'Version': '1', - 'VpcConfig': { - "SecurityGroupIds": [], - "SubnetIds": [], - }, - - 'ResponseMetadata': {'HTTPStatusCode': 201}, - }) + result.should.equal( + { + "FunctionName": "testFunction", + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunction".format( + _lambda_region, ACCOUNT_ID + ), + "Runtime": "python2.7", + "Role": result["Role"], + "Handler": "lambda_function.lambda_handler", + "CodeSize": len(zip_content), + "Description": "test lambda function", + "Timeout": 3, + "MemorySize": 128, + "CodeSha256": hashlib.sha256(zip_content).hexdigest(), + "Version": "1", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "ResponseMetadata": {"HTTPStatusCode": 201}, + } + ) @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_get_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file1() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + Environment={"Variables": {"test_variable": "test_value"}}, + ) + + result = conn.get_function(FunctionName="testFunction") + # this is hard to match against, so remove it + result["ResponseMetadata"].pop("HTTPHeaders", None) + # Botocore inserts retry attempts not seen in Python27 + result["ResponseMetadata"].pop("RetryAttempts", None) + result["Configuration"].pop("LastModified") + + result["Code"]["Location"].should.equal( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format(_lambda_region) + ) + result["Code"]["RepositoryType"].should.equal("S3") + + result["Configuration"]["CodeSha256"].should.equal( + hashlib.sha256(zip_content).hexdigest() + ) + result["Configuration"]["CodeSize"].should.equal(len(zip_content)) + result["Configuration"]["Description"].should.equal("test lambda function") + result["Configuration"].should.contain("FunctionArn") + result["Configuration"]["FunctionName"].should.equal("testFunction") + result["Configuration"]["Handler"].should.equal("lambda_function.lambda_handler") + result["Configuration"]["MemorySize"].should.equal(128) + result["Configuration"]["Role"].should.equal(get_role_name()) + result["Configuration"]["Runtime"].should.equal("python2.7") + result["Configuration"]["Timeout"].should.equal(3) + result["Configuration"]["Version"].should.equal("$LATEST") + result["Configuration"].should.contain("VpcConfig") + result["Configuration"].should.contain("Environment") + result["Configuration"]["Environment"].should.contain("Variables") + result["Configuration"]["Environment"]["Variables"].should.equal( + {"test_variable": "test_value"} + ) + + # Test get function with + result = conn.get_function(FunctionName="testFunction", Qualifier="$LATEST") + result["Configuration"]["Version"].should.equal("$LATEST") + result["Configuration"]["FunctionArn"].should.equal( + "arn:aws:lambda:us-west-2:{}:function:testFunction:$LATEST".format(ACCOUNT_ID) + ) + + # Test get function when can't find function name + with assert_raises(ClientError): + conn.get_function(FunctionName="junk", Qualifier="$LATEST") + + +@mock_lambda +@mock_s3 +def test_get_function_by_arn(): + bucket_name = "test-bucket" + s3_conn = boto3.client("s3", "us-east-1") + s3_conn.create_bucket(Bucket=bucket_name) + + zip_content = get_test_zip_file2() + s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-east-1") + + fnc = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": bucket_name, "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - result = conn.get_function(FunctionName='testFunction') - # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) - # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result['Configuration'].pop('LastModified') - - result['Code']['Location'].should.equal('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip'.format(_lambda_region)) - result['Code']['RepositoryType'].should.equal('S3') - - result['Configuration']['CodeSha256'].should.equal(hashlib.sha256(zip_content).hexdigest()) - result['Configuration']['CodeSize'].should.equal(len(zip_content)) - result['Configuration']['Description'].should.equal('test lambda function') - result['Configuration'].should.contain('FunctionArn') - result['Configuration']['FunctionName'].should.equal('testFunction') - result['Configuration']['Handler'].should.equal('lambda_function.lambda_handler') - result['Configuration']['MemorySize'].should.equal(128) - result['Configuration']['Role'].should.equal('test-iam-role') - result['Configuration']['Runtime'].should.equal('python2.7') - result['Configuration']['Timeout'].should.equal(3) - result['Configuration']['Version'].should.equal('$LATEST') - result['Configuration'].should.contain('VpcConfig') - - # Test get function with - result = conn.get_function(FunctionName='testFunction', Qualifier='$LATEST') - result['Configuration']['Version'].should.equal('$LATEST') - result['Configuration']['FunctionArn'].should.equal('arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST') - - - # Test get function when can't find function name - with assert_raises(ClientError): - conn.get_function(FunctionName='junk', Qualifier='$LATEST') - + result = conn.get_function(FunctionName=fnc["FunctionArn"]) + result["Configuration"]["FunctionName"].should.equal("testFunction") @mock_lambda @mock_s3 def test_delete_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - success_result = conn.delete_function(FunctionName='testFunction') + success_result = conn.delete_function(FunctionName="testFunction") # this is hard to match against, so remove it - success_result['ResponseMetadata'].pop('HTTPHeaders', None) + success_result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - success_result['ResponseMetadata'].pop('RetryAttempts', None) + success_result["ResponseMetadata"].pop("RetryAttempts", None) - success_result.should.equal({'ResponseMetadata': {'HTTPStatusCode': 204}}) + success_result.should.equal({"ResponseMetadata": {"HTTPStatusCode": 204}}) + function_list = conn.list_functions() + function_list["Functions"].should.have.length_of(0) + + +@mock_lambda +@mock_s3 +def test_delete_function_by_arn(): + bucket_name = "test-bucket" + s3_conn = boto3.client("s3", "us-east-1") + s3_conn.create_bucket(Bucket=bucket_name) + + zip_content = get_test_zip_file2() + s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-east-1") + + fnc = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": bucket_name, "S3Key": "test.zip"}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + ) + + conn.delete_function(FunctionName=fnc["FunctionArn"]) + function_list = conn.list_functions() + function_list["Functions"].should.have.length_of(0) + + +@mock_lambda +def test_delete_unknown_function(): + conn = boto3.client("lambda", "us-west-2") conn.delete_function.when.called_with( - FunctionName='testFunctionThatDoesntExist').should.throw(botocore.client.ClientError) + FunctionName="testFunctionThatDoesntExist" + ).should.throw(botocore.client.ClientError) @mock_lambda @mock_s3 def test_publish(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=False, ) function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(1) - latest_arn = function_list['Functions'][0]['FunctionArn'] + function_list["Functions"].should.have.length_of(1) + latest_arn = function_list["Functions"][0]["FunctionArn"] - res = conn.publish_version(FunctionName='testFunction') - assert res['ResponseMetadata']['HTTPStatusCode'] == 201 + res = conn.publish_version(FunctionName="testFunction") + assert res["ResponseMetadata"]["HTTPStatusCode"] == 201 function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(2) + function_list["Functions"].should.have.length_of(2) # #SetComprehension ;-) - published_arn = list({f['FunctionArn'] for f in function_list['Functions']} - {latest_arn})[0] - published_arn.should.contain('testFunction:1') + published_arn = list( + {f["FunctionArn"] for f in function_list["Functions"]} - {latest_arn} + )[0] + published_arn.should.contain("testFunction:1") - conn.delete_function(FunctionName='testFunction', Qualifier='1') + conn.delete_function(FunctionName="testFunction", Qualifier="1") function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(1) - function_list['Functions'][0]['FunctionArn'].should.contain('testFunction') + function_list["Functions"].should.have.length_of(1) + function_list["Functions"][0]["FunctionArn"].should.contain("testFunction") @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_list_create_list_get_delete_list(): """ test `list -> create -> list -> get -> delete -> list` integration """ - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") - conn.list_functions()['Functions'].should.have.length_of(0) + conn.list_functions()["Functions"].should.have.length_of(0) conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) expected_function_result = { "Code": { - "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format(_lambda_region), - "RepositoryType": "S3" + "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format( + _lambda_region + ), + "RepositoryType": "S3", }, "Configuration": { "CodeSha256": hashlib.sha256(zip_content).hexdigest(), "CodeSize": len(zip_content), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunction".format( + _lambda_region, ACCOUNT_ID + ), "FunctionName": "testFunction", "Handler": "lambda_function.lambda_handler", "MemorySize": 128, - "Role": "test-iam-role", + "Role": get_role_name(), "Runtime": "python2.7", "Timeout": 3, - "Version": '$LATEST', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } + "Version": "$LATEST", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, }, - 'ResponseMetadata': {'HTTPStatusCode': 200}, + "ResponseMetadata": {"HTTPStatusCode": 200}, } - func = conn.list_functions()['Functions'][0] - func.pop('LastModified') - func.should.equal(expected_function_result['Configuration']) + func = conn.list_functions()["Functions"][0] + func.pop("LastModified") + func.should.equal(expected_function_result["Configuration"]) - func = conn.get_function(FunctionName='testFunction') + func = conn.get_function(FunctionName="testFunction") # this is hard to match against, so remove it - func['ResponseMetadata'].pop('HTTPHeaders', None) + func["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - func['ResponseMetadata'].pop('RetryAttempts', None) - func['Configuration'].pop('LastModified') + func["ResponseMetadata"].pop("RetryAttempts", None) + func["Configuration"].pop("LastModified") func.should.equal(expected_function_result) - conn.delete_function(FunctionName='testFunction') + conn.delete_function(FunctionName="testFunction") - conn.list_functions()['Functions'].should.have.length_of(0) + conn.list_functions()["Functions"].should.have.length_of(0) @mock_lambda @@ -578,34 +658,30 @@ def lambda_handler(event, context): raise Exception('failsauce') """ zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.py', lambda_fx) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.py", lambda_fx) zip_file.close() zip_output.seek(0) - client = boto3.client('lambda', region_name='us-east-1') + client = boto3.client("lambda", region_name="us-east-1") client.create_function( - FunctionName='test-lambda-fx', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Description='test lambda function', + FunctionName="test-lambda-fx", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - Code={ - 'ZipFile': zip_output.read() - }, + Code={"ZipFile": zip_output.read()}, ) result = client.invoke( - FunctionName='test-lambda-fx', - InvocationType='RequestResponse', - LogType='Tail' + FunctionName="test-lambda-fx", InvocationType="RequestResponse", LogType="Tail" ) - assert 'FunctionError' in result - assert result['FunctionError'] == 'Handled' + assert "FunctionError" in result + assert result["FunctionError"] == "Handled" @mock_lambda @@ -614,65 +690,56 @@ def test_tags(): """ test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration """ - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") function = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) # List tags when there are none - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict()) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal(dict()) # List tags when there is one - conn.tag_resource( - Resource=function['FunctionArn'], - Tags=dict(spam='eggs') - )['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(spam='eggs')) + conn.tag_resource(Resource=function["FunctionArn"], Tags=dict(spam="eggs"))[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(200) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(spam="eggs") + ) # List tags when another has been added - conn.tag_resource( - Resource=function['FunctionArn'], - Tags=dict(foo='bar') - )['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(spam='eggs', foo='bar')) + conn.tag_resource(Resource=function["FunctionArn"], Tags=dict(foo="bar"))[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(200) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(spam="eggs", foo="bar") + ) # Untag resource - conn.untag_resource( - Resource=function['FunctionArn'], - TagKeys=['spam', 'trolls'] - )['ResponseMetadata']['HTTPStatusCode'].should.equal(204) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(foo='bar')) + conn.untag_resource(Resource=function["FunctionArn"], TagKeys=["spam", "trolls"])[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(204) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(foo="bar") + ) # Untag a tag that does not exist (no error and no change) - conn.untag_resource( - Resource=function['FunctionArn'], - TagKeys=['spam'] - )['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + conn.untag_resource(Resource=function["FunctionArn"], TagKeys=["spam"])[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(204) @mock_lambda @@ -680,347 +747,391 @@ def test_tags_not_found(): """ Test list_tags and tag_resource when the lambda with the given arn does not exist """ - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.list_tags.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found' + Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID) ).should.throw(botocore.client.ClientError) conn.tag_resource.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found', - Tags=dict(spam='eggs') + Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID), + Tags=dict(spam="eggs"), ).should.throw(botocore.client.ClientError) conn.untag_resource.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found', - TagKeys=['spam'] + Resource="arn:aws:lambda:{}:function:not-found".format(ACCOUNT_ID), + TagKeys=["spam"], ).should.throw(botocore.client.ClientError) @mock_lambda def test_invoke_async_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={'ZipFile': get_test_zip_file1()}, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) success_result = conn.invoke_async( - FunctionName='testFunction', - InvokeArgs=json.dumps({'test': 'event'}) - ) + FunctionName="testFunction", InvokeArgs=json.dumps({"test": "event"}) + ) - success_result['Status'].should.equal(202) + success_result["Status"].should.equal(202) @mock_lambda -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_get_function_created_with_zipfile(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - response = conn.get_function( - FunctionName='testFunction' - ) - response['Configuration'].pop('LastModified') + response = conn.get_function(FunctionName="testFunction") + response["Configuration"].pop("LastModified") - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - assert len(response['Code']) == 2 - assert response['Code']['RepositoryType'] == 'S3' - assert response['Code']['Location'].startswith('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com'.format(_lambda_region)) - response['Configuration'].should.equal( + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( { "CodeSha256": hashlib.sha256(zip_content).hexdigest(), "CodeSize": len(zip_content), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunction".format( + _lambda_region, ACCOUNT_ID + ), "FunctionName": "testFunction", "Handler": "lambda_function.handler", "MemorySize": 128, - "Role": "test-iam-role", + "Role": get_role_name(), "Runtime": "python2.7", "Timeout": 3, - "Version": '$LATEST', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } - }, + "Version": "$LATEST", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } ) @mock_lambda -def add_function_permission(): - conn = boto3.client('lambda', 'us-west-2') +def test_add_function_permission(): + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() - result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=(get_role_name()), + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.add_permission( - FunctionName='testFunction', - StatementId='1', + FunctionName="testFunction", + StatementId="1", Action="lambda:InvokeFunction", - Principal='432143214321', + Principal="432143214321", SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", - SourceAccount='123412341234', - EventSourceToken='blah', - Qualifier='2' + SourceAccount="123412341234", + EventSourceToken="blah", + Qualifier="2", ) - assert 'Statement' in response - res = json.loads(response['Statement']) - assert res['Action'] == "lambda:InvokeFunction" + assert "Statement" in response + res = json.loads(response["Statement"]) + assert res["Action"] == "lambda:InvokeFunction" @mock_lambda -def get_function_policy(): - conn = boto3.client('lambda', 'us-west-2') +def test_get_function_policy(): + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() - result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.add_permission( - FunctionName='testFunction', - StatementId='1', + FunctionName="testFunction", + StatementId="1", Action="lambda:InvokeFunction", - Principal='432143214321', + Principal="432143214321", SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", - SourceAccount='123412341234', - EventSourceToken='blah', - Qualifier='2' + SourceAccount="123412341234", + EventSourceToken="blah", + Qualifier="2", ) - response = conn.get_policy( - FunctionName='testFunction' - ) + response = conn.get_policy(FunctionName="testFunction") - assert 'Policy' in response - assert isinstance(response['Policy'], str) - res = json.loads(response['Policy']) - assert res['Statement'][0]['Action'] == 'lambda:InvokeFunction' + assert "Policy" in response + res = json.loads(response["Policy"]) + assert res["Statement"][0]["Action"] == "lambda:InvokeFunction" @mock_lambda @mock_s3 def test_list_versions_by_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='arn:aws:iam::123456789012:role/test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - res = conn.publish_version(FunctionName='testFunction') - assert res['ResponseMetadata']['HTTPStatusCode'] == 201 - versions = conn.list_versions_by_function(FunctionName='testFunction') - assert len(versions['Versions']) == 3 - assert versions['Versions'][0]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST' - assert versions['Versions'][1]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:1' - assert versions['Versions'][2]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:2' + res = conn.publish_version(FunctionName="testFunction") + assert res["ResponseMetadata"]["HTTPStatusCode"] == 201 + versions = conn.list_versions_by_function(FunctionName="testFunction") + assert len(versions["Versions"]) == 3 + assert versions["Versions"][0][ + "FunctionArn" + ] == "arn:aws:lambda:us-west-2:{}:function:testFunction:$LATEST".format(ACCOUNT_ID) + assert versions["Versions"][1][ + "FunctionArn" + ] == "arn:aws:lambda:us-west-2:{}:function:testFunction:1".format(ACCOUNT_ID) + assert versions["Versions"][2][ + "FunctionArn" + ] == "arn:aws:lambda:us-west-2:{}:function:testFunction:2".format(ACCOUNT_ID) conn.create_function( - FunctionName='testFunction_2', - Runtime='python2.7', - Role='arn:aws:iam::123456789012:role/test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction_2", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=False, ) - versions = conn.list_versions_by_function(FunctionName='testFunction_2') - assert len(versions['Versions']) == 1 - assert versions['Versions'][0]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction_2:$LATEST' + versions = conn.list_versions_by_function(FunctionName="testFunction_2") + assert len(versions["Versions"]) == 1 + assert versions["Versions"][0][ + "FunctionArn" + ] == "arn:aws:lambda:us-west-2:{}:function:testFunction_2:$LATEST".format( + ACCOUNT_ID + ) @mock_lambda @mock_s3 def test_create_function_with_already_exists(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - assert response['FunctionName'] == 'testFunction' + assert response["FunctionName"] == "testFunction" @mock_lambda @mock_s3 def test_list_versions_by_function_for_nonexistent_function(): - conn = boto3.client('lambda', 'us-west-2') - versions = conn.list_versions_by_function(FunctionName='testFunction') + conn = boto3.client("lambda", "us-west-2") + versions = conn.list_versions_by_function(FunctionName="testFunction") - assert len(versions['Versions']) == 0 + assert len(versions["Versions"]) == 0 @mock_logs @mock_lambda @mock_sqs def test_create_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['FunctionArn'] == func['FunctionArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["FunctionArn"] == func["FunctionArn"] + assert response["State"] == "Enabled" @mock_logs @mock_lambda @mock_sqs def test_invoke_function_from_sqs(): - logs_conn = boto3.client("logs") - sqs = boto3.resource('sqs') + logs_conn = boto3.client("logs", region_name="us-east-1") + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["State"] == "Enabled" - sqs_client = boto3.client('sqs') - sqs_client.send_message(QueueUrl=queue.url, MessageBody='test') + sqs_client = boto3.client("sqs", region_name="us-east-1") + sqs_client.send_message(QueueUrl=queue.url, MessageBody="test") start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) == 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if event['message'] == 'get_test_zip_file3 success': + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": + return + time.sleep(1) + + assert False, "Test Failed" + + +@mock_logs +@mock_lambda +@mock_dynamodb2 +def test_invoke_function_from_dynamodb(): + logs_conn = boto3.client("logs", region_name="us-east-1") + dynamodb = boto3.client("dynamodb", region_name="us-east-1") + table_name = "table_with_stream" + table = dynamodb.create_table( + TableName=table_name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + StreamSpecification={ + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, + ) + + conn = boto3.client("lambda", region_name="us-east-1") + func = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function executed after a DynamoDB table is updated", + Timeout=3, + MemorySize=128, + Publish=True, + ) + + response = conn.create_event_source_mapping( + EventSourceArn=table["TableDescription"]["LatestStreamArn"], + FunctionName=func["FunctionArn"], + ) + + assert response["EventSourceArn"] == table["TableDescription"]["LatestStreamArn"] + assert response["State"] == "Enabled" + + dynamodb.put_item(TableName=table_name, Item={"id": {"S": "item 1"}}) + start = time.time() + while (time.time() - start) < 30: + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") + if not log_streams: + time.sleep(1) + continue + + assert len(log_streams) == 1 + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": return time.sleep(1) @@ -1031,62 +1142,56 @@ def test_invoke_function_from_sqs(): @mock_lambda @mock_sqs def test_invoke_function_from_sqs_exception(): - logs_conn = boto3.client("logs") - sqs = boto3.resource('sqs') + logs_conn = boto3.client("logs", region_name="us-east-1") + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file4(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file4()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["State"] == "Enabled" entries = [] for i in range(3): - body = { - "uuid": str(uuid.uuid4()), - "test": "test_{}".format(i), - } - entry = { - 'Id': str(i), - 'MessageBody': json.dumps(body) - } + body = {"uuid": str(uuid.uuid4()), "test": "test_{}".format(i)} + entry = {"Id": str(i), "MessageBody": json.dumps(body)} entries.append(entry) queue.send_messages(Entries=entries) start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) >= 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if 'I failed!' in event['message']: + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if "I failed!" in event["message"]: messages = queue.receive_messages(MaxNumberOfMessages=10) # Verify messages are still visible and unprocessed - assert len(messages) is 3 + assert len(messages) == 3 return time.sleep(1) @@ -1097,151 +1202,358 @@ def test_invoke_function_from_sqs_exception(): @mock_lambda @mock_sqs def test_list_event_source_mappings(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - mappings = conn.list_event_source_mappings(EventSourceArn='123') - assert len(mappings['EventSourceMappings']) == 0 + mappings = conn.list_event_source_mappings(EventSourceArn="123") + assert len(mappings["EventSourceMappings"]) == 0 - mappings = conn.list_event_source_mappings(EventSourceArn=queue.attributes['QueueArn']) - assert len(mappings['EventSourceMappings']) == 1 - assert mappings['EventSourceMappings'][0]['UUID'] == response['UUID'] - assert mappings['EventSourceMappings'][0]['FunctionArn'] == func['FunctionArn'] + mappings = conn.list_event_source_mappings( + EventSourceArn=queue.attributes["QueueArn"] + ) + assert len(mappings["EventSourceMappings"]) == 1 + assert mappings["EventSourceMappings"][0]["UUID"] == response["UUID"] + assert mappings["EventSourceMappings"][0]["FunctionArn"] == func["FunctionArn"] @mock_lambda @mock_sqs def test_get_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - mapping = conn.get_event_source_mapping(UUID=response['UUID']) - assert mapping['UUID'] == response['UUID'] - assert mapping['FunctionArn'] == func['FunctionArn'] + mapping = conn.get_event_source_mapping(UUID=response["UUID"]) + assert mapping["UUID"] == response["UUID"] + assert mapping["FunctionArn"] == func["FunctionArn"] - conn.get_event_source_mapping.when.called_with(UUID='1')\ - .should.throw(botocore.client.ClientError) + conn.get_event_source_mapping.when.called_with(UUID="1").should.throw( + botocore.client.ClientError + ) @mock_lambda @mock_sqs def test_update_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func1 = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) func2 = conn.create_function( - FunctionName='testFunction2', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction2", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func1['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func1["FunctionArn"] ) - assert response['FunctionArn'] == func1['FunctionArn'] - assert response['BatchSize'] == 10 - assert response['State'] == 'Enabled' + assert response["FunctionArn"] == func1["FunctionArn"] + assert response["BatchSize"] == 10 + assert response["State"] == "Enabled" mapping = conn.update_event_source_mapping( - UUID=response['UUID'], - Enabled=False, - BatchSize=15, - FunctionName='testFunction2' - + UUID=response["UUID"], Enabled=False, BatchSize=15, FunctionName="testFunction2" ) - assert mapping['UUID'] == response['UUID'] - assert mapping['FunctionArn'] == func2['FunctionArn'] - assert mapping['State'] == 'Disabled' + assert mapping["UUID"] == response["UUID"] + assert mapping["FunctionArn"] == func2["FunctionArn"] + assert mapping["State"] == "Disabled" @mock_lambda @mock_sqs def test_delete_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda", region_name="us-east-1") func1 = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func1['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func1["FunctionArn"] ) - assert response['FunctionArn'] == func1['FunctionArn'] - assert response['BatchSize'] == 10 - assert response['State'] == 'Enabled' + assert response["FunctionArn"] == func1["FunctionArn"] + assert response["BatchSize"] == 10 + assert response["State"] == "Enabled" - response = conn.delete_event_source_mapping(UUID=response['UUID']) + response = conn.delete_event_source_mapping(UUID=response["UUID"]) - assert response['State'] == 'Deleting' - conn.get_event_source_mapping.when.called_with(UUID=response['UUID'])\ - .should.throw(botocore.client.ClientError) + assert response["State"] == "Deleting" + conn.get_event_source_mapping.when.called_with(UUID=response["UUID"]).should.throw( + botocore.client.ClientError + ) + + +@mock_lambda +@mock_s3 +def test_update_configuration(): + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") + + zip_content = get_test_zip_file2() + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") + + fxn = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + Environment={"Variables": {"test_old_environment": "test_old_value"}}, + ) + + assert fxn["Description"] == "test lambda function" + assert fxn["Handler"] == "lambda_function.lambda_handler" + assert fxn["MemorySize"] == 128 + assert fxn["Runtime"] == "python2.7" + assert fxn["Timeout"] == 3 + + updated_config = conn.update_function_configuration( + FunctionName="testFunction", + Description="updated test lambda function", + Handler="lambda_function.new_lambda_handler", + Runtime="python3.6", + Timeout=7, + Environment={"Variables": {"test_environment": "test_value"}}, + ) + + assert updated_config["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert updated_config["Description"] == "updated test lambda function" + assert updated_config["Handler"] == "lambda_function.new_lambda_handler" + assert updated_config["MemorySize"] == 128 + assert updated_config["Runtime"] == "python3.6" + assert updated_config["Timeout"] == 7 + assert updated_config["Environment"]["Variables"] == { + "test_environment": "test_value" + } + + +@mock_lambda +def test_update_function_zip(): + conn = boto3.client("lambda", "us-west-2") + + zip_content_one = get_test_zip_file1() + + fxn = conn.create_function( + FunctionName="testFunctionZip", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"ZipFile": zip_content_one}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + ) + + zip_content_two = get_test_zip_file2() + + fxn_updated = conn.update_function_code( + FunctionName="testFunctionZip", ZipFile=zip_content_two, Publish=True + ) + + response = conn.get_function(FunctionName="testFunctionZip", Qualifier="2") + response["Configuration"].pop("LastModified") + + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( + { + "CodeSha256": hashlib.sha256(zip_content_two).hexdigest(), + "CodeSize": len(zip_content_two), + "Description": "test lambda function", + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunctionZip:2".format( + _lambda_region, ACCOUNT_ID + ), + "FunctionName": "testFunctionZip", + "Handler": "lambda_function.lambda_handler", + "MemorySize": 128, + "Role": fxn["Role"], + "Runtime": "python2.7", + "Timeout": 3, + "Version": "2", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } + ) + + +@mock_lambda +@mock_s3 +def test_update_function_s3(): + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") + + zip_content = get_test_zip_file1() + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + + conn = boto3.client("lambda", "us-west-2") + + fxn = conn.create_function( + FunctionName="testFunctionS3", + Runtime="python2.7", + Role=get_role_name(), + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + ) + + zip_content_two = get_test_zip_file2() + s3_conn.put_object(Bucket="test-bucket", Key="test2.zip", Body=zip_content_two) + + fxn_updated = conn.update_function_code( + FunctionName="testFunctionS3", + S3Bucket="test-bucket", + S3Key="test2.zip", + Publish=True, + ) + + response = conn.get_function(FunctionName="testFunctionS3", Qualifier="2") + response["Configuration"].pop("LastModified") + + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( + { + "CodeSha256": hashlib.sha256(zip_content_two).hexdigest(), + "CodeSize": len(zip_content_two), + "Description": "test lambda function", + "FunctionArn": "arn:aws:lambda:{}:{}:function:testFunctionS3:2".format( + _lambda_region, ACCOUNT_ID + ), + "FunctionName": "testFunctionS3", + "Handler": "lambda_function.lambda_handler", + "MemorySize": 128, + "Role": fxn["Role"], + "Runtime": "python2.7", + "Timeout": 3, + "Version": "2", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } + ) + + +@mock_lambda +def test_create_function_with_invalid_arn(): + err = create_invalid_lambda("test-iam-role") + err.exception.response["Error"]["Message"].should.equal( + "1 validation error detected: Value 'test-iam-role' at 'role' failed to satisfy constraint: Member must satisfy regular expression pattern: arn:(aws[a-zA-Z-]*)?:iam::(\d{12}):role/?[a-zA-Z_0-9+=,.@\-_/]+" + ) + + +@mock_lambda +def test_create_function_with_arn_from_different_account(): + err = create_invalid_lambda("arn:aws:iam::000000000000:role/example_role") + err.exception.response["Error"]["Message"].should.equal( + "Cross-account pass role is not allowed." + ) + + +@mock_lambda +def test_create_function_with_unknown_arn(): + err = create_invalid_lambda( + "arn:aws:iam::" + str(ACCOUNT_ID) + ":role/service-role/unknown_role" + ) + err.exception.response["Error"]["Message"].should.equal( + "The role defined for the function cannot be assumed by Lambda." + ) + + +def create_invalid_lambda(role): + conn = boto3.client("lambda", "us-west-2") + zip_content = get_test_zip_file1() + with assert_raises(ClientError) as err: + conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role=role, + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + ) + return err + + +def get_role_name(): + with mock_iam(): + iam = boto3.client("iam", region_name="us-west-2") + try: + return iam.get_role(RoleName="my-role")["Role"]["Arn"] + except ClientError: + return iam.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Path="/my-path/", + )["Role"]["Arn"] diff --git a/tests/test_awslambda/test_lambda_cloudformation.py b/tests/test_awslambda/test_lambda_cloudformation.py new file mode 100644 index 000000000..a5d4d23fd --- /dev/null +++ b/tests/test_awslambda/test_lambda_cloudformation.py @@ -0,0 +1,138 @@ +import boto3 +import io +import sure # noqa +import zipfile +from botocore.exceptions import ClientError +from moto import mock_cloudformation, mock_iam, mock_lambda, mock_s3 +from nose.tools import assert_raises +from string import Template +from uuid import uuid4 + + +def _process_lambda(func_str): + zip_output = io.BytesIO() + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.py", func_str) + zip_file.close() + zip_output.seek(0) + return zip_output.read() + + +def get_zip_file(): + pfunc = """ +def lambda_handler1(event, context): + return event +def lambda_handler2(event, context): + return event +""" + return _process_lambda(pfunc) + + +template = Template( + """{ + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + "LF3ABOV": { + "Type": "AWS::Lambda::Function", + "Properties": { + "Handler": "$handler", + "Role": "$role_arn", + "Runtime": "$runtime", + "Code": { + "S3Bucket": "$bucket_name", + "S3Key": "$key" + }, + } + } + } +}""" +) + + +@mock_cloudformation +@mock_lambda +@mock_s3 +def test_lambda_can_be_updated_by_cloudformation(): + s3 = boto3.client("s3", "us-east-1") + cf = boto3.client("cloudformation", region_name="us-east-1") + lmbda = boto3.client("lambda", region_name="us-east-1") + body2, stack = create_stack(cf, s3) + created_fn_name = get_created_function_name(cf, stack) + # Verify function has been created + created_fn = lmbda.get_function(FunctionName=created_fn_name) + created_fn["Configuration"]["Handler"].should.equal( + "lambda_function.lambda_handler1" + ) + created_fn["Configuration"]["Runtime"].should.equal("python3.7") + created_fn["Code"]["Location"].should.match("/test1.zip") + # Update CF stack + cf.update_stack(StackName="teststack", TemplateBody=body2) + updated_fn_name = get_created_function_name(cf, stack) + # Verify function has been updated + updated_fn = lmbda.get_function(FunctionName=updated_fn_name) + updated_fn["Configuration"]["FunctionArn"].should.equal( + created_fn["Configuration"]["FunctionArn"] + ) + updated_fn["Configuration"]["Handler"].should.equal( + "lambda_function.lambda_handler2" + ) + updated_fn["Configuration"]["Runtime"].should.equal("python3.8") + updated_fn["Code"]["Location"].should.match("/test2.zip") + + +@mock_cloudformation +@mock_lambda +@mock_s3 +def test_lambda_can_be_deleted_by_cloudformation(): + s3 = boto3.client("s3", "us-east-1") + cf = boto3.client("cloudformation", region_name="us-east-1") + lmbda = boto3.client("lambda", region_name="us-east-1") + _, stack = create_stack(cf, s3) + created_fn_name = get_created_function_name(cf, stack) + # Delete Stack + cf.delete_stack(StackName=stack["StackId"]) + # Verify function was deleted + with assert_raises(ClientError) as e: + lmbda.get_function(FunctionName=created_fn_name) + e.exception.response["Error"]["Code"].should.equal("404") + + +def create_stack(cf, s3): + bucket_name = str(uuid4()) + s3.create_bucket(Bucket=bucket_name) + s3.put_object(Bucket=bucket_name, Key="test1.zip", Body=get_zip_file()) + s3.put_object(Bucket=bucket_name, Key="test2.zip", Body=get_zip_file()) + body1 = get_template(bucket_name, "1", "python3.7") + body2 = get_template(bucket_name, "2", "python3.8") + stack = cf.create_stack(StackName="teststack", TemplateBody=body1) + return body2, stack + + +def get_created_function_name(cf, stack): + res = cf.list_stack_resources(StackName=stack["StackId"]) + return res["StackResourceSummaries"][0]["PhysicalResourceId"] + + +def get_template(bucket_name, version, runtime): + key = "test" + version + ".zip" + handler = "lambda_function.lambda_handler" + version + return template.substitute( + bucket_name=bucket_name, + key=key, + handler=handler, + role_arn=get_role_arn(), + runtime=runtime, + ) + + +def get_role_arn(): + with mock_iam(): + iam = boto3.client("iam", region_name="us-west-2") + try: + return iam.get_role(RoleName="my-role")["Role"]["Arn"] + except ClientError: + return iam.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Path="/my-path/", + )["Role"]["Arn"] diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 89a8d4d0e..141d6b343 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -17,17 +17,21 @@ def expected_failure(test): test(*args, **kwargs) except Exception as err: raise nose.SkipTest + return inner -DEFAULT_REGION = 'eu-central-1' + +DEFAULT_REGION = "eu-central-1" def _get_clients(): - return boto3.client('ec2', region_name=DEFAULT_REGION), \ - boto3.client('iam', region_name=DEFAULT_REGION), \ - boto3.client('ecs', region_name=DEFAULT_REGION), \ - boto3.client('logs', region_name=DEFAULT_REGION), \ - boto3.client('batch', region_name=DEFAULT_REGION) + return ( + boto3.client("ec2", region_name=DEFAULT_REGION), + boto3.client("iam", region_name=DEFAULT_REGION), + boto3.client("ecs", region_name=DEFAULT_REGION), + boto3.client("logs", region_name=DEFAULT_REGION), + boto3.client("batch", region_name=DEFAULT_REGION), + ) def _setup(ec2_client, iam_client): @@ -36,26 +40,25 @@ def _setup(ec2_client, iam_client): :return: VPC ID, Subnet ID, Security group ID, IAM Role ARN :rtype: tuple """ - resp = ec2_client.create_vpc(CidrBlock='172.30.0.0/24') - vpc_id = resp['Vpc']['VpcId'] + resp = ec2_client.create_vpc(CidrBlock="172.30.0.0/24") + vpc_id = resp["Vpc"]["VpcId"] resp = ec2_client.create_subnet( - AvailabilityZone='eu-central-1a', - CidrBlock='172.30.0.0/25', - VpcId=vpc_id + AvailabilityZone="eu-central-1a", CidrBlock="172.30.0.0/25", VpcId=vpc_id ) - subnet_id = resp['Subnet']['SubnetId'] + subnet_id = resp["Subnet"]["SubnetId"] resp = ec2_client.create_security_group( - Description='test_sg_desc', - GroupName='test_sg', - VpcId=vpc_id + Description="test_sg_desc", GroupName="test_sg", VpcId=vpc_id ) - sg_id = resp['GroupId'] + sg_id = resp["GroupId"] resp = iam_client.create_role( - RoleName='TestRole', - AssumeRolePolicyDocument='some_policy' + RoleName="TestRole", AssumeRolePolicyDocument="some_policy" + ) + iam_arn = resp["Role"]["Arn"] + iam_client.create_instance_profile(InstanceProfileName="TestRole") + iam_client.add_role_to_instance_profile( + InstanceProfileName="TestRole", RoleName="TestRole" ) - iam_arn = resp['Role']['Arn'] return vpc_id, subnet_id, sg_id, iam_arn @@ -69,49 +72,40 @@ def test_create_managed_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='MANAGED', - state='ENABLED', + type="MANAGED", + state="ENABLED", computeResources={ - 'type': 'EC2', - 'minvCpus': 5, - 'maxvCpus': 10, - 'desiredvCpus': 5, - 'instanceTypes': [ - 't2.small', - 't2.medium' - ], - 'imageId': 'some_image_id', - 'subnets': [ - subnet_id, - ], - 'securityGroupIds': [ - sg_id, - ], - 'ec2KeyPair': 'string', - 'instanceRole': iam_arn, - 'tags': { - 'string': 'string' - }, - 'bidPercentage': 123, - 'spotIamFleetRole': 'string' + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["t2.small", "t2.medium"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn.replace("role", "instance-profile"), + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", }, - serviceRole=iam_arn + serviceRole=iam_arn, ) - resp.should.contain('computeEnvironmentArn') - resp['computeEnvironmentName'].should.equal(compute_name) + resp.should.contain("computeEnvironmentArn") + resp["computeEnvironmentName"].should.equal(compute_name) # Given a t2.medium is 2 vcpu and t2.small is 1, therefore 2 mediums and 1 small should be created resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(3) + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(3) # Should have created 1 ECS cluster resp = ecs_client.list_clusters() - resp.should.contain('clusterArns') - len(resp['clusterArns']).should.equal(1) + resp.should.contain("clusterArns") + len(resp["clusterArns"]).should.equal(1) @mock_ec2 @@ -122,25 +116,26 @@ def test_create_unmanaged_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - resp.should.contain('computeEnvironmentArn') - resp['computeEnvironmentName'].should.equal(compute_name) + resp.should.contain("computeEnvironmentArn") + resp["computeEnvironmentName"].should.equal(compute_name) # Its unmanaged so no instances should be created resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(0) + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(0) # Should have created 1 ECS cluster resp = ecs_client.list_clusters() - resp.should.contain('clusterArns') - len(resp['clusterArns']).should.equal(1) + resp.should.contain("clusterArns") + len(resp["clusterArns"]).should.equal(1) + # TODO create 1000s of tests to test complex option combinations of create environment @@ -153,23 +148,21 @@ def test_describe_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(1) - resp['computeEnvironments'][0]['computeEnvironmentName'].should.equal(compute_name) + len(resp["computeEnvironments"]).should.equal(1) + resp["computeEnvironments"][0]["computeEnvironmentName"].should.equal(compute_name) # Test filtering - resp = batch_client.describe_compute_environments( - computeEnvironments=['test1'] - ) - len(resp['computeEnvironments']).should.equal(0) + resp = batch_client.describe_compute_environments(computeEnvironments=["test1"]) + len(resp["computeEnvironments"]).should.equal(0) @mock_ec2 @@ -180,23 +173,21 @@ def test_delete_unmanaged_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - batch_client.delete_compute_environment( - computeEnvironment=compute_name, - ) + batch_client.delete_compute_environment(computeEnvironment=compute_name) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(0) + len(resp["computeEnvironments"]).should.equal(0) resp = ecs_client.list_clusters() - len(resp.get('clusterArns', [])).should.equal(0) + len(resp.get("clusterArns", [])).should.equal(0) @mock_ec2 @@ -207,53 +198,42 @@ def test_delete_managed_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='MANAGED', - state='ENABLED', + type="MANAGED", + state="ENABLED", computeResources={ - 'type': 'EC2', - 'minvCpus': 5, - 'maxvCpus': 10, - 'desiredvCpus': 5, - 'instanceTypes': [ - 't2.small', - 't2.medium' - ], - 'imageId': 'some_image_id', - 'subnets': [ - subnet_id, - ], - 'securityGroupIds': [ - sg_id, - ], - 'ec2KeyPair': 'string', - 'instanceRole': iam_arn, - 'tags': { - 'string': 'string' - }, - 'bidPercentage': 123, - 'spotIamFleetRole': 'string' + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["t2.small", "t2.medium"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn.replace("role", "instance-profile"), + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", }, - serviceRole=iam_arn + serviceRole=iam_arn, ) - batch_client.delete_compute_environment( - computeEnvironment=compute_name, - ) + batch_client.delete_compute_environment(computeEnvironment=compute_name) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(0) + len(resp["computeEnvironments"]).should.equal(0) resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(3) - for reservation in resp['Reservations']: - reservation['Instances'][0]['State']['Name'].should.equal('terminated') + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(3) + for reservation in resp["Reservations"]: + reservation["Instances"][0]["State"]["Name"].should.equal("terminated") resp = ecs_client.list_clusters() - len(resp.get('clusterArns', [])).should.equal(0) + len(resp.get("clusterArns", [])).should.equal(0) @mock_ec2 @@ -264,22 +244,21 @@ def test_update_unmanaged_compute_environment_state(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) batch_client.update_compute_environment( - computeEnvironment=compute_name, - state='DISABLED' + computeEnvironment=compute_name, state="DISABLED" ) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(1) - resp['computeEnvironments'][0]['state'].should.equal('DISABLED') + len(resp["computeEnvironments"]).should.equal(1) + resp["computeEnvironments"][0]["state"].should.equal("DISABLED") @mock_ec2 @@ -290,87 +269,70 @@ def test_create_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - resp.should.contain('jobQueueArn') - resp.should.contain('jobQueueName') - queue_arn = resp['jobQueueArn'] + resp.should.contain("jobQueueArn") + resp.should.contain("jobQueueName") + queue_arn = resp["jobQueueArn"] resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['jobQueueArn'].should.equal(queue_arn) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["jobQueueArn"].should.equal(queue_arn) - resp = batch_client.describe_job_queues(jobQueues=['test_invalid_queue']) - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(0) + resp = batch_client.describe_job_queues(jobQueues=["test_invalid_queue"]) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(0) # Create job queue which already exists try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') - + err.response["Error"]["Code"].should.equal("ClientException") # Create job queue with incorrect state try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue2', - state='JUNK', + jobQueueName="test_job_queue2", + state="JUNK", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") # Create job queue with no compute env try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue3', - state='JUNK', + jobQueueName="test_job_queue3", + state="JUNK", priority=123, - computeEnvironmentOrder=[ - - ] + computeEnvironmentOrder=[], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") + @mock_ec2 @mock_ecs @@ -380,29 +342,26 @@ def test_job_queue_bad_arn(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] try: batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn + 'LALALA' - }, - ] + {"order": 123, "computeEnvironment": arn + "LALALA"} + ], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") @mock_ec2 @@ -413,48 +372,36 @@ def test_update_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - batch_client.update_job_queue( - jobQueue=queue_arn, - priority=5 - ) + batch_client.update_job_queue(jobQueue=queue_arn, priority=5) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['priority'].should.equal(5) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["priority"].should.equal(5) - batch_client.update_job_queue( - jobQueue='test_job_queue', - priority=5 - ) + batch_client.update_job_queue(jobQueue="test_job_queue", priority=5) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['priority'].should.equal(5) - + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["priority"].should.equal(5) @mock_ec2 @@ -465,35 +412,28 @@ def test_update_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - batch_client.delete_job_queue( - jobQueue=queue_arn - ) + batch_client.delete_job_queue(jobQueue=queue_arn) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(0) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(0) @mock_ec2 @@ -505,21 +445,23 @@ def test_register_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - resp.should.contain('jobDefinitionArn') - resp.should.contain('jobDefinitionName') - resp.should.contain('revision') + resp.should.contain("jobDefinitionArn") + resp.should.contain("jobDefinitionName") + resp.should.contain("revision") - assert resp['jobDefinitionArn'].endswith('{0}:{1}'.format(resp['jobDefinitionName'], resp['revision'])) + assert resp["jobDefinitionArn"].endswith( + "{0}:{1}".format(resp["jobDefinitionName"], resp["revision"]) + ) @mock_ec2 @@ -532,36 +474,69 @@ def test_reregister_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp1 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - resp1.should.contain('jobDefinitionArn') - resp1.should.contain('jobDefinitionName') - resp1.should.contain('revision') + resp1.should.contain("jobDefinitionArn") + resp1.should.contain("jobDefinitionName") + resp1.should.contain("revision") - assert resp1['jobDefinitionArn'].endswith('{0}:{1}'.format(resp1['jobDefinitionName'], resp1['revision'])) - resp1['revision'].should.equal(1) + assert resp1["jobDefinitionArn"].endswith( + "{0}:{1}".format(resp1["jobDefinitionName"], resp1["revision"]) + ) + resp1["revision"].should.equal(1) resp2 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 68, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 68, + "command": ["sleep", "10"], + }, ) - resp2['revision'].should.equal(2) + resp2["revision"].should.equal(2) - resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp2["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) + + resp3 = batch_client.register_job_definition( + jobDefinitionName="sleep10", + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 42, + "command": ["sleep", "10"], + }, + ) + resp3["revision"].should.equal(3) + + resp3["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) + resp3["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"]) + + resp4 = batch_client.register_job_definition( + jobDefinitionName="sleep10", + type="container", + containerProperties={ + "image": "busybox", + "vcpus": 1, + "memory": 41, + "command": ["sleep", "10"], + }, + ) + resp4["revision"].should.equal(4) + + resp4["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) + resp4["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"]) + resp4["jobDefinitionArn"].should_not.equal(resp3["jobDefinitionArn"]) @mock_ec2 @@ -573,20 +548,20 @@ def test_delete_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - batch_client.deregister_job_definition(jobDefinition=resp['jobDefinitionArn']) + batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"]) resp = batch_client.describe_job_definitions() - len(resp['jobDefinitions']).should.equal(0) + len(resp["jobDefinitions"]).should.equal(0) @mock_ec2 @@ -598,48 +573,47 @@ def test_describe_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 64, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 64, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( - jobDefinitionName='test1', - type='container', + jobDefinitionName="test1", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 64, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 64, + "command": ["sleep", "10"], + }, ) - resp = batch_client.describe_job_definitions( - jobDefinitionName='sleep10' - ) - len(resp['jobDefinitions']).should.equal(2) + resp = batch_client.describe_job_definitions(jobDefinitionName="sleep10") + len(resp["jobDefinitions"]).should.equal(2) resp = batch_client.describe_job_definitions() - len(resp['jobDefinitions']).should.equal(3) + len(resp["jobDefinitions"]).should.equal(3) - resp = batch_client.describe_job_definitions( - jobDefinitions=['sleep10', 'test1'] - ) - len(resp['jobDefinitions']).should.equal(3) + resp = batch_client.describe_job_definitions(jobDefinitions=["sleep10", "test1"]) + len(resp["jobDefinitions"]).should.equal(3) + + for job_definition in resp["jobDefinitions"]: + job_definition["status"].should.equal("ACTIVE") @mock_logs @@ -651,77 +625,71 @@ def test_submit_job_by_name(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - job_definition_name = 'sleep10' + job_definition_name = "sleep10" batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 256, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 256, + "command": ["sleep", "10"], + }, ) resp = batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 512, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 512, + "command": ["sleep", "10"], + }, ) - job_definition_arn = resp['jobDefinitionArn'] + job_definition_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_definition_name + jobName="test1", jobQueue=queue_arn, jobDefinition=job_definition_name ) - job_id = resp['jobId'] + job_id = resp["jobId"] resp_jobs = batch_client.describe_jobs(jobs=[job_id]) # batch_client.terminate_job(jobId=job_id) - len(resp_jobs['jobs']).should.equal(1) - resp_jobs['jobs'][0]['jobId'].should.equal(job_id) - resp_jobs['jobs'][0]['jobQueue'].should.equal(queue_arn) - resp_jobs['jobs'][0]['jobDefinition'].should.equal(job_definition_arn) + len(resp_jobs["jobs"]).should.equal(1) + resp_jobs["jobs"][0]["jobId"].should.equal(job_id) + resp_jobs["jobs"][0]["jobQueue"].should.equal(queue_arn) + resp_jobs["jobs"][0]["jobDefinition"].should.equal(job_definition_arn) + # SLOW TESTS @expected_failure @@ -734,67 +702,68 @@ def test_submit_job(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id = resp['jobId'] + job_id = resp["jobId"] future = datetime.datetime.now() + datetime.timedelta(seconds=30) while datetime.datetime.now() < future: resp = batch_client.describe_jobs(jobs=[job_id]) - print("{0}:{1} {2}".format(resp['jobs'][0]['jobName'], resp['jobs'][0]['jobId'], resp['jobs'][0]['status'])) + print( + "{0}:{1} {2}".format( + resp["jobs"][0]["jobName"], + resp["jobs"][0]["jobId"], + resp["jobs"][0]["status"], + ) + ) - if resp['jobs'][0]['status'] == 'FAILED': - raise RuntimeError('Batch job failed') - if resp['jobs'][0]['status'] == 'SUCCEEDED': + if resp["jobs"][0]["status"] == "FAILED": + raise RuntimeError("Batch job failed") + if resp["jobs"][0]["status"] == "SUCCEEDED": break time.sleep(0.5) else: - raise RuntimeError('Batch job timed out') + raise RuntimeError("Batch job timed out") - resp = logs_client.describe_log_streams(logGroupName='/aws/batch/job') - len(resp['logStreams']).should.equal(1) - ls_name = resp['logStreams'][0]['logStreamName'] + resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job") + len(resp["logStreams"]).should.equal(1) + ls_name = resp["logStreams"][0]["logStreamName"] - resp = logs_client.get_log_events(logGroupName='/aws/batch/job', logStreamName=ls_name) - len(resp['events']).should.be.greater_than(5) + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + len(resp["events"]).should.be.greater_than(5) @expected_failure @@ -807,82 +776,71 @@ def test_list_jobs(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id1 = resp['jobId'] + job_id1 = resp["jobId"] resp = batch_client.submit_job( - jobName='test2', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id2 = resp['jobId'] + job_id2 = resp["jobId"] future = datetime.datetime.now() + datetime.timedelta(seconds=30) resp_finished_jobs = batch_client.list_jobs( - jobQueue=queue_arn, - jobStatus='SUCCEEDED' + jobQueue=queue_arn, jobStatus="SUCCEEDED" ) # Wait only as long as it takes to run the jobs while datetime.datetime.now() < future: resp = batch_client.describe_jobs(jobs=[job_id1, job_id2]) - any_failed_jobs = any([job['status'] == 'FAILED' for job in resp['jobs']]) - succeeded_jobs = all([job['status'] == 'SUCCEEDED' for job in resp['jobs']]) + any_failed_jobs = any([job["status"] == "FAILED" for job in resp["jobs"]]) + succeeded_jobs = all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]) if any_failed_jobs: - raise RuntimeError('A Batch job failed') + raise RuntimeError("A Batch job failed") if succeeded_jobs: break time.sleep(0.5) else: - raise RuntimeError('Batch jobs timed out') + raise RuntimeError("Batch jobs timed out") resp_finished_jobs2 = batch_client.list_jobs( - jobQueue=queue_arn, - jobStatus='SUCCEEDED' + jobQueue=queue_arn, jobStatus="SUCCEEDED" ) - len(resp_finished_jobs['jobSummaryList']).should.equal(0) - len(resp_finished_jobs2['jobSummaryList']).should.equal(2) + len(resp_finished_jobs["jobSummaryList"]).should.equal(0) + len(resp_finished_jobs2["jobSummaryList"]).should.equal(2) @expected_failure @@ -895,55 +853,47 @@ def test_terminate_job(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id = resp['jobId'] + job_id = resp["jobId"] time.sleep(2) - batch_client.terminate_job(jobId=job_id, reason='test_terminate') + batch_client.terminate_job(jobId=job_id, reason="test_terminate") time.sleep(1) resp = batch_client.describe_jobs(jobs=[job_id]) - resp['jobs'][0]['jobName'].should.equal('test1') - resp['jobs'][0]['status'].should.equal('FAILED') - resp['jobs'][0]['statusReason'].should.equal('test_terminate') - + resp["jobs"][0]["jobName"].should.equal("test1") + resp["jobs"][0]["status"].should.equal("FAILED") + resp["jobs"][0]["statusReason"].should.equal("test_terminate") diff --git a/tests/test_batch/test_cloudformation.py b/tests/test_batch/test_cloudformation.py index 1e37aa3a6..a6baedb38 100644 --- a/tests/test_batch/test_cloudformation.py +++ b/tests/test_batch/test_cloudformation.py @@ -5,20 +5,29 @@ import datetime import boto3 from botocore.exceptions import ClientError import sure # noqa -from moto import mock_batch, mock_iam, mock_ec2, mock_ecs, mock_logs, mock_cloudformation +from moto import ( + mock_batch, + mock_iam, + mock_ec2, + mock_ecs, + mock_logs, + mock_cloudformation, +) import functools import nose import json -DEFAULT_REGION = 'eu-central-1' +DEFAULT_REGION = "eu-central-1" def _get_clients(): - return boto3.client('ec2', region_name=DEFAULT_REGION), \ - boto3.client('iam', region_name=DEFAULT_REGION), \ - boto3.client('ecs', region_name=DEFAULT_REGION), \ - boto3.client('logs', region_name=DEFAULT_REGION), \ - boto3.client('batch', region_name=DEFAULT_REGION) + return ( + boto3.client("ec2", region_name=DEFAULT_REGION), + boto3.client("iam", region_name=DEFAULT_REGION), + boto3.client("ecs", region_name=DEFAULT_REGION), + boto3.client("logs", region_name=DEFAULT_REGION), + boto3.client("batch", region_name=DEFAULT_REGION), + ) def _setup(ec2_client, iam_client): @@ -27,26 +36,25 @@ def _setup(ec2_client, iam_client): :return: VPC ID, Subnet ID, Security group ID, IAM Role ARN :rtype: tuple """ - resp = ec2_client.create_vpc(CidrBlock='172.30.0.0/24') - vpc_id = resp['Vpc']['VpcId'] + resp = ec2_client.create_vpc(CidrBlock="172.30.0.0/24") + vpc_id = resp["Vpc"]["VpcId"] resp = ec2_client.create_subnet( - AvailabilityZone='eu-central-1a', - CidrBlock='172.30.0.0/25', - VpcId=vpc_id + AvailabilityZone="eu-central-1a", CidrBlock="172.30.0.0/25", VpcId=vpc_id ) - subnet_id = resp['Subnet']['SubnetId'] + subnet_id = resp["Subnet"]["SubnetId"] resp = ec2_client.create_security_group( - Description='test_sg_desc', - GroupName='test_sg', - VpcId=vpc_id + Description="test_sg_desc", GroupName="test_sg", VpcId=vpc_id ) - sg_id = resp['GroupId'] + sg_id = resp["GroupId"] resp = iam_client.create_role( - RoleName='TestRole', - AssumeRolePolicyDocument='some_policy' + RoleName="TestRole", AssumeRolePolicyDocument="some_policy" + ) + iam_arn = resp["Role"]["Arn"] + iam_client.create_instance_profile(InstanceProfileName="TestRole") + iam_client.add_role_to_instance_profile( + InstanceProfileName="TestRole", RoleName="TestRole" ) - iam_arn = resp['Role']['Arn'] return vpc_id, subnet_id, sg_id, iam_arn @@ -61,7 +69,7 @@ def test_create_env_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -71,32 +79,35 @@ def test_create_env_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn.replace("role", "instance-profile"), }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, } } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'][0]['ResourceStatus'].should.equal('CREATE_COMPLETE') + stack_resources["StackResourceSummaries"][0]["ResourceStatus"].should.equal( + "CREATE_COMPLETE" + ) # Spot checks on the ARN - stack_resources['StackResourceSummaries'][0]['PhysicalResourceId'].startswith('arn:aws:batch:') - stack_resources['StackResourceSummaries'][0]['PhysicalResourceId'].should.contain('test_stack') + stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"].startswith( + "arn:aws:batch:" + ) + stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"].should.contain( + "test_stack" + ) @mock_cloudformation() @@ -109,7 +120,7 @@ def test_create_job_queue_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -119,17 +130,14 @@ def test_create_job_queue_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn.replace("role", "instance-profile"), }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, }, - "JobQueue": { "Type": "AWS::Batch::JobQueue", "Properties": { @@ -137,31 +145,35 @@ def test_create_job_queue_cf(): "ComputeEnvironmentOrder": [ { "Order": 1, - "ComputeEnvironment": {"Ref": "ComputeEnvironment"} + "ComputeEnvironment": {"Ref": "ComputeEnvironment"}, } - ] - } + ], + }, }, } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - len(stack_resources['StackResourceSummaries']).should.equal(2) + len(stack_resources["StackResourceSummaries"]).should.equal(2) - job_queue_resource = list(filter(lambda item: item['ResourceType'] == 'AWS::Batch::JobQueue', stack_resources['StackResourceSummaries']))[0] + job_queue_resource = list( + filter( + lambda item: item["ResourceType"] == "AWS::Batch::JobQueue", + stack_resources["StackResourceSummaries"], + ) + )[0] - job_queue_resource['ResourceStatus'].should.equal('CREATE_COMPLETE') + job_queue_resource["ResourceStatus"].should.equal("CREATE_COMPLETE") # Spot checks on the ARN - job_queue_resource['PhysicalResourceId'].startswith('arn:aws:batch:') - job_queue_resource['PhysicalResourceId'].should.contain('test_stack') - job_queue_resource['PhysicalResourceId'].should.contain('job-queue/') + job_queue_resource["PhysicalResourceId"].startswith("arn:aws:batch:") + job_queue_resource["PhysicalResourceId"].should.contain("test_stack") + job_queue_resource["PhysicalResourceId"].should.contain("job-queue/") @mock_cloudformation() @@ -174,7 +186,7 @@ def test_create_job_def_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -184,17 +196,14 @@ def test_create_job_def_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn.replace("role", "instance-profile"), }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, }, - "JobQueue": { "Type": "AWS::Batch::JobQueue", "Properties": { @@ -202,46 +211,54 @@ def test_create_job_def_cf(): "ComputeEnvironmentOrder": [ { "Order": 1, - "ComputeEnvironment": {"Ref": "ComputeEnvironment"} + "ComputeEnvironment": {"Ref": "ComputeEnvironment"}, } - ] - } + ], + }, }, - "JobDefinition": { "Type": "AWS::Batch::JobDefinition", "Properties": { "Type": "container", "ContainerProperties": { "Image": { - "Fn::Join": ["", ["137112412989.dkr.ecr.", {"Ref": "AWS::Region"}, ".amazonaws.com/amazonlinux:latest"]] + "Fn::Join": [ + "", + [ + "137112412989.dkr.ecr.", + {"Ref": "AWS::Region"}, + ".amazonaws.com/amazonlinux:latest", + ], + ] }, "Vcpus": 2, "Memory": 2000, - "Command": ["echo", "Hello world"] + "Command": ["echo", "Hello world"], }, - "RetryStrategy": { - "Attempts": 1 - } - } + "RetryStrategy": {"Attempts": 1}, + }, }, } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - len(stack_resources['StackResourceSummaries']).should.equal(3) + len(stack_resources["StackResourceSummaries"]).should.equal(3) - job_def_resource = list(filter(lambda item: item['ResourceType'] == 'AWS::Batch::JobDefinition', stack_resources['StackResourceSummaries']))[0] + job_def_resource = list( + filter( + lambda item: item["ResourceType"] == "AWS::Batch::JobDefinition", + stack_resources["StackResourceSummaries"], + ) + )[0] - job_def_resource['ResourceStatus'].should.equal('CREATE_COMPLETE') + job_def_resource["ResourceStatus"].should.equal("CREATE_COMPLETE") # Spot checks on the ARN - job_def_resource['PhysicalResourceId'].startswith('arn:aws:batch:') - job_def_resource['PhysicalResourceId'].should.contain('test_stack-JobDef') - job_def_resource['PhysicalResourceId'].should.contain('job-definition/') + job_def_resource["PhysicalResourceId"].startswith("arn:aws:batch:") + job_def_resource["PhysicalResourceId"].should.contain("test_stack-JobDef") + job_def_resource["PhysicalResourceId"].should.contain("job-definition/") diff --git a/tests/test_batch/test_server.py b/tests/test_batch/test_server.py index 4a74260a8..91b5f0c47 100644 --- a/tests/test_batch/test_server.py +++ b/tests/test_batch/test_server.py @@ -5,9 +5,9 @@ import sure # noqa import moto.server as server from moto import mock_batch -''' +""" Test the different server responses -''' +""" @mock_batch @@ -15,5 +15,5 @@ def test_batch_list(): backend = server.create_backend_app("batch") test_client = backend.test_client() - res = test_client.get('/v1/describecomputeenvironments') + res = test_client.get("/v1/describecomputeenvironments") res.status_code.should.equal(200) diff --git a/tests/test_cloudformation/fixtures/ec2_classic_eip.py b/tests/test_cloudformation/fixtures/ec2_classic_eip.py index 626e90ada..fd7758300 100644 --- a/tests/test_cloudformation/fixtures/ec2_classic_eip.py +++ b/tests/test_cloudformation/fixtures/ec2_classic_eip.py @@ -1,9 +1,3 @@ from __future__ import unicode_literals -template = { - "Resources": { - "EC2EIP": { - "Type": "AWS::EC2::EIP" - } - } -} +template = {"Resources": {"EC2EIP": {"Type": "AWS::EC2::EIP"}}} diff --git a/tests/test_cloudformation/fixtures/fn_join.py b/tests/test_cloudformation/fixtures/fn_join.py index 79b62d01e..ac73e3cd2 100644 --- a/tests/test_cloudformation/fixtures/fn_join.py +++ b/tests/test_cloudformation/fixtures/fn_join.py @@ -1,23 +1,11 @@ from __future__ import unicode_literals template = { - "Resources": { - "EC2EIP": { - "Type": "AWS::EC2::EIP" - } - }, + "Resources": {"EC2EIP": {"Type": "AWS::EC2::EIP"}}, "Outputs": { "EIP": { "Description": "EIP for joining", - "Value": { - "Fn::Join": [ - ":", - [ - "test eip", - {"Ref": "EC2EIP"} - ] - ] - } + "Value": {"Fn::Join": [":", ["test eip", {"Ref": "EC2EIP"}]]}, } - } + }, } diff --git a/tests/test_cloudformation/fixtures/kms_key.py b/tests/test_cloudformation/fixtures/kms_key.py index 366dbfcf5..af6a535d1 100644 --- a/tests/test_cloudformation/fixtures/kms_key.py +++ b/tests/test_cloudformation/fixtures/kms_key.py @@ -2,38 +2,45 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template to create a KMS Key. The Fn::GetAtt is used to retrieve the ARN", - - "Resources" : { - "myKey" : { - "Type" : "AWS::KMS::Key", - "Properties" : { + "Resources": { + "myKey": { + "Type": "AWS::KMS::Key", + "Properties": { "Description": "Sample KmsKey", "EnableKeyRotation": False, "Enabled": True, - "KeyPolicy" : { + "KeyPolicy": { "Version": "2012-10-17", "Id": "key-default-1", "Statement": [ { - "Sid": "Enable IAM User Permissions", - "Effect": "Allow", - "Principal": { - "AWS": { "Fn::Join" : ["" , ["arn:aws:iam::", {"Ref" : "AWS::AccountId"} ,":root" ]] } - }, - "Action": "kms:*", - "Resource": "*" + "Sid": "Enable IAM User Permissions", + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::Join": [ + "", + [ + "arn:aws:iam::", + {"Ref": "AWS::AccountId"}, + ":root", + ], + ] + } + }, + "Action": "kms:*", + "Resource": "*", } - ] - } - } + ], + }, + }, } }, - "Outputs" : { - "KeyArn" : { + "Outputs": { + "KeyArn": { "Description": "Generated Key Arn", - "Value" : { "Fn::GetAtt" : [ "myKey", "Arn" ] } + "Value": {"Fn::GetAtt": ["myKey", "Arn"]}, } - } -} \ No newline at end of file + }, +} diff --git a/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py b/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py index 6f379daa6..d58516d3d 100644 --- a/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py +++ b/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py @@ -2,9 +2,7 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template RDS_MySQL_With_Read_Replica: Sample template showing how to create a highly-available, RDS DBInstance with a read replica. **WARNING** This template creates an Amazon Relational Database Service database instance and Amazon CloudWatch alarms. You will be billed for the AWS resources used if you create a stack from this template.", - "Parameters": { "DBName": { "Default": "MyDatabase", @@ -13,13 +11,9 @@ template = { "MinLength": "1", "MaxLength": "64", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - - "DBInstanceIdentifier": { - "Type": "String" - }, - + "DBInstanceIdentifier": {"Type": "String"}, "DBUser": { "NoEcho": "true", "Description": "The database admin account username", @@ -27,9 +21,8 @@ template = { "MinLength": "1", "MaxLength": "16", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - "DBPassword": { "NoEcho": "true", "Description": "The database admin account password", @@ -37,112 +30,121 @@ template = { "MinLength": "1", "MaxLength": "41", "AllowedPattern": "[a-zA-Z0-9]+", - "ConstraintDescription": "must contain only alphanumeric characters." + "ConstraintDescription": "must contain only alphanumeric characters.", }, - "DBAllocatedStorage": { "Default": "5", "Description": "The size of the database (Gb)", "Type": "Number", "MinValue": "5", "MaxValue": "1024", - "ConstraintDescription": "must be between 5 and 1024Gb." + "ConstraintDescription": "must be between 5 and 1024Gb.", }, - "DBInstanceClass": { "Description": "The database instance type", "Type": "String", "Default": "db.m1.small", - "AllowedValues": ["db.t1.micro", "db.m1.small", "db.m1.medium", "db.m1.large", "db.m1.xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.m3.medium", "db.m3.large", "db.m3.xlarge", "db.m3.2xlarge", "db.r3.large", "db.r3.xlarge", "db.r3.2xlarge", "db.r3.4xlarge", "db.r3.8xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.cr1.8xlarge"], - "ConstraintDescription": "must select a valid database instance type." + "AllowedValues": [ + "db.t1.micro", + "db.m1.small", + "db.m1.medium", + "db.m1.large", + "db.m1.xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.m3.medium", + "db.m3.large", + "db.m3.xlarge", + "db.m3.2xlarge", + "db.r3.large", + "db.r3.xlarge", + "db.r3.2xlarge", + "db.r3.4xlarge", + "db.r3.8xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.cr1.8xlarge", + ], + "ConstraintDescription": "must select a valid database instance type.", }, - "EC2SecurityGroup": { "Description": "The EC2 security group that contains instances that need access to the database", "Default": "default", "Type": "String", "AllowedPattern": "[a-zA-Z0-9\\-]+", - "ConstraintDescription": "must be a valid security group name." + "ConstraintDescription": "must be a valid security group name.", }, - "MultiAZ": { "Description": "Multi-AZ master database", "Type": "String", "Default": "false", "AllowedValues": ["true", "false"], - "ConstraintDescription": "must be true or false." - } + "ConstraintDescription": "must be true or false.", + }, }, - "Conditions": { - "Is-EC2-VPC": {"Fn::Or": [{"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, - {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}]}, - "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]} + "Is-EC2-VPC": { + "Fn::Or": [ + {"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, + {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}, + ] + }, + "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]}, }, - "Resources": { "DBParameterGroup": { "Type": "AWS::RDS::DBParameterGroup", "Properties": { "Description": "DB Parameter Goup", "Family": "MySQL5.1", - "Parameters": { - "BACKLOG_QUEUE_LIMIT": "2048" - } - } + "Parameters": {"BACKLOG_QUEUE_LIMIT": "2048"}, + }, }, - "DBEC2SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Condition": "Is-EC2-VPC", "Properties": { "GroupDescription": "Open database for access", - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "3306", - "ToPort": "3306", - "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"} - }] - } + "SecurityGroupIngress": [ + { + "IpProtocol": "tcp", + "FromPort": "3306", + "ToPort": "3306", + "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"}, + } + ], + }, }, - "DBSecurityGroup": { "Type": "AWS::RDS::DBSecurityGroup", "Condition": "Is-EC2-Classic", "Properties": { - "DBSecurityGroupIngress": [{ - "EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"} - }], - "GroupDescription": "database access" - } + "DBSecurityGroupIngress": [ + {"EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"}} + ], + "GroupDescription": "database access", + }, }, - - "my_vpc": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } - }, - + "my_vpc": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "EC2Subnet": { "Type": "AWS::EC2::Subnet", "Condition": "Is-EC2-VPC", "Properties": { "AvailabilityZone": "eu-central-1a", "CidrBlock": "10.0.1.0/24", - "VpcId": {"Ref": "my_vpc"} - } + "VpcId": {"Ref": "my_vpc"}, + }, }, - "DBSubnet": { "Type": "AWS::RDS::DBSubnetGroup", "Condition": "Is-EC2-VPC", "Properties": { "DBSubnetGroupDescription": "my db subnet group", "SubnetIds": [{"Ref": "EC2Subnet"}], - } + }, }, - "MasterDB": { "Type": "AWS::RDS::DBInstance", "Properties": { @@ -151,54 +153,79 @@ template = { "AllocatedStorage": {"Ref": "DBAllocatedStorage"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, "Engine": "MySQL", - "DBSubnetGroupName": {"Fn::If": ["Is-EC2-VPC", {"Ref": "DBSubnet"}, {"Ref": "AWS::NoValue"}]}, + "DBSubnetGroupName": { + "Fn::If": [ + "Is-EC2-VPC", + {"Ref": "DBSubnet"}, + {"Ref": "AWS::NoValue"}, + ] + }, "MasterUsername": {"Ref": "DBUser"}, "MasterUserPassword": {"Ref": "DBPassword"}, "MultiAZ": {"Ref": "MultiAZ"}, "Tags": [{"Key": "Name", "Value": "Master Database"}], - "VPCSecurityGroups": {"Fn::If": ["Is-EC2-VPC", [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], {"Ref": "AWS::NoValue"}]}, - "DBSecurityGroups": {"Fn::If": ["Is-EC2-Classic", [{"Ref": "DBSecurityGroup"}], {"Ref": "AWS::NoValue"}]} + "VPCSecurityGroups": { + "Fn::If": [ + "Is-EC2-VPC", + [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], + {"Ref": "AWS::NoValue"}, + ] + }, + "DBSecurityGroups": { + "Fn::If": [ + "Is-EC2-Classic", + [{"Ref": "DBSecurityGroup"}], + {"Ref": "AWS::NoValue"}, + ] + }, }, - "DeletionPolicy": "Snapshot" + "DeletionPolicy": "Snapshot", }, - "ReplicaDB": { "Type": "AWS::RDS::DBInstance", "Properties": { "SourceDBInstanceIdentifier": {"Ref": "MasterDB"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, - "Tags": [{"Key": "Name", "Value": "Read Replica Database"}] - } - } + "Tags": [{"Key": "Name", "Value": "Read Replica Database"}], + }, + }, }, - "Outputs": { "EC2Platform": { "Description": "Platform in which this stack is deployed", - "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]} + "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]}, }, - "MasterJDBCConnectionString": { "Description": "JDBC connection string for the master database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, }, "ReplicaJDBCConnectionString": { "Description": "JDBC connection string for the replica database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} - } - } + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py b/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py index 2fbfb4cad..30f2210fc 100644 --- a/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py +++ b/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py @@ -2,9 +2,7 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template RDS_MySQL_With_Read_Replica: Sample template showing how to create a highly-available, RDS DBInstance with a read replica. **WARNING** This template creates an Amazon Relational Database Service database instance and Amazon CloudWatch alarms. You will be billed for the AWS resources used if you create a stack from this template.", - "Parameters": { "DBName": { "Default": "MyDatabase", @@ -13,13 +11,9 @@ template = { "MinLength": "1", "MaxLength": "64", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - - "DBInstanceIdentifier": { - "Type": "String" - }, - + "DBInstanceIdentifier": {"Type": "String"}, "DBUser": { "NoEcho": "true", "Description": "The database admin account username", @@ -27,9 +21,8 @@ template = { "MinLength": "1", "MaxLength": "16", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - "DBPassword": { "NoEcho": "true", "Description": "The database admin account password", @@ -37,101 +30,113 @@ template = { "MinLength": "1", "MaxLength": "41", "AllowedPattern": "[a-zA-Z0-9]+", - "ConstraintDescription": "must contain only alphanumeric characters." + "ConstraintDescription": "must contain only alphanumeric characters.", }, - "DBAllocatedStorage": { "Default": "5", "Description": "The size of the database (Gb)", "Type": "Number", "MinValue": "5", "MaxValue": "1024", - "ConstraintDescription": "must be between 5 and 1024Gb." + "ConstraintDescription": "must be between 5 and 1024Gb.", }, - "DBInstanceClass": { "Description": "The database instance type", "Type": "String", "Default": "db.m1.small", - "AllowedValues": ["db.t1.micro", "db.m1.small", "db.m1.medium", "db.m1.large", "db.m1.xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.m3.medium", "db.m3.large", "db.m3.xlarge", "db.m3.2xlarge", "db.r3.large", "db.r3.xlarge", "db.r3.2xlarge", "db.r3.4xlarge", "db.r3.8xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.cr1.8xlarge"], - "ConstraintDescription": "must select a valid database instance type." + "AllowedValues": [ + "db.t1.micro", + "db.m1.small", + "db.m1.medium", + "db.m1.large", + "db.m1.xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.m3.medium", + "db.m3.large", + "db.m3.xlarge", + "db.m3.2xlarge", + "db.r3.large", + "db.r3.xlarge", + "db.r3.2xlarge", + "db.r3.4xlarge", + "db.r3.8xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.cr1.8xlarge", + ], + "ConstraintDescription": "must select a valid database instance type.", }, - "EC2SecurityGroup": { "Description": "The EC2 security group that contains instances that need access to the database", "Default": "default", "Type": "String", "AllowedPattern": "[a-zA-Z0-9\\-]+", - "ConstraintDescription": "must be a valid security group name." + "ConstraintDescription": "must be a valid security group name.", }, - "MultiAZ": { "Description": "Multi-AZ master database", "Type": "String", "Default": "false", "AllowedValues": ["true", "false"], - "ConstraintDescription": "must be true or false." - } + "ConstraintDescription": "must be true or false.", + }, }, - "Conditions": { - "Is-EC2-VPC": {"Fn::Or": [{"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, - {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}]}, - "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]} + "Is-EC2-VPC": { + "Fn::Or": [ + {"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, + {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}, + ] + }, + "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]}, }, - "Resources": { "DBEC2SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Condition": "Is-EC2-VPC", "Properties": { "GroupDescription": "Open database for access", - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "3306", - "ToPort": "3306", - "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"} - }] - } + "SecurityGroupIngress": [ + { + "IpProtocol": "tcp", + "FromPort": "3306", + "ToPort": "3306", + "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"}, + } + ], + }, }, - "DBSecurityGroup": { "Type": "AWS::RDS::DBSecurityGroup", "Condition": "Is-EC2-Classic", "Properties": { - "DBSecurityGroupIngress": [{ - "EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"} - }], - "GroupDescription": "database access" - } + "DBSecurityGroupIngress": [ + {"EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"}} + ], + "GroupDescription": "database access", + }, }, - - "my_vpc": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } - }, - + "my_vpc": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "EC2Subnet": { "Type": "AWS::EC2::Subnet", "Condition": "Is-EC2-VPC", "Properties": { "AvailabilityZone": "eu-central-1a", "CidrBlock": "10.0.1.0/24", - "VpcId": {"Ref": "my_vpc"} - } + "VpcId": {"Ref": "my_vpc"}, + }, }, - "DBSubnet": { "Type": "AWS::RDS::DBSubnetGroup", "Condition": "Is-EC2-VPC", "Properties": { "DBSubnetGroupDescription": "my db subnet group", "SubnetIds": [{"Ref": "EC2Subnet"}], - } + }, }, - "MasterDB": { "Type": "AWS::RDS::DBInstance", "Properties": { @@ -140,54 +145,79 @@ template = { "AllocatedStorage": {"Ref": "DBAllocatedStorage"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, "Engine": "MySQL", - "DBSubnetGroupName": {"Fn::If": ["Is-EC2-VPC", {"Ref": "DBSubnet"}, {"Ref": "AWS::NoValue"}]}, + "DBSubnetGroupName": { + "Fn::If": [ + "Is-EC2-VPC", + {"Ref": "DBSubnet"}, + {"Ref": "AWS::NoValue"}, + ] + }, "MasterUsername": {"Ref": "DBUser"}, "MasterUserPassword": {"Ref": "DBPassword"}, "MultiAZ": {"Ref": "MultiAZ"}, "Tags": [{"Key": "Name", "Value": "Master Database"}], - "VPCSecurityGroups": {"Fn::If": ["Is-EC2-VPC", [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], {"Ref": "AWS::NoValue"}]}, - "DBSecurityGroups": {"Fn::If": ["Is-EC2-Classic", [{"Ref": "DBSecurityGroup"}], {"Ref": "AWS::NoValue"}]} + "VPCSecurityGroups": { + "Fn::If": [ + "Is-EC2-VPC", + [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], + {"Ref": "AWS::NoValue"}, + ] + }, + "DBSecurityGroups": { + "Fn::If": [ + "Is-EC2-Classic", + [{"Ref": "DBSecurityGroup"}], + {"Ref": "AWS::NoValue"}, + ] + }, }, - "DeletionPolicy": "Snapshot" + "DeletionPolicy": "Snapshot", }, - "ReplicaDB": { "Type": "AWS::RDS::DBInstance", "Properties": { "SourceDBInstanceIdentifier": {"Ref": "MasterDB"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, - "Tags": [{"Key": "Name", "Value": "Read Replica Database"}] - } - } + "Tags": [{"Key": "Name", "Value": "Read Replica Database"}], + }, + }, }, - "Outputs": { "EC2Platform": { "Description": "Platform in which this stack is deployed", - "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]} + "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]}, }, - "MasterJDBCConnectionString": { "Description": "JDBC connection string for the master database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, }, "ReplicaJDBCConnectionString": { "Description": "JDBC connection string for the replica database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} - } - } + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/redshift.py b/tests/test_cloudformation/fixtures/redshift.py index 317e213bc..6da5c30db 100644 --- a/tests/test_cloudformation/fixtures/redshift.py +++ b/tests/test_cloudformation/fixtures/redshift.py @@ -7,35 +7,35 @@ template = { "Description": "The name of the first database to be created when the cluster is created", "Type": "String", "Default": "dev", - "AllowedPattern": "([a-z]|[0-9])+" + "AllowedPattern": "([a-z]|[0-9])+", }, "ClusterType": { "Description": "The type of cluster", "Type": "String", "Default": "single-node", - "AllowedValues": ["single-node", "multi-node"] + "AllowedValues": ["single-node", "multi-node"], }, "NumberOfNodes": { "Description": "The number of compute nodes in the cluster. For multi-node clusters, the NumberOfNodes parameter must be greater than 1", "Type": "Number", - "Default": "1" + "Default": "1", }, "NodeType": { "Description": "The type of node to be provisioned", "Type": "String", "Default": "dw1.xlarge", - "AllowedValues": ["dw1.xlarge", "dw1.8xlarge", "dw2.large", "dw2.8xlarge"] + "AllowedValues": ["dw1.xlarge", "dw1.8xlarge", "dw2.large", "dw2.8xlarge"], }, "MasterUsername": { "Description": "The user name that is associated with the master user account for the cluster that is being created", "Type": "String", "Default": "defaultuser", - "AllowedPattern": "([a-z])([a-z]|[0-9])*" + "AllowedPattern": "([a-z])([a-z]|[0-9])*", }, - "MasterUserPassword": { + "MasterUserPassword": { "Description": "The password that is associated with the master user account for the cluster that is being created.", "Type": "String", - "NoEcho": "true" + "NoEcho": "true", }, "InboundTraffic": { "Description": "Allow inbound traffic to the cluster from this CIDR range.", @@ -44,18 +44,16 @@ template = { "MaxLength": "18", "Default": "0.0.0.0/0", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", - "ConstraintDescription": "must be a valid CIDR range of the form x.x.x.x/x." + "ConstraintDescription": "must be a valid CIDR range of the form x.x.x.x/x.", }, "PortNumber": { "Description": "The port number on which the cluster accepts incoming connections.", "Type": "Number", - "Default": "5439" - } + "Default": "5439", + }, }, "Conditions": { - "IsMultiNodeCluster": { - "Fn::Equals": [{"Ref": "ClusterType"}, "multi-node"] - } + "IsMultiNodeCluster": {"Fn::Equals": [{"Ref": "ClusterType"}, "multi-node"]} }, "Resources": { "RedshiftCluster": { @@ -63,7 +61,13 @@ template = { "DependsOn": "AttachGateway", "Properties": { "ClusterType": {"Ref": "ClusterType"}, - "NumberOfNodes": {"Fn::If": ["IsMultiNodeCluster", {"Ref": "NumberOfNodes"}, {"Ref": "AWS::NoValue"}]}, + "NumberOfNodes": { + "Fn::If": [ + "IsMultiNodeCluster", + {"Ref": "NumberOfNodes"}, + {"Ref": "AWS::NoValue"}, + ] + }, "NodeType": {"Ref": "NodeType"}, "DBName": {"Ref": "DatabaseName"}, "MasterUsername": {"Ref": "MasterUsername"}, @@ -72,116 +76,106 @@ template = { "VpcSecurityGroupIds": [{"Ref": "SecurityGroup"}], "ClusterSubnetGroupName": {"Ref": "RedshiftClusterSubnetGroup"}, "PubliclyAccessible": "true", - "Port": {"Ref": "PortNumber"} - } + "Port": {"Ref": "PortNumber"}, + }, }, "RedshiftClusterParameterGroup": { "Type": "AWS::Redshift::ClusterParameterGroup", "Properties": { "Description": "Cluster parameter group", "ParameterGroupFamily": "redshift-1.0", - "Parameters": [{ - "ParameterName": "enable_user_activity_logging", - "ParameterValue": "true" - }] - } + "Parameters": [ + { + "ParameterName": "enable_user_activity_logging", + "ParameterValue": "true", + } + ], + }, }, "RedshiftClusterSubnetGroup": { "Type": "AWS::Redshift::ClusterSubnetGroup", "Properties": { "Description": "Cluster subnet group", - "SubnetIds": [{"Ref": "PublicSubnet"}] - } - }, - "VPC": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16" - } + "SubnetIds": [{"Ref": "PublicSubnet"}], + }, }, + "VPC": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "PublicSubnet": { "Type": "AWS::EC2::Subnet", - "Properties": { - "CidrBlock": "10.0.0.0/24", - "VpcId": {"Ref": "VPC"} - } + "Properties": {"CidrBlock": "10.0.0.0/24", "VpcId": {"Ref": "VPC"}}, }, "SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "Security group", - "SecurityGroupIngress": [{ - "CidrIp": {"Ref": "InboundTraffic"}, - "FromPort": {"Ref": "PortNumber"}, - "ToPort": {"Ref": "PortNumber"}, - "IpProtocol": "tcp" - }], - "VpcId": {"Ref": "VPC"} - } - }, - "myInternetGateway": { - "Type": "AWS::EC2::InternetGateway" + "SecurityGroupIngress": [ + { + "CidrIp": {"Ref": "InboundTraffic"}, + "FromPort": {"Ref": "PortNumber"}, + "ToPort": {"Ref": "PortNumber"}, + "IpProtocol": "tcp", + } + ], + "VpcId": {"Ref": "VPC"}, + }, }, + "myInternetGateway": {"Type": "AWS::EC2::InternetGateway"}, "AttachGateway": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { "VpcId": {"Ref": "VPC"}, - "InternetGatewayId": {"Ref": "myInternetGateway"} - } + "InternetGatewayId": {"Ref": "myInternetGateway"}, + }, }, "PublicRouteTable": { "Type": "AWS::EC2::RouteTable", - "Properties": { - "VpcId": { - "Ref": "VPC" - } - } + "Properties": {"VpcId": {"Ref": "VPC"}}, }, "PublicRoute": { "Type": "AWS::EC2::Route", "DependsOn": "AttachGateway", "Properties": { - "RouteTableId": { - "Ref": "PublicRouteTable" - }, + "RouteTableId": {"Ref": "PublicRouteTable"}, "DestinationCidrBlock": "0.0.0.0/0", - "GatewayId": { - "Ref": "myInternetGateway" - } - } + "GatewayId": {"Ref": "myInternetGateway"}, + }, }, "PublicSubnetRouteTableAssociation": { "Type": "AWS::EC2::SubnetRouteTableAssociation", "Properties": { - "SubnetId": { - "Ref": "PublicSubnet" - }, - "RouteTableId": { - "Ref": "PublicRouteTable" - } - } - } + "SubnetId": {"Ref": "PublicSubnet"}, + "RouteTableId": {"Ref": "PublicRouteTable"}, + }, + }, }, "Outputs": { "ClusterEndpoint": { "Description": "Cluster endpoint", - "Value": {"Fn::Join": [":", [{"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Address"]}, {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Port"]}]]} + "Value": { + "Fn::Join": [ + ":", + [ + {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Address"]}, + {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Port"]}, + ], + ] + }, }, "ClusterName": { "Description": "Name of cluster", - "Value": {"Ref": "RedshiftCluster"} + "Value": {"Ref": "RedshiftCluster"}, }, "ParameterGroupName": { "Description": "Name of parameter group", - "Value": {"Ref": "RedshiftClusterParameterGroup"} + "Value": {"Ref": "RedshiftClusterParameterGroup"}, }, "RedshiftClusterSubnetGroupName": { "Description": "Name of cluster subnet group", - "Value": {"Ref": "RedshiftClusterSubnetGroup"} + "Value": {"Ref": "RedshiftClusterSubnetGroup"}, }, "RedshiftClusterSecurityGroupName": { "Description": "Name of cluster security group", - "Value": {"Ref": "SecurityGroup"} - } - } + "Value": {"Ref": "SecurityGroup"}, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py b/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py index 43a11104b..3f5735bba 100644 --- a/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py +++ b/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py @@ -1,47 +1,38 @@ from __future__ import unicode_literals template = { - "Parameters": { - "R53ZoneName": { - "Type": "String", - "Default": "my_zone" - } - }, - + "Parameters": {"R53ZoneName": {"Type": "String", "Default": "my_zone"}}, "Resources": { "Ec2Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "PrivateIpAddress": "10.0.0.25", - } + "Properties": {"ImageId": "ami-1234abcd", "PrivateIpAddress": "10.0.0.25"}, }, - "HostedZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": {"Ref": "R53ZoneName"} - } + "Properties": {"Name": {"Ref": "R53ZoneName"}}, }, - "myDNSRecord": { "Type": "AWS::Route53::RecordSet", "Properties": { "HostedZoneId": {"Ref": "HostedZone"}, "Comment": "DNS name for my instance.", "Name": { - "Fn::Join": ["", [ - {"Ref": "Ec2Instance"}, ".", - {"Ref": "AWS::Region"}, ".", - {"Ref": "R53ZoneName"}, "." - ]] + "Fn::Join": [ + "", + [ + {"Ref": "Ec2Instance"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] }, "Type": "A", "TTL": "900", - "ResourceRecords": [ - {"Fn::GetAtt": ["Ec2Instance", "PrivateIp"]} - ] - } - } + "ResourceRecords": [{"Fn::GetAtt": ["Ec2Instance", "PrivateIp"]}], + }, + }, }, } diff --git a/tests/test_cloudformation/fixtures/route53_health_check.py b/tests/test_cloudformation/fixtures/route53_health_check.py index 420cd38ba..876caf299 100644 --- a/tests/test_cloudformation/fixtures/route53_health_check.py +++ b/tests/test_cloudformation/fixtures/route53_health_check.py @@ -4,11 +4,8 @@ template = { "Resources": { "HostedZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": "my_zone" - } + "Properties": {"Name": "my_zone"}, }, - "my_health_check": { "Type": "AWS::Route53::HealthCheck", "Properties": { @@ -20,9 +17,8 @@ template = { "ResourcePath": "/", "Type": "HTTP", } - } + }, }, - "myDNSRecord": { "Type": "AWS::Route53::RecordSet", "Properties": { @@ -33,7 +29,7 @@ template = { "TTL": "900", "ResourceRecords": ["my.example.com"], "HealthCheckId": {"Ref": "my_health_check"}, - } - } - }, + }, + }, + } } diff --git a/tests/test_cloudformation/fixtures/route53_roundrobin.py b/tests/test_cloudformation/fixtures/route53_roundrobin.py index 199e3e088..9c9f8a6f9 100644 --- a/tests/test_cloudformation/fixtures/route53_roundrobin.py +++ b/tests/test_cloudformation/fixtures/route53_roundrobin.py @@ -2,53 +2,71 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template Route53_RoundRobin: Sample template showing how to use weighted round robin (WRR) DNS entried via Amazon Route 53. This contrived sample uses weighted CNAME records to illustrate that the weighting influences the return records. It assumes that you already have a Hosted Zone registered with Amazon Route 53. **WARNING** This template creates one or more AWS resources. You will be billed for the AWS resources used if you create a stack from this template.", - - "Parameters": { - "R53ZoneName": { - "Type": "String", - "Default": "my_zone" - } - }, - + "Parameters": {"R53ZoneName": {"Type": "String", "Default": "my_zone"}}, "Resources": { - "MyZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": {"Ref": "R53ZoneName"} - } + "Properties": {"Name": {"Ref": "R53ZoneName"}}, }, - "MyDNSRecord": { "Type": "AWS::Route53::RecordSetGroup", "Properties": { "HostedZoneId": {"Ref": "MyZone"}, "Comment": "Contrived example to redirect to aws.amazon.com 75% of the time and www.amazon.com 25% of the time.", - "RecordSets": [{ - "SetIdentifier": {"Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "AWS"]]}, - "Name": {"Fn::Join": ["", [{"Ref": "AWS::StackName"}, ".", {"Ref": "AWS::Region"}, ".", {"Ref": "R53ZoneName"}, "."]]}, - "Type": "CNAME", - "TTL": "900", - "ResourceRecords": ["aws.amazon.com"], - "Weight": "3" - }, { - "SetIdentifier": {"Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "Amazon"]]}, - "Name": {"Fn::Join": ["", [{"Ref": "AWS::StackName"}, ".", {"Ref": "AWS::Region"}, ".", {"Ref": "R53ZoneName"}, "."]]}, - "Type": "CNAME", - "TTL": "900", - "ResourceRecords": ["www.amazon.com"], - "Weight": "1" - }] - } - } + "RecordSets": [ + { + "SetIdentifier": { + "Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "AWS"]] + }, + "Name": { + "Fn::Join": [ + "", + [ + {"Ref": "AWS::StackName"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] + }, + "Type": "CNAME", + "TTL": "900", + "ResourceRecords": ["aws.amazon.com"], + "Weight": "3", + }, + { + "SetIdentifier": { + "Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "Amazon"]] + }, + "Name": { + "Fn::Join": [ + "", + [ + {"Ref": "AWS::StackName"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] + }, + "Type": "CNAME", + "TTL": "900", + "ResourceRecords": ["www.amazon.com"], + "Weight": "1", + }, + ], + }, + }, }, - "Outputs": { "DomainName": { "Description": "Fully qualified domain name", - "Value": {"Ref": "MyDNSRecord"} + "Value": {"Ref": "MyDNSRecord"}, } - } + }, } diff --git a/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py b/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py index 189cc36cd..8226b5ad3 100644 --- a/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py +++ b/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py @@ -10,7 +10,7 @@ template = { "MinLength": "9", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", "MaxLength": "18", - "Type": "String" + "Type": "String", }, "KeyName": { "Type": "String", @@ -18,7 +18,7 @@ template = { "MinLength": "1", "AllowedPattern": "[\\x20-\\x7E]*", "MaxLength": "255", - "ConstraintDescription": "can contain only ASCII characters." + "ConstraintDescription": "can contain only ASCII characters.", }, "InstanceType": { "Default": "m1.small", @@ -40,8 +40,8 @@ template = { "c1.xlarge", "cc1.4xlarge", "cc2.8xlarge", - "cg1.4xlarge" - ] + "cg1.4xlarge", + ], }, "VolumeSize": { "Description": "WebServer EC2 instance type", @@ -49,8 +49,8 @@ template = { "Type": "Number", "MaxValue": "1024", "MinValue": "5", - "ConstraintDescription": "must be between 5 and 1024 Gb." - } + "ConstraintDescription": "must be between 5 and 1024 Gb.", + }, }, "AWSTemplateFormatVersion": "2010-09-09", "Outputs": { @@ -59,17 +59,9 @@ template = { "Value": { "Fn::Join": [ "", - [ - "http://", - { - "Fn::GetAtt": [ - "WebServer", - "PublicDnsName" - ] - } - ] + ["http://", {"Fn::GetAtt": ["WebServer", "PublicDnsName"]}], ] - } + }, } }, "Resources": { @@ -81,19 +73,17 @@ template = { "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", - "FromPort": "80" + "FromPort": "80", }, { "ToPort": "22", "IpProtocol": "tcp", - "CidrIp": { - "Ref": "SSHLocation" - }, - "FromPort": "22" - } + "CidrIp": {"Ref": "SSHLocation"}, + "FromPort": "22", + }, ], - "GroupDescription": "Enable SSH access and HTTP access on the inbound port" - } + "GroupDescription": "Enable SSH access and HTTP access on the inbound port", + }, }, "WebServer": { "Type": "AWS::EC2::Instance", @@ -108,23 +98,17 @@ template = { "# Helper function\n", "function error_exit\n", "{\n", - " /opt/aws/bin/cfn-signal -e 1 -r \"$1\" '", - { - "Ref": "WaitHandle" - }, + ' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'', + {"Ref": "WaitHandle"}, "'\n", " exit 1\n", "}\n", "# Install Rails packages\n", "/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServer ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, " || error_exit 'Failed to run cfn-init'\n", "# Wait for the EBS volume to show up\n", "while [ ! -e /dev/sdh ]; do echo Waiting for EBS volume to attach; sleep 5; done\n", @@ -137,56 +121,38 @@ template = { "git init\n", "gollum --port 80 --host 0.0.0.0 &\n", "# If all is well so signal success\n", - "/opt/aws/bin/cfn-signal -e $? -r \"Rails application setup complete\" '", - { - "Ref": "WaitHandle" - }, - "'\n" - ] + '/opt/aws/bin/cfn-signal -e $? -r "Rails application setup complete" \'', + {"Ref": "WaitHandle"}, + "'\n", + ], ] } }, - "KeyName": { - "Ref": "KeyName" - }, - "SecurityGroups": [ - { - "Ref": "WebServerSecurityGroup" - } - ], - "InstanceType": { - "Ref": "InstanceType" - }, + "KeyName": {"Ref": "KeyName"}, + "SecurityGroups": [{"Ref": "WebServerSecurityGroup"}], + "InstanceType": {"Ref": "InstanceType"}, "ImageId": { "Fn::FindInMap": [ "AWSRegionArch2AMI", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, { "Fn::FindInMap": [ "AWSInstanceType2Arch", - { - "Ref": "InstanceType" - }, - "Arch" + {"Ref": "InstanceType"}, + "Arch", ] - } + }, ] - } + }, }, "Metadata": { "AWS::CloudFormation::Init": { "config": { "packages": { "rubygems": { - "nokogiri": [ - "1.5.10" - ], + "nokogiri": ["1.5.10"], "rdiscount": [], - "gollum": [ - "1.1.1" - ] + "gollum": ["1.1.1"], }, "yum": { "libxslt-devel": [], @@ -196,150 +162,99 @@ template = { "ruby-devel": [], "ruby-rdoc": [], "make": [], - "libxml2-devel": [] - } + "libxml2-devel": [], + }, } } } - } + }, }, "DataVolume": { "Type": "AWS::EC2::Volume", "Properties": { - "Tags": [ - { - "Value": "Gollum Data Volume", - "Key": "Usage" - } - ], - "AvailabilityZone": { - "Fn::GetAtt": [ - "WebServer", - "AvailabilityZone" - ] - }, + "Tags": [{"Value": "Gollum Data Volume", "Key": "Usage"}], + "AvailabilityZone": {"Fn::GetAtt": ["WebServer", "AvailabilityZone"]}, "Size": "100", - } + }, }, "MountPoint": { "Type": "AWS::EC2::VolumeAttachment", "Properties": { - "InstanceId": { - "Ref": "WebServer" - }, + "InstanceId": {"Ref": "WebServer"}, "Device": "/dev/sdh", - "VolumeId": { - "Ref": "DataVolume" - } - } + "VolumeId": {"Ref": "DataVolume"}, + }, }, "WaitCondition": { "DependsOn": "MountPoint", "Type": "AWS::CloudFormation::WaitCondition", - "Properties": { - "Handle": { - "Ref": "WaitHandle" - }, - "Timeout": "300" - }, + "Properties": {"Handle": {"Ref": "WaitHandle"}, "Timeout": "300"}, "Metadata": { "Comment1": "Note that the WaitCondition is dependent on the volume mount point allowing the volume to be created and attached to the EC2 instance", - "Comment2": "The instance bootstrap script waits for the volume to be attached to the instance prior to installing Gollum and signalling completion" - } + "Comment2": "The instance bootstrap script waits for the volume to be attached to the instance prior to installing Gollum and signalling completion", + }, }, - "WaitHandle": { - "Type": "AWS::CloudFormation::WaitConditionHandle" - } + "WaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"}, }, "Mappings": { "AWSInstanceType2Arch": { - "m3.2xlarge": { - "Arch": "64" - }, - "m2.2xlarge": { - "Arch": "64" - }, - "m1.small": { - "Arch": "64" - }, - "c1.medium": { - "Arch": "64" - }, - "cg1.4xlarge": { - "Arch": "64HVM" - }, - "m2.xlarge": { - "Arch": "64" - }, - "t1.micro": { - "Arch": "64" - }, - "cc1.4xlarge": { - "Arch": "64HVM" - }, - "m1.medium": { - "Arch": "64" - }, - "cc2.8xlarge": { - "Arch": "64HVM" - }, - "m1.large": { - "Arch": "64" - }, - "m1.xlarge": { - "Arch": "64" - }, - "m2.4xlarge": { - "Arch": "64" - }, - "c1.xlarge": { - "Arch": "64" - }, - "m3.xlarge": { - "Arch": "64" - } + "m3.2xlarge": {"Arch": "64"}, + "m2.2xlarge": {"Arch": "64"}, + "m1.small": {"Arch": "64"}, + "c1.medium": {"Arch": "64"}, + "cg1.4xlarge": {"Arch": "64HVM"}, + "m2.xlarge": {"Arch": "64"}, + "t1.micro": {"Arch": "64"}, + "cc1.4xlarge": {"Arch": "64HVM"}, + "m1.medium": {"Arch": "64"}, + "cc2.8xlarge": {"Arch": "64HVM"}, + "m1.large": {"Arch": "64"}, + "m1.xlarge": {"Arch": "64"}, + "m2.4xlarge": {"Arch": "64"}, + "c1.xlarge": {"Arch": "64"}, + "m3.xlarge": {"Arch": "64"}, }, "AWSRegionArch2AMI": { "ap-southeast-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-b4b0cae6", - "64": "ami-beb0caec" + "64": "ami-beb0caec", }, "ap-southeast-2": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-b3990e89", - "64": "ami-bd990e87" + "64": "ami-bd990e87", }, "us-west-2": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-38fe7308", - "64": "ami-30fe7300" + "64": "ami-30fe7300", }, "us-east-1": { "64HVM": "ami-0da96764", "32": "ami-31814f58", - "64": "ami-1b814f72" + "64": "ami-1b814f72", }, "ap-northeast-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-0644f007", - "64": "ami-0a44f00b" + "64": "ami-0a44f00b", }, "us-west-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-11d68a54", - "64": "ami-1bd68a5e" + "64": "ami-1bd68a5e", }, "eu-west-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-973b06e3", - "64": "ami-953b06e1" + "64": "ami-953b06e1", }, "sa-east-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-3e3be423", - "64": "ami-3c3be421" - } - } - } + "64": "ami-3c3be421", + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/vpc_eip.py b/tests/test_cloudformation/fixtures/vpc_eip.py index 2d6872f64..154d4c2d4 100644 --- a/tests/test_cloudformation/fixtures/vpc_eip.py +++ b/tests/test_cloudformation/fixtures/vpc_eip.py @@ -1,12 +1,5 @@ from __future__ import unicode_literals template = { - "Resources": { - "VPCEIP": { - "Type": "AWS::EC2::EIP", - "Properties": { - "Domain": "vpc" - } - } - } + "Resources": {"VPCEIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}} } diff --git a/tests/test_cloudformation/fixtures/vpc_eni.py b/tests/test_cloudformation/fixtures/vpc_eni.py index 3f8eb2d03..fc2d7d61b 100644 --- a/tests/test_cloudformation/fixtures/vpc_eni.py +++ b/tests/test_cloudformation/fixtures/vpc_eni.py @@ -6,33 +6,26 @@ template = { "Resources": { "ENI": { "Type": "AWS::EC2::NetworkInterface", - "Properties": { - "SubnetId": {"Ref": "Subnet"} - } + "Properties": {"SubnetId": {"Ref": "Subnet"}}, }, "Subnet": { "Type": "AWS::EC2::Subnet", "Properties": { "AvailabilityZone": "us-east-1a", "VpcId": {"Ref": "VPC"}, - "CidrBlock": "10.0.0.0/24" - } + "CidrBlock": "10.0.0.0/24", + }, }, - "VPC": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16" - } - } + "VPC": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, }, "Outputs": { "NinjaENI": { "Description": "Elastic IP mapping to Auto-Scaling Group", - "Value": {"Ref": "ENI"} + "Value": {"Ref": "ENI"}, }, "ENIIpAddress": { "Description": "ENI's Private IP address", - "Value": {"Fn::GetAtt": ["ENI", "PrimaryPrivateIpAddress"]} - } - } + "Value": {"Fn::GetAtt": ["ENI", "PrimaryPrivateIpAddress"]}, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py b/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py index 39f02462e..546f68cb4 100644 --- a/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py +++ b/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py @@ -10,7 +10,7 @@ template = { "MinLength": "9", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", "MaxLength": "18", - "Type": "String" + "Type": "String", }, "KeyName": { "Type": "String", @@ -18,7 +18,7 @@ template = { "MinLength": "1", "AllowedPattern": "[\\x20-\\x7E]*", "MaxLength": "255", - "ConstraintDescription": "can contain only ASCII characters." + "ConstraintDescription": "can contain only ASCII characters.", }, "InstanceType": { "Default": "m1.small", @@ -40,9 +40,9 @@ template = { "c1.xlarge", "cc1.4xlarge", "cc2.8xlarge", - "cg1.4xlarge" - ] - } + "cg1.4xlarge", + ], + }, }, "AWSTemplateFormatVersion": "2010-09-09", "Outputs": { @@ -51,116 +51,61 @@ template = { "Value": { "Fn::Join": [ "", - [ - "http://", - { - "Fn::GetAtt": [ - "WebServerInstance", - "PublicIp" - ] - } - ] + ["http://", {"Fn::GetAtt": ["WebServerInstance", "PublicIp"]}], ] - } + }, } }, "Resources": { "Subnet": { "Type": "AWS::EC2::Subnet", "Properties": { - "VpcId": { - "Ref": "VPC" - }, + "VpcId": {"Ref": "VPC"}, "CidrBlock": "10.0.0.0/24", - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } - }, - "WebServerWaitHandle": { - "Type": "AWS::CloudFormation::WaitConditionHandle" + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, + "WebServerWaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"}, "Route": { "Type": "AWS::EC2::Route", "Properties": { - "GatewayId": { - "Ref": "InternetGateway" - }, + "GatewayId": {"Ref": "InternetGateway"}, "DestinationCidrBlock": "0.0.0.0/0", - "RouteTableId": { - "Ref": "RouteTable" - } + "RouteTableId": {"Ref": "RouteTable"}, }, - "DependsOn": "AttachGateway" + "DependsOn": "AttachGateway", }, "SubnetRouteTableAssociation": { "Type": "AWS::EC2::SubnetRouteTableAssociation", "Properties": { - "SubnetId": { - "Ref": "Subnet" - }, - "RouteTableId": { - "Ref": "RouteTable" - } - } + "SubnetId": {"Ref": "Subnet"}, + "RouteTableId": {"Ref": "RouteTable"}, + }, }, "InternetGateway": { "Type": "AWS::EC2::InternetGateway", "Properties": { - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}] + }, }, "RouteTable": { "Type": "AWS::EC2::RouteTable", "Properties": { - "VpcId": { - "Ref": "VPC" - }, - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "VpcId": {"Ref": "VPC"}, + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, "WebServerWaitCondition": { "Type": "AWS::CloudFormation::WaitCondition", - "Properties": { - "Handle": { - "Ref": "WebServerWaitHandle" - }, - "Timeout": "300" - }, - "DependsOn": "WebServerInstance" + "Properties": {"Handle": {"Ref": "WebServerWaitHandle"}, "Timeout": "300"}, + "DependsOn": "WebServerInstance", }, "VPC": { "Type": "AWS::EC2::VPC", "Properties": { "CidrBlock": "10.0.0.0/16", - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, "InstanceSecurityGroup": { "Type": "AWS::EC2::SecurityGroup", @@ -169,23 +114,19 @@ template = { { "ToPort": "22", "IpProtocol": "tcp", - "CidrIp": { - "Ref": "SSHLocation" - }, - "FromPort": "22" + "CidrIp": {"Ref": "SSHLocation"}, + "FromPort": "22", }, { "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", - "FromPort": "80" - } + "FromPort": "80", + }, ], - "VpcId": { - "Ref": "VPC" - }, - "GroupDescription": "Enable SSH access via port 22" - } + "VpcId": {"Ref": "VPC"}, + "GroupDescription": "Enable SSH access via port 22", + }, }, "WebServerInstance": { "Type": "AWS::EC2::Instance", @@ -200,71 +141,39 @@ template = { "# Helper function\n", "function error_exit\n", "{\n", - " /opt/aws/bin/cfn-signal -e 1 -r \"$1\" '", - { - "Ref": "WebServerWaitHandle" - }, + ' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'', + {"Ref": "WebServerWaitHandle"}, "'\n", " exit 1\n", "}\n", "# Install the simple web page\n", "/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServerInstance ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, " || error_exit 'Failed to run cfn-init'\n", "# Start up the cfn-hup daemon to listen for changes to the Web Server metadata\n", "/opt/aws/bin/cfn-hup || error_exit 'Failed to start cfn-hup'\n", "# All done so signal success\n", - "/opt/aws/bin/cfn-signal -e 0 -r \"WebServer setup complete\" '", - { - "Ref": "WebServerWaitHandle" - }, - "'\n" - ] + '/opt/aws/bin/cfn-signal -e 0 -r "WebServer setup complete" \'', + {"Ref": "WebServerWaitHandle"}, + "'\n", + ], ] } }, "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - }, - { - "Value": "Bar", - "Key": "Foo" - } + {"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}, + {"Value": "Bar", "Key": "Foo"}, ], - "SecurityGroupIds": [ - { - "Ref": "InstanceSecurityGroup" - } - ], - "KeyName": { - "Ref": "KeyName" - }, - "SubnetId": { - "Ref": "Subnet" - }, + "SecurityGroupIds": [{"Ref": "InstanceSecurityGroup"}], + "KeyName": {"Ref": "KeyName"}, + "SubnetId": {"Ref": "Subnet"}, "ImageId": { - "Fn::FindInMap": [ - "RegionMap", - { - "Ref": "AWS::Region" - }, - "AMI" - ] + "Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "AMI"] }, - "InstanceType": { - "Ref": "InstanceType" - } + "InstanceType": {"Ref": "InstanceType"}, }, "Metadata": { "Comment": "Install a simple PHP application", @@ -278,21 +187,17 @@ template = { [ "[main]\n", "stack=", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, "\n", "region=", - { - "Ref": "AWS::Region" - }, - "\n" - ] + {"Ref": "AWS::Region"}, + "\n", + ], ] }, "owner": "root", "group": "root", - "mode": "000400" + "mode": "000400", }, "/etc/cfn/hooks.d/cfn-auto-reloader.conf": { "content": { @@ -303,17 +208,13 @@ template = { "triggers=post.update\n", "path=Resources.WebServerInstance.Metadata.AWS::CloudFormation::Init\n", "action=/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServerInstance ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, "\n", - "runas=root\n" - ] + "runas=root\n", + ], ] } }, @@ -324,85 +225,52 @@ template = { [ "AWS CloudFormation sample PHP application';\n", - "?>\n" - ] + "?>\n", + ], ] }, "owner": "apache", "group": "apache", - "mode": "000644" - } + "mode": "000644", + }, }, "services": { "sysvinit": { - "httpd": { - "ensureRunning": "true", - "enabled": "true" - }, + "httpd": {"ensureRunning": "true", "enabled": "true"}, "sendmail": { "ensureRunning": "false", - "enabled": "false" - } + "enabled": "false", + }, } }, - "packages": { - "yum": { - "httpd": [], - "php": [] - } - } + "packages": {"yum": {"httpd": [], "php": []}}, } - } - } + }, + }, }, "IPAddress": { "Type": "AWS::EC2::EIP", - "Properties": { - "InstanceId": { - "Ref": "WebServerInstance" - }, - "Domain": "vpc" - }, - "DependsOn": "AttachGateway" + "Properties": {"InstanceId": {"Ref": "WebServerInstance"}, "Domain": "vpc"}, + "DependsOn": "AttachGateway", }, "AttachGateway": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "VpcId": { - "Ref": "VPC" - }, - "InternetGatewayId": { - "Ref": "InternetGateway" - } - } - } + "VpcId": {"Ref": "VPC"}, + "InternetGatewayId": {"Ref": "InternetGateway"}, + }, + }, }, "Mappings": { "RegionMap": { - "ap-southeast-1": { - "AMI": "ami-74dda626" - }, - "ap-southeast-2": { - "AMI": "ami-b3990e89" - }, - "us-west-2": { - "AMI": "ami-16fd7026" - }, - "us-east-1": { - "AMI": "ami-7f418316" - }, - "ap-northeast-1": { - "AMI": "ami-dcfa4edd" - }, - "us-west-1": { - "AMI": "ami-951945d0" - }, - "eu-west-1": { - "AMI": "ami-24506250" - }, - "sa-east-1": { - "AMI": "ami-3e3be423" - } + "ap-southeast-1": {"AMI": "ami-74dda626"}, + "ap-southeast-2": {"AMI": "ami-b3990e89"}, + "us-west-2": {"AMI": "ami-16fd7026"}, + "us-east-1": {"AMI": "ami-7f418316"}, + "ap-northeast-1": {"AMI": "ami-dcfa4edd"}, + "us-west-1": {"AMI": "ami-951945d0"}, + "eu-west-1": {"AMI": "ami-24506250"}, + "sa-east-1": {"AMI": "ami-3e3be423"}, } - } + }, } diff --git a/tests/test_cloudformation/test_cloudformation_stack_crud.py b/tests/test_cloudformation/test_cloudformation_stack_crud.py index 27424bf8c..75f705ea7 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_crud.py +++ b/tests/test_cloudformation/test_cloudformation_stack_crud.py @@ -4,16 +4,24 @@ import os import json import boto +import boto.iam import boto.s3 import boto.s3.key import boto.cloudformation from boto.exception import BotoServerError import sure # noqa + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises # noqa from nose.tools import assert_raises +from moto.core import ACCOUNT_ID -from moto import mock_cloudformation_deprecated, mock_s3_deprecated, mock_route53_deprecated +from moto import ( + mock_cloudformation_deprecated, + mock_s3_deprecated, + mock_route53_deprecated, + mock_iam_deprecated, +) from moto.cloudformation import cloudformation_backends dummy_template = { @@ -33,12 +41,7 @@ dummy_template3 = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 3", "Resources": { - "VPC": { - "Properties": { - "CidrBlock": "192.168.0.0/16", - }, - "Type": "AWS::EC2::VPC" - } + "VPC": {"Properties": {"CidrBlock": "192.168.0.0/16"}, "Type": "AWS::EC2::VPC"} }, } @@ -50,24 +53,22 @@ dummy_template_json3 = json.dumps(dummy_template3) @mock_cloudformation_deprecated def test_create_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks()[0] - stack.stack_name.should.equal('test_stack') - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.stack_name.should.equal("test_stack") + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - - }) + ) @mock_cloudformation_deprecated @@ -77,44 +78,34 @@ def test_create_stack_hosted_zone_by_id(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1", - "Parameters": { - }, + "Parameters": {}, "Resources": { "Bar": { - "Type" : "AWS::Route53::HostedZone", - "Properties" : { - "Name" : "foo.bar.baz", - } - }, + "Type": "AWS::Route53::HostedZone", + "Properties": {"Name": "foo.bar.baz"}, + } }, } dummy_template2 = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 2", - "Parameters": { - "ZoneId": { "Type": "String" } - }, + "Parameters": {"ZoneId": {"Type": "String"}}, "Resources": { "Foo": { - "Properties": { - "HostedZoneId": {"Ref": "ZoneId"}, - "RecordSets": [] - }, - "Type": "AWS::Route53::RecordSetGroup" + "Properties": {"HostedZoneId": {"Ref": "ZoneId"}, "RecordSets": []}, + "Type": "AWS::Route53::RecordSetGroup", } }, } conn.create_stack( - "test_stack", - template_body=json.dumps(dummy_template), - parameters={}.items() + "test_stack", template_body=json.dumps(dummy_template), parameters={}.items() ) r53_conn = boto.connect_route53() zone_id = r53_conn.get_zones()[0].id conn.create_stack( "test_stack", template_body=json.dumps(dummy_template2), - parameters={"ZoneId": zone_id}.items() + parameters={"ZoneId": zone_id}.items(), ) stack = conn.describe_stacks()[0] @@ -139,62 +130,57 @@ def test_create_stack_with_notification_arn(): conn.create_stack( "test_stack_with_notifications", template_body=dummy_template_json, - notification_arns='arn:aws:sns:us-east-1:123456789012:fake-queue' + notification_arns="arn:aws:sns:us-east-1:{}:fake-queue".format(ACCOUNT_ID), ) stack = conn.describe_stacks()[0] [n.value for n in stack.notification_arns].should.contain( - 'arn:aws:sns:us-east-1:123456789012:fake-queue') + "arn:aws:sns:us-east-1:{}:fake-queue".format(ACCOUNT_ID) + ) @mock_cloudformation_deprecated @mock_s3_deprecated def test_create_stack_from_s3_url(): - s3_conn = boto.s3.connect_to_region('us-west-1') + s3_conn = boto.s3.connect_to_region("us-west-1") bucket = s3_conn.create_bucket("foobar") key = boto.s3.key.Key(bucket) key.key = "template-key" key.set_contents_from_string(dummy_template_json) key_url = key.generate_url(expires_in=0, query_auth=False) - conn = boto.cloudformation.connect_to_region('us-west-1') - conn.create_stack('new-stack', template_url=key_url) + conn = boto.cloudformation.connect_to_region("us-west-1") + conn.create_stack("new-stack", template_url=key_url) stack = conn.describe_stacks()[0] - stack.stack_name.should.equal('new-stack') + stack.stack_name.should.equal("new-stack") stack.get_template().should.equal( { - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' - } + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } - - }) + } + ) @mock_cloudformation_deprecated def test_describe_stack_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] - stack.stack_name.should.equal('test_stack') + stack.stack_name.should.equal("test_stack") @mock_cloudformation_deprecated def test_describe_stack_by_stack_id(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] stack_by_id = conn.describe_stacks(stack.stack_id)[0] @@ -205,10 +191,7 @@ def test_describe_stack_by_stack_id(): @mock_cloudformation_deprecated def test_describe_deleted_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] stack_id = stack.stack_id @@ -222,36 +205,28 @@ def test_describe_deleted_stack(): @mock_cloudformation_deprecated def test_get_template_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) template = conn.get_template("test_stack") - template.should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + template.should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - - }) + ) @mock_cloudformation_deprecated def test_list_stacks(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) - conn.create_stack( - "test_stack2", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) + conn.create_stack("test_stack2", template_body=dummy_template_json) stacks = conn.list_stacks() stacks.should.have.length_of(2) @@ -261,10 +236,7 @@ def test_list_stacks(): @mock_cloudformation_deprecated def test_delete_stack_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.describe_stacks().should.have.length_of(1) conn.delete_stack("test_stack") @@ -274,10 +246,7 @@ def test_delete_stack_by_name(): @mock_cloudformation_deprecated def test_delete_stack_by_id(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) conn.describe_stacks().should.have.length_of(1) conn.delete_stack(stack_id) @@ -291,10 +260,7 @@ def test_delete_stack_by_id(): @mock_cloudformation_deprecated def test_delete_stack_with_resource_missing_delete_attr(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json3, - ) + conn.create_stack("test_stack", template_body=dummy_template_json3) conn.describe_stacks().should.have.length_of(1) conn.delete_stack("test_stack") @@ -318,19 +284,22 @@ def test_cloudformation_params(): "APPNAME": { "Default": "app-name", "Description": "The name of the app", - "Type": "String" + "Type": "String", } - } + }, } dummy_template_json = json.dumps(dummy_template) cfn = boto.connect_cloudformation() - cfn.create_stack('test_stack1', template_body=dummy_template_json, parameters=[ - ('APPNAME', 'testing123')]) - stack = cfn.describe_stacks('test_stack1')[0] + cfn.create_stack( + "test_stack1", + template_body=dummy_template_json, + parameters=[("APPNAME", "testing123")], + ) + stack = cfn.describe_stacks("test_stack1")[0] stack.parameters.should.have.length_of(1) param = stack.parameters[0] - param.key.should.equal('APPNAME') - param.value.should.equal('testing123') + param.key.should.equal("APPNAME") + param.value.should.equal("testing123") @mock_cloudformation_deprecated @@ -339,52 +308,34 @@ def test_cloudformation_params_conditions_and_resources_are_distinct(): "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1", "Conditions": { - "FooEnabled": { - "Fn::Equals": [ - { - "Ref": "FooEnabled" - }, - "true" - ] - }, + "FooEnabled": {"Fn::Equals": [{"Ref": "FooEnabled"}, "true"]}, "FooDisabled": { - "Fn::Not": [ - { - "Fn::Equals": [ - { - "Ref": "FooEnabled" - }, - "true" - ] - } - ] - } + "Fn::Not": [{"Fn::Equals": [{"Ref": "FooEnabled"}, "true"]}] + }, }, "Parameters": { - "FooEnabled": { - "Type": "String", - "AllowedValues": [ - "true", - "false" - ] - } + "FooEnabled": {"Type": "String", "AllowedValues": ["true", "false"]} }, "Resources": { "Bar": { - "Properties": { - "CidrBlock": "192.168.0.0/16", - }, + "Properties": {"CidrBlock": "192.168.0.0/16"}, "Condition": "FooDisabled", - "Type": "AWS::EC2::VPC" + "Type": "AWS::EC2::VPC", } - } + }, } dummy_template_json = json.dumps(dummy_template) cfn = boto.connect_cloudformation() - cfn.create_stack('test_stack1', template_body=dummy_template_json, parameters=[('FooEnabled', 'true')]) - stack = cfn.describe_stacks('test_stack1')[0] + cfn.create_stack( + "test_stack1", + template_body=dummy_template_json, + parameters=[("FooEnabled", "true")], + ) + stack = cfn.describe_stacks("test_stack1")[0] resources = stack.list_resources() - assert not [resource for resource in resources if resource.logical_resource_id == 'Bar'] + assert not [ + resource for resource in resources if resource.logical_resource_id == "Bar" + ] @mock_cloudformation_deprecated @@ -403,48 +354,46 @@ def test_stack_tags(): @mock_cloudformation_deprecated def test_update_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack("test_stack", dummy_template_json2) stack = conn.describe_stacks()[0] stack.stack_status.should.equal("UPDATE_COMPLETE") - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json2, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json2, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) @mock_cloudformation_deprecated def test_update_stack_with_previous_template(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack("test_stack", use_previous_template=True) stack = conn.describe_stacks()[0] stack.stack_status.should.equal("UPDATE_COMPLETE") - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) @mock_cloudformation_deprecated @@ -454,29 +403,23 @@ def test_update_stack_with_parameters(): "Description": "Stack", "Resources": { "VPC": { - "Properties": { - "CidrBlock": {"Ref": "Bar"} - }, - "Type": "AWS::EC2::VPC" + "Properties": {"CidrBlock": {"Ref": "Bar"}}, + "Type": "AWS::EC2::VPC", } }, - "Parameters": { - "Bar": { - "Type": "String" - } - } + "Parameters": {"Bar": {"Type": "String"}}, } dummy_template_json = json.dumps(dummy_template) conn = boto.connect_cloudformation() conn.create_stack( "test_stack", template_body=dummy_template_json, - parameters=[("Bar", "192.168.0.0/16")] + parameters=[("Bar", "192.168.0.0/16")], ) conn.update_stack( "test_stack", template_body=dummy_template_json, - parameters=[("Bar", "192.168.0.1/16")] + parameters=[("Bar", "192.168.0.1/16")], ) stack = conn.describe_stacks()[0] @@ -487,14 +430,10 @@ def test_update_stack_with_parameters(): def test_update_stack_replace_tags(): conn = boto.connect_cloudformation() conn.create_stack( - "test_stack", - template_body=dummy_template_json, - tags={"foo": "bar"}, + "test_stack", template_body=dummy_template_json, tags={"foo": "bar"} ) conn.update_stack( - "test_stack", - template_body=dummy_template_json, - tags={"foo": "baz"}, + "test_stack", template_body=dummy_template_json, tags={"foo": "baz"} ) stack = conn.describe_stacks()[0] @@ -506,28 +445,26 @@ def test_update_stack_replace_tags(): @mock_cloudformation_deprecated def test_update_stack_when_rolled_back(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", template_body=dummy_template_json) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) cloudformation_backends[conn.region.name].stacks[ - stack_id].status = 'ROLLBACK_COMPLETE' + stack_id + ].status = "ROLLBACK_COMPLETE" with assert_raises(BotoServerError) as err: conn.update_stack("test_stack", dummy_template_json) ex = err.exception - ex.body.should.match( - r'is in ROLLBACK_COMPLETE state and can not be updated') - ex.error_code.should.equal('ValidationError') - ex.reason.should.equal('Bad Request') + ex.body.should.match(r"is in ROLLBACK_COMPLETE state and can not be updated") + ex.error_code.should.equal("ValidationError") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_cloudformation_deprecated def test_describe_stack_events_shows_create_update_and_delete(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", template_body=dummy_template_json) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack(stack_id, template_body=dummy_template_json2) conn.delete_stack(stack_id) @@ -538,14 +475,16 @@ def test_describe_stack_events_shows_create_update_and_delete(): # testing ordering of stack events without assuming resource events will not exist # the AWS API returns events in reverse chronological order - stack_events_to_look_for = iter([ - ("DELETE_COMPLETE", None), - ("DELETE_IN_PROGRESS", "User Initiated"), - ("UPDATE_COMPLETE", None), - ("UPDATE_IN_PROGRESS", "User Initiated"), - ("CREATE_COMPLETE", None), - ("CREATE_IN_PROGRESS", "User Initiated"), - ]) + stack_events_to_look_for = iter( + [ + ("DELETE_COMPLETE", None), + ("DELETE_IN_PROGRESS", "User Initiated"), + ("UPDATE_COMPLETE", None), + ("UPDATE_IN_PROGRESS", "User Initiated"), + ("CREATE_COMPLETE", None), + ("CREATE_IN_PROGRESS", "User Initiated"), + ] + ) try: for event in events: event.stack_id.should.equal(stack_id) @@ -556,12 +495,10 @@ def test_describe_stack_events_shows_create_update_and_delete(): event.logical_resource_id.should.equal("test_stack") event.physical_resource_id.should.equal(stack_id) - status_to_look_for, reason_to_look_for = next( - stack_events_to_look_for) + status_to_look_for, reason_to_look_for = next(stack_events_to_look_for) event.resource_status.should.equal(status_to_look_for) if reason_to_look_for is not None: - event.resource_status_reason.should.equal( - reason_to_look_for) + event.resource_status_reason.should.equal(reason_to_look_for) except StopIteration: assert False, "Too many stack events" @@ -574,74 +511,60 @@ def test_create_stack_lambda_and_dynamodb(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack Lambda Test 1", - "Parameters": { - }, + "Parameters": {}, "Resources": { "func1": { - "Type" : "AWS::Lambda::Function", - "Properties" : { - "Code": { - "S3Bucket": "bucket_123", - "S3Key": "key_123" - }, + "Type": "AWS::Lambda::Function", + "Properties": { + "Code": {"S3Bucket": "bucket_123", "S3Key": "key_123"}, "FunctionName": "func1", "Handler": "handler.handler", - "Role": "role1", + "Role": get_role_name(), "Runtime": "python2.7", "Description": "descr", "MemorySize": 12345, - } + }, }, "func1version": { "Type": "AWS::Lambda::Version", - "Properties": { - "FunctionName": { - "Ref": "func1" - } - } + "Properties": {"FunctionName": {"Ref": "func1"}}, }, "tab1": { - "Type" : "AWS::DynamoDB::Table", - "Properties" : { + "Type": "AWS::DynamoDB::Table", + "Properties": { "TableName": "tab1", - "KeySchema": [{ - "AttributeName": "attr1", - "KeyType": "HASH" - }], - "AttributeDefinitions": [{ - "AttributeName": "attr1", - "AttributeType": "string" - }], + "KeySchema": [{"AttributeName": "attr1", "KeyType": "HASH"}], + "AttributeDefinitions": [ + {"AttributeName": "attr1", "AttributeType": "string"} + ], "ProvisionedThroughput": { "ReadCapacityUnits": 10, - "WriteCapacityUnits": 10 - } - } + "WriteCapacityUnits": 10, + }, + }, }, "func1mapping": { "Type": "AWS::Lambda::EventSourceMapping", "Properties": { - "FunctionName": { - "Ref": "func1" - }, + "FunctionName": {"Ref": "func1"}, "EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000", "StartingPosition": "0", "BatchSize": 100, - "Enabled": True - } - } + "Enabled": True, + }, + }, }, } - validate_s3_before = os.environ.get('VALIDATE_LAMBDA_S3', '') + validate_s3_before = os.environ.get("VALIDATE_LAMBDA_S3", "") try: - os.environ['VALIDATE_LAMBDA_S3'] = 'false' + os.environ["VALIDATE_LAMBDA_S3"] = "false" conn.create_stack( "test_stack_lambda_1", template_body=json.dumps(dummy_template), - parameters={}.items() + parameters={}.items(), ) finally: - os.environ['VALIDATE_LAMBDA_S3'] = validate_s3_before + os.environ["VALIDATE_LAMBDA_S3"] = validate_s3_before stack = conn.describe_stacks()[0] resources = stack.list_resources() @@ -657,20 +580,26 @@ def test_create_stack_kinesis(): "Parameters": {}, "Resources": { "stream1": { - "Type" : "AWS::Kinesis::Stream", - "Properties" : { - "Name": "stream1", - "ShardCount": 2 - } + "Type": "AWS::Kinesis::Stream", + "Properties": {"Name": "stream1", "ShardCount": 2}, } - } + }, } conn.create_stack( "test_stack_kinesis_1", template_body=json.dumps(dummy_template), - parameters={}.items() + parameters={}.items(), ) stack = conn.describe_stacks()[0] resources = stack.list_resources() assert len(resources) == 1 + + +def get_role_name(): + with mock_iam_deprecated(): + iam = boto.connect_iam() + role = iam.create_role("my-role")["create_role_response"]["create_role_result"][ + "role" + ]["arn"] + return role diff --git a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py index d05bc1b53..40fb2d669 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py +++ b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py @@ -6,10 +6,12 @@ from collections import OrderedDict import boto3 from botocore.exceptions import ClientError import sure # noqa + # Ensure 'assert_raises' context manager support for Python 2.6 from nose.tools import assert_raises from moto import mock_cloudformation, mock_s3, mock_sqs, mock_ec2 +from moto.core import ACCOUNT_ID dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -22,18 +24,12 @@ dummy_template = { "KeyName": "dummy", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Description", - "Value": "Test tag" - }, - { - "Key": "Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Description", "Value": "Test tag"}, + {"Key": "Name", "Value": "Name tag for tests"}, + ], + }, } - } + }, } dummy_template_yaml = """--- @@ -100,17 +96,15 @@ dummy_update_template = { "KeyName": { "Description": "Name of an existing EC2 KeyPair", "Type": "AWS::EC2::KeyPair::KeyName", - "ConstraintDescription": "must be the name of an existing EC2 KeyPair." + "ConstraintDescription": "must be the name of an existing EC2 KeyPair.", } }, "Resources": { "Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-08111162" - } + "Properties": {"ImageId": "ami-08111162"}, } - } + }, } dummy_output_template = { @@ -119,20 +113,16 @@ dummy_output_template = { "Resources": { "Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-08111162" - } + "Properties": {"ImageId": "ami-08111162"}, } }, "Outputs": { "StackVPC": { "Description": "The ID of the VPC", "Value": "VPCID", - "Export": { - "Name": "My VPC ID" - } + "Export": {"Name": "My VPC ID"}, } - } + }, } dummy_import_template = { @@ -141,11 +131,11 @@ dummy_import_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::ImportValue": 'My VPC ID'}, + "QueueName": {"Fn::ImportValue": "My VPC ID"}, "VisibilityTimeout": 60, - } + }, } - } + }, } dummy_redrive_template = { @@ -158,23 +148,16 @@ dummy_redrive_template = { "FifoQueue": True, "ContentBasedDeduplication": False, "RedrivePolicy": { - "deadLetterTargetArn": { - "Fn::GetAtt": [ - "DeadLetterQueue", - "Arn" - ] - }, - "maxReceiveCount": 5 - } - } + "deadLetterTargetArn": {"Fn::GetAtt": ["DeadLetterQueue", "Arn"]}, + "maxReceiveCount": 5, + }, + }, }, "DeadLetterQueue": { "Type": "AWS::SQS::Queue", - "Properties": { - "FifoQueue": True - } + "Properties": {"FifoQueue": True}, }, - } + }, } dummy_template_json = json.dumps(dummy_template) @@ -186,43 +169,48 @@ dummy_redrive_template_json = json.dumps(dummy_redrive_template) @mock_cloudformation def test_boto3_describe_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-west-2", ) use1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-east-1', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-east-1", ) - usw2_instance['StackInstance'].should.have.key('Region').which.should.equal('us-west-2') - usw2_instance['StackInstance'].should.have.key('Account').which.should.equal('123456789012') - use1_instance['StackInstance'].should.have.key('Region').which.should.equal('us-east-1') - use1_instance['StackInstance'].should.have.key('Account').which.should.equal('123456789012') + usw2_instance["StackInstance"].should.have.key("Region").which.should.equal( + "us-west-2" + ) + usw2_instance["StackInstance"].should.have.key("Account").which.should.equal( + ACCOUNT_ID + ) + use1_instance["StackInstance"].should.have.key("Region").which.should.equal( + "us-east-1" + ) + use1_instance["StackInstance"].should.have.key("Account").which.should.equal( + ACCOUNT_ID + ) @mock_cloudformation def test_boto3_list_stacksets_length(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_set( - StackSetName="test_stack_set2", - TemplateBody=dummy_template_yaml, + StackSetName="test_stack_set2", TemplateBody=dummy_template_yaml ) stacksets = cf_conn.list_stack_sets() stacksets.should.have.length_of(2) @@ -230,106 +218,100 @@ def test_boto3_list_stacksets_length(): @mock_cloudformation def test_boto3_list_stacksets_contents(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) stacksets = cf_conn.list_stack_sets() - stacksets['Summaries'][0].should.have.key('StackSetName').which.should.equal('test_stack_set') - stacksets['Summaries'][0].should.have.key('Status').which.should.equal('ACTIVE') + stacksets["Summaries"][0].should.have.key("StackSetName").which.should.equal( + "test_stack_set" + ) + stacksets["Summaries"][0].should.have.key("Status").which.should.equal("ACTIVE") @mock_cloudformation def test_boto3_stop_stack_set_operation(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) - list_operation = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set" - ) - list_operation['Summaries'][-1]['Status'].should.equal('STOPPED') + list_operation = cf_conn.list_stack_set_operations(StackSetName="test_stack_set") + list_operation["Summaries"][-1]["Status"].should.equal("STOPPED") @mock_cloudformation def test_boto3_describe_stack_set_operation(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) response = cf_conn.describe_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id, + StackSetName="test_stack_set", OperationId=operation_id ) - response['StackSetOperation']['Status'].should.equal('STOPPED') - response['StackSetOperation']['Action'].should.equal('CREATE') + response["StackSetOperation"]["Status"].should.equal("STOPPED") + response["StackSetOperation"]["Action"].should.equal("CREATE") @mock_cloudformation def test_boto3_list_stack_set_operation_results(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) response = cf_conn.list_stack_set_operation_results( - StackSetName="test_stack_set", - OperationId=operation_id, + StackSetName="test_stack_set", OperationId=operation_id ) - response['Summaries'].should.have.length_of(3) - response['Summaries'][0].should.have.key('Account').which.should.equal('123456789012') - response['Summaries'][1].should.have.key('Status').which.should.equal('STOPPED') + response["Summaries"].should.have.length_of(3) + response["Summaries"][0].should.have.key("Account").which.should.equal(ACCOUNT_ID) + response["Summaries"][1].should.have.key("Status").which.should.equal("STOPPED") @mock_cloudformation def test_boto3_update_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'SomeParam', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'AnotherParam', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "SomeParam", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "AnotherParam", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'SomeParam', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'AnotherParam', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "SomeParam", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "AnotherParam", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -338,97 +320,117 @@ def test_boto3_update_stack_instances(): ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) cf_conn.update_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-west-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-west-1", "us-west-2"], ParameterOverrides=param_overrides, ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-west-2", ) usw1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-1', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-west-1", ) use1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-east-1', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-east-1", ) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) - usw1_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw1_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw1_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw1_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw1_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw1_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw1_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw1_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) - use1_instance['StackInstance']['ParameterOverrides'].should.be.empty + use1_instance["StackInstance"]["ParameterOverrides"].should.be.empty @mock_cloudformation def test_boto3_delete_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ) cf_conn.delete_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1"], RetainStacks=False, ) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'].should.have.length_of(1) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'][0]['Region'].should.equal( - 'us-west-2') + cf_conn.list_stack_instances(StackSetName="test_stack_set")[ + "Summaries" + ].should.have.length_of(1) + cf_conn.list_stack_instances(StackSetName="test_stack_set")["Summaries"][0][ + "Region" + ].should.equal("us-west-2") @mock_cloudformation def test_boto3_create_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'].should.have.length_of(2) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'][0]['Account'].should.equal( - '123456789012') + cf_conn.list_stack_instances(StackSetName="test_stack_set")[ + "Summaries" + ].should.have.length_of(2) + cf_conn.list_stack_instances(StackSetName="test_stack_set")["Summaries"][0][ + "Account" + ].should.equal(ACCOUNT_ID) @mock_cloudformation def test_boto3_create_stack_instances_with_param_overrides(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "TagDescription", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "TagName", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "TagDescription", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "TagName", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -437,32 +439,40 @@ def test_boto3_create_stack_instances_with_param_overrides(): ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ParameterOverrides=param_overrides, ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount=ACCOUNT_ID, + StackInstanceRegion="us-west-2", ) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) @mock_cloudformation def test_update_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "TagDescription", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "TagName", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "TagDescription", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "TagName", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -470,203 +480,196 @@ def test_update_stack_set(): Parameters=param, ) cf_conn.update_stack_set( - StackSetName='test_stack_set', + StackSetName="test_stack_set", TemplateBody=dummy_template_yaml_with_ref, Parameters=param_overrides, ) - stackset = cf_conn.describe_stack_set(StackSetName='test_stack_set') + stackset = cf_conn.describe_stack_set(StackSetName="test_stack_set") - stackset['StackSet']['Parameters'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - stackset['StackSet']['Parameters'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) - stackset['StackSet']['Parameters'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - stackset['StackSet']['Parameters'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) + stackset["StackSet"]["Parameters"][0]["ParameterValue"].should.equal( + param_overrides[0]["ParameterValue"] + ) + stackset["StackSet"]["Parameters"][1]["ParameterValue"].should.equal( + param_overrides[1]["ParameterValue"] + ) + stackset["StackSet"]["Parameters"][0]["ParameterKey"].should.equal( + param_overrides[0]["ParameterKey"] + ) + stackset["StackSet"]["Parameters"][1]["ParameterKey"].should.equal( + param_overrides[1]["ParameterKey"] + ) @mock_cloudformation def test_boto3_list_stack_set_operations(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ) cf_conn.update_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=[ACCOUNT_ID], + Regions=["us-east-1", "us-west-2"], ) list_operation = cf_conn.list_stack_set_operations(StackSetName="test_stack_set") - list_operation['Summaries'].should.have.length_of(2) - list_operation['Summaries'][-1]['Action'].should.equal('UPDATE') + list_operation["Summaries"].should.have.length_of(2) + list_operation["Summaries"][-1]["Action"].should.equal("UPDATE") @mock_cloudformation def test_boto3_delete_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) - cf_conn.delete_stack_set(StackSetName='test_stack_set') + cf_conn.delete_stack_set(StackSetName="test_stack_set") - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['Status'].should.equal( - 'DELETED') + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "Status" + ].should.equal("DELETED") @mock_cloudformation def test_boto3_create_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['TemplateBody'].should.equal( - dummy_template_json) + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_json) @mock_cloudformation def test_boto3_create_stack_set_with_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_yaml, + StackSetName="test_stack_set", TemplateBody=dummy_template_yaml ) - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['TemplateBody'].should.equal( - dummy_template_yaml) + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_yaml) @mock_cloudformation @mock_s3 def test_create_stack_set_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') - cf_conn.create_stack_set( - StackSetName='stack_from_url', - TemplateURL=key_url, - ) - cf_conn.describe_stack_set(StackSetName="stack_from_url")['StackSet']['TemplateBody'].should.equal( - dummy_template_json) + cf_conn = boto3.client("cloudformation", region_name="us-west-1") + cf_conn.create_stack_set(StackSetName="stack_from_url", TemplateURL=key_url) + cf_conn.describe_stack_set(StackSetName="stack_from_url")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_json) @mock_cloudformation def test_boto3_create_stack_set_with_ref_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack_set( StackSetName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.describe_stack_set(StackSetName="test_stack")['StackSet']['TemplateBody'].should.equal( - dummy_template_yaml_with_ref) + cf_conn.describe_stack_set(StackSetName="test_stack")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_yaml_with_ref) @mock_cloudformation def test_boto3_describe_stack_set_params(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack_set( StackSetName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.describe_stack_set(StackSetName="test_stack")['StackSet']['Parameters'].should.equal( - params) + cf_conn.describe_stack_set(StackSetName="test_stack")["StackSet"][ + "Parameters" + ].should.equal(params) @mock_cloudformation def test_boto3_create_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - json.loads(dummy_template_json, object_pairs_hook=OrderedDict)) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + json.loads(dummy_template_json, object_pairs_hook=OrderedDict) + ) @mock_cloudformation def test_boto3_create_stack_with_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_yaml, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_yaml) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml + ) @mock_cloudformation def test_boto3_create_stack_with_short_form_func_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_yaml_with_short_form_func, + StackName="test_stack", TemplateBody=dummy_template_yaml_with_short_form_func ) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml_with_short_form_func) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml_with_short_form_func + ) @mock_cloudformation def test_boto3_create_stack_with_ref_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack( StackName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml_with_ref) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml_with_ref + ) @mock_cloudformation def test_creating_stacks_across_regions(): - west1_cf = boto3.resource('cloudformation', region_name='us-west-1') - west2_cf = boto3.resource('cloudformation', region_name='us-west-2') - west1_cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) - west2_cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + west1_cf = boto3.resource("cloudformation", region_name="us-west-1") + west2_cf = boto3.resource("cloudformation", region_name="us-west-2") + west1_cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) + west2_cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) list(west1_cf.stacks.all()).should.have.length_of(1) list(west2_cf.stacks.all()).should.have.length_of(1) @@ -674,289 +677,266 @@ def test_creating_stacks_across_regions(): @mock_cloudformation def test_create_stack_with_notification_arn(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") cf.create_stack( StackName="test_stack_with_notifications", TemplateBody=dummy_template_json, - NotificationARNs=['arn:aws:sns:us-east-1:123456789012:fake-queue'], + NotificationARNs=["arn:aws:sns:us-east-1:{}:fake-queue".format(ACCOUNT_ID)], ) stack = list(cf.stacks.all())[0] stack.notification_arns.should.contain( - 'arn:aws:sns:us-east-1:123456789012:fake-queue') + "arn:aws:sns:us-east-1:{}:fake-queue".format(ACCOUNT_ID) + ) @mock_cloudformation def test_create_stack_with_role_arn(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") cf.create_stack( StackName="test_stack_with_notifications", TemplateBody=dummy_template_json, - RoleARN='arn:aws:iam::123456789012:role/moto', + RoleARN="arn:aws:iam::{}:role/moto".format(ACCOUNT_ID), ) stack = list(cf.stacks.all())[0] - stack.role_arn.should.equal('arn:aws:iam::123456789012:role/moto') + stack.role_arn.should.equal("arn:aws:iam::{}:role/moto".format(ACCOUNT_ID)) @mock_cloudformation @mock_s3 def test_create_stack_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') - cf_conn.create_stack( - StackName='stack_from_url', - TemplateURL=key_url, + cf_conn = boto3.client("cloudformation", region_name="us-west-1") + cf_conn.create_stack(StackName="stack_from_url", TemplateURL=key_url) + cf_conn.get_template(StackName="stack_from_url")["TemplateBody"].should.equal( + json.loads(dummy_template_json, object_pairs_hook=OrderedDict) ) - cf_conn.get_template(StackName="stack_from_url")['TemplateBody'].should.equal( - json.loads(dummy_template_json, object_pairs_hook=OrderedDict)) @mock_cloudformation def test_update_stack_with_previous_value(): - name = 'update_stack_with_previous_value' - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + name = "update_stack_with_previous_value" + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( - StackName=name, TemplateBody=dummy_template_yaml_with_ref, + StackName=name, + TemplateBody=dummy_template_yaml_with_ref, Parameters=[ - {'ParameterKey': 'TagName', 'ParameterValue': 'foo'}, - {'ParameterKey': 'TagDescription', 'ParameterValue': 'bar'}, - ] + {"ParameterKey": "TagName", "ParameterValue": "foo"}, + {"ParameterKey": "TagDescription", "ParameterValue": "bar"}, + ], ) cf_conn.update_stack( - StackName=name, UsePreviousTemplate=True, + StackName=name, + UsePreviousTemplate=True, Parameters=[ - {'ParameterKey': 'TagName', 'UsePreviousValue': True}, - {'ParameterKey': 'TagDescription', 'ParameterValue': 'not bar'}, - ] + {"ParameterKey": "TagName", "UsePreviousValue": True}, + {"ParameterKey": "TagDescription", "ParameterValue": "not bar"}, + ], ) - stack = cf_conn.describe_stacks(StackName=name)['Stacks'][0] - tag_name = [x['ParameterValue'] for x in stack['Parameters'] - if x['ParameterKey'] == 'TagName'][0] - tag_desc = [x['ParameterValue'] for x in stack['Parameters'] - if x['ParameterKey'] == 'TagDescription'][0] - assert tag_name == 'foo' - assert tag_desc == 'not bar' + stack = cf_conn.describe_stacks(StackName=name)["Stacks"][0] + tag_name = [ + x["ParameterValue"] + for x in stack["Parameters"] + if x["ParameterKey"] == "TagName" + ][0] + tag_desc = [ + x["ParameterValue"] + for x in stack["Parameters"] + if x["ParameterKey"] == "TagDescription" + ][0] + assert tag_name == "foo" + assert tag_desc == "not bar" @mock_cloudformation @mock_s3 @mock_ec2 def test_update_stack_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( StackName="update_stack_from_url", TemplateBody=dummy_template_json, - Tags=[{'Key': 'foo', 'Value': 'bar'}], + Tags=[{"Key": "foo", "Value": "bar"}], ) s3_conn.create_bucket(Bucket="foobar") - s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_update_template_json) + s3_conn.Object("foobar", "template-key").put(Body=dummy_update_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn.update_stack( - StackName="update_stack_from_url", - TemplateURL=key_url, - ) + cf_conn.update_stack(StackName="update_stack_from_url", TemplateURL=key_url) - cf_conn.get_template(StackName="update_stack_from_url")[ 'TemplateBody'].should.equal( - json.loads(dummy_update_template_json, object_pairs_hook=OrderedDict)) + cf_conn.get_template(StackName="update_stack_from_url")[ + "TemplateBody" + ].should.equal( + json.loads(dummy_update_template_json, object_pairs_hook=OrderedDict) + ) @mock_cloudformation @mock_s3 def test_create_change_set_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') + cf_conn = boto3.client("cloudformation", region_name="us-west-1") response = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateURL=key_url, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', - Tags=[ - {'Key': 'tag-key', 'Value': 'tag-value'} - ], + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", + Tags=[{"Key": "tag-key", "Value": "tag-value"}], + ) + assert ( + "arn:aws:cloudformation:us-west-1:123456789:changeSet/NewChangeSet/" + in response["Id"] + ) + assert ( + "arn:aws:cloudformation:us-east-1:123456789:stack/NewStack" + in response["StackId"] ) - assert 'arn:aws:cloudformation:us-west-1:123456789:changeSet/NewChangeSet/' in response['Id'] - assert 'arn:aws:cloudformation:us-east-1:123456789:stack/NewStack' in response['StackId'] @mock_cloudformation def test_describe_change_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) stack = cf_conn.describe_change_set(ChangeSetName="NewChangeSet") - stack['ChangeSetName'].should.equal('NewChangeSet') - stack['StackName'].should.equal('NewStack') + stack["ChangeSetName"].should.equal("NewChangeSet") + stack["StackName"].should.equal("NewStack") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_update_template_json, - ChangeSetName='NewChangeSet2', - ChangeSetType='UPDATE', + ChangeSetName="NewChangeSet2", + ChangeSetType="UPDATE", ) stack = cf_conn.describe_change_set(ChangeSetName="NewChangeSet2") - stack['ChangeSetName'].should.equal('NewChangeSet2') - stack['StackName'].should.equal('NewStack') - stack['Changes'].should.have.length_of(2) + stack["ChangeSetName"].should.equal("NewChangeSet2") + stack["StackName"].should.equal("NewStack") + stack["Changes"].should.have.length_of(2) @mock_cloudformation def test_execute_change_set_w_arn(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") change_set = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.execute_change_set(ChangeSetName=change_set['Id']) + cf_conn.execute_change_set(ChangeSetName=change_set["Id"]) @mock_cloudformation def test_execute_change_set_w_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") change_set = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.execute_change_set(ChangeSetName='NewChangeSet', StackName='NewStack') + cf_conn.execute_change_set(ChangeSetName="NewChangeSet", StackName="NewStack") @mock_cloudformation def test_describe_stack_pagination(): - conn = boto3.client('cloudformation', region_name='us-east-1') + conn = boto3.client("cloudformation", region_name="us-east-1") for i in range(100): - conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) resp = conn.describe_stacks() - stacks = resp['Stacks'] + stacks = resp["Stacks"] stacks.should.have.length_of(50) - next_token = resp['NextToken'] + next_token = resp["NextToken"] next_token.should_not.be.none resp2 = conn.describe_stacks(NextToken=next_token) - stacks.extend(resp2['Stacks']) + stacks.extend(resp2["Stacks"]) stacks.should.have.length_of(100) - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @mock_cloudformation def test_describe_stack_resources(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] - response = cf_conn.describe_stack_resources(StackName=stack['StackName']) - resource = response['StackResources'][0] - resource['LogicalResourceId'].should.equal('EC2Instance1') - resource['ResourceStatus'].should.equal('CREATE_COMPLETE') - resource['ResourceType'].should.equal('AWS::EC2::Instance') - resource['StackId'].should.equal(stack['StackId']) + response = cf_conn.describe_stack_resources(StackName=stack["StackName"]) + resource = response["StackResources"][0] + resource["LogicalResourceId"].should.equal("EC2Instance1") + resource["ResourceStatus"].should.equal("CREATE_COMPLETE") + resource["ResourceType"].should.equal("AWS::EC2::Instance") + resource["StackId"].should.equal(stack["StackId"]) @mock_cloudformation def test_describe_stack_by_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack['StackName'].should.equal('test_stack') + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack["StackName"].should.equal("test_stack") @mock_cloudformation def test_describe_stack_by_stack_id(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_by_id = cf_conn.describe_stacks(StackName=stack['StackId'])['Stacks'][ - 0] + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_by_id = cf_conn.describe_stacks(StackName=stack["StackId"])["Stacks"][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") @mock_cloudformation def test_list_change_sets(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack2', + StackName="NewStack2", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet2', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet2", + ChangeSetType="CREATE", ) - change_set = cf_conn.list_change_sets(StackName='NewStack2')['Summaries'][0] - change_set['StackName'].should.equal('NewStack2') - change_set['ChangeSetName'].should.equal('NewChangeSet2') + change_set = cf_conn.list_change_sets(StackName="NewStack2")["Summaries"][0] + change_set["StackName"].should.equal("NewStack2") + change_set["ChangeSetName"].should.equal("NewChangeSet2") @mock_cloudformation def test_list_stacks(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) - cf.create_stack( - StackName="test_stack2", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) + cf.create_stack(StackName="test_stack2", TemplateBody=dummy_template_json) stacks = list(cf.stacks.all()) stacks.should.have.length_of(2) @@ -967,11 +947,8 @@ def test_list_stacks(): @mock_cloudformation def test_delete_stack_from_resource(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + stack = cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) list(cf.stacks.all()).should.have.length_of(1) stack.delete() @@ -981,95 +958,84 @@ def test_delete_stack_from_resource(): @mock_cloudformation @mock_ec2 def test_delete_change_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.list_change_sets(StackName='NewStack')['Summaries'].should.have.length_of(1) - cf_conn.delete_change_set(ChangeSetName='NewChangeSet', StackName='NewStack') - cf_conn.list_change_sets(StackName='NewStack')['Summaries'].should.have.length_of(0) + cf_conn.list_change_sets(StackName="NewStack")["Summaries"].should.have.length_of(1) + cf_conn.delete_change_set(ChangeSetName="NewChangeSet", StackName="NewStack") + cf_conn.list_change_sets(StackName="NewStack")["Summaries"].should.have.length_of(0) @mock_cloudformation @mock_ec2 def test_delete_stack_by_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf_conn.describe_stacks()['Stacks'].should.have.length_of(1) + cf_conn.describe_stacks()["Stacks"].should.have.length_of(1) cf_conn.delete_stack(StackName="test_stack") - cf_conn.describe_stacks()['Stacks'].should.have.length_of(0) + cf_conn.describe_stacks()["Stacks"].should.have.length_of(0) @mock_cloudformation def test_delete_stack(): - cf = boto3.client('cloudformation', region_name='us-east-1') - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.client("cloudformation", region_name="us-east-1") + cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf.delete_stack( - StackName="test_stack", - ) + cf.delete_stack(StackName="test_stack") stacks = cf.list_stacks() - assert stacks['StackSummaries'][0]['StackStatus'] == 'DELETE_COMPLETE' + assert stacks["StackSummaries"][0]["StackStatus"] == "DELETE_COMPLETE" @mock_cloudformation def test_describe_deleted_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_id = stack['StackId'] - cf_conn.delete_stack(StackName=stack['StackId']) - stack_by_id = cf_conn.describe_stacks(StackName=stack_id)['Stacks'][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") - stack_by_id['StackStatus'].should.equal("DELETE_COMPLETE") + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_id = stack["StackId"] + cf_conn.delete_stack(StackName=stack["StackId"]) + stack_by_id = cf_conn.describe_stacks(StackName=stack_id)["Stacks"][0] + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") + stack_by_id["StackStatus"].should.equal("DELETE_COMPLETE") @mock_cloudformation @mock_ec2 def test_describe_updated_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( StackName="test_stack", TemplateBody=dummy_template_json, - Tags=[{'Key': 'foo', 'Value': 'bar'}], + Tags=[{"Key": "foo", "Value": "bar"}], ) cf_conn.update_stack( StackName="test_stack", - RoleARN='arn:aws:iam::123456789012:role/moto', + RoleARN="arn:aws:iam::{}:role/moto".format(ACCOUNT_ID), TemplateBody=dummy_update_template_json, - Tags=[{'Key': 'foo', 'Value': 'baz'}], + Tags=[{"Key": "foo", "Value": "baz"}], ) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_id = stack['StackId'] - stack_by_id = cf_conn.describe_stacks(StackName=stack_id)['Stacks'][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") - stack_by_id['StackStatus'].should.equal("UPDATE_COMPLETE") - stack_by_id['RoleARN'].should.equal('arn:aws:iam::123456789012:role/moto') - stack_by_id['Tags'].should.equal([{'Key': 'foo', 'Value': 'baz'}]) + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_id = stack["StackId"] + stack_by_id = cf_conn.describe_stacks(StackName=stack_id)["Stacks"][0] + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") + stack_by_id["StackStatus"].should.equal("UPDATE_COMPLETE") + stack_by_id["RoleARN"].should.equal("arn:aws:iam::{}:role/moto".format(ACCOUNT_ID)) + stack_by_id["Tags"].should.equal([{"Key": "foo", "Value": "baz"}]) @mock_cloudformation def test_bad_describe_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") with assert_raises(ClientError): cf_conn.describe_stacks(StackName="non_existent_stack") @@ -1084,61 +1050,46 @@ def test_cloudformation_params(): "APPNAME": { "Default": "app-name", "Description": "The name of the app", - "Type": "String" + "Type": "String", } - } + }, } dummy_template_with_params_json = json.dumps(dummy_template_with_params) - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName='test_stack', + StackName="test_stack", TemplateBody=dummy_template_with_params_json, - Parameters=[{ - "ParameterKey": "APPNAME", - "ParameterValue": "testing123", - }], + Parameters=[{"ParameterKey": "APPNAME", "ParameterValue": "testing123"}], ) stack.parameters.should.have.length_of(1) param = stack.parameters[0] - param['ParameterKey'].should.equal('APPNAME') - param['ParameterValue'].should.equal('testing123') + param["ParameterKey"].should.equal("APPNAME") + param["ParameterValue"].should.equal("testing123") @mock_cloudformation def test_stack_tags(): - tags = [ - { - "Key": "foo", - "Value": "bar" - }, - { - "Key": "baz", - "Value": "bleh" - } - ] - cf = boto3.resource('cloudformation', region_name='us-east-1') + tags = [{"Key": "foo", "Value": "bar"}, {"Key": "baz", "Value": "bleh"}] + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - Tags=tags, + StackName="test_stack", TemplateBody=dummy_template_json, Tags=tags ) observed_tag_items = set( - item for items in [tag.items() for tag in stack.tags] for item in items) + item for items in [tag.items() for tag in stack.tags] for item in items + ) expected_tag_items = set( - item for items in [tag.items() for tag in tags] for item in items) + item for items in [tag.items() for tag in tags] for item in items + ) observed_tag_items.should.equal(expected_tag_items) @mock_cloudformation @mock_ec2 def test_stack_events(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + stack = cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) stack.update(TemplateBody=dummy_update_template_json) stack = cf.Stack(stack.stack_id) stack.delete() @@ -1150,14 +1101,16 @@ def test_stack_events(): # testing ordering of stack events without assuming resource events will not exist # the AWS API returns events in reverse chronological order - stack_events_to_look_for = iter([ - ("DELETE_COMPLETE", None), - ("DELETE_IN_PROGRESS", "User Initiated"), - ("UPDATE_COMPLETE", None), - ("UPDATE_IN_PROGRESS", "User Initiated"), - ("CREATE_COMPLETE", None), - ("CREATE_IN_PROGRESS", "User Initiated"), - ]) + stack_events_to_look_for = iter( + [ + ("DELETE_COMPLETE", None), + ("DELETE_IN_PROGRESS", "User Initiated"), + ("UPDATE_COMPLETE", None), + ("UPDATE_IN_PROGRESS", "User Initiated"), + ("CREATE_COMPLETE", None), + ("CREATE_IN_PROGRESS", "User Initiated"), + ] + ) try: for event in events: event.stack_id.should.equal(stack.stack_id) @@ -1168,12 +1121,10 @@ def test_stack_events(): event.logical_resource_id.should.equal("test_stack") event.physical_resource_id.should.equal(stack.stack_id) - status_to_look_for, reason_to_look_for = next( - stack_events_to_look_for) + status_to_look_for, reason_to_look_for = next(stack_events_to_look_for) event.resource_status.should.equal(status_to_look_for) if reason_to_look_for is not None: - event.resource_status_reason.should.equal( - reason_to_look_for) + event.resource_status_reason.should.equal(reason_to_look_for) except StopIteration: assert False, "Too many stack events" @@ -1182,90 +1133,81 @@ def test_stack_events(): @mock_cloudformation def test_list_exports(): - cf_client = boto3.client('cloudformation', region_name='us-east-1') - cf_resource = boto3.resource('cloudformation', region_name='us-east-1') + cf_client = boto3.client("cloudformation", region_name="us-east-1") + cf_resource = boto3.resource("cloudformation", region_name="us-east-1") stack = cf_resource.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) - output_value = 'VPCID' - exports = cf_client.list_exports()['Exports'] + output_value = "VPCID" + exports = cf_client.list_exports()["Exports"] stack.outputs.should.have.length_of(1) - stack.outputs[0]['OutputValue'].should.equal(output_value) + stack.outputs[0]["OutputValue"].should.equal(output_value) exports.should.have.length_of(1) - exports[0]['ExportingStackId'].should.equal(stack.stack_id) - exports[0]['Name'].should.equal('My VPC ID') - exports[0]['Value'].should.equal(output_value) + exports[0]["ExportingStackId"].should.equal(stack.stack_id) + exports[0]["Name"].should.equal("My VPC ID") + exports[0]["Value"].should.equal(output_value) @mock_cloudformation def test_list_exports_with_token(): - cf = boto3.client('cloudformation', region_name='us-east-1') + cf = boto3.client("cloudformation", region_name="us-east-1") for i in range(101): # Add index to ensure name is unique - dummy_output_template['Outputs']['StackVPC']['Export']['Name'] += str(i) + dummy_output_template["Outputs"]["StackVPC"]["Export"]["Name"] += str(i) cf.create_stack( - StackName="test_stack", - TemplateBody=json.dumps(dummy_output_template), + StackName="test_stack", TemplateBody=json.dumps(dummy_output_template) ) exports = cf.list_exports() - exports['Exports'].should.have.length_of(100) - exports.get('NextToken').should_not.be.none + exports["Exports"].should.have.length_of(100) + exports.get("NextToken").should_not.be.none - more_exports = cf.list_exports(NextToken=exports['NextToken']) - more_exports['Exports'].should.have.length_of(1) - more_exports.get('NextToken').should.be.none + more_exports = cf.list_exports(NextToken=exports["NextToken"]) + more_exports["Exports"].should.have.length_of(1) + more_exports.get("NextToken").should.be.none @mock_cloudformation def test_delete_stack_with_export(): - cf = boto3.client('cloudformation', region_name='us-east-1') + cf = boto3.client("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) - stack_id = stack['StackId'] - exports = cf.list_exports()['Exports'] + stack_id = stack["StackId"] + exports = cf.list_exports()["Exports"] exports.should.have.length_of(1) cf.delete_stack(StackName=stack_id) - cf.list_exports()['Exports'].should.have.length_of(0) + cf.list_exports()["Exports"].should.have.length_of(0) @mock_cloudformation def test_export_names_must_be_unique(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") first_stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) with assert_raises(ClientError): - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, - ) + cf.create_stack(StackName="test_stack", TemplateBody=dummy_output_template_json) @mock_sqs @mock_cloudformation def test_stack_with_imports(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - ec2_resource = boto3.resource('sqs', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") + ec2_resource = boto3.resource("sqs", region_name="us-east-1") output_stack = cf.create_stack( - StackName="test_stack1", - TemplateBody=dummy_output_template_json, + StackName="test_stack1", TemplateBody=dummy_output_template_json ) import_stack = cf.create_stack( - StackName="test_stack2", - TemplateBody=dummy_import_template_json + StackName="test_stack2", TemplateBody=dummy_import_template_json ) output_stack.outputs.should.have.length_of(1) - output = output_stack.outputs[0]['OutputValue'] + output = output_stack.outputs[0]["OutputValue"] queue = ec2_resource.get_queue_by_name(QueueName=output) queue.should_not.be.none @@ -1273,14 +1215,11 @@ def test_stack_with_imports(): @mock_sqs @mock_cloudformation def test_non_json_redrive_policy(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack1", - TemplateBody=dummy_redrive_template_json + StackName="test_stack1", TemplateBody=dummy_redrive_template_json ) - stack.Resource('MainQueue').resource_status\ - .should.equal("CREATE_COMPLETE") - stack.Resource('DeadLetterQueue').resource_status\ - .should.equal("CREATE_COMPLETE") + stack.Resource("MainQueue").resource_status.should.equal("CREATE_COMPLETE") + stack.Resource("DeadLetterQueue").resource_status.should.equal("CREATE_COMPLETE") diff --git a/tests/test_cloudformation/test_cloudformation_stack_integration.py b/tests/test_cloudformation/test_cloudformation_stack_integration.py index 42ddd2351..e296ef2ed 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_integration.py +++ b/tests/test_cloudformation/test_cloudformation_stack_integration.py @@ -41,7 +41,9 @@ from moto import ( mock_sns_deprecated, mock_sqs, mock_sqs_deprecated, - mock_elbv2) + mock_elbv2, +) +from moto.core import ACCOUNT_ID from moto.dynamodb2.models import Table from .fixtures import ( @@ -65,26 +67,19 @@ def test_stack_sqs_integration(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) stack = conn.describe_stacks()[0] queue = stack.describe_resources()[0] - queue.resource_type.should.equal('AWS::SQS::Queue') + queue.resource_type.should.equal("AWS::SQS::Queue") queue.logical_resource_id.should.equal("QueueGroup") queue.physical_resource_id.should.equal("my-queue") @@ -95,27 +90,20 @@ def test_stack_list_resources(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) resources = conn.list_stack_resources("test_stack") assert len(resources) == 1 queue = resources[0] - queue.resource_type.should.equal('AWS::SQS::Queue') + queue.resource_type.should.equal("AWS::SQS::Queue") queue.logical_resource_id.should.equal("QueueGroup") queue.physical_resource_id.should.equal("my-queue") @@ -127,38 +115,32 @@ def test_update_stack(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - queues[0].get_attributes('VisibilityTimeout')[ - 'VisibilityTimeout'].should.equal('60') + queues[0].get_attributes("VisibilityTimeout")["VisibilityTimeout"].should.equal( + "60" + ) - sqs_template['Resources']['QueueGroup'][ - 'Properties']['VisibilityTimeout'] = 100 + sqs_template["Resources"]["QueueGroup"]["Properties"]["VisibilityTimeout"] = 100 sqs_template_json = json.dumps(sqs_template) conn.update_stack("test_stack", sqs_template_json) queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - queues[0].get_attributes('VisibilityTimeout')[ - 'VisibilityTimeout'].should.equal('100') + queues[0].get_attributes("VisibilityTimeout")["VisibilityTimeout"].should.equal( + "100" + ) @mock_cloudformation_deprecated() @@ -168,28 +150,21 @@ def test_update_stack_and_remove_resource(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - sqs_template['Resources'].pop('QueueGroup') + sqs_template["Resources"].pop("QueueGroup") sqs_template_json = json.dumps(sqs_template) conn.update_stack("test_stack", sqs_template_json) @@ -200,17 +175,11 @@ def test_update_stack_and_remove_resource(): @mock_cloudformation_deprecated() @mock_sqs_deprecated() def test_update_stack_and_add_resource(): - sqs_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Resources": {}, - } + sqs_template = {"AWSTemplateFormatVersion": "2010-09-09", "Resources": {}} sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() @@ -220,13 +189,9 @@ def test_update_stack_and_add_resource(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) @@ -244,20 +209,14 @@ def test_stack_ec2_integration(): "Resources": { "WebServerGroup": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } - }, + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, + } }, } ec2_template_json = json.dumps(ec2_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "ec2_stack", - template_body=ec2_template_json, - ) + conn.create_stack("ec2_stack", template_body=ec2_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] @@ -265,7 +224,7 @@ def test_stack_ec2_integration(): stack = conn.describe_stacks()[0] instance = stack.describe_resources()[0] - instance.resource_type.should.equal('AWS::EC2::Instance') + instance.resource_type.should.equal("AWS::EC2::Instance") instance.logical_resource_id.should.contain("WebServerGroup") instance.physical_resource_id.should.equal(ec2_instance.id) @@ -282,7 +241,7 @@ def test_stack_elb_integration_with_attached_ec2_instances(): "Properties": { "Instances": [{"Ref": "Ec2Instance1"}], "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-east-1'], + "AvailabilityZones": ["us-east-1"], "Listeners": [ { "InstancePort": "80", @@ -290,24 +249,18 @@ def test_stack_elb_integration_with_attached_ec2_instances(): "Protocol": "HTTP", } ], - } + }, }, "Ec2Instance1": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] @@ -317,7 +270,7 @@ def test_stack_elb_integration_with_attached_ec2_instances(): ec2_instance = reservation.instances[0] load_balancer.instances[0].id.should.equal(ec2_instance.id) - list(load_balancer.availability_zones).should.equal(['us-east-1']) + list(load_balancer.availability_zones).should.equal(["us-east-1"]) @mock_elb_deprecated() @@ -330,7 +283,7 @@ def test_stack_elb_integration_with_health_check(): "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-west-1'], + "AvailabilityZones": ["us-west-1"], "HealthCheck": { "HealthyThreshold": "3", "Interval": "5", @@ -345,17 +298,14 @@ def test_stack_elb_integration_with_health_check(): "Protocol": "HTTP", } ], - } - }, + }, + } }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] @@ -378,7 +328,7 @@ def test_stack_elb_integration_with_update(): "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-west-1a'], + "AvailabilityZones": ["us-west-1a"], "Listeners": [ { "InstancePort": "80", @@ -387,31 +337,26 @@ def test_stack_elb_integration_with_update(): } ], "Policies": {"Ref": "AWS::NoValue"}, - } - }, + }, + } }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] - load_balancer.availability_zones[0].should.equal('us-west-1a') + load_balancer.availability_zones[0].should.equal("us-west-1a") - elb_template['Resources']['MyELB']['Properties'][ - 'AvailabilityZones'] = ['us-west-1b'] + elb_template["Resources"]["MyELB"]["Properties"]["AvailabilityZones"] = [ + "us-west-1b" + ] elb_template_json = json.dumps(elb_template) - conn.update_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.update_stack("elb_stack", template_body=elb_template_json) load_balancer = elb_conn.get_all_load_balancers()[0] - load_balancer.availability_zones[0].should.equal('us-west-1b') + load_balancer.availability_zones[0].should.equal("us-west-1b") @mock_ec2_deprecated() @@ -434,23 +379,24 @@ def test_redshift_stack(): ("MasterUserPassword", "mypass"), ("InboundTraffic", "10.0.0.1/16"), ("PortNumber", 5439), - ] + ], ) redshift_conn = boto.redshift.connect_to_region("us-west-2") cluster_res = redshift_conn.describe_clusters() - clusters = cluster_res['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = cluster_res["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ] clusters.should.have.length_of(1) cluster = clusters[0] - cluster['DBName'].should.equal("mydb") - cluster['NumberOfNodes'].should.equal(2) - cluster['NodeType'].should.equal("dw1.xlarge") - cluster['MasterUsername'].should.equal("myuser") - cluster['Port'].should.equal(5439) - cluster['VpcSecurityGroups'].should.have.length_of(1) - security_group_id = cluster['VpcSecurityGroups'][0]['VpcSecurityGroupId'] + cluster["DBName"].should.equal("mydb") + cluster["NumberOfNodes"].should.equal(2) + cluster["NodeType"].should.equal("dw1.xlarge") + cluster["MasterUsername"].should.equal("myuser") + cluster["Port"].should.equal(5439) + cluster["VpcSecurityGroups"].should.have.length_of(1) + security_group_id = cluster["VpcSecurityGroups"][0]["VpcSecurityGroupId"] groups = vpc_conn.get_all_security_groups(group_ids=[security_group_id]) groups.should.have.length_of(1) @@ -467,40 +413,36 @@ def test_stack_security_groups(): "Resources": { "my-security-group": { "Type": "AWS::EC2::SecurityGroup", - "Properties": { - "GroupDescription": "My other group", - }, + "Properties": {"GroupDescription": "My other group"}, }, "Ec2Instance2": { "Type": "AWS::EC2::Instance", "Properties": { "SecurityGroups": [{"Ref": "InstanceSecurityGroup"}], "ImageId": "ami-1234abcd", - } + }, }, "InstanceSecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "My security group", - "Tags": [ + "Tags": [{"Key": "bar", "Value": "baz"}], + "SecurityGroupIngress": [ { - "Key": "bar", - "Value": "baz" - } + "IpProtocol": "tcp", + "FromPort": "22", + "ToPort": "22", + "CidrIp": "123.123.123.123/32", + }, + { + "IpProtocol": "tcp", + "FromPort": "80", + "ToPort": "8000", + "SourceSecurityGroupId": {"Ref": "my-security-group"}, + }, ], - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "22", - "ToPort": "22", - "CidrIp": "123.123.123.123/32", - }, { - "IpProtocol": "tcp", - "FromPort": "80", - "ToPort": "8000", - "SourceSecurityGroupId": {"Ref": "my-security-group"}, - }] - } - } + }, + }, }, } security_group_template_json = json.dumps(security_group_template) @@ -509,31 +451,33 @@ def test_stack_security_groups(): conn.create_stack( "security_group_stack", template_body=security_group_template_json, - tags={"foo": "bar"} + tags={"foo": "bar"}, ) ec2_conn = boto.ec2.connect_to_region("us-west-1") instance_group = ec2_conn.get_all_security_groups( - filters={'description': ['My security group']})[0] + filters={"description": ["My security group"]} + )[0] other_group = ec2_conn.get_all_security_groups( - filters={'description': ['My other group']})[0] + filters={"description": ["My other group"]} + )[0] reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] ec2_instance.groups[0].id.should.equal(instance_group.id) instance_group.description.should.equal("My security group") - instance_group.tags.should.have.key('foo').which.should.equal('bar') - instance_group.tags.should.have.key('bar').which.should.equal('baz') + instance_group.tags.should.have.key("foo").which.should.equal("bar") + instance_group.tags.should.have.key("bar").which.should.equal("baz") rule1, rule2 = instance_group.rules int(rule1.to_port).should.equal(22) int(rule1.from_port).should.equal(22) rule1.grants[0].cidr_ip.should.equal("123.123.123.123/32") - rule1.ip_protocol.should.equal('tcp') + rule1.ip_protocol.should.equal("tcp") int(rule2.to_port).should.equal(8000) int(rule2.from_port).should.equal(80) - rule2.ip_protocol.should.equal('tcp') + rule2.ip_protocol.should.equal("tcp") rule2.grants[0].group_id.should.equal(other_group.id) @@ -544,12 +488,11 @@ def test_stack_security_groups(): def test_autoscaling_group_with_elb(): web_setup_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { "my-as-group": { "Type": "AWS::AutoScaling::AutoScalingGroup", "Properties": { - "AvailabilityZones": ['us-east1'], + "AvailabilityZones": ["us-east1"], "LaunchConfigurationName": {"Ref": "my-launch-config"}, "MinSize": "2", "MaxSize": "2", @@ -557,34 +500,33 @@ def test_autoscaling_group_with_elb(): "LoadBalancerNames": [{"Ref": "my-elb"}], "Tags": [ { - "Key": "propagated-test-tag", "Value": "propagated-test-tag-value", - "PropagateAtLaunch": True}, + "Key": "propagated-test-tag", + "Value": "propagated-test-tag-value", + "PropagateAtLaunch": True, + }, { "Key": "not-propagated-test-tag", "Value": "not-propagated-test-tag-value", - "PropagateAtLaunch": False - } - ] + "PropagateAtLaunch": False, + }, + ], }, }, - "my-launch-config": { "Type": "AWS::AutoScaling::LaunchConfiguration", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, - "my-elb": { "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { - "AvailabilityZones": ['us-east1'], - "Listeners": [{ - "LoadBalancerPort": "80", - "InstancePort": "80", - "Protocol": "HTTP", - }], + "AvailabilityZones": ["us-east1"], + "Listeners": [ + { + "LoadBalancerPort": "80", + "InstancePort": "80", + "Protocol": "HTTP", + } + ], "LoadBalancerName": "my-elb", "HealthCheck": { "Target": "HTTP:80", @@ -595,21 +537,18 @@ def test_autoscaling_group_with_elb(): }, }, }, - } + }, } web_setup_template_json = json.dumps(web_setup_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "web_stack", - template_body=web_setup_template_json, - ) + conn.create_stack("web_stack", template_body=web_setup_template_json) autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") autoscale_group = autoscale_conn.get_all_groups()[0] autoscale_group.launch_config_name.should.contain("my-launch-config") - autoscale_group.load_balancers[0].should.equal('my-elb') + autoscale_group.load_balancers[0].should.equal("my-elb") # Confirm the Launch config was actually created autoscale_conn.get_all_launch_configurations().should.have.length_of(1) @@ -620,29 +559,36 @@ def test_autoscaling_group_with_elb(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() - as_group_resource = [resource for resource in resources if resource.resource_type == - 'AWS::AutoScaling::AutoScalingGroup'][0] + as_group_resource = [ + resource + for resource in resources + if resource.resource_type == "AWS::AutoScaling::AutoScalingGroup" + ][0] as_group_resource.physical_resource_id.should.contain("my-as-group") launch_config_resource = [ - resource for resource in resources if - resource.resource_type == 'AWS::AutoScaling::LaunchConfiguration'][0] - launch_config_resource.physical_resource_id.should.contain( - "my-launch-config") + resource + for resource in resources + if resource.resource_type == "AWS::AutoScaling::LaunchConfiguration" + ][0] + launch_config_resource.physical_resource_id.should.contain("my-launch-config") - elb_resource = [resource for resource in resources if resource.resource_type == - 'AWS::ElasticLoadBalancing::LoadBalancer'][0] + elb_resource = [ + resource + for resource in resources + if resource.resource_type == "AWS::ElasticLoadBalancing::LoadBalancer" + ][0] elb_resource.physical_resource_id.should.contain("my-elb") # confirm the instances were created with the right tags - ec2_conn = boto.ec2.connect_to_region('us-west-1') + ec2_conn = boto.ec2.connect_to_region("us-west-1") reservations = ec2_conn.get_all_reservations() len(reservations).should.equal(1) reservation = reservations[0] len(reservation.instances).should.equal(2) for instance in reservation.instances: - instance.tags['propagated-test-tag'].should.equal('propagated-test-tag-value') - instance.tags.keys().should_not.contain('not-propagated-test-tag') + instance.tags["propagated-test-tag"].should.equal("propagated-test-tag-value") + instance.tags.keys().should_not.contain("not-propagated-test-tag") @mock_autoscaling_deprecated() @@ -655,30 +601,23 @@ def test_autoscaling_group_update(): "my-as-group": { "Type": "AWS::AutoScaling::AutoScalingGroup", "Properties": { - "AvailabilityZones": ['us-west-1'], + "AvailabilityZones": ["us-west-1"], "LaunchConfigurationName": {"Ref": "my-launch-config"}, "MinSize": "2", "MaxSize": "2", - "DesiredCapacity": "2" + "DesiredCapacity": "2", }, }, - "my-launch-config": { "Type": "AWS::AutoScaling::LaunchConfiguration", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } asg_template_json = json.dumps(asg_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "asg_stack", - template_body=asg_template_json, - ) + conn.create_stack("asg_stack", template_body=asg_template_json) autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") asg = autoscale_conn.get_all_groups()[0] @@ -686,37 +625,38 @@ def test_autoscaling_group_update(): asg.max_size.should.equal(2) asg.desired_capacity.should.equal(2) - asg_template['Resources']['my-as-group']['Properties']['MaxSize'] = 3 - asg_template['Resources']['my-as-group']['Properties']['Tags'] = [ + asg_template["Resources"]["my-as-group"]["Properties"]["MaxSize"] = 3 + asg_template["Resources"]["my-as-group"]["Properties"]["Tags"] = [ { - "Key": "propagated-test-tag", "Value": "propagated-test-tag-value", - "PropagateAtLaunch": True}, + "Key": "propagated-test-tag", + "Value": "propagated-test-tag-value", + "PropagateAtLaunch": True, + }, { "Key": "not-propagated-test-tag", "Value": "not-propagated-test-tag-value", - "PropagateAtLaunch": False - } + "PropagateAtLaunch": False, + }, ] asg_template_json = json.dumps(asg_template) - conn.update_stack( - "asg_stack", - template_body=asg_template_json, - ) + conn.update_stack("asg_stack", template_body=asg_template_json) asg = autoscale_conn.get_all_groups()[0] asg.min_size.should.equal(2) asg.max_size.should.equal(3) asg.desired_capacity.should.equal(2) # confirm the instances were created with the right tags - ec2_conn = boto.ec2.connect_to_region('us-west-1') + ec2_conn = boto.ec2.connect_to_region("us-west-1") reservations = ec2_conn.get_all_reservations() running_instance_count = 0 for res in reservations: for instance in res.instances: - if instance.state == 'running': + if instance.state == "running": running_instance_count += 1 - instance.tags['propagated-test-tag'].should.equal('propagated-test-tag-value') - instance.tags.keys().should_not.contain('not-propagated-test-tag') + instance.tags["propagated-test-tag"].should.equal( + "propagated-test-tag-value" + ) + instance.tags.keys().should_not.contain("not-propagated-test-tag") running_instance_count.should.equal(2) @@ -726,20 +666,18 @@ def test_vpc_single_instance_in_subnet(): template_json = json.dumps(vpc_single_instance_in_subnet.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack( - "test_stack", - template_body=template_json, - parameters=[("KeyName", "my_key")], + "test_stack", template_body=template_json, parameters=[("KeyName", "my_key")] ) vpc_conn = boto.vpc.connect_to_region("us-west-1") - vpc = vpc_conn.get_all_vpcs(filters={'cidrBlock': '10.0.0.0/16'})[0] + vpc = vpc_conn.get_all_vpcs(filters={"cidrBlock": "10.0.0.0/16"})[0] vpc.cidr_block.should.equal("10.0.0.0/16") # Add this once we implement the endpoint # vpc_conn.get_all_internet_gateways().should.have.length_of(1) - subnet = vpc_conn.get_all_subnets(filters={'vpcId': vpc.id})[0] + subnet = vpc_conn.get_all_subnets(filters={"vpcId": vpc.id})[0] subnet.vpc_id.should.equal(vpc.id) ec2_conn = boto.ec2.connect_to_region("us-west-1") @@ -748,28 +686,32 @@ def test_vpc_single_instance_in_subnet(): instance.tags["Foo"].should.equal("Bar") # Check that the EIP is attached the the EC2 instance eip = ec2_conn.get_all_addresses()[0] - eip.domain.should.equal('vpc') + eip.domain.should.equal("vpc") eip.instance_id.should.equal(instance.id) - security_group = ec2_conn.get_all_security_groups( - filters={'vpc_id': [vpc.id]})[0] + security_group = ec2_conn.get_all_security_groups(filters={"vpc_id": [vpc.id]})[0] security_group.vpc_id.should.equal(vpc.id) stack = conn.describe_stacks()[0] - vpc.tags.should.have.key('Application').which.should.equal(stack.stack_id) + vpc.tags.should.have.key("Application").which.should.equal(stack.stack_id) resources = stack.describe_resources() vpc_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::VPC'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::VPC" + ][0] vpc_resource.physical_resource_id.should.equal(vpc.id) subnet_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::Subnet'][0] + resource + for resource in resources + if resource.resource_type == "AWS::EC2::Subnet" + ][0] subnet_resource.physical_resource_id.should.equal(subnet.id) eip_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] eip_resource.physical_resource_id.should.equal(eip.public_ip) @@ -779,39 +721,45 @@ def test_vpc_single_instance_in_subnet(): def test_rds_db_parameter_groups(): ec2_conn = boto3.client("ec2", region_name="us-west-1") ec2_conn.create_security_group( - GroupName='application', Description='Our Application Group') + GroupName="application", Description="Our Application Group" + ) template_json = json.dumps(rds_mysql_with_db_parameter_group.template) - cf_conn = boto3.client('cloudformation', 'us-west-1') + cf_conn = boto3.client("cloudformation", "us-west-1") cf_conn.create_stack( StackName="test_stack", TemplateBody=template_json, - Parameters=[{'ParameterKey': key, 'ParameterValue': value} for - key, value in [ - ("DBInstanceIdentifier", "master_db"), - ("DBName", "my_db"), - ("DBUser", "my_user"), - ("DBPassword", "my_password"), - ("DBAllocatedStorage", "20"), - ("DBInstanceClass", "db.m1.medium"), - ("EC2SecurityGroup", "application"), - ("MultiAZ", "true"), - ] - ], + Parameters=[ + {"ParameterKey": key, "ParameterValue": value} + for key, value in [ + ("DBInstanceIdentifier", "master_db"), + ("DBName", "my_db"), + ("DBUser", "my_user"), + ("DBPassword", "my_password"), + ("DBAllocatedStorage", "20"), + ("DBInstanceClass", "db.m1.medium"), + ("EC2SecurityGroup", "application"), + ("MultiAZ", "true"), + ] + ], ) - rds_conn = boto3.client('rds', region_name="us-west-1") + rds_conn = boto3.client("rds", region_name="us-west-1") db_parameter_groups = rds_conn.describe_db_parameter_groups() - len(db_parameter_groups['DBParameterGroups']).should.equal(1) - db_parameter_group_name = db_parameter_groups[ - 'DBParameterGroups'][0]['DBParameterGroupName'] + len(db_parameter_groups["DBParameterGroups"]).should.equal(1) + db_parameter_group_name = db_parameter_groups["DBParameterGroups"][0][ + "DBParameterGroupName" + ] found_cloudformation_set_parameter = False - for db_parameter in rds_conn.describe_db_parameters(DBParameterGroupName=db_parameter_group_name)[ - 'Parameters']: - if db_parameter['ParameterName'] == 'BACKLOG_QUEUE_LIMIT' and db_parameter[ - 'ParameterValue'] == '2048': + for db_parameter in rds_conn.describe_db_parameters( + DBParameterGroupName=db_parameter_group_name + )["Parameters"]: + if ( + db_parameter["ParameterName"] == "BACKLOG_QUEUE_LIMIT" + and db_parameter["ParameterValue"] == "2048" + ): found_cloudformation_set_parameter = True found_cloudformation_set_parameter.should.equal(True) @@ -822,7 +770,7 @@ def test_rds_db_parameter_groups(): @mock_rds_deprecated() def test_rds_mysql_with_read_replica(): ec2_conn = boto.ec2.connect_to_region("us-west-1") - ec2_conn.create_security_group('application', 'Our Application Group') + ec2_conn.create_security_group("application", "Our Application Group") template_json = json.dumps(rds_mysql_with_read_replica.template) conn = boto.cloudformation.connect_to_region("us-west-1") @@ -893,43 +841,33 @@ def test_rds_mysql_with_read_replica_in_vpc(): def test_iam_roles(): iam_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { - "my-launch-config": { "Properties": { "IamInstanceProfile": {"Ref": "my-instance-profile-with-path"}, "ImageId": "ami-1234abcd", }, - "Type": "AWS::AutoScaling::LaunchConfiguration" + "Type": "AWS::AutoScaling::LaunchConfiguration", }, "my-instance-profile-with-path": { "Properties": { "Path": "my-path", "Roles": [{"Ref": "my-role-with-path"}], }, - "Type": "AWS::IAM::InstanceProfile" + "Type": "AWS::IAM::InstanceProfile", }, "my-instance-profile-no-path": { - "Properties": { - "Roles": [{"Ref": "my-role-no-path"}], - }, - "Type": "AWS::IAM::InstanceProfile" + "Properties": {"Roles": [{"Ref": "my-role-no-path"}]}, + "Type": "AWS::IAM::InstanceProfile", }, "my-role-with-path": { "Properties": { "AssumeRolePolicyDocument": { "Statement": [ { - "Action": [ - "sts:AssumeRole" - ], + "Action": ["sts:AssumeRole"], "Effect": "Allow", - "Principal": { - "Service": [ - "ec2.amazonaws.com" - ] - } + "Principal": {"Service": ["ec2.amazonaws.com"]}, } ] }, @@ -942,102 +880,90 @@ def test_iam_roles(): "Action": [ "ec2:CreateTags", "ec2:DescribeInstances", - "ec2:DescribeTags" + "ec2:DescribeTags", ], "Effect": "Allow", - "Resource": [ - "*" - ] + "Resource": ["*"], } ], - "Version": "2012-10-17" + "Version": "2012-10-17", }, - "PolicyName": "EC2_Tags" + "PolicyName": "EC2_Tags", }, { "PolicyDocument": { "Statement": [ { - "Action": [ - "sqs:*" - ], + "Action": ["sqs:*"], "Effect": "Allow", - "Resource": [ - "*" - ] + "Resource": ["*"], } ], - "Version": "2012-10-17" + "Version": "2012-10-17", }, - "PolicyName": "SQS" + "PolicyName": "SQS", }, - ] + ], }, - "Type": "AWS::IAM::Role" + "Type": "AWS::IAM::Role", }, "my-role-no-path": { "Properties": { "AssumeRolePolicyDocument": { "Statement": [ { - "Action": [ - "sts:AssumeRole" - ], + "Action": ["sts:AssumeRole"], "Effect": "Allow", - "Principal": { - "Service": [ - "ec2.amazonaws.com" - ] - } + "Principal": {"Service": ["ec2.amazonaws.com"]}, } ] - }, + } }, - "Type": "AWS::IAM::Role" - } - } + "Type": "AWS::IAM::Role", + }, + }, } iam_template_json = json.dumps(iam_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=iam_template_json, - ) + conn.create_stack("test_stack", template_body=iam_template_json) iam_conn = boto.iam.connect_to_region("us-west-1") - role_results = iam_conn.list_roles()['list_roles_response'][ - 'list_roles_result']['roles'] + role_results = iam_conn.list_roles()["list_roles_response"]["list_roles_result"][ + "roles" + ] role_name_to_id = {} for role_result in role_results: role = iam_conn.get_role(role_result.role_name) role.role_name.should.contain("my-role") - if 'with-path' in role.role_name: - role_name_to_id['with-path'] = role.role_id + if "with-path" in role.role_name: + role_name_to_id["with-path"] = role.role_id role.path.should.equal("my-path") else: - role_name_to_id['no-path'] = role.role_id - role.role_name.should.contain('no-path') - role.path.should.equal('/') + role_name_to_id["no-path"] = role.role_id + role.role_name.should.contain("no-path") + role.path.should.equal("/") instance_profile_responses = iam_conn.list_instance_profiles()[ - 'list_instance_profiles_response']['list_instance_profiles_result']['instance_profiles'] + "list_instance_profiles_response" + ]["list_instance_profiles_result"]["instance_profiles"] instance_profile_responses.should.have.length_of(2) instance_profile_names = [] for instance_profile_response in instance_profile_responses: - instance_profile = iam_conn.get_instance_profile(instance_profile_response.instance_profile_name) + instance_profile = iam_conn.get_instance_profile( + instance_profile_response.instance_profile_name + ) instance_profile_names.append(instance_profile.instance_profile_name) - instance_profile.instance_profile_name.should.contain( - "my-instance-profile") + instance_profile.instance_profile_name.should.contain("my-instance-profile") if "with-path" in instance_profile.instance_profile_name: instance_profile.path.should.equal("my-path") - instance_profile.role_id.should.equal(role_name_to_id['with-path']) + instance_profile.role_id.should.equal(role_name_to_id["with-path"]) else: - instance_profile.instance_profile_name.should.contain('no-path') - instance_profile.role_id.should.equal(role_name_to_id['no-path']) - instance_profile.path.should.equal('/') + instance_profile.instance_profile_name.should.contain("no-path") + instance_profile.role_id.should.equal(role_name_to_id["no-path"]) + instance_profile.path.should.equal("/") autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") launch_config = autoscale_conn.get_all_launch_configurations()[0] @@ -1046,12 +972,20 @@ def test_iam_roles(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() instance_profile_resources = [ - resource for resource in resources if resource.resource_type == 'AWS::IAM::InstanceProfile'] - {ip.physical_resource_id for ip in instance_profile_resources}.should.equal(set(instance_profile_names)) + resource + for resource in resources + if resource.resource_type == "AWS::IAM::InstanceProfile" + ] + {ip.physical_resource_id for ip in instance_profile_resources}.should.equal( + set(instance_profile_names) + ) role_resources = [ - resource for resource in resources if resource.resource_type == 'AWS::IAM::Role'] - {r.physical_resource_id for r in role_resources}.should.equal(set(role_name_to_id.values())) + resource for resource in resources if resource.resource_type == "AWS::IAM::Role" + ] + {r.physical_resource_id for r in role_resources}.should.equal( + set(role_name_to_id.values()) + ) @mock_ec2_deprecated() @@ -1060,9 +994,7 @@ def test_single_instance_with_ebs_volume(): template_json = json.dumps(single_instance_with_ebs_volume.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack( - "test_stack", - template_body=template_json, - parameters=[("KeyName", "key_name")] + "test_stack", template_body=template_json, parameters=[("KeyName", "key_name")] ) ec2_conn = boto.ec2.connect_to_region("us-west-1") @@ -1071,15 +1003,19 @@ def test_single_instance_with_ebs_volume(): volumes = ec2_conn.get_all_volumes() # Grab the mounted drive - volume = [ - volume for volume in volumes if volume.attach_data.device == '/dev/sdh'][0] - volume.volume_state().should.equal('in-use') + volume = [volume for volume in volumes if volume.attach_data.device == "/dev/sdh"][ + 0 + ] + volume.volume_state().should.equal("in-use") volume.attach_data.instance_id.should.equal(ec2_instance.id) stack = conn.describe_stacks()[0] resources = stack.describe_resources() ebs_volumes = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::Volume'] + resource + for resource in resources + if resource.resource_type == "AWS::EC2::Volume" + ] ebs_volumes[0].physical_resource_id.should.equal(volume.id) @@ -1088,8 +1024,7 @@ def test_create_template_without_required_param(): template_json = json.dumps(single_instance_with_ebs_volume.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack.when.called_with( - "test_stack", - template_body=template_json, + "test_stack", template_body=template_json ).should.throw(BotoServerError) @@ -1105,7 +1040,8 @@ def test_classic_eip(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() cfn_eip = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] cfn_eip.physical_resource_id.should.equal(eip.public_ip) @@ -1121,7 +1057,8 @@ def test_vpc_eip(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() cfn_eip = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] cfn_eip.physical_resource_id.should.equal(eip.public_ip) @@ -1136,7 +1073,7 @@ def test_fn_join(): stack = conn.describe_stacks()[0] fn_join_output = stack.outputs[0] - fn_join_output.value.should.equal('test eip:{0}'.format(eip.public_ip)) + fn_join_output.value.should.equal("test eip:{0}".format(eip.public_ip)) @mock_cloudformation_deprecated() @@ -1145,23 +1082,15 @@ def test_conditional_resources(): sqs_template = { "AWSTemplateFormatVersion": "2010-09-09", "Parameters": { - "EnvType": { - "Description": "Environment type.", - "Type": "String", - } - }, - "Conditions": { - "CreateQueue": {"Fn::Equals": [{"Ref": "EnvType"}, "prod"]} + "EnvType": {"Description": "Environment type.", "Type": "String"} }, + "Conditions": {"CreateQueue": {"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}}, "Resources": { "QueueGroup": { "Condition": "CreateQueue", "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) @@ -1190,43 +1119,30 @@ def test_conditional_resources(): def test_conditional_if_handling(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Conditions": { - "EnvEqualsPrd": { - "Fn::Equals": [ - { - "Ref": "ENV" - }, - "prd" - ] - } - }, + "Conditions": {"EnvEqualsPrd": {"Fn::Equals": [{"Ref": "ENV"}, "prd"]}}, "Parameters": { "ENV": { "Default": "dev", "Description": "Deployment environment for the stack (dev/prd)", - "Type": "String" - }, + "Type": "String", + } }, "Description": "Stack 1", "Resources": { "App1": { "Properties": { "ImageId": { - "Fn::If": [ - "EnvEqualsPrd", - "ami-00000000", - "ami-ffffffff" - ] - }, + "Fn::If": ["EnvEqualsPrd", "ami-00000000", "ami-ffffffff"] + } }, - "Type": "AWS::EC2::Instance" - }, - } + "Type": "AWS::EC2::Instance", + } + }, } dummy_template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1235,7 +1151,8 @@ def test_conditional_if_handling(): conn = boto.cloudformation.connect_to_region("us-west-2") conn.create_stack( - 'test_stack1', template_body=dummy_template_json, parameters=[("ENV", "prd")]) + "test_stack1", template_body=dummy_template_json, parameters=[("ENV", "prd")] + ) ec2_conn = boto.ec2.connect_to_region("us-west-2") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1253,7 +1170,7 @@ def test_cloudformation_mapping(): "us-west-1": {"32": "ami-c9c7978c", "64": "ami-cfc7978a"}, "eu-west-1": {"32": "ami-37c2f643", "64": "ami-31c2f645"}, "ap-southeast-1": {"32": "ami-66f28c34", "64": "ami-60f28c32"}, - "ap-northeast-1": {"32": "ami-9c03a89d", "64": "ami-a003a8a1"} + "ap-northeast-1": {"32": "ami-9c03a89d", "64": "ami-a003a8a1"}, } }, "Resources": { @@ -1263,24 +1180,24 @@ def test_cloudformation_mapping(): "ImageId": { "Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "32"] }, - "InstanceType": "m1.small" + "InstanceType": "m1.small", }, "Type": "AWS::EC2::Instance", - }, + } }, } dummy_template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-east-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-east-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] ec2_instance.image_id.should.equal("ami-6411e20d") conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1294,42 +1211,39 @@ def test_route53_roundrobin(): template_json = json.dumps(route53_roundrobin.template) conn = boto.cloudformation.connect_to_region("us-west-1") - stack = conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack = conn.create_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) rrsets.hosted_zone_id.should.equal(zone_id) rrsets.should.have.length_of(2) record_set1 = rrsets[0] - record_set1.name.should.equal('test_stack.us-west-1.my_zone.') + record_set1.name.should.equal("test_stack.us-west-1.my_zone.") record_set1.identifier.should.equal("test_stack AWS") - record_set1.type.should.equal('CNAME') - record_set1.ttl.should.equal('900') - record_set1.weight.should.equal('3') + record_set1.type.should.equal("CNAME") + record_set1.ttl.should.equal("900") + record_set1.weight.should.equal("3") record_set1.resource_records[0].should.equal("aws.amazon.com") record_set2 = rrsets[1] - record_set2.name.should.equal('test_stack.us-west-1.my_zone.') + record_set2.name.should.equal("test_stack.us-west-1.my_zone.") record_set2.identifier.should.equal("test_stack Amazon") - record_set2.type.should.equal('CNAME') - record_set2.ttl.should.equal('900') - record_set2.weight.should.equal('1') + record_set2.type.should.equal("CNAME") + record_set2.ttl.should.equal("900") + record_set2.weight.should.equal("1") record_set2.resource_records[0].should.equal("www.amazon.com") stack = conn.describe_stacks()[0] output = stack.outputs[0] - output.key.should.equal('DomainName') - output.value.should.equal( - 'arn:aws:route53:::hostedzone/{0}'.format(zone_id)) + output.key.should.equal("DomainName") + output.value.should.equal("arn:aws:route53:::hostedzone/{0}".format(zone_id)) @mock_cloudformation_deprecated() @@ -1341,28 +1255,26 @@ def test_route53_ec2_instance_with_public_ip(): template_json = json.dumps(route53_ec2_instance_with_public_ip.template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) instance_id = ec2_conn.get_all_reservations()[0].instances[0].id - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) rrsets.should.have.length_of(1) record_set1 = rrsets[0] - record_set1.name.should.equal('{0}.us-west-1.my_zone.'.format(instance_id)) + record_set1.name.should.equal("{0}.us-west-1.my_zone.".format(instance_id)) record_set1.identifier.should.equal(None) - record_set1.type.should.equal('A') - record_set1.ttl.should.equal('900') + record_set1.type.should.equal("A") + record_set1.ttl.should.equal("900") record_set1.weight.should.equal(None) record_set1.resource_records[0].should.equal("10.0.0.25") @@ -1374,17 +1286,15 @@ def test_route53_associate_health_check(): template_json = json.dumps(route53_health_check.template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) - checks = route53_conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = route53_conn.get_list_health_checks()["ListHealthChecksResponse"][ + "HealthChecks" + ] list(checks).should.have.length_of(1) check = checks[0] - health_check_id = check['Id'] - config = check['HealthCheckConfig'] + health_check_id = check["Id"] + config = check["HealthCheckConfig"] config["FailureThreshold"].should.equal("3") config["IPAddress"].should.equal("10.0.0.4") config["Port"].should.equal("80") @@ -1392,11 +1302,12 @@ def test_route53_associate_health_check(): config["ResourcePath"].should.equal("/") config["Type"].should.equal("HTTP") - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1413,16 +1324,14 @@ def test_route53_with_update(): template_json = json.dumps(route53_health_check.template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1431,19 +1340,18 @@ def test_route53_with_update(): record_set = rrsets[0] record_set.resource_records.should.equal(["my.example.com"]) - route53_health_check.template['Resources']['myDNSRecord'][ - 'Properties']['ResourceRecords'] = ["my_other.example.com"] + route53_health_check.template["Resources"]["myDNSRecord"]["Properties"][ + "ResourceRecords" + ] = ["my_other.example.com"] template_json = json.dumps(route53_health_check.template) - cf_conn.update_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.update_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1463,37 +1371,32 @@ def test_sns_topic(): "Type": "AWS::SNS::Topic", "Properties": { "Subscription": [ - {"Endpoint": "https://example.com", "Protocol": "https"}, + {"Endpoint": "https://example.com", "Protocol": "https"} ], "TopicName": "my_topics", - } + }, } }, "Outputs": { - "topic_name": { - "Value": {"Fn::GetAtt": ["MySNSTopic", "TopicName"]} - }, - "topic_arn": { - "Value": {"Ref": "MySNSTopic"} - }, - } + "topic_name": {"Value": {"Fn::GetAtt": ["MySNSTopic", "TopicName"]}}, + "topic_arn": {"Value": {"Ref": "MySNSTopic"}}, + }, } template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-west-1") - stack = conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack = conn.create_stack("test_stack", template_body=template_json) sns_conn = boto.sns.connect_to_region("us-west-1") - topics = sns_conn.get_all_topics()["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] + topics = sns_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"][ + "Topics" + ] topics.should.have.length_of(1) - topic_arn = topics[0]['TopicArn'] + topic_arn = topics[0]["TopicArn"] topic_arn.should.contain("my_topics") subscriptions = sns_conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(1) subscription = subscriptions[0] subscription["TopicArn"].should.equal(topic_arn) @@ -1502,9 +1405,9 @@ def test_sns_topic(): subscription["Endpoint"].should.equal("https://example.com") stack = conn.describe_stacks()[0] - topic_name_output = [x for x in stack.outputs if x.key == 'topic_name'][0] + topic_name_output = [x for x in stack.outputs if x.key == "topic_name"][0] topic_name_output.value.should.equal("my_topics") - topic_arn_output = [x for x in stack.outputs if x.key == 'topic_arn'][0] + topic_arn_output = [x for x in stack.outputs if x.key == "topic_arn"][0] topic_arn_output.value.should.equal(topic_arn) @@ -1514,44 +1417,33 @@ def test_vpc_gateway_attachment_creation_should_attach_itself_to_vpc(): template = { "AWSTemplateFormatVersion": "2010-09-09", "Resources": { - "internetgateway": { - "Type": "AWS::EC2::InternetGateway" - }, + "internetgateway": {"Type": "AWS::EC2::InternetGateway"}, "testvpc": { "Type": "AWS::EC2::VPC", "Properties": { "CidrBlock": "10.0.0.0/16", "EnableDnsHostnames": "true", "EnableDnsSupport": "true", - "InstanceTenancy": "default" + "InstanceTenancy": "default", }, }, "vpcgatewayattachment": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "InternetGatewayId": { - "Ref": "internetgateway" - }, - "VpcId": { - "Ref": "testvpc" - } + "InternetGatewayId": {"Ref": "internetgateway"}, + "VpcId": {"Ref": "testvpc"}, }, }, - } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) vpc_conn = boto.vpc.connect_to_region("us-west-1") - vpc = vpc_conn.get_all_vpcs(filters={'cidrBlock': '10.0.0.0/16'})[0] - igws = vpc_conn.get_all_internet_gateways( - filters={'attachment.vpc-id': vpc.id} - ) + vpc = vpc_conn.get_all_vpcs(filters={"cidrBlock": "10.0.0.0/16"})[0] + igws = vpc_conn.get_all_internet_gateways(filters={"attachment.vpc-id": vpc.id}) igws.should.have.length_of(1) @@ -1567,20 +1459,14 @@ def test_vpc_peering_creation(): "Resources": { "vpcpeeringconnection": { "Type": "AWS::EC2::VPCPeeringConnection", - "Properties": { - "PeerVpcId": peer_vpc.id, - "VpcId": vpc_source.id, - } - }, - } + "Properties": {"PeerVpcId": peer_vpc.id, "VpcId": vpc_source.id}, + } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) peering_connections = vpc_conn.get_all_vpc_peering_connections() peering_connections.should.have.length_of(1) @@ -1596,24 +1482,14 @@ def test_multiple_security_group_ingress_separate_from_security_group_by_id(): "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg1" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg1"}], }, }, "test-security-group2": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1624,39 +1500,36 @@ def test_multiple_security_group_ingress_separate_from_security_group_by_id(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") - security_group1 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg1"})[0] - security_group2 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + security_group1 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg1"})[ + 0 + ] + security_group2 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @mock_ec2_deprecated def test_security_group_ingress_separate_from_security_group_by_id(): ec2_conn = boto.ec2.connect_to_region("us-west-1") - ec2_conn.create_security_group( - "test-security-group1", "test security group") + ec2_conn.create_security_group("test-security-group1", "test security group") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1665,12 +1538,7 @@ def test_security_group_ingress_separate_from_security_group_by_id(): "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1681,29 +1549,27 @@ def test_security_group_ingress_separate_from_security_group_by_id(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) security_group1 = ec2_conn.get_all_security_groups( - groupnames=["test-security-group1"])[0] - security_group2 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + groupnames=["test-security-group1"] + )[0] + security_group2 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @@ -1720,12 +1586,7 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg1" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg1"}], }, }, "test-security-group2": { @@ -1733,12 +1594,7 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1750,29 +1606,27 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - security_group1 = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg1"})[0] - security_group2 = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + cf_conn.create_stack("test_stack", template_body=template_json) + security_group1 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg1"})[ + 0 + ] + security_group2 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @@ -1789,44 +1643,30 @@ def test_security_group_with_update(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc1.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg"}], }, - }, - } + } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - security_group = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg"})[0] + cf_conn.create_stack("test_stack", template_body=template_json) + security_group = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg"})[0] security_group.vpc_id.should.equal(vpc1.id) vpc2 = vpc_conn.create_vpc("10.1.0.0/16") - template['Resources'][ - 'test-security-group']['Properties']['VpcId'] = vpc2.id + template["Resources"]["test-security-group"]["Properties"]["VpcId"] = vpc2.id template_json = json.dumps(template) - cf_conn.update_stack( - "test_stack", - template_body=template_json, - ) - security_group = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg"})[0] + cf_conn.update_stack("test_stack", template_body=template_json) + security_group = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg"})[0] security_group.vpc_id.should.equal(vpc2.id) @mock_cloudformation_deprecated @mock_ec2_deprecated def test_subnets_should_be_created_with_availability_zone(): - vpc_conn = boto.vpc.connect_to_region('us-west-1') + vpc_conn = boto.vpc.connect_to_region("us-west-1") vpc = vpc_conn.create_vpc("10.0.0.0/16") subnet_template = { @@ -1838,18 +1678,15 @@ def test_subnets_should_be_created_with_availability_zone(): "VpcId": vpc.id, "CidrBlock": "10.0.0.0/24", "AvailabilityZone": "us-west-1b", - } + }, } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-west-1") template_json = json.dumps(subnet_template) - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - subnet = vpc_conn.get_all_subnets(filters={'cidrBlock': '10.0.0.0/24'})[0] - subnet.availability_zone.should.equal('us-west-1b') + cf_conn.create_stack("test_stack", template_body=template_json) + subnet = vpc_conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"})[0] + subnet.availability_zone.should.equal("us-west-1b") @mock_cloudformation_deprecated @@ -1867,71 +1704,53 @@ def test_datapipeline(): "Fields": [ { "Key": "failureAndRerunMode", - "StringValue": "CASCADE" - }, - { - "Key": "scheduleType", - "StringValue": "cron" - }, - { - "Key": "schedule", - "RefValue": "DefaultSchedule" + "StringValue": "CASCADE", }, + {"Key": "scheduleType", "StringValue": "cron"}, + {"Key": "schedule", "RefValue": "DefaultSchedule"}, { "Key": "pipelineLogUri", - "StringValue": "s3://bucket/logs" - }, - { - "Key": "type", - "StringValue": "Default" + "StringValue": "s3://bucket/logs", }, + {"Key": "type", "StringValue": "Default"}, ], "Id": "Default", - "Name": "Default" + "Name": "Default", }, { "Fields": [ { "Key": "startDateTime", - "StringValue": "1970-01-01T01:00:00" + "StringValue": "1970-01-01T01:00:00", }, - { - "Key": "period", - "StringValue": "1 Day" - }, - { - "Key": "type", - "StringValue": "Schedule" - } + {"Key": "period", "StringValue": "1 Day"}, + {"Key": "type", "StringValue": "Schedule"}, ], "Id": "DefaultSchedule", - "Name": "RunOnce" - } + "Name": "RunOnce", + }, ], - "PipelineTags": [] + "PipelineTags": [], }, - "Type": "AWS::DataPipeline::Pipeline" + "Type": "AWS::DataPipeline::Pipeline", } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-east-1") template_json = json.dumps(dp_template) - stack_id = cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack_id = cf_conn.create_stack("test_stack", template_body=template_json) - dp_conn = boto.datapipeline.connect_to_region('us-east-1') + dp_conn = boto.datapipeline.connect_to_region("us-east-1") data_pipelines = dp_conn.list_pipelines() - data_pipelines['pipelineIdList'].should.have.length_of(1) - data_pipelines['pipelineIdList'][0][ - 'name'].should.equal('testDataPipeline') + data_pipelines["pipelineIdList"].should.have.length_of(1) + data_pipelines["pipelineIdList"][0]["name"].should.equal("testDataPipeline") stack_resources = cf_conn.list_stack_resources(stack_id) stack_resources.should.have.length_of(1) stack_resources[0].physical_resource_id.should.equal( - data_pipelines['pipelineIdList'][0]['id']) + data_pipelines["pipelineIdList"][0]["id"] + ) @mock_cloudformation @@ -1955,47 +1774,55 @@ def lambda_handler(event, context): "Handler": "lambda_function.handler", "Description": "Test function", "MemorySize": 128, - "Role": "test-role", + "Role": {"Fn::GetAtt": ["MyRole", "Arn"]}, "Runtime": "python2.7", - "Environment": { - "Variables": { - "TEST_ENV_KEY": "test-env-val", - } - }, - } - } - } + "Environment": {"Variables": {"TEST_ENV_KEY": "test-env-val"}}, + }, + }, + "MyRole": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": ["sts:AssumeRole"], + "Effect": "Allow", + "Principal": {"Service": ["ec2.amazonaws.com"]}, + } + ] + } + }, + }, + }, } template_json = json.dumps(template) - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - conn = boto3.client('lambda', 'us-east-1') + conn = boto3.client("lambda", "us-east-1") result = conn.list_functions() - result['Functions'].should.have.length_of(1) - result['Functions'][0]['Description'].should.equal('Test function') - result['Functions'][0]['Handler'].should.equal('lambda_function.handler') - result['Functions'][0]['MemorySize'].should.equal(128) - result['Functions'][0]['Role'].should.equal('test-role') - result['Functions'][0]['Runtime'].should.equal('python2.7') - result['Functions'][0]['Environment'].should.equal({ - "Variables": {"TEST_ENV_KEY": "test-env-val"} - }) + result["Functions"].should.have.length_of(1) + result["Functions"][0]["Description"].should.equal("Test function") + result["Functions"][0]["Handler"].should.equal("lambda_function.handler") + result["Functions"][0]["MemorySize"].should.equal(128) + result["Functions"][0]["Runtime"].should.equal("python2.7") + result["Functions"][0]["Environment"].should.equal( + {"Variables": {"TEST_ENV_KEY": "test-env-val"}} + ) @mock_cloudformation @mock_ec2 def test_nat_gateway(): - ec2_conn = boto3.client('ec2', 'us-east-1') - vpc_id = ec2_conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc']['VpcId'] - subnet_id = ec2_conn.create_subnet( - CidrBlock='10.0.1.0/24', VpcId=vpc_id)['Subnet']['SubnetId'] - route_table_id = ec2_conn.create_route_table( - VpcId=vpc_id)['RouteTable']['RouteTableId'] + ec2_conn = boto3.client("ec2", "us-east-1") + vpc_id = ec2_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"] + subnet_id = ec2_conn.create_subnet(CidrBlock="10.0.1.0/24", VpcId=vpc_id)["Subnet"][ + "SubnetId" + ] + route_table_id = ec2_conn.create_route_table(VpcId=vpc_id)["RouteTable"][ + "RouteTableId" + ] template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -2005,109 +1832,95 @@ def test_nat_gateway(): "Type": "AWS::EC2::NatGateway", "Properties": { "AllocationId": {"Fn::GetAtt": ["EIP", "AllocationId"]}, - "SubnetId": subnet_id - } - }, - "EIP": { - "Type": "AWS::EC2::EIP", - "Properties": { - "Domain": "vpc" - } + "SubnetId": subnet_id, + }, }, + "EIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}, "Route": { "Type": "AWS::EC2::Route", "Properties": { "RouteTableId": route_table_id, "DestinationCidrBlock": "0.0.0.0/0", - "NatGatewayId": {"Ref": "NAT"} - } - }, - "internetgateway": { - "Type": "AWS::EC2::InternetGateway" + "NatGatewayId": {"Ref": "NAT"}, + }, }, + "internetgateway": {"Type": "AWS::EC2::InternetGateway"}, "vpcgatewayattachment": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "InternetGatewayId": { - "Ref": "internetgateway" - }, + "InternetGatewayId": {"Ref": "internetgateway"}, "VpcId": vpc_id, }, - } - } + }, + }, } - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=json.dumps(template), - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=json.dumps(template)) result = ec2_conn.describe_nat_gateways() - result['NatGateways'].should.have.length_of(1) - result['NatGateways'][0]['VpcId'].should.equal(vpc_id) - result['NatGateways'][0]['SubnetId'].should.equal(subnet_id) - result['NatGateways'][0]['State'].should.equal('available') + result["NatGateways"].should.have.length_of(1) + result["NatGateways"][0]["VpcId"].should.equal(vpc_id) + result["NatGateways"][0]["SubnetId"].should.equal(subnet_id) + result["NatGateways"][0]["State"].should.equal("available") @mock_cloudformation() @mock_kms() def test_stack_kms(): kms_key_template = { - 'Resources': { - 'kmskey': { - 'Properties': { - 'Description': 'A kms key', - 'EnableKeyRotation': True, - 'Enabled': True, - 'KeyPolicy': 'a policy', + "Resources": { + "kmskey": { + "Properties": { + "Description": "A kms key", + "EnableKeyRotation": True, + "Enabled": True, + "KeyPolicy": "a policy", }, - 'Type': 'AWS::KMS::Key' + "Type": "AWS::KMS::Key", } } } kms_key_template_json = json.dumps(kms_key_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName='test_stack', - TemplateBody=kms_key_template_json, - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=kms_key_template_json) - kms_conn = boto3.client('kms', 'us-east-1') - keys = kms_conn.list_keys()['Keys'] + kms_conn = boto3.client("kms", "us-east-1") + keys = kms_conn.list_keys()["Keys"] len(keys).should.equal(1) - result = kms_conn.describe_key(KeyId=keys[0]['KeyId']) + result = kms_conn.describe_key(KeyId=keys[0]["KeyId"]) - result['KeyMetadata']['Enabled'].should.equal(True) - result['KeyMetadata']['KeyUsage'].should.equal('ENCRYPT_DECRYPT') + result["KeyMetadata"]["Enabled"].should.equal(True) + result["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") @mock_cloudformation() @mock_ec2() def test_stack_spot_fleet(): - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] spot_fleet_template = { - 'Resources': { + "Resources": { "SpotFleet": { "Type": "AWS::EC2::SpotFleet", "Properties": { "SpotFleetRequestConfigData": { - "IamFleetRole": "arn:aws:iam::123456789012:role/fleet", + "IamFleetRole": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID), "SpotPrice": "0.12", "TargetCapacity": 6, "AllocationStrategy": "diversified", "LaunchSpecifications": [ { "EbsOptimized": "false", - "InstanceType": 't2.small', + "InstanceType": "t2.small", "ImageId": "ami-1234", "SubnetId": subnet_id, "WeightedCapacity": "2", @@ -2115,129 +1928,137 @@ def test_stack_spot_fleet(): }, { "EbsOptimized": "true", - "InstanceType": 't2.large', + "InstanceType": "t2.large", "ImageId": "ami-1234", "Monitoring": {"Enabled": "true"}, "SecurityGroups": [{"GroupId": "sg-123"}], "SubnetId": subnet_id, - "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::{}:role/fleet".format( + ACCOUNT_ID + ) + }, "WeightedCapacity": "4", "SpotPrice": "10.00", - } - ] + }, + ], } - } + }, } } } spot_fleet_template_json = json.dumps(spot_fleet_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') + cf_conn = boto3.client("cloudformation", "us-east-1") stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=spot_fleet_template_json, - )['StackId'] + StackName="test_stack", TemplateBody=spot_fleet_template_json + )["StackId"] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'].should.have.length_of(1) - spot_fleet_id = stack_resources[ - 'StackResourceSummaries'][0]['PhysicalResourceId'] + stack_resources["StackResourceSummaries"].should.have.length_of(1) + spot_fleet_id = stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - spot_fleet_config['SpotPrice'].should.equal('0.12') - spot_fleet_config['TargetCapacity'].should.equal(6) - spot_fleet_config['IamFleetRole'].should.equal( - 'arn:aws:iam::123456789012:role/fleet') - spot_fleet_config['AllocationStrategy'].should.equal('diversified') - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + spot_fleet_config["SpotPrice"].should.equal("0.12") + spot_fleet_config["TargetCapacity"].should.equal(6) + spot_fleet_config["IamFleetRole"].should.equal( + "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID) + ) + spot_fleet_config["AllocationStrategy"].should.equal("diversified") + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec = spot_fleet_config['LaunchSpecifications'][0] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec = spot_fleet_config["LaunchSpecifications"][0] - launch_spec['EbsOptimized'].should.equal(False) - launch_spec['ImageId'].should.equal("ami-1234") - launch_spec['InstanceType'].should.equal("t2.small") - launch_spec['SubnetId'].should.equal(subnet_id) - launch_spec['SpotPrice'].should.equal("0.13") - launch_spec['WeightedCapacity'].should.equal(2.0) + launch_spec["EbsOptimized"].should.equal(False) + launch_spec["ImageId"].should.equal("ami-1234") + launch_spec["InstanceType"].should.equal("t2.small") + launch_spec["SubnetId"].should.equal(subnet_id) + launch_spec["SpotPrice"].should.equal("0.13") + launch_spec["WeightedCapacity"].should.equal(2.0) @mock_cloudformation() @mock_ec2() def test_stack_spot_fleet_should_figure_out_default_price(): - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] spot_fleet_template = { - 'Resources': { + "Resources": { "SpotFleet1": { "Type": "AWS::EC2::SpotFleet", "Properties": { "SpotFleetRequestConfigData": { - "IamFleetRole": "arn:aws:iam::123456789012:role/fleet", + "IamFleetRole": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID), "TargetCapacity": 6, "AllocationStrategy": "diversified", "LaunchSpecifications": [ { "EbsOptimized": "false", - "InstanceType": 't2.small', + "InstanceType": "t2.small", "ImageId": "ami-1234", "SubnetId": subnet_id, "WeightedCapacity": "2", }, { "EbsOptimized": "true", - "InstanceType": 't2.large', + "InstanceType": "t2.large", "ImageId": "ami-1234", "Monitoring": {"Enabled": "true"}, "SecurityGroups": [{"GroupId": "sg-123"}], "SubnetId": subnet_id, - "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::{}:role/fleet".format( + ACCOUNT_ID + ) + }, "WeightedCapacity": "4", - } - ] + }, + ], } - } + }, } } } spot_fleet_template_json = json.dumps(spot_fleet_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') + cf_conn = boto3.client("cloudformation", "us-east-1") stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=spot_fleet_template_json, - )['StackId'] + StackName="test_stack", TemplateBody=spot_fleet_template_json + )["StackId"] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'].should.have.length_of(1) - spot_fleet_id = stack_resources[ - 'StackResourceSummaries'][0]['PhysicalResourceId'] + stack_resources["StackResourceSummaries"].should.have.length_of(1) + spot_fleet_id = stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - assert 'SpotPrice' not in spot_fleet_config - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec1 = spot_fleet_config['LaunchSpecifications'][0] - launch_spec2 = spot_fleet_config['LaunchSpecifications'][1] + assert "SpotPrice" not in spot_fleet_config + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec1 = spot_fleet_config["LaunchSpecifications"][0] + launch_spec2 = spot_fleet_config["LaunchSpecifications"][1] - assert 'SpotPrice' not in launch_spec1 - assert 'SpotPrice' not in launch_spec2 + assert "SpotPrice" not in launch_spec1 + assert "SpotPrice" not in launch_spec2 @mock_ec2 @@ -2262,19 +2083,15 @@ def test_stack_elbv2_resources_integration(): }, "Resources": { "alb": { - "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", - "Properties": { - "Name": "myelbv2", - "Scheme": "internet-facing", - "Subnets": [{ - "Ref": "mysubnet", - }], - "SecurityGroups": [{ - "Ref": "mysg", - }], - "Type": "application", - "IpAddressType": "ipv4", - } + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "myelbv2", + "Scheme": "internet-facing", + "Subnets": [{"Ref": "mysubnet"}], + "SecurityGroups": [{"Ref": "mysg"}], + "Type": "application", + "IpAddressType": "ipv4", + }, }, "mytargetgroup1": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -2286,23 +2103,14 @@ def test_stack_elbv2_resources_integration(): "HealthCheckTimeoutSeconds": 5, "HealthyThresholdCount": 30, "UnhealthyThresholdCount": 5, - "Matcher": { - "HttpCode": "200,201" - }, + "Matcher": {"HttpCode": "200,201"}, "Name": "mytargetgroup1", "Port": 80, "Protocol": "HTTP", "TargetType": "instance", - "Targets": [{ - "Id": { - "Ref": "ec2instance", - "Port": 80, - }, - }], - "VpcId": { - "Ref": "myvpc", - } - } + "Targets": [{"Id": {"Ref": "ec2instance", "Port": 80}}], + "VpcId": {"Ref": "myvpc"}, + }, }, "mytargetgroup2": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -2318,250 +2126,206 @@ def test_stack_elbv2_resources_integration(): "Port": 8080, "Protocol": "HTTP", "TargetType": "instance", - "Targets": [{ - "Id": { - "Ref": "ec2instance", - "Port": 8080, - }, - }], - "VpcId": { - "Ref": "myvpc", - } - } + "Targets": [{"Id": {"Ref": "ec2instance", "Port": 8080}}], + "VpcId": {"Ref": "myvpc"}, + }, }, "listener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", "Properties": { - "DefaultActions": [{ - "Type": "forward", - "TargetGroupArn": {"Ref": "mytargetgroup1"} - }], + "DefaultActions": [ + {"Type": "forward", "TargetGroupArn": {"Ref": "mytargetgroup1"}} + ], "LoadBalancerArn": {"Ref": "alb"}, "Port": "80", - "Protocol": "HTTP" - } + "Protocol": "HTTP", + }, }, "myvpc": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "mysubnet": { "Type": "AWS::EC2::Subnet", - "Properties": { - "CidrBlock": "10.0.0.0/27", - "VpcId": {"Ref": "myvpc"}, - } + "Properties": {"CidrBlock": "10.0.0.0/27", "VpcId": {"Ref": "myvpc"}}, }, "mysg": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupName": "mysg", "GroupDescription": "test security group", - "VpcId": {"Ref": "myvpc"} - } + "VpcId": {"Ref": "myvpc"}, + }, }, "ec2instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } alb_template_json = json.dumps(alb_template) cfn_conn = boto3.client("cloudformation", "us-west-1") - cfn_conn.create_stack( - StackName="elb_stack", - TemplateBody=alb_template_json, - ) + cfn_conn.create_stack(StackName="elb_stack", TemplateBody=alb_template_json) elbv2_conn = boto3.client("elbv2", "us-west-1") - load_balancers = elbv2_conn.describe_load_balancers()['LoadBalancers'] + load_balancers = elbv2_conn.describe_load_balancers()["LoadBalancers"] len(load_balancers).should.equal(1) - load_balancers[0]['LoadBalancerName'].should.equal('myelbv2') - load_balancers[0]['Scheme'].should.equal('internet-facing') - load_balancers[0]['Type'].should.equal('application') - load_balancers[0]['IpAddressType'].should.equal('ipv4') + load_balancers[0]["LoadBalancerName"].should.equal("myelbv2") + load_balancers[0]["Scheme"].should.equal("internet-facing") + load_balancers[0]["Type"].should.equal("application") + load_balancers[0]["IpAddressType"].should.equal("ipv4") target_groups = sorted( - elbv2_conn.describe_target_groups()['TargetGroups'], - key=lambda tg: tg['TargetGroupName']) # sort to do comparison with indexes + elbv2_conn.describe_target_groups()["TargetGroups"], + key=lambda tg: tg["TargetGroupName"], + ) # sort to do comparison with indexes len(target_groups).should.equal(2) - target_groups[0]['HealthCheckIntervalSeconds'].should.equal(30) - target_groups[0]['HealthCheckPath'].should.equal('/status') - target_groups[0]['HealthCheckPort'].should.equal('80') - target_groups[0]['HealthCheckProtocol'].should.equal('HTTP') - target_groups[0]['HealthCheckTimeoutSeconds'].should.equal(5) - target_groups[0]['HealthyThresholdCount'].should.equal(30) - target_groups[0]['UnhealthyThresholdCount'].should.equal(5) - target_groups[0]['Matcher'].should.equal({'HttpCode': '200,201'}) - target_groups[0]['TargetGroupName'].should.equal('mytargetgroup1') - target_groups[0]['Port'].should.equal(80) - target_groups[0]['Protocol'].should.equal('HTTP') - target_groups[0]['TargetType'].should.equal('instance') + target_groups[0]["HealthCheckIntervalSeconds"].should.equal(30) + target_groups[0]["HealthCheckPath"].should.equal("/status") + target_groups[0]["HealthCheckPort"].should.equal("80") + target_groups[0]["HealthCheckProtocol"].should.equal("HTTP") + target_groups[0]["HealthCheckTimeoutSeconds"].should.equal(5) + target_groups[0]["HealthyThresholdCount"].should.equal(30) + target_groups[0]["UnhealthyThresholdCount"].should.equal(5) + target_groups[0]["Matcher"].should.equal({"HttpCode": "200,201"}) + target_groups[0]["TargetGroupName"].should.equal("mytargetgroup1") + target_groups[0]["Port"].should.equal(80) + target_groups[0]["Protocol"].should.equal("HTTP") + target_groups[0]["TargetType"].should.equal("instance") - target_groups[1]['HealthCheckIntervalSeconds'].should.equal(30) - target_groups[1]['HealthCheckPath'].should.equal('/status') - target_groups[1]['HealthCheckPort'].should.equal('8080') - target_groups[1]['HealthCheckProtocol'].should.equal('HTTP') - target_groups[1]['HealthCheckTimeoutSeconds'].should.equal(5) - target_groups[1]['HealthyThresholdCount'].should.equal(30) - target_groups[1]['UnhealthyThresholdCount'].should.equal(5) - target_groups[1]['Matcher'].should.equal({'HttpCode': '200'}) - target_groups[1]['TargetGroupName'].should.equal('mytargetgroup2') - target_groups[1]['Port'].should.equal(8080) - target_groups[1]['Protocol'].should.equal('HTTP') - target_groups[1]['TargetType'].should.equal('instance') + target_groups[1]["HealthCheckIntervalSeconds"].should.equal(30) + target_groups[1]["HealthCheckPath"].should.equal("/status") + target_groups[1]["HealthCheckPort"].should.equal("8080") + target_groups[1]["HealthCheckProtocol"].should.equal("HTTP") + target_groups[1]["HealthCheckTimeoutSeconds"].should.equal(5) + target_groups[1]["HealthyThresholdCount"].should.equal(30) + target_groups[1]["UnhealthyThresholdCount"].should.equal(5) + target_groups[1]["Matcher"].should.equal({"HttpCode": "200"}) + target_groups[1]["TargetGroupName"].should.equal("mytargetgroup2") + target_groups[1]["Port"].should.equal(8080) + target_groups[1]["Protocol"].should.equal("HTTP") + target_groups[1]["TargetType"].should.equal("instance") - listeners = elbv2_conn.describe_listeners(LoadBalancerArn=load_balancers[0]['LoadBalancerArn'])['Listeners'] + listeners = elbv2_conn.describe_listeners( + LoadBalancerArn=load_balancers[0]["LoadBalancerArn"] + )["Listeners"] len(listeners).should.equal(1) - listeners[0]['LoadBalancerArn'].should.equal(load_balancers[0]['LoadBalancerArn']) - listeners[0]['Port'].should.equal(80) - listeners[0]['Protocol'].should.equal('HTTP') - listeners[0]['DefaultActions'].should.equal([{ - "Type": "forward", - "TargetGroupArn": target_groups[0]['TargetGroupArn'] - }]) + listeners[0]["LoadBalancerArn"].should.equal(load_balancers[0]["LoadBalancerArn"]) + listeners[0]["Port"].should.equal(80) + listeners[0]["Protocol"].should.equal("HTTP") + listeners[0]["DefaultActions"].should.equal( + [{"Type": "forward", "TargetGroupArn": target_groups[0]["TargetGroupArn"]}] + ) # test outputs - stacks = cfn_conn.describe_stacks(StackName='elb_stack')['Stacks'] + stacks = cfn_conn.describe_stacks(StackName="elb_stack")["Stacks"] len(stacks).should.equal(1) - dns = list(filter(lambda item: item['OutputKey'] == 'albdns', stacks[0]['Outputs']))[0] - name = list(filter(lambda item: item['OutputKey'] == 'albname', stacks[0]['Outputs']))[0] + dns = list( + filter(lambda item: item["OutputKey"] == "albdns", stacks[0]["Outputs"]) + )[0] + name = list( + filter(lambda item: item["OutputKey"] == "albname", stacks[0]["Outputs"]) + )[0] - dns['OutputValue'].should.equal(load_balancers[0]['DNSName']) - name['OutputValue'].should.equal(load_balancers[0]['LoadBalancerName']) + dns["OutputValue"].should.equal(load_balancers[0]["DNSName"]) + name["OutputValue"].should.equal(load_balancers[0]["LoadBalancerName"]) @mock_dynamodb2 @mock_cloudformation def test_stack_dynamodb_resources_integration(): dynamodb_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { - "myDynamoDBTable": { - "Type": "AWS::DynamoDB::Table", - "Properties": { - "AttributeDefinitions": [ - { - "AttributeName": "Album", - "AttributeType": "S" - }, - { - "AttributeName": "Artist", - "AttributeType": "S" - }, - { - "AttributeName": "Sales", - "AttributeType": "N" - }, - { - "AttributeName": "NumberOfSongs", - "AttributeType": "N" - } - ], - "KeySchema": [ - { - "AttributeName": "Album", - "KeyType": "HASH" - }, - { - "AttributeName": "Artist", - "KeyType": "RANGE" - } - ], - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - }, - "TableName": "myTableName", - "GlobalSecondaryIndexes": [{ - "IndexName": "myGSI", - "KeySchema": [ - { - "AttributeName": "Sales", - "KeyType": "HASH" + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + "myDynamoDBTable": { + "Type": "AWS::DynamoDB::Table", + "Properties": { + "AttributeDefinitions": [ + {"AttributeName": "Album", "AttributeType": "S"}, + {"AttributeName": "Artist", "AttributeType": "S"}, + {"AttributeName": "Sales", "AttributeType": "N"}, + {"AttributeName": "NumberOfSongs", "AttributeType": "N"}, + ], + "KeySchema": [ + {"AttributeName": "Album", "KeyType": "HASH"}, + {"AttributeName": "Artist", "KeyType": "RANGE"}, + ], + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + "TableName": "myTableName", + "GlobalSecondaryIndexes": [ + { + "IndexName": "myGSI", + "KeySchema": [ + {"AttributeName": "Sales", "KeyType": "HASH"}, + {"AttributeName": "Artist", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Album", "NumberOfSongs"], + "ProjectionType": "INCLUDE", + }, + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + }, + { + "IndexName": "myGSI2", + "KeySchema": [ + {"AttributeName": "NumberOfSongs", "KeyType": "HASH"}, + {"AttributeName": "Sales", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Album", "Artist"], + "ProjectionType": "INCLUDE", + }, + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + }, + ], + "LocalSecondaryIndexes": [ + { + "IndexName": "myLSI", + "KeySchema": [ + {"AttributeName": "Album", "KeyType": "HASH"}, + {"AttributeName": "Sales", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Artist", "NumberOfSongs"], + "ProjectionType": "INCLUDE", + }, + } + ], }, - { - "AttributeName": "Artist", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Album","NumberOfSongs"], - "ProjectionType": "INCLUDE" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - } - }, - { - "IndexName": "myGSI2", - "KeySchema": [ - { - "AttributeName": "NumberOfSongs", - "KeyType": "HASH" - }, - { - "AttributeName": "Sales", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Album","Artist"], - "ProjectionType": "INCLUDE" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - } - }], - "LocalSecondaryIndexes":[{ - "IndexName": "myLSI", - "KeySchema": [ - { - "AttributeName": "Album", - "KeyType": "HASH" - }, - { - "AttributeName": "Sales", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Artist","NumberOfSongs"], - "ProjectionType": "INCLUDE" - } - }] - } - } - } + } + }, } dynamodb_template_json = json.dumps(dynamodb_template) - cfn_conn = boto3.client('cloudformation', 'us-east-1') + cfn_conn = boto3.client("cloudformation", "us-east-1") cfn_conn.create_stack( - StackName='dynamodb_stack', - TemplateBody=dynamodb_template_json, + StackName="dynamodb_stack", TemplateBody=dynamodb_template_json ) - dynamodb_conn = boto3.resource('dynamodb', region_name='us-east-1') - table = dynamodb_conn.Table('myTableName') - table.name.should.equal('myTableName') + dynamodb_conn = boto3.resource("dynamodb", region_name="us-east-1") + table = dynamodb_conn.Table("myTableName") + table.name.should.equal("myTableName") - table.put_item(Item={"Album": "myAlbum", "Artist": "myArtist", "Sales": 10, "NumberOfSongs": 5}) + table.put_item( + Item={"Album": "myAlbum", "Artist": "myArtist", "Sales": 10, "NumberOfSongs": 5} + ) response = table.get_item(Key={"Album": "myAlbum", "Artist": "myArtist"}) - response['Item']['Album'].should.equal('myAlbum') - response['Item']['Sales'].should.equal(Decimal('10')) - response['Item']['NumberOfSongs'].should.equal(Decimal('5')) - response['Item']['Album'].should.equal('myAlbum') + response["Item"]["Album"].should.equal("myAlbum") + response["Item"]["Sales"].should.equal(Decimal("10")) + response["Item"]["NumberOfSongs"].should.equal(Decimal("5")) + response["Item"]["Album"].should.equal("myAlbum") diff --git a/tests/test_cloudformation/test_import_value.py b/tests/test_cloudformation/test_import_value.py index d702753a6..41a8a8a30 100644 --- a/tests/test_cloudformation/test_import_value.py +++ b/tests/test_cloudformation/test_import_value.py @@ -1,87 +1,89 @@ -# -*- coding: utf-8 -*- -from __future__ import absolute_import, division, print_function, unicode_literals - -# Standard library modules -import unittest - -# Third-party modules -import boto3 -from botocore.exceptions import ClientError - -# Package modules -from moto import mock_cloudformation - -AWS_REGION = 'us-west-1' - -SG_STACK_NAME = 'simple-sg-stack' -SG_TEMPLATE = """ -AWSTemplateFormatVersion: 2010-09-09 -Description: Simple test CF template for moto_cloudformation - - -Resources: - SimpleSecurityGroup: - Type: AWS::EC2::SecurityGroup - Description: "A simple security group" - Properties: - GroupName: simple-security-group - GroupDescription: "A simple security group" - SecurityGroupEgress: - - - Description: "Egress to remote HTTPS servers" - CidrIp: 0.0.0.0/0 - IpProtocol: tcp - FromPort: 443 - ToPort: 443 - -Outputs: - SimpleSecurityGroupName: - Value: !GetAtt SimpleSecurityGroup.GroupId - Export: - Name: "SimpleSecurityGroup" - -""" - -EC2_STACK_NAME = 'simple-ec2-stack' -EC2_TEMPLATE = """ ---- -# The latest template format version is "2010-09-09" and as of 2018-04-09 -# is currently the only valid value. -AWSTemplateFormatVersion: 2010-09-09 -Description: Simple test CF template for moto_cloudformation - - -Resources: - SimpleInstance: - Type: AWS::EC2::Instance - Properties: - ImageId: ami-03cf127a - InstanceType: t2.micro - SecurityGroups: !Split [',', !ImportValue SimpleSecurityGroup] -""" - - -class TestSimpleInstance(unittest.TestCase): - def test_simple_instance(self): - """Test that we can create a simple CloudFormation stack that imports values from an existing CloudFormation stack""" - with mock_cloudformation(): - client = boto3.client('cloudformation', region_name=AWS_REGION) - client.create_stack(StackName=SG_STACK_NAME, TemplateBody=SG_TEMPLATE) - response = client.create_stack(StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE) - self.assertIn('StackId', response) - response = client.describe_stacks(StackName=response['StackId']) - self.assertIn('Stacks', response) - stack_info = response['Stacks'] - self.assertEqual(1, len(stack_info)) - self.assertIn('StackName', stack_info[0]) - self.assertEqual(EC2_STACK_NAME, stack_info[0]['StackName']) - - def test_simple_instance_missing_export(self): - """Test that we get an exception if a CloudFormation stack tries to imports a non-existent export value""" - with mock_cloudformation(): - client = boto3.client('cloudformation', region_name=AWS_REGION) - with self.assertRaises(ClientError) as e: - client.create_stack(StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE) - self.assertIn('Error', e.exception.response) - self.assertIn('Code', e.exception.response['Error']) - self.assertEqual('ExportNotFound', e.exception.response['Error']['Code']) +# -*- coding: utf-8 -*- +from __future__ import absolute_import, division, print_function, unicode_literals + +# Standard library modules +import unittest + +# Third-party modules +import boto3 +from botocore.exceptions import ClientError + +# Package modules +from moto import mock_cloudformation + +AWS_REGION = "us-west-1" + +SG_STACK_NAME = "simple-sg-stack" +SG_TEMPLATE = """ +AWSTemplateFormatVersion: 2010-09-09 +Description: Simple test CF template for moto_cloudformation + + +Resources: + SimpleSecurityGroup: + Type: AWS::EC2::SecurityGroup + Description: "A simple security group" + Properties: + GroupName: simple-security-group + GroupDescription: "A simple security group" + SecurityGroupEgress: + - + Description: "Egress to remote HTTPS servers" + CidrIp: 0.0.0.0/0 + IpProtocol: tcp + FromPort: 443 + ToPort: 443 + +Outputs: + SimpleSecurityGroupName: + Value: !GetAtt SimpleSecurityGroup.GroupId + Export: + Name: "SimpleSecurityGroup" + +""" + +EC2_STACK_NAME = "simple-ec2-stack" +EC2_TEMPLATE = """ +--- +# The latest template format version is "2010-09-09" and as of 2018-04-09 +# is currently the only valid value. +AWSTemplateFormatVersion: 2010-09-09 +Description: Simple test CF template for moto_cloudformation + + +Resources: + SimpleInstance: + Type: AWS::EC2::Instance + Properties: + ImageId: ami-03cf127a + InstanceType: t2.micro + SecurityGroups: !Split [',', !ImportValue SimpleSecurityGroup] +""" + + +class TestSimpleInstance(unittest.TestCase): + def test_simple_instance(self): + """Test that we can create a simple CloudFormation stack that imports values from an existing CloudFormation stack""" + with mock_cloudformation(): + client = boto3.client("cloudformation", region_name=AWS_REGION) + client.create_stack(StackName=SG_STACK_NAME, TemplateBody=SG_TEMPLATE) + response = client.create_stack( + StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE + ) + self.assertIn("StackId", response) + response = client.describe_stacks(StackName=response["StackId"]) + self.assertIn("Stacks", response) + stack_info = response["Stacks"] + self.assertEqual(1, len(stack_info)) + self.assertIn("StackName", stack_info[0]) + self.assertEqual(EC2_STACK_NAME, stack_info[0]["StackName"]) + + def test_simple_instance_missing_export(self): + """Test that we get an exception if a CloudFormation stack tries to imports a non-existent export value""" + with mock_cloudformation(): + client = boto3.client("cloudformation", region_name=AWS_REGION) + with self.assertRaises(ClientError) as e: + client.create_stack(StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE) + self.assertIn("Error", e.exception.response) + self.assertIn("Code", e.exception.response["Error"]) + self.assertEqual("ExportNotFound", e.exception.response["Error"]["Code"]) diff --git a/tests/test_cloudformation/test_server.py b/tests/test_cloudformation/test_server.py index 11f810357..f3f037c42 100644 --- a/tests/test_cloudformation/test_server.py +++ b/tests/test_cloudformation/test_server.py @@ -1,33 +1,36 @@ -from __future__ import unicode_literals - -import json -from six.moves.urllib.parse import urlencode -import re -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_cloudformation_server_get(): - backend = server.create_backend_app("cloudformation") - stack_name = 'test stack' - test_client = backend.test_client() - template_body = { - "Resources": {}, - } - create_stack_resp = test_client.action_data("CreateStack", StackName=stack_name, - TemplateBody=json.dumps(template_body)) - create_stack_resp.should.match( - r".*.*.*.*.*", re.DOTALL) - stack_id_from_create_response = re.search( - "(.*)", create_stack_resp).groups()[0] - - list_stacks_resp = test_client.action_data("ListStacks") - stack_id_from_list_response = re.search( - "(.*)", list_stacks_resp).groups()[0] - - stack_id_from_create_response.should.equal(stack_id_from_list_response) +from __future__ import unicode_literals + +import json +from six.moves.urllib.parse import urlencode +import re +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_cloudformation_server_get(): + backend = server.create_backend_app("cloudformation") + stack_name = "test stack" + test_client = backend.test_client() + template_body = {"Resources": {}} + create_stack_resp = test_client.action_data( + "CreateStack", StackName=stack_name, TemplateBody=json.dumps(template_body) + ) + create_stack_resp.should.match( + r".*.*.*.*.*", + re.DOTALL, + ) + stack_id_from_create_response = re.search( + "(.*)", create_stack_resp + ).groups()[0] + + list_stacks_resp = test_client.action_data("ListStacks") + stack_id_from_list_response = re.search( + "(.*)", list_stacks_resp + ).groups()[0] + + stack_id_from_create_response.should.equal(stack_id_from_list_response) diff --git a/tests/test_cloudformation/test_stack_parsing.py b/tests/test_cloudformation/test_stack_parsing.py index 25242e352..85df76592 100644 --- a/tests/test_cloudformation/test_stack_parsing.py +++ b/tests/test_cloudformation/test_stack_parsing.py @@ -7,91 +7,57 @@ import sure # noqa from moto.cloudformation.exceptions import ValidationError from moto.cloudformation.models import FakeStack -from moto.cloudformation.parsing import resource_class_from_type, parse_condition, Export +from moto.cloudformation.parsing import ( + resource_class_from_type, + parse_condition, + Export, +) from moto.sqs.models import Queue from moto.s3.models import FakeBucket from moto.cloudformation.utils import yaml_tag_constructor from boto.cloudformation.stack import Output - dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Create a multi-az, load balanced, Auto Scaled sample web site. The Auto Scaling trigger is based on the CPU utilization of the web servers. The AMI is chosen based on the region in which the stack is run. This example creates a web service running across all availability zones in a region. The instances are load balanced with a simple health check. The web site is available on port 80, however, the instances can be configured to listen on any port (8888 by default). **WARNING** This template creates one or more Amazon EC2 instances. You will be billed for the AWS resources used if you create a stack from this template.", - "Resources": { "Queue": { "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, - "S3Bucket": { - "Type": "AWS::S3::Bucket", - "DeletionPolicy": "Retain" + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, }, + "S3Bucket": {"Type": "AWS::S3::Bucket", "DeletionPolicy": "Retain"}, }, } name_type_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Create a multi-az, load balanced, Auto Scaled sample web site. The Auto Scaling trigger is based on the CPU utilization of the web servers. The AMI is chosen based on the region in which the stack is run. This example creates a web service running across all availability zones in a region. The instances are load balanced with a simple health check. The web site is available on port 80, however, the instances can be configured to listen on any port (8888 by default). **WARNING** This template creates one or more Amazon EC2 instances. You will be billed for the AWS resources used if you create a stack from this template.", - "Resources": { - "Queue": { - "Type": "AWS::SQS::Queue", - "Properties": { - "VisibilityTimeout": 60, - } - }, + "Queue": {"Type": "AWS::SQS::Queue", "Properties": {"VisibilityTimeout": 60}} }, } output_dict = { "Outputs": { - "Output1": { - "Value": {"Ref": "Queue"}, - "Description": "This is a description." - } + "Output1": {"Value": {"Ref": "Queue"}, "Description": "This is a description."} } } bad_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAtt": ["Queue", "InvalidAttribute"]} - } - } + "Outputs": {"Output1": {"Value": {"Fn::GetAtt": ["Queue", "InvalidAttribute"]}}} } get_attribute_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAtt": ["Queue", "QueueName"]} - } - } + "Outputs": {"Output1": {"Value": {"Fn::GetAtt": ["Queue", "QueueName"]}}} } -get_availability_zones_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAZs": ""} - } - } -} +get_availability_zones_output = {"Outputs": {"Output1": {"Value": {"Fn::GetAZs": ""}}}} parameters = { "Parameters": { - "Param": { - "Type": "String", - }, - "NoEchoParam": { - "Type": "String", - "NoEcho": True - } + "Param": {"Type": "String"}, + "NoEchoParam": {"Type": "String", "NoEcho": True}, } } @@ -101,11 +67,11 @@ split_select_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Select": [ "1", {"Fn::Split": [ "-", "123-myqueue" ] } ] }, + "QueueName": {"Fn::Select": ["1", {"Fn::Split": ["-", "123-myqueue"]}]}, "VisibilityTimeout": 60, - } + }, } - } + }, } sub_template = { @@ -114,18 +80,18 @@ sub_template = { "Queue1": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${AWS::StackName}-queue-${!Literal}'}, + "QueueName": {"Fn::Sub": "${AWS::StackName}-queue-${!Literal}"}, "VisibilityTimeout": 60, - } + }, }, "Queue2": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${Queue1.QueueName}'}, + "QueueName": {"Fn::Sub": "${Queue1.QueueName}"}, "VisibilityTimeout": 60, - } + }, }, - } + }, } export_value_template = { @@ -134,17 +100,12 @@ export_value_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${AWS::StackName}-queue'}, + "QueueName": {"Fn::Sub": "${AWS::StackName}-queue"}, "VisibilityTimeout": 60, - } + }, } }, - "Outputs": { - "Output1": { - "Value": "value", - "Export": {"Name": 'queue-us-west-1'} - } - } + "Outputs": {"Output1": {"Value": "value", "Export": {"Name": "queue-us-west-1"}}}, } import_value_template = { @@ -153,33 +114,30 @@ import_value_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::ImportValue": 'queue-us-west-1'}, + "QueueName": {"Fn::ImportValue": "queue-us-west-1"}, "VisibilityTimeout": 60, - } + }, } - } + }, } -outputs_template = dict(list(dummy_template.items()) + - list(output_dict.items())) -bad_outputs_template = dict( - list(dummy_template.items()) + list(bad_output.items())) +outputs_template = dict(list(dummy_template.items()) + list(output_dict.items())) +bad_outputs_template = dict(list(dummy_template.items()) + list(bad_output.items())) get_attribute_outputs_template = dict( - list(dummy_template.items()) + list(get_attribute_output.items())) + list(dummy_template.items()) + list(get_attribute_output.items()) +) get_availability_zones_template = dict( - list(dummy_template.items()) + list(get_availability_zones_output.items())) + list(dummy_template.items()) + list(get_availability_zones_output.items()) +) -parameters_template = dict( - list(dummy_template.items()) + list(parameters.items())) +parameters_template = dict(list(dummy_template.items()) + list(parameters.items())) dummy_template_json = json.dumps(dummy_template) name_type_template_json = json.dumps(name_type_template) output_type_template_json = json.dumps(outputs_template) bad_output_template_json = json.dumps(bad_outputs_template) -get_attribute_outputs_template_json = json.dumps( - get_attribute_outputs_template) -get_availability_zones_template_json = json.dumps( - get_availability_zones_template) +get_attribute_outputs_template_json = json.dumps(get_attribute_outputs_template) +get_availability_zones_template_json = json.dumps(get_availability_zones_template) parameters_template_json = json.dumps(parameters_template) split_select_template_json = json.dumps(split_select_template) sub_template_json = json.dumps(sub_template) @@ -193,15 +151,16 @@ def test_parse_stack_resources(): name="test_stack", template=dummy_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(2) - queue = stack.resource_map['Queue'] + queue = stack.resource_map["Queue"] queue.should.be.a(Queue) queue.name.should.equal("my-queue") - bucket = stack.resource_map['S3Bucket'] + bucket = stack.resource_map["S3Bucket"] bucket.should.be.a(FakeBucket) bucket.physical_resource_id.should.equal(bucket.name) @@ -209,8 +168,7 @@ def test_parse_stack_resources(): @patch("moto.cloudformation.parsing.logger") def test_missing_resource_logs(logger): resource_class_from_type("foobar") - logger.warning.assert_called_with( - 'No Moto CloudFormation support for %s', 'foobar') + logger.warning.assert_called_with("No Moto CloudFormation support for %s", "foobar") def test_parse_stack_with_name_type_resource(): @@ -219,10 +177,11 @@ def test_parse_stack_with_name_type_resource(): name="test_stack", template=name_type_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - list(stack.resource_map.keys())[0].should.equal('Queue') + list(stack.resource_map.keys())[0].should.equal("Queue") queue = list(stack.resource_map.values())[0] queue.should.be.a(Queue) @@ -233,10 +192,11 @@ def test_parse_stack_with_yaml_template(): name="test_stack", template=yaml.dump(name_type_template), parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - list(stack.resource_map.keys())[0].should.equal('Queue') + list(stack.resource_map.keys())[0].should.equal("Queue") queue = list(stack.resource_map.values())[0] queue.should.be.a(Queue) @@ -247,10 +207,11 @@ def test_parse_stack_with_outputs(): name="test_stack", template=output_type_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) output.description.should.equal("This is a description.") @@ -262,14 +223,16 @@ def test_parse_stack_with_get_attribute_outputs(): name="test_stack", template=get_attribute_outputs_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) output.value.should.equal("my-queue") + def test_parse_stack_with_get_attribute_kms(): from .fixtures.kms_key import template @@ -279,31 +242,35 @@ def test_parse_stack_with_get_attribute_kms(): name="test_stack", template=template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('KeyArn') + list(stack.output_map.keys())[0].should.equal("KeyArn") output = list(stack.output_map.values())[0] output.should.be.a(Output) + def test_parse_stack_with_get_availability_zones(): stack = FakeStack( stack_id="test_id", name="test_stack", template=get_availability_zones_template_json, parameters={}, - region_name='us-east-1') + region_name="us-east-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) - output.value.should.equal([ "us-east-1a", "us-east-1b", "us-east-1c", "us-east-1d" ]) + output.value.should.equal(["us-east-1a", "us-east-1b", "us-east-1c", "us-east-1d"]) def test_parse_stack_with_bad_get_attribute_outputs(): FakeStack.when.called_with( - "test_id", "test_stack", bad_output_template_json, {}, "us-west-1").should.throw(ValidationError) + "test_id", "test_stack", bad_output_template_json, {}, "us-west-1" + ).should.throw(ValidationError) def test_parse_stack_with_parameters(): @@ -312,7 +279,8 @@ def test_parse_stack_with_parameters(): name="test_stack", template=parameters_template_json, parameters={"Param": "visible value", "NoEchoParam": "hidden value"}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.no_echo_parameter_keys.should.have("NoEchoParam") stack.resource_map.no_echo_parameter_keys.should_not.have("Param") @@ -334,21 +302,13 @@ def test_parse_equals_condition(): def test_parse_not_condition(): parse_condition( - condition={ - "Fn::Not": [{ - "Fn::Equals": [{"Ref": "EnvType"}, "prod"] - }] - }, + condition={"Fn::Not": [{"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}]}, resources_map={"EnvType": "prod"}, condition_map={}, ).should.equal(False) parse_condition( - condition={ - "Fn::Not": [{ - "Fn::Equals": [{"Ref": "EnvType"}, "prod"] - }] - }, + condition={"Fn::Not": [{"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}]}, resources_map={"EnvType": "staging"}, condition_map={}, ).should.equal(True) @@ -416,10 +376,11 @@ def test_parse_split_and_select(): name="test_stack", template=split_select_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - queue = stack.resource_map['Queue'] + queue = stack.resource_map["Queue"] queue.name.should.equal("myqueue") @@ -429,10 +390,11 @@ def test_sub(): name="test_stack", template=sub_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) - queue1 = stack.resource_map['Queue1'] - queue2 = stack.resource_map['Queue2'] + queue1 = stack.resource_map["Queue1"] + queue2 = stack.resource_map["Queue2"] queue2.name.should.equal(queue1.name) @@ -442,20 +404,21 @@ def test_import(): name="test_stack", template=export_value_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) import_stack = FakeStack( stack_id="test_id", name="test_stack", template=import_value_template_json, parameters={}, - region_name='us-west-1', - cross_stack_resources={export_stack.exports[0].value: export_stack.exports[0]}) + region_name="us-west-1", + cross_stack_resources={export_stack.exports[0].value: export_stack.exports[0]}, + ) - queue = import_stack.resource_map['Queue'] + queue = import_stack.resource_map["Queue"] queue.name.should.equal("value") - def test_short_form_func_in_yaml_teamplate(): template = """--- KeyB64: !Base64 valueToEncode @@ -476,24 +439,24 @@ def test_short_form_func_in_yaml_teamplate(): KeySplit: !Split [A, B] KeySub: !Sub A """ - yaml.add_multi_constructor('', yaml_tag_constructor, Loader=yaml.Loader) + yaml.add_multi_constructor("", yaml_tag_constructor, Loader=yaml.Loader) template_dict = yaml.load(template, Loader=yaml.Loader) key_and_expects = [ - ['KeyRef', {'Ref': 'foo'}], - ['KeyB64', {'Fn::Base64': 'valueToEncode'}], - ['KeyAnd', {'Fn::And': ['A', 'B']}], - ['KeyEquals', {'Fn::Equals': ['A', 'B']}], - ['KeyIf', {'Fn::If': ['A', 'B', 'C']}], - ['KeyNot', {'Fn::Not': ['A']}], - ['KeyOr', {'Fn::Or': ['A', 'B']}], - ['KeyFindInMap', {'Fn::FindInMap': ['A', 'B', 'C']}], - ['KeyGetAtt', {'Fn::GetAtt': ['A', 'B']}], - ['KeyGetAZs', {'Fn::GetAZs': 'A'}], - ['KeyImportValue', {'Fn::ImportValue': 'A'}], - ['KeyJoin', {'Fn::Join': [ ":", [ 'A', 'B', 'C' ] ]}], - ['KeySelect', {'Fn::Select': ['A', 'B']}], - ['KeySplit', {'Fn::Split': ['A', 'B']}], - ['KeySub', {'Fn::Sub': 'A'}], + ["KeyRef", {"Ref": "foo"}], + ["KeyB64", {"Fn::Base64": "valueToEncode"}], + ["KeyAnd", {"Fn::And": ["A", "B"]}], + ["KeyEquals", {"Fn::Equals": ["A", "B"]}], + ["KeyIf", {"Fn::If": ["A", "B", "C"]}], + ["KeyNot", {"Fn::Not": ["A"]}], + ["KeyOr", {"Fn::Or": ["A", "B"]}], + ["KeyFindInMap", {"Fn::FindInMap": ["A", "B", "C"]}], + ["KeyGetAtt", {"Fn::GetAtt": ["A", "B"]}], + ["KeyGetAZs", {"Fn::GetAZs": "A"}], + ["KeyImportValue", {"Fn::ImportValue": "A"}], + ["KeyJoin", {"Fn::Join": [":", ["A", "B", "C"]]}], + ["KeySelect", {"Fn::Select": ["A", "B"]}], + ["KeySplit", {"Fn::Split": ["A", "B"]}], + ["KeySub", {"Fn::Sub": "A"}], ] for k, v in key_and_expects: template_dict.should.have.key(k).which.should.be.equal(v) diff --git a/tests/test_cloudformation/test_validate.py b/tests/test_cloudformation/test_validate.py index e2c3af05d..4dd4d7e08 100644 --- a/tests/test_cloudformation/test_validate.py +++ b/tests/test_cloudformation/test_validate.py @@ -9,7 +9,11 @@ import botocore from moto.cloudformation.exceptions import ValidationError from moto.cloudformation.models import FakeStack -from moto.cloudformation.parsing import resource_class_from_type, parse_condition, Export +from moto.cloudformation.parsing import ( + resource_class_from_type, + parse_condition, + Export, +) from moto.sqs.models import Queue from moto.s3.models import FakeBucket from moto.cloudformation.utils import yaml_tag_constructor @@ -27,25 +31,16 @@ json_template = { "KeyName": "dummy", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Description", - "Value": "Test tag" - }, - { - "Key": "Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Description", "Value": "Test tag"}, + {"Key": "Name", "Value": "Name tag for tests"}, + ], + }, } - } + }, } # One resource is required -json_bad_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Stack 1" -} +json_bad_template = {"AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1"} dummy_template_json = json.dumps(json_template) dummy_bad_template_json = json.dumps(json_bad_template) @@ -53,25 +48,25 @@ dummy_bad_template_json = json.dumps(json_bad_template) @mock_cloudformation def test_boto3_json_validate_successful(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - response = cf_conn.validate_template( - TemplateBody=dummy_template_json, - ) - assert response['Description'] == "Stack 1" - assert response['Parameters'] == [] - assert response['ResponseMetadata']['HTTPStatusCode'] == 200 + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + response = cf_conn.validate_template(TemplateBody=dummy_template_json) + assert response["Description"] == "Stack 1" + assert response["Parameters"] == [] + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + @mock_cloudformation def test_boto3_json_invalid_missing_resource(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") try: - cf_conn.validate_template( - TemplateBody=dummy_bad_template_json, - ) + cf_conn.validate_template(TemplateBody=dummy_bad_template_json) assert False except botocore.exceptions.ClientError as e: - assert str(e) == 'An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack' \ - ' with id Missing top level item Resources to file module does not exist' + assert ( + str(e) + == "An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack" + " with id Missing top level item Resources to file module does not exist" + ) assert True @@ -91,25 +86,26 @@ yaml_bad_template = """ Description: Simple CloudFormation Test Template """ + @mock_cloudformation def test_boto3_yaml_validate_successful(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - response = cf_conn.validate_template( - TemplateBody=yaml_template, - ) - assert response['Description'] == "Simple CloudFormation Test Template" - assert response['Parameters'] == [] - assert response['ResponseMetadata']['HTTPStatusCode'] == 200 + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + response = cf_conn.validate_template(TemplateBody=yaml_template) + assert response["Description"] == "Simple CloudFormation Test Template" + assert response["Parameters"] == [] + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + @mock_cloudformation def test_boto3_yaml_invalid_missing_resource(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") try: - cf_conn.validate_template( - TemplateBody=yaml_bad_template, - ) + cf_conn.validate_template(TemplateBody=yaml_bad_template) assert False except botocore.exceptions.ClientError as e: - assert str(e) == 'An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack' \ - ' with id Missing top level item Resources to file module does not exist' + assert ( + str(e) + == "An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack" + " with id Missing top level item Resources to file module does not exist" + ) assert True diff --git a/tests/test_cloudwatch/test_cloudwatch.py b/tests/test_cloudwatch/test_cloudwatch.py index 2ba233735..cc624e852 100644 --- a/tests/test_cloudwatch/test_cloudwatch.py +++ b/tests/test_cloudwatch/test_cloudwatch.py @@ -1,30 +1,27 @@ import boto from boto.ec2.cloudwatch.alarm import MetricAlarm -import boto3 -from datetime import datetime, timedelta -import pytz import sure # noqa from moto import mock_cloudwatch_deprecated def alarm_fixture(name="tester", action=None): - action = action or ['arn:alarm'] + action = action or ["arn:alarm"] return MetricAlarm( name=name, namespace="{0}_namespace".format(name), metric="{0}_metric".format(name), - comparison='>=', + comparison=">=", threshold=2.0, period=60, evaluation_periods=5, - statistic='Average', - description='A test', - dimensions={'InstanceId': ['i-0123456,i-0123457']}, + statistic="Average", + description="A test", + dimensions={"InstanceId": ["i-0123456,i-0123457"]}, alarm_actions=action, - ok_actions=['arn:ok'], - insufficient_data_actions=['arn:insufficient'], - unit='Seconds', + ok_actions=["arn:ok"], + insufficient_data_actions=["arn:insufficient"], + unit="Seconds", ) @@ -38,21 +35,20 @@ def test_create_alarm(): alarms = conn.describe_alarms() alarms.should.have.length_of(1) alarm = alarms[0] - alarm.name.should.equal('tester') - alarm.namespace.should.equal('tester_namespace') - alarm.metric.should.equal('tester_metric') - alarm.comparison.should.equal('>=') + alarm.name.should.equal("tester") + alarm.namespace.should.equal("tester_namespace") + alarm.metric.should.equal("tester_metric") + alarm.comparison.should.equal(">=") alarm.threshold.should.equal(2.0) alarm.period.should.equal(60) alarm.evaluation_periods.should.equal(5) - alarm.statistic.should.equal('Average') - alarm.description.should.equal('A test') - dict(alarm.dimensions).should.equal( - {'InstanceId': ['i-0123456,i-0123457']}) - list(alarm.alarm_actions).should.equal(['arn:alarm']) - list(alarm.ok_actions).should.equal(['arn:ok']) - list(alarm.insufficient_data_actions).should.equal(['arn:insufficient']) - alarm.unit.should.equal('Seconds') + alarm.statistic.should.equal("Average") + alarm.description.should.equal("A test") + dict(alarm.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) + list(alarm.alarm_actions).should.equal(["arn:alarm"]) + list(alarm.ok_actions).should.equal(["arn:ok"]) + list(alarm.insufficient_data_actions).should.equal(["arn:insufficient"]) + alarm.unit.should.equal("Seconds") @mock_cloudwatch_deprecated @@ -79,19 +75,18 @@ def test_put_metric_data(): conn = boto.connect_cloudwatch() conn.put_metric_data( - namespace='tester', - name='metric', + namespace="tester", + name="metric", value=1.5, - dimensions={'InstanceId': ['i-0123456,i-0123457']}, + dimensions={"InstanceId": ["i-0123456,i-0123457"]}, ) metrics = conn.list_metrics() metrics.should.have.length_of(1) metric = metrics[0] - metric.namespace.should.equal('tester') - metric.name.should.equal('metric') - dict(metric.dimensions).should.equal( - {'InstanceId': ['i-0123456,i-0123457']}) + metric.namespace.should.equal("tester") + metric.name.should.equal("metric") + dict(metric.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) @mock_cloudwatch_deprecated @@ -110,8 +105,7 @@ def test_describe_alarms(): alarms.should.have.length_of(4) alarms = conn.describe_alarms(alarm_name_prefix="nfoo") alarms.should.have.length_of(2) - alarms = conn.describe_alarms( - alarm_names=["nfoobar", "nbarfoo", "nbazfoo"]) + alarms = conn.describe_alarms(alarm_names=["nfoobar", "nbarfoo", "nbazfoo"]) alarms.should.have.length_of(3) alarms = conn.describe_alarms(action_prefix="afoo") alarms.should.have.length_of(2) diff --git a/tests/test_cloudwatch/test_cloudwatch_boto3.py b/tests/test_cloudwatch/test_cloudwatch_boto3.py old mode 100755 new mode 100644 index 3c205f400..5bd9ed13d --- a/tests/test_cloudwatch/test_cloudwatch_boto3.py +++ b/tests/test_cloudwatch/test_cloudwatch_boto3.py @@ -1,224 +1,291 @@ -from __future__ import unicode_literals - -import boto3 -from botocore.exceptions import ClientError -from datetime import datetime, timedelta -import pytz -import sure # noqa - -from moto import mock_cloudwatch - - -@mock_cloudwatch -def test_put_list_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - resp = client.list_dashboards() - - len(resp['DashboardEntries']).should.equal(1) - - -@mock_cloudwatch -def test_put_list_prefix_nomatch_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - resp = client.list_dashboards(DashboardNamePrefix='nomatch') - - len(resp['DashboardEntries']).should.equal(0) - - -@mock_cloudwatch -def test_delete_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - client.put_dashboard(DashboardName='test2', DashboardBody=widget) - client.put_dashboard(DashboardName='test3', DashboardBody=widget) - client.delete_dashboards(DashboardNames=['test2', 'test1']) - - resp = client.list_dashboards(DashboardNamePrefix='test3') - len(resp['DashboardEntries']).should.equal(1) - - -@mock_cloudwatch -def test_delete_dashboard_fail(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - client.put_dashboard(DashboardName='test2', DashboardBody=widget) - client.put_dashboard(DashboardName='test3', DashboardBody=widget) - # Doesnt delete anything if all dashboards to be deleted do not exist - try: - client.delete_dashboards(DashboardNames=['test2', 'test1', 'test_no_match']) - except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFound') - else: - raise RuntimeError('Should of raised error') - - resp = client.list_dashboards() - len(resp['DashboardEntries']).should.equal(3) - - -@mock_cloudwatch -def test_get_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - - resp = client.get_dashboard(DashboardName='test1') - resp.should.contain('DashboardArn') - resp.should.contain('DashboardBody') - resp['DashboardName'].should.equal('test1') - - -@mock_cloudwatch -def test_get_dashboard_fail(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - - try: - client.get_dashboard(DashboardName='test1') - except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFound') - else: - raise RuntimeError('Should of raised error') - - -@mock_cloudwatch -def test_alarm_state(): - client = boto3.client('cloudwatch', region_name='eu-central-1') - - client.put_metric_alarm( - AlarmName='testalarm1', - MetricName='cpu', - Namespace='blah', - Period=10, - EvaluationPeriods=5, - Statistic='Average', - Threshold=2, - ComparisonOperator='GreaterThanThreshold', - ) - client.put_metric_alarm( - AlarmName='testalarm2', - MetricName='cpu', - Namespace='blah', - Period=10, - EvaluationPeriods=5, - Statistic='Average', - Threshold=2, - ComparisonOperator='GreaterThanThreshold', - ) - - # This is tested implicitly as if it doesnt work the rest will die - client.set_alarm_state( - AlarmName='testalarm1', - StateValue='ALARM', - StateReason='testreason', - StateReasonData='{"some": "json_data"}' - ) - - resp = client.describe_alarms( - StateValue='ALARM' - ) - len(resp['MetricAlarms']).should.equal(1) - resp['MetricAlarms'][0]['AlarmName'].should.equal('testalarm1') - resp['MetricAlarms'][0]['StateValue'].should.equal('ALARM') - - resp = client.describe_alarms( - StateValue='OK' - ) - len(resp['MetricAlarms']).should.equal(1) - resp['MetricAlarms'][0]['AlarmName'].should.equal('testalarm2') - resp['MetricAlarms'][0]['StateValue'].should.equal('OK') - - # Just for sanity - resp = client.describe_alarms() - len(resp['MetricAlarms']).should.equal(2) - - -@mock_cloudwatch -def test_put_metric_data_no_dimensions(): - conn = boto3.client('cloudwatch', region_name='us-east-1') - - conn.put_metric_data( - Namespace='tester', - MetricData=[ - dict( - MetricName='metric', - Value=1.5, - ) - ] - ) - - metrics = conn.list_metrics()['Metrics'] - metrics.should.have.length_of(1) - metric = metrics[0] - metric['Namespace'].should.equal('tester') - metric['MetricName'].should.equal('metric') - - - -@mock_cloudwatch -def test_put_metric_data_with_statistics(): - conn = boto3.client('cloudwatch', region_name='us-east-1') - - conn.put_metric_data( - Namespace='tester', - MetricData=[ - dict( - MetricName='statmetric', - Timestamp=datetime(2015, 1, 1), - # no Value to test https://github.com/spulec/moto/issues/1615 - StatisticValues=dict( - SampleCount=123.0, - Sum=123.0, - Minimum=123.0, - Maximum=123.0 - ), - Unit='Milliseconds', - StorageResolution=123 - ) - ] - ) - - metrics = conn.list_metrics()['Metrics'] - metrics.should.have.length_of(1) - metric = metrics[0] - metric['Namespace'].should.equal('tester') - metric['MetricName'].should.equal('statmetric') - # TODO: test statistics - https://github.com/spulec/moto/issues/1615 - -@mock_cloudwatch -def test_get_metric_statistics(): - conn = boto3.client('cloudwatch', region_name='us-east-1') - utc_now = datetime.now(tz=pytz.utc) - - conn.put_metric_data( - Namespace='tester', - MetricData=[ - dict( - MetricName='metric', - Value=1.5, - Timestamp=utc_now - ) - ] - ) - - stats = conn.get_metric_statistics( - Namespace='tester', - MetricName='metric', - StartTime=utc_now - timedelta(seconds=60), - EndTime=utc_now + timedelta(seconds=60), - Period=60, - Statistics=['SampleCount', 'Sum'] - ) - - stats['Datapoints'].should.have.length_of(1) - datapoint = stats['Datapoints'][0] - datapoint['SampleCount'].should.equal(1.0) - datapoint['Sum'].should.equal(1.5) +# from __future__ import unicode_literals + +import boto3 +from botocore.exceptions import ClientError +from datetime import datetime, timedelta +from nose.tools import assert_raises +from uuid import uuid4 +import pytz +import sure # noqa + +from moto import mock_cloudwatch + + +@mock_cloudwatch +def test_put_list_dashboard(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' + + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + resp = client.list_dashboards() + + len(resp["DashboardEntries"]).should.equal(1) + + +@mock_cloudwatch +def test_put_list_prefix_nomatch_dashboard(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' + + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + resp = client.list_dashboards(DashboardNamePrefix="nomatch") + + len(resp["DashboardEntries"]).should.equal(0) + + +@mock_cloudwatch +def test_delete_dashboard(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' + + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + client.put_dashboard(DashboardName="test2", DashboardBody=widget) + client.put_dashboard(DashboardName="test3", DashboardBody=widget) + client.delete_dashboards(DashboardNames=["test2", "test1"]) + + resp = client.list_dashboards(DashboardNamePrefix="test3") + len(resp["DashboardEntries"]).should.equal(1) + + +@mock_cloudwatch +def test_delete_dashboard_fail(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' + + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + client.put_dashboard(DashboardName="test2", DashboardBody=widget) + client.put_dashboard(DashboardName="test3", DashboardBody=widget) + # Doesnt delete anything if all dashboards to be deleted do not exist + try: + client.delete_dashboards(DashboardNames=["test2", "test1", "test_no_match"]) + except ClientError as err: + err.response["Error"]["Code"].should.equal("ResourceNotFound") + else: + raise RuntimeError("Should of raised error") + + resp = client.list_dashboards() + len(resp["DashboardEntries"]).should.equal(3) + + +@mock_cloudwatch +def test_get_dashboard(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + + resp = client.get_dashboard(DashboardName="test1") + resp.should.contain("DashboardArn") + resp.should.contain("DashboardBody") + resp["DashboardName"].should.equal("test1") + + +@mock_cloudwatch +def test_get_dashboard_fail(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + + try: + client.get_dashboard(DashboardName="test1") + except ClientError as err: + err.response["Error"]["Code"].should.equal("ResourceNotFound") + else: + raise RuntimeError("Should of raised error") + + +@mock_cloudwatch +def test_alarm_state(): + client = boto3.client("cloudwatch", region_name="eu-central-1") + + client.put_metric_alarm( + AlarmName="testalarm1", + MetricName="cpu", + Namespace="blah", + Period=10, + EvaluationPeriods=5, + Statistic="Average", + Threshold=2, + ComparisonOperator="GreaterThanThreshold", + ) + client.put_metric_alarm( + AlarmName="testalarm2", + MetricName="cpu", + Namespace="blah", + Period=10, + EvaluationPeriods=5, + Statistic="Average", + Threshold=2, + ComparisonOperator="GreaterThanThreshold", + ) + + # This is tested implicitly as if it doesnt work the rest will die + client.set_alarm_state( + AlarmName="testalarm1", + StateValue="ALARM", + StateReason="testreason", + StateReasonData='{"some": "json_data"}', + ) + + resp = client.describe_alarms(StateValue="ALARM") + len(resp["MetricAlarms"]).should.equal(1) + resp["MetricAlarms"][0]["AlarmName"].should.equal("testalarm1") + resp["MetricAlarms"][0]["StateValue"].should.equal("ALARM") + + resp = client.describe_alarms(StateValue="OK") + len(resp["MetricAlarms"]).should.equal(1) + resp["MetricAlarms"][0]["AlarmName"].should.equal("testalarm2") + resp["MetricAlarms"][0]["StateValue"].should.equal("OK") + + # Just for sanity + resp = client.describe_alarms() + len(resp["MetricAlarms"]).should.equal(2) + + +@mock_cloudwatch +def test_put_metric_data_no_dimensions(): + conn = boto3.client("cloudwatch", region_name="us-east-1") + + conn.put_metric_data( + Namespace="tester", MetricData=[dict(MetricName="metric", Value=1.5)] + ) + + metrics = conn.list_metrics()["Metrics"] + metrics.should.have.length_of(1) + metric = metrics[0] + metric["Namespace"].should.equal("tester") + metric["MetricName"].should.equal("metric") + + +@mock_cloudwatch +def test_put_metric_data_with_statistics(): + conn = boto3.client("cloudwatch", region_name="us-east-1") + utc_now = datetime.now(tz=pytz.utc) + + conn.put_metric_data( + Namespace="tester", + MetricData=[ + dict( + MetricName="statmetric", + Timestamp=utc_now, + # no Value to test https://github.com/spulec/moto/issues/1615 + StatisticValues=dict( + SampleCount=123.0, Sum=123.0, Minimum=123.0, Maximum=123.0 + ), + Unit="Milliseconds", + StorageResolution=123, + ) + ], + ) + + metrics = conn.list_metrics()["Metrics"] + metrics.should.have.length_of(1) + metric = metrics[0] + metric["Namespace"].should.equal("tester") + metric["MetricName"].should.equal("statmetric") + # TODO: test statistics - https://github.com/spulec/moto/issues/1615 + + +@mock_cloudwatch +def test_get_metric_statistics(): + conn = boto3.client("cloudwatch", region_name="us-east-1") + utc_now = datetime.now(tz=pytz.utc) + + conn.put_metric_data( + Namespace="tester", + MetricData=[dict(MetricName="metric", Value=1.5, Timestamp=utc_now)], + ) + + stats = conn.get_metric_statistics( + Namespace="tester", + MetricName="metric", + StartTime=utc_now - timedelta(seconds=60), + EndTime=utc_now + timedelta(seconds=60), + Period=60, + Statistics=["SampleCount", "Sum"], + ) + + stats["Datapoints"].should.have.length_of(1) + datapoint = stats["Datapoints"][0] + datapoint["SampleCount"].should.equal(1.0) + datapoint["Sum"].should.equal(1.5) + + +@mock_cloudwatch +def test_list_metrics(): + cloudwatch = boto3.client("cloudwatch", "eu-west-1") + # Verify namespace has to exist + res = cloudwatch.list_metrics(Namespace="unknown/")["Metrics"] + res.should.be.empty + # Create some metrics to filter on + create_metrics(cloudwatch, namespace="list_test_1/", metrics=4, data_points=2) + create_metrics(cloudwatch, namespace="list_test_2/", metrics=4, data_points=2) + # Verify we can retrieve everything + res = cloudwatch.list_metrics()["Metrics"] + len(res).should.equal(16) # 2 namespaces * 4 metrics * 2 data points + # Verify we can filter by namespace/metric name + res = cloudwatch.list_metrics(Namespace="list_test_1/")["Metrics"] + len(res).should.equal(8) # 1 namespace * 4 metrics * 2 data points + res = cloudwatch.list_metrics(Namespace="list_test_1/", MetricName="metric1")[ + "Metrics" + ] + len(res).should.equal(2) # 1 namespace * 1 metrics * 2 data points + # Verify format + res.should.equal( + [ + {u"Namespace": "list_test_1/", u"Dimensions": [], u"MetricName": "metric1"}, + {u"Namespace": "list_test_1/", u"Dimensions": [], u"MetricName": "metric1"}, + ] + ) + # Verify unknown namespace still has no results + res = cloudwatch.list_metrics(Namespace="unknown/")["Metrics"] + res.should.be.empty + + +@mock_cloudwatch +def test_list_metrics_paginated(): + cloudwatch = boto3.client("cloudwatch", "eu-west-1") + # Verify that only a single page of metrics is returned + cloudwatch.list_metrics()["Metrics"].should.be.empty + # Verify we can't pass a random NextToken + with assert_raises(ClientError) as e: + cloudwatch.list_metrics(NextToken=str(uuid4())) + e.exception.response["Error"]["Message"].should.equal( + "Request parameter NextToken is invalid" + ) + # Add a boatload of metrics + create_metrics(cloudwatch, namespace="test", metrics=100, data_points=1) + # Verify that a single page is returned until we've reached 500 + first_page = cloudwatch.list_metrics() + first_page["Metrics"].shouldnt.be.empty + len(first_page["Metrics"]).should.equal(100) + create_metrics(cloudwatch, namespace="test", metrics=200, data_points=2) + first_page = cloudwatch.list_metrics() + len(first_page["Metrics"]).should.equal(500) + first_page.shouldnt.contain("NextToken") + # Verify that adding more data points results in pagination + create_metrics(cloudwatch, namespace="test", metrics=60, data_points=10) + first_page = cloudwatch.list_metrics() + len(first_page["Metrics"]).should.equal(500) + first_page["NextToken"].shouldnt.be.empty + # Retrieve second page - and verify there's more where that came from + second_page = cloudwatch.list_metrics(NextToken=first_page["NextToken"]) + len(second_page["Metrics"]).should.equal(500) + second_page.should.contain("NextToken") + # Last page should only have the last 100 results, and no NextToken (indicating that pagination is finished) + third_page = cloudwatch.list_metrics(NextToken=second_page["NextToken"]) + len(third_page["Metrics"]).should.equal(100) + third_page.shouldnt.contain("NextToken") + # Verify that we can't reuse an existing token + with assert_raises(ClientError) as e: + cloudwatch.list_metrics(NextToken=first_page["NextToken"]) + e.exception.response["Error"]["Message"].should.equal( + "Request parameter NextToken is invalid" + ) + + +def create_metrics(cloudwatch, namespace, metrics=5, data_points=5): + for i in range(0, metrics): + metric_name = "metric" + str(i) + for j in range(0, data_points): + cloudwatch.put_metric_data( + Namespace=namespace, + MetricData=[{"MetricName": metric_name, "Value": j, "Unit": "Seconds"}], + ) diff --git a/tests/test_codepipeline/test_codepipeline.py b/tests/test_codepipeline/test_codepipeline.py new file mode 100644 index 000000000..926d7f873 --- /dev/null +++ b/tests/test_codepipeline/test_codepipeline.py @@ -0,0 +1,877 @@ +import json +from datetime import datetime + +import boto3 +import sure # noqa +from botocore.exceptions import ClientError +from nose.tools import assert_raises + +from moto import mock_codepipeline, mock_iam + + +@mock_codepipeline +def test_create_pipeline(): + client = boto3.client("codepipeline", region_name="us-east-1") + + response = client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + tags=[{"key": "key", "value": "value"}], + ) + + response["pipeline"].should.equal( + { + "name": "test-pipeline", + "roleArn": "arn:aws:iam::123456789012:role/test-role", + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"}], + "inputArtifacts": [], + } + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + "runOrder": 1, + "configuration": {}, + "outputArtifacts": [], + "inputArtifacts": [], + } + ], + }, + ], + "version": 1, + } + ) + response["tags"].should.equal([{"key": "key", "value": "value"}]) + + +@mock_codepipeline +@mock_iam +def test_create_pipeline_errors(): + client = boto3.client("codepipeline", region_name="us-east-1") + client_iam = boto3.client("iam", region_name="us-east-1") + client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + } + ) + + with assert_raises(ClientError) as e: + client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + } + ) + ex = e.exception + ex.operation_name.should.equal("CreatePipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidStructureException") + ex.response["Error"]["Message"].should.equal( + "A pipeline with the name 'test-pipeline' already exists in account '123456789012'" + ) + + with assert_raises(ClientError) as e: + client.create_pipeline( + pipeline={ + "name": "invalid-pipeline", + "roleArn": "arn:aws:iam::123456789012:role/not-existing", + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + }, + ], + }, + ], + } + ) + ex = e.exception + ex.operation_name.should.equal("CreatePipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidStructureException") + ex.response["Error"]["Message"].should.equal( + "CodePipeline is not authorized to perform AssumeRole on role arn:aws:iam::123456789012:role/not-existing" + ) + + wrong_role_arn = client_iam.create_role( + RoleName="wrong-role", + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "s3.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + )["Role"]["Arn"] + + with assert_raises(ClientError) as e: + client.create_pipeline( + pipeline={ + "name": "invalid-pipeline", + "roleArn": wrong_role_arn, + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + }, + ], + }, + ], + } + ) + ex = e.exception + ex.operation_name.should.equal("CreatePipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidStructureException") + ex.response["Error"]["Message"].should.equal( + "CodePipeline is not authorized to perform AssumeRole on role arn:aws:iam::123456789012:role/wrong-role" + ) + + with assert_raises(ClientError) as e: + client.create_pipeline( + pipeline={ + "name": "invalid-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + }, + ], + }, + ], + } + ) + ex = e.exception + ex.operation_name.should.equal("CreatePipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidStructureException") + ex.response["Error"]["Message"].should.equal( + "Pipeline has only 1 stage(s). There should be a minimum of 2 stages in a pipeline" + ) + + +@mock_codepipeline +def test_get_pipeline(): + client = boto3.client("codepipeline", region_name="us-east-1") + client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + tags=[{"key": "key", "value": "value"}], + ) + + response = client.get_pipeline(name="test-pipeline") + + response["pipeline"].should.equal( + { + "name": "test-pipeline", + "roleArn": "arn:aws:iam::123456789012:role/test-role", + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"}], + "inputArtifacts": [], + } + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + "runOrder": 1, + "configuration": {}, + "outputArtifacts": [], + "inputArtifacts": [], + } + ], + }, + ], + "version": 1, + } + ) + response["metadata"]["pipelineArn"].should.equal( + "arn:aws:codepipeline:us-east-1:123456789012:test-pipeline" + ) + response["metadata"]["created"].should.be.a(datetime) + response["metadata"]["updated"].should.be.a(datetime) + + +@mock_codepipeline +def test_get_pipeline_errors(): + client = boto3.client("codepipeline", region_name="us-east-1") + + with assert_raises(ClientError) as e: + client.get_pipeline(name="not-existing") + ex = e.exception + ex.operation_name.should.equal("GetPipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("PipelineNotFoundException") + ex.response["Error"]["Message"].should.equal( + "Account '123456789012' does not have a pipeline with name 'not-existing'" + ) + + +@mock_codepipeline +def test_update_pipeline(): + client = boto3.client("codepipeline", region_name="us-east-1") + role_arn = get_role_arn() + client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": role_arn, + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + tags=[{"key": "key", "value": "value"}], + ) + + response = client.get_pipeline(name="test-pipeline") + created_time = response["metadata"]["created"] + updated_time = response["metadata"]["updated"] + + response = client.update_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": role_arn, + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "different-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + } + ) + + response["pipeline"].should.equal( + { + "name": "test-pipeline", + "roleArn": "arn:aws:iam::123456789012:role/test-role", + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "runOrder": 1, + "configuration": { + "S3Bucket": "different-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"}], + "inputArtifacts": [], + } + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + "runOrder": 1, + "configuration": {}, + "outputArtifacts": [], + "inputArtifacts": [], + } + ], + }, + ], + "version": 2, + } + ) + + metadata = client.get_pipeline(name="test-pipeline")["metadata"] + metadata["created"].should.equal(created_time) + metadata["updated"].should.be.greater_than(updated_time) + + +@mock_codepipeline +def test_update_pipeline_errors(): + client = boto3.client("codepipeline", region_name="us-east-1") + + with assert_raises(ClientError) as e: + client.update_pipeline( + pipeline={ + "name": "not-existing", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + } + ) + ex = e.exception + ex.operation_name.should.equal("UpdatePipeline") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("ResourceNotFoundException") + ex.response["Error"]["Message"].should.equal( + "The account with id '123456789012' does not include a pipeline with the name 'not-existing'" + ) + + +@mock_codepipeline +def test_list_pipelines(): + client = boto3.client("codepipeline", region_name="us-east-1") + client.create_pipeline( + pipeline={ + "name": "test-pipeline-1", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + ) + client.create_pipeline( + pipeline={ + "name": "test-pipeline-2", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + ) + + response = client.list_pipelines() + + response["pipelines"].should.have.length_of(2) + response["pipelines"][0]["name"].should.equal("test-pipeline-1") + response["pipelines"][0]["version"].should.equal(1) + response["pipelines"][0]["created"].should.be.a(datetime) + response["pipelines"][0]["updated"].should.be.a(datetime) + response["pipelines"][1]["name"].should.equal("test-pipeline-2") + response["pipelines"][1]["version"].should.equal(1) + response["pipelines"][1]["created"].should.be.a(datetime) + response["pipelines"][1]["updated"].should.be.a(datetime) + + +@mock_codepipeline +def test_delete_pipeline(): + client = boto3.client("codepipeline", region_name="us-east-1") + client.create_pipeline( + pipeline={ + "name": "test-pipeline", + "roleArn": get_role_arn(), + "artifactStore": { + "type": "S3", + "location": "codepipeline-us-east-1-123456789012", + }, + "stages": [ + { + "name": "Stage-1", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Source", + "owner": "AWS", + "provider": "S3", + "version": "1", + }, + "configuration": { + "S3Bucket": "test-bucket", + "S3ObjectKey": "test-object", + }, + "outputArtifacts": [{"name": "artifact"},], + }, + ], + }, + { + "name": "Stage-2", + "actions": [ + { + "name": "Action-1", + "actionTypeId": { + "category": "Approval", + "owner": "AWS", + "provider": "Manual", + "version": "1", + }, + }, + ], + }, + ], + }, + ) + client.list_pipelines()["pipelines"].should.have.length_of(1) + + client.delete_pipeline(name="test-pipeline") + + client.list_pipelines()["pipelines"].should.have.length_of(0) + + # deleting a not existing pipeline, should raise no exception + client.delete_pipeline(name="test-pipeline") + + +@mock_iam +def get_role_arn(): + iam = boto3.client("iam", region_name="us-east-1") + try: + return iam.get_role(RoleName="test-role")["Role"]["Arn"] + except ClientError: + return iam.create_role( + RoleName="test-role", + AssumeRolePolicyDocument=json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Principal": {"Service": "codepipeline.amazonaws.com"}, + "Action": "sts:AssumeRole", + } + ], + } + ), + )["Role"]["Arn"] diff --git a/tests/test_cognitoidentity/test_cognitoidentity.py b/tests/test_cognitoidentity/test_cognitoidentity.py index ea9ccbc78..8eae183c6 100644 --- a/tests/test_cognitoidentity/test_cognitoidentity.py +++ b/tests/test_cognitoidentity/test_cognitoidentity.py @@ -1,97 +1,150 @@ from __future__ import unicode_literals import boto3 +from botocore.exceptions import ClientError +from nose.tools import assert_raises from moto import mock_cognitoidentity -import sure # noqa - from moto.cognitoidentity.utils import get_random_identity_id +from moto.core import ACCOUNT_ID @mock_cognitoidentity def test_create_identity_pool(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") - result = conn.create_identity_pool(IdentityPoolName='TestPool', + result = conn.create_identity_pool( + IdentityPoolName="TestPool", AllowUnauthenticatedIdentities=False, - SupportedLoginProviders={'graph.facebook.com': '123456789012345'}, - DeveloperProviderName='devname', - OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'], + SupportedLoginProviders={"graph.facebook.com": "123456789012345"}, + DeveloperProviderName="devname", + OpenIdConnectProviderARNs=[ + "arn:aws:rds:eu-west-2:{}:db:mysql-db".format(ACCOUNT_ID) + ], CognitoIdentityProviders=[ { - 'ProviderName': 'testprovider', - 'ClientId': 'CLIENT12345', - 'ServerSideTokenCheck': True - }, + "ProviderName": "testprovider", + "ClientId": "CLIENT12345", + "ServerSideTokenCheck": True, + } ], - SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db']) - assert result['IdentityPoolId'] != '' + SamlProviderARNs=["arn:aws:rds:eu-west-2:{}:db:mysql-db".format(ACCOUNT_ID)], + ) + assert result["IdentityPoolId"] != "" + + +@mock_cognitoidentity +def test_describe_identity_pool(): + conn = boto3.client("cognito-identity", "us-west-2") + + res = conn.create_identity_pool( + IdentityPoolName="TestPool", + AllowUnauthenticatedIdentities=False, + SupportedLoginProviders={"graph.facebook.com": "123456789012345"}, + DeveloperProviderName="devname", + OpenIdConnectProviderARNs=[ + "arn:aws:rds:eu-west-2:{}:db:mysql-db".format(ACCOUNT_ID) + ], + CognitoIdentityProviders=[ + { + "ProviderName": "testprovider", + "ClientId": "CLIENT12345", + "ServerSideTokenCheck": True, + } + ], + SamlProviderARNs=["arn:aws:rds:eu-west-2:{}:db:mysql-db".format(ACCOUNT_ID)], + ) + + result = conn.describe_identity_pool(IdentityPoolId=res["IdentityPoolId"]) + + assert result["IdentityPoolId"] == res["IdentityPoolId"] + assert ( + result["AllowUnauthenticatedIdentities"] + == res["AllowUnauthenticatedIdentities"] + ) + assert result["SupportedLoginProviders"] == res["SupportedLoginProviders"] + assert result["DeveloperProviderName"] == res["DeveloperProviderName"] + assert result["OpenIdConnectProviderARNs"] == res["OpenIdConnectProviderARNs"] + assert result["CognitoIdentityProviders"] == res["CognitoIdentityProviders"] + assert result["SamlProviderARNs"] == res["SamlProviderARNs"] + + +@mock_cognitoidentity +def test_describe_identity_pool_with_invalid_id_raises_error(): + conn = boto3.client("cognito-identity", "us-west-2") + + with assert_raises(ClientError) as cm: + conn.describe_identity_pool(IdentityPoolId="us-west-2_non-existent") + + cm.exception.operation_name.should.equal("DescribeIdentityPool") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) # testing a helper function def test_get_random_identity_id(): - assert len(get_random_identity_id('us-west-2')) > 0 - assert len(get_random_identity_id('us-west-2').split(':')[1]) == 19 + assert len(get_random_identity_id("us-west-2")) > 0 + assert len(get_random_identity_id("us-west-2").split(":")[1]) == 19 @mock_cognitoidentity def test_get_id(): # These two do NOT work in server mode. They just don't return the data from the model. - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_id(AccountId='someaccount', - IdentityPoolId='us-west-2:12345', - Logins={ - 'someurl': '12345' - }) + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_id( + AccountId="someaccount", + IdentityPoolId="us-west-2:12345", + Logins={"someurl": "12345"}, + ) print(result) - assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 + assert ( + result.get("IdentityId", "").startswith("us-west-2") + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) @mock_cognitoidentity def test_get_credentials_for_identity(): # These two do NOT work in server mode. They just don't return the data from the model. - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_credentials_for_identity(IdentityId='12345') + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_credentials_for_identity(IdentityId="12345") - assert result.get('Expiration', 0) > 0 or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 - assert result.get('IdentityId') == '12345' or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 + assert ( + result.get("Expiration", 0) > 0 + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) + assert ( + result.get("IdentityId") == "12345" + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) @mock_cognitoidentity def test_get_open_id_token_for_developer_identity(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") result = conn.get_open_id_token_for_developer_identity( - IdentityPoolId='us-west-2:12345', - IdentityId='12345', - Logins={ - 'someurl': '12345' - }, - TokenDuration=123 + IdentityPoolId="us-west-2:12345", + IdentityId="12345", + Logins={"someurl": "12345"}, + TokenDuration=123, ) - assert len(result['Token']) > 0 - assert result['IdentityId'] == '12345' + assert len(result["Token"]) > 0 + assert result["IdentityId"] == "12345" + @mock_cognitoidentity def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") result = conn.get_open_id_token_for_developer_identity( - IdentityPoolId='us-west-2:12345', - Logins={ - 'someurl': '12345' - }, - TokenDuration=123 + IdentityPoolId="us-west-2:12345", Logins={"someurl": "12345"}, TokenDuration=123 ) - assert len(result['Token']) > 0 - assert len(result['IdentityId']) > 0 + assert len(result["Token"]) > 0 + assert len(result["IdentityId"]) > 0 + @mock_cognitoidentity def test_get_open_id_token(): - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_open_id_token( - IdentityId='12345', - Logins={ - 'someurl': '12345' - } - ) - assert len(result['Token']) > 0 - assert result['IdentityId'] == '12345' + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_open_id_token(IdentityId="12345", Logins={"someurl": "12345"}) + assert len(result["Token"]) > 0 + assert result["IdentityId"] == "12345" diff --git a/tests/test_cognitoidentity/test_server.py b/tests/test_cognitoidentity/test_server.py index d093158c5..903dae290 100644 --- a/tests/test_cognitoidentity/test_server.py +++ b/tests/test_cognitoidentity/test_server.py @@ -1,45 +1,53 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_cognitoidentity - -''' -Test the different server responses -''' - - -@mock_cognitoidentity -def test_create_identity_pool(): - - backend = server.create_backend_app("cognito-identity") - test_client = backend.test_client() - - res = test_client.post('/', - data={"IdentityPoolName": "test", "AllowUnauthenticatedIdentities": True}, - headers={ - "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.CreateIdentityPool"}, - ) - - json_data = json.loads(res.data.decode("utf-8")) - assert json_data['IdentityPoolName'] == "test" - - -@mock_cognitoidentity -def test_get_id(): - backend = server.create_backend_app("cognito-identity") - test_client = backend.test_client() - - res = test_client.post('/', - data=json.dumps({'AccountId': 'someaccount', - 'IdentityPoolId': 'us-west-2:12345', - 'Logins': {'someurl': '12345'}}), - headers={ - "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.GetId"}, - ) - - print(res.data) - json_data = json.loads(res.data.decode("utf-8")) - assert ':' in json_data['IdentityId'] +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_cognitoidentity + +""" +Test the different server responses +""" + + +@mock_cognitoidentity +def test_create_identity_pool(): + + backend = server.create_backend_app("cognito-identity") + test_client = backend.test_client() + + res = test_client.post( + "/", + data={"IdentityPoolName": "test", "AllowUnauthenticatedIdentities": True}, + headers={ + "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.CreateIdentityPool" + }, + ) + + json_data = json.loads(res.data.decode("utf-8")) + assert json_data["IdentityPoolName"] == "test" + + +@mock_cognitoidentity +def test_get_id(): + backend = server.create_backend_app("cognito-identity") + test_client = backend.test_client() + + res = test_client.post( + "/", + data=json.dumps( + { + "AccountId": "someaccount", + "IdentityPoolId": "us-west-2:12345", + "Logins": {"someurl": "12345"}, + } + ), + headers={ + "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.GetId" + }, + ) + + print(res.data) + json_data = json.loads(res.data.decode("utf-8")) + assert ":" in json_data["IdentityId"] diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index 774ff7621..7ac1038b0 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -6,6 +6,7 @@ import random import uuid import boto3 + # noinspection PyUnresolvedReferences import sure # noqa from botocore.exceptions import ClientError @@ -13,6 +14,7 @@ from jose import jws from nose.tools import assert_raises from moto import mock_cognitoidp +from moto.core import ACCOUNT_ID @mock_cognitoidp @@ -21,15 +23,10 @@ def test_create_user_pool(): name = str(uuid.uuid4()) value = str(uuid.uuid4()) - result = conn.create_user_pool( - PoolName=name, - LambdaConfig={ - "PreSignUp": value - } - ) + result = conn.create_user_pool(PoolName=name, LambdaConfig={"PreSignUp": value}) result["UserPool"]["Id"].should_not.be.none - result["UserPool"]["Id"].should.match(r'[\w-]+_[0-9a-zA-Z]+') + result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+") result["UserPool"]["Name"].should.equal(name) result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value) @@ -102,10 +99,7 @@ def test_describe_user_pool(): name = str(uuid.uuid4()) value = str(uuid.uuid4()) user_pool_details = conn.create_user_pool( - PoolName=name, - LambdaConfig={ - "PreSignUp": value - } + PoolName=name, LambdaConfig={"PreSignUp": value} ) result = conn.describe_user_pool(UserPoolId=user_pool_details["UserPool"]["Id"]) @@ -139,7 +133,9 @@ def test_create_user_pool_domain_custom_domain_config(): domain = str(uuid.uuid4()) custom_domain_config = { - "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012", + "CertificateArn": "arn:aws:acm:us-east-1:{}:certificate/123456789012".format( + ACCOUNT_ID + ) } user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] result = conn.create_user_pool_domain( @@ -184,7 +180,9 @@ def test_update_user_pool_domain(): domain = str(uuid.uuid4()) custom_domain_config = { - "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012", + "CertificateArn": "arn:aws:acm:us-east-1:{}:certificate/123456789012".format( + ACCOUNT_ID + ) } user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] conn.create_user_pool_domain(UserPoolId=user_pool_id, Domain=domain) @@ -203,9 +201,7 @@ def test_create_user_pool_client(): value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] result = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=client_name, - CallbackURLs=[value], + UserPoolId=user_pool_id, ClientName=client_name, CallbackURLs=[value] ) result["UserPoolClient"]["UserPoolId"].should.equal(user_pool_id) @@ -236,11 +232,11 @@ def test_list_user_pool_clients_returns_max_items(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(max_results) result.should.have.key("NextToken") @@ -254,18 +250,18 @@ def test_list_user_pool_clients_returns_next_tokens(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(max_results) result.should.have.key("NextToken") next_token = result["NextToken"] - result_2 = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results, - NextToken=next_token) + result_2 = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results, NextToken=next_token + ) result_2["UserPoolClients"].should.have.length_of(max_results) result_2.shouldnt.have.key("NextToken") @@ -279,11 +275,11 @@ def test_list_user_pool_clients_when_max_items_more_than_total_items(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = client_count + 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(client_count) result.shouldnt.have.key("NextToken") @@ -296,14 +292,11 @@ def test_describe_user_pool_client(): value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=client_name, - CallbackURLs=[value], + UserPoolId=user_pool_id, ClientName=client_name, CallbackURLs=[value] ) result = conn.describe_user_pool_client( - UserPoolId=user_pool_id, - ClientId=client_details["UserPoolClient"]["ClientId"], + UserPoolId=user_pool_id, ClientId=client_details["UserPoolClient"]["ClientId"] ) result["UserPoolClient"]["ClientName"].should.equal(client_name) @@ -321,9 +314,7 @@ def test_update_user_pool_client(): new_value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=old_client_name, - CallbackURLs=[old_value], + UserPoolId=user_pool_id, ClientName=old_client_name, CallbackURLs=[old_value] ) result = conn.update_user_pool_client( @@ -344,13 +335,11 @@ def test_delete_user_pool_client(): user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=str(uuid.uuid4()), + UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()) ) conn.delete_user_pool_client( - UserPoolId=user_pool_id, - ClientId=client_details["UserPoolClient"]["ClientId"], + UserPoolId=user_pool_id, ClientId=client_details["UserPoolClient"]["ClientId"] ) caught = False @@ -377,9 +366,7 @@ def test_create_identity_provider(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -402,10 +389,7 @@ def test_list_identity_providers(): ProviderDetails={}, ) - result = conn.list_identity_providers( - UserPoolId=user_pool_id, - MaxResults=10, - ) + result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=10) result["Providers"].should.have.length_of(1) result["Providers"][0]["ProviderName"].should.equal(provider_name) @@ -430,8 +414,9 @@ def test_list_identity_providers_returns_max_items(): ) max_results = 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(max_results) result.should.have.key("NextToken") @@ -454,14 +439,16 @@ def test_list_identity_providers_returns_next_tokens(): ) max_results = 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(max_results) result.should.have.key("NextToken") next_token = result["NextToken"] - result_2 = conn.list_identity_providers(UserPoolId=user_pool_id, - MaxResults=max_results, - NextToken=next_token) + result_2 = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results, NextToken=next_token + ) result_2["Providers"].should.have.length_of(max_results) result_2.shouldnt.have.key("NextToken") @@ -484,7 +471,9 @@ def test_list_identity_providers_when_max_items_more_than_total_items(): ) max_results = identity_provider_count + 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(identity_provider_count) result.shouldnt.have.key("NextToken") @@ -501,14 +490,11 @@ def test_describe_identity_providers(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result = conn.describe_identity_provider( - UserPoolId=user_pool_id, - ProviderName=provider_name, + UserPoolId=user_pool_id, ProviderName=provider_name ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -530,17 +516,13 @@ def test_update_identity_provider(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result = conn.update_identity_provider( UserPoolId=user_pool_id, ProviderName=provider_name, - ProviderDetails={ - "thing": new_value - }, + ProviderDetails={"thing": new_value}, ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -557,16 +539,12 @@ def test_update_identity_provider_no_user_pool(): with assert_raises(conn.exceptions.ResourceNotFoundException) as cm: conn.update_identity_provider( - UserPoolId="foo", - ProviderName="bar", - ProviderDetails={ - "thing": new_value - }, + UserPoolId="foo", ProviderName="bar", ProviderDetails={"thing": new_value} ) - cm.exception.operation_name.should.equal('UpdateIdentityProvider') - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("UpdateIdentityProvider") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -583,14 +561,12 @@ def test_update_identity_provider_no_identity_provider(): conn.update_identity_provider( UserPoolId=user_pool_id, ProviderName="foo", - ProviderDetails={ - "thing": new_value - }, + ProviderDetails={"thing": new_value}, ) - cm.exception.operation_name.should.equal('UpdateIdentityProvider') - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("UpdateIdentityProvider") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -605,9 +581,7 @@ def test_delete_identity_providers(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) conn.delete_identity_provider(UserPoolId=user_pool_id, ProviderName=provider_name) @@ -615,8 +589,7 @@ def test_delete_identity_providers(): caught = False try: conn.describe_identity_provider( - UserPoolId=user_pool_id, - ProviderName=provider_name, + UserPoolId=user_pool_id, ProviderName=provider_name ) except conn.exceptions.ResourceNotFoundException: caught = True @@ -662,9 +635,9 @@ def test_create_group_with_duplicate_name_raises_error(): with assert_raises(ClientError) as cm: conn.create_group(GroupName=group_name, UserPoolId=user_pool_id) - cm.exception.operation_name.should.equal('CreateGroup') - cm.exception.response['Error']['Code'].should.equal('GroupExistsException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("CreateGroup") + cm.exception.response["Error"]["Code"].should.equal("GroupExistsException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -710,7 +683,7 @@ def test_delete_group(): with assert_raises(ClientError) as cm: conn.get_group(GroupName=group_name, UserPoolId=user_pool_id) - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") @mock_cognitoidp @@ -724,7 +697,9 @@ def test_admin_add_user_to_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - result = conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + result = conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected @@ -739,8 +714,12 @@ def test_admin_add_user_to_group_again_is_noop(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) @mock_cognitoidp @@ -754,7 +733,9 @@ def test_list_users_in_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) result = conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) @@ -775,8 +756,12 @@ def test_list_users_in_group_ignores_deleted_user(): username2 = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username2) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username2, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username2, GroupName=group_name + ) conn.admin_delete_user(UserPoolId=user_pool_id, Username=username) result = conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) @@ -796,7 +781,9 @@ def test_admin_list_groups_for_user(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) result = conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) @@ -817,8 +804,12 @@ def test_admin_list_groups_for_user_ignores_deleted_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name2) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name2 + ) conn.delete_group(GroupName=group_name, UserPoolId=user_pool_id) result = conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) @@ -838,14 +829,20 @@ def test_admin_remove_user_from_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) - result = conn.admin_remove_user_from_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + result = conn.admin_remove_user_from_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) \ - ["Users"].should.have.length_of(0) - conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) \ - ["Groups"].should.have.length_of(0) + conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name)[ + "Users" + ].should.have.length_of(0) + conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id)[ + "Groups" + ].should.have.length_of(0) @mock_cognitoidp @@ -859,8 +856,12 @@ def test_admin_remove_user_from_group_again_is_noop(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) @mock_cognitoidp @@ -873,9 +874,7 @@ def test_admin_create_user(): result = conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) result["User"]["Username"].should.equal(username) @@ -886,6 +885,32 @@ def test_admin_create_user(): result["User"]["Enabled"].should.equal(True) +@mock_cognitoidp +def test_admin_create_existing_user(): + conn = boto3.client("cognito-idp", "us-west-2") + + username = str(uuid.uuid4()) + value = str(uuid.uuid4()) + user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] + conn.admin_create_user( + UserPoolId=user_pool_id, + Username=username, + UserAttributes=[{"Name": "thing", "Value": value}], + ) + + caught = False + try: + conn.admin_create_user( + UserPoolId=user_pool_id, + Username=username, + UserAttributes=[{"Name": "thing", "Value": value}], + ) + except conn.exceptions.UsernameExistsException: + caught = True + + caught.should.be.true + + @mock_cognitoidp def test_admin_get_user(): conn = boto3.client("cognito-idp", "us-west-2") @@ -896,9 +921,7 @@ def test_admin_get_user(): conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) result = conn.admin_get_user(UserPoolId=user_pool_id, Username=username) @@ -944,8 +967,7 @@ def test_list_users_returns_limit_items(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) result["Users"].should.have.length_of(max_results) @@ -960,8 +982,7 @@ def test_list_users_returns_pagination_tokens(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) @@ -969,8 +990,9 @@ def test_list_users_returns_pagination_tokens(): result.should.have.key("PaginationToken") next_token = result["PaginationToken"] - result_2 = conn.list_users(UserPoolId=user_pool_id, - Limit=max_results, PaginationToken=next_token) + result_2 = conn.list_users( + UserPoolId=user_pool_id, Limit=max_results, PaginationToken=next_token + ) result_2["Users"].should.have.length_of(max_results) result_2.shouldnt.have.key("PaginationToken") @@ -983,8 +1005,7 @@ def test_list_users_when_limit_more_than_total_items(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = user_count + 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) @@ -1003,8 +1024,9 @@ def test_admin_disable_user(): result = conn.admin_disable_user(UserPoolId=user_pool_id, Username=username) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.admin_get_user(UserPoolId=user_pool_id, Username=username) \ - ["Enabled"].should.equal(False) + conn.admin_get_user(UserPoolId=user_pool_id, Username=username)[ + "Enabled" + ].should.equal(False) @mock_cognitoidp @@ -1019,8 +1041,9 @@ def test_admin_enable_user(): result = conn.admin_enable_user(UserPoolId=user_pool_id, Username=username) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.admin_get_user(UserPoolId=user_pool_id, Username=username) \ - ["Enabled"].should.equal(True) + conn.admin_get_user(UserPoolId=user_pool_id, Username=username)[ + "Enabled" + ].should.equal(True) @mock_cognitoidp @@ -1050,27 +1073,21 @@ def authentication_flow(conn): client_id = conn.create_user_pool_client( UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()), - ReadAttributes=[user_attribute_name] + ReadAttributes=[user_attribute_name], )["UserPoolClient"]["ClientId"] conn.admin_create_user( UserPoolId=user_pool_id, Username=username, TemporaryPassword=temporary_password, - UserAttributes=[{ - 'Name': user_attribute_name, - 'Value': user_attribute_value - }] + UserAttributes=[{"Name": user_attribute_name, "Value": user_attribute_value}], ) result = conn.admin_initiate_auth( UserPoolId=user_pool_id, ClientId=client_id, AuthFlow="ADMIN_NO_SRP_AUTH", - AuthParameters={ - "USERNAME": username, - "PASSWORD": temporary_password - }, + AuthParameters={"USERNAME": username, "PASSWORD": temporary_password}, ) # A newly created user is forced to set a new password @@ -1083,10 +1100,7 @@ def authentication_flow(conn): Session=result["Session"], ClientId=client_id, ChallengeName="NEW_PASSWORD_REQUIRED", - ChallengeResponses={ - "USERNAME": username, - "NEW_PASSWORD": new_password - } + ChallengeResponses={"USERNAME": username, "NEW_PASSWORD": new_password}, ) result["AuthenticationResult"]["IdToken"].should_not.be.none @@ -1099,9 +1113,7 @@ def authentication_flow(conn): "access_token": result["AuthenticationResult"]["AccessToken"], "username": username, "password": new_password, - "additional_fields": { - user_attribute_name: user_attribute_value - } + "additional_fields": {user_attribute_name: user_attribute_value}, } @@ -1124,7 +1136,9 @@ def test_token_legitimacy(): id_token = outputs["id_token"] access_token = outputs["access_token"] client_id = outputs["client_id"] - issuer = "https://cognito-idp.us-west-2.amazonaws.com/{}".format(outputs["user_pool_id"]) + issuer = "https://cognito-idp.us-west-2.amazonaws.com/{}".format( + outputs["user_pool_id"] + ) id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256")) id_claims["iss"].should.equal(issuer) id_claims["aud"].should.equal(client_id) @@ -1155,10 +1169,7 @@ def test_change_password(): UserPoolId=outputs["user_pool_id"], ClientId=outputs["client_id"], AuthFlow="ADMIN_NO_SRP_AUTH", - AuthParameters={ - "USERNAME": outputs["username"], - "PASSWORD": newer_password, - }, + AuthParameters={"USERNAME": outputs["username"], "PASSWORD": newer_password}, ) result["AuthenticationResult"].should_not.be.none @@ -1168,7 +1179,9 @@ def test_change_password(): def test_forgot_password(): conn = boto3.client("cognito-idp", "us-west-2") - result = conn.forgot_password(ClientId=str(uuid.uuid4()), Username=str(uuid.uuid4())) + result = conn.forgot_password( + ClientId=str(uuid.uuid4()), Username=str(uuid.uuid4()) + ) result["CodeDeliveryDetails"].should_not.be.none @@ -1179,14 +1192,11 @@ def test_confirm_forgot_password(): username = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_id = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=str(uuid.uuid4()), + UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()) )["UserPoolClient"]["ClientId"] conn.admin_create_user( - UserPoolId=user_pool_id, - Username=username, - TemporaryPassword=str(uuid.uuid4()), + UserPoolId=user_pool_id, Username=username, TemporaryPassword=str(uuid.uuid4()) ) conn.confirm_forgot_password( @@ -1196,6 +1206,7 @@ def test_confirm_forgot_password(): Password=str(uuid.uuid4()), ) + @mock_cognitoidp def test_admin_update_user_attributes(): conn = boto3.client("cognito-idp", "us-west-2") @@ -1207,41 +1218,26 @@ def test_admin_update_user_attributes(): UserPoolId=user_pool_id, Username=username, UserAttributes=[ - { - 'Name': 'family_name', - 'Value': 'Doe', - }, - { - 'Name': 'given_name', - 'Value': 'John', - } - ] + {"Name": "family_name", "Value": "Doe"}, + {"Name": "given_name", "Value": "John"}, + ], ) conn.admin_update_user_attributes( UserPoolId=user_pool_id, Username=username, UserAttributes=[ - { - 'Name': 'family_name', - 'Value': 'Doe', - }, - { - 'Name': 'given_name', - 'Value': 'Jane', - } - ] + {"Name": "family_name", "Value": "Doe"}, + {"Name": "given_name", "Value": "Jane"}, + ], ) - user = conn.admin_get_user( - UserPoolId=user_pool_id, - Username=username - ) - attributes = user['UserAttributes'] + user = conn.admin_get_user(UserPoolId=user_pool_id, Username=username) + attributes = user["UserAttributes"] attributes.should.be.a(list) for attr in attributes: - val = attr['Value'] - if attr['Name'] == 'family_name': - val.should.equal('Doe') - elif attr['Name'] == 'given_name': - val.should.equal('Jane') + val = attr["Value"] + if attr["Name"] == "family_name": + val.should.equal("Doe") + elif attr["Name"] == "given_name": + val.should.equal("Jane") diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 95e88cab1..d5ec8f0bc 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -1,1011 +1,1804 @@ +import json from datetime import datetime, timedelta import boto3 from botocore.exceptions import ClientError from nose.tools import assert_raises +from moto import mock_s3 from moto.config import mock_config +from moto.core import ACCOUNT_ID @mock_config def test_put_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Try without a name supplied: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={'roleARN': 'somearn'}) - assert ce.exception.response['Error']['Code'] == 'InvalidConfigurationRecorderNameException' - assert 'is not valid, blank string.' in ce.exception.response['Error']['Message'] + client.put_configuration_recorder(ConfigurationRecorder={"roleARN": "somearn"}) + assert ( + ce.exception.response["Error"]["Code"] + == "InvalidConfigurationRecorderNameException" + ) + assert "is not valid, blank string." in ce.exception.response["Error"]["Message"] # Try with a really long name: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={'name': 'a' * 257, 'roleARN': 'somearn'}) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] + client.put_configuration_recorder( + ConfigurationRecorder={"name": "a" * 257, "roleARN": "somearn"} + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) # With resource types and flags set to True: bad_groups = [ - {'allSupported': True, 'includeGlobalResourceTypes': True, 'resourceTypes': ['item']}, - {'allSupported': False, 'includeGlobalResourceTypes': True, 'resourceTypes': ['item']}, - {'allSupported': True, 'includeGlobalResourceTypes': False, 'resourceTypes': ['item']}, - {'allSupported': False, 'includeGlobalResourceTypes': False, 'resourceTypes': []}, - {'includeGlobalResourceTypes': False, 'resourceTypes': []}, - {'includeGlobalResourceTypes': True}, - {'resourceTypes': []}, - {} + { + "allSupported": True, + "includeGlobalResourceTypes": True, + "resourceTypes": ["item"], + }, + { + "allSupported": False, + "includeGlobalResourceTypes": True, + "resourceTypes": ["item"], + }, + { + "allSupported": True, + "includeGlobalResourceTypes": False, + "resourceTypes": ["item"], + }, + { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": [], + }, + {"includeGlobalResourceTypes": False, "resourceTypes": []}, + {"includeGlobalResourceTypes": True}, + {"resourceTypes": []}, + {}, ] for bg in bad_groups: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'default', - 'roleARN': 'somearn', - 'recordingGroup': bg - }) - assert ce.exception.response['Error']['Code'] == 'InvalidRecordingGroupException' - assert ce.exception.response['Error']['Message'] == 'The recording group provided is not valid' + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "default", + "roleARN": "somearn", + "recordingGroup": bg, + } + ) + assert ( + ce.exception.response["Error"]["Code"] == "InvalidRecordingGroupException" + ) + assert ( + ce.exception.response["Error"]["Message"] + == "The recording group provided is not valid" + ) # With an invalid Resource Type: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'default', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - # 2 good, and 2 bad: - 'resourceTypes': ['AWS::EC2::Volume', 'LOLNO', 'AWS::EC2::VPC', 'LOLSTILLNO'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "default", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + # 2 good, and 2 bad: + "resourceTypes": [ + "AWS::EC2::Volume", + "LOLNO", + "AWS::EC2::VPC", + "LOLSTILLNO", + ], + }, } - }) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert "2 validation error detected: Value '['LOLNO', 'LOLSTILLNO']" in str(ce.exception.response['Error']['Message']) - assert 'AWS::EC2::Instance' in ce.exception.response['Error']['Message'] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert "2 validation error detected: Value '['LOLNO', 'LOLSTILLNO']" in str( + ce.exception.response["Error"]["Message"] + ) + assert "AWS::EC2::Instance" in ce.exception.response["Error"]["Message"] # Create a proper one: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert not result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 2 - assert 'AWS::EC2::Volume' in result[0]['recordingGroup']['resourceTypes'] \ - and 'AWS::EC2::VPC' in result[0]['recordingGroup']['resourceTypes'] + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert not result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 2 + assert ( + "AWS::EC2::Volume" in result[0]["recordingGroup"]["resourceTypes"] + and "AWS::EC2::VPC" in result[0]["recordingGroup"]["resourceTypes"] + ) # Now update the configuration recorder: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': True, - 'includeGlobalResourceTypes': True + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": True, + "includeGlobalResourceTypes": True, + }, } - }) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + ) + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert result[0]['recordingGroup']['allSupported'] - assert result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 0 + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert result[0]["recordingGroup"]["allSupported"] + assert result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 0 # With a default recording group (i.e. lacking one) - client.put_configuration_recorder(ConfigurationRecorder={'name': 'testrecorder', 'roleARN': 'somearn'}) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + client.put_configuration_recorder( + ConfigurationRecorder={"name": "testrecorder", "roleARN": "somearn"} + ) + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert not result[0]['recordingGroup'].get('resourceTypes') + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert not result[0]["recordingGroup"].get("resourceTypes") # Can currently only have exactly 1 Config Recorder in an account/region: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'someotherrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "someotherrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + }, } - }) - assert ce.exception.response['Error']['Code'] == 'MaxNumberOfConfigurationRecordersExceededException' - assert "maximum number of configuration recorders: 1 is reached." in ce.exception.response['Error']['Message'] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "MaxNumberOfConfigurationRecordersExceededException" + ) + assert ( + "maximum number of configuration recorders: 1 is reached." + in ce.exception.response["Error"]["Message"] + ) @mock_config def test_put_configuration_aggregator(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With too many aggregation sources: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ] + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], }, { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ] - } - ] + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + }, + ], ) - assert 'Member must have length less than or equal to 1' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 1" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # With an invalid region config (no regions defined): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AllAwsRegions': False + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": False, } - ] + ], ) - assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "Your request does not specify any regions" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole' - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole" + }, ) - assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "Your request does not specify any regions" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # With both region flags defined: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': True + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": True, } - ] + ], ) - assert 'You must choose one of these options' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "You must choose one of these options" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': True - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": True, + }, ) - assert 'You must choose one of these options' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "You must choose one of these options" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # Name too long: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='a' * 257, + ConfigurationAggregatorName="a" * 257, AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) - assert 'configurationAggregatorName' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert "configurationAggregatorName" in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Too many tags (>50): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], + Tags=[ + {"Key": "{}".format(x), "Value": "{}".format(x)} for x in range(0, 51) ], - Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)] ) - assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 50" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag key is too big (>128 chars): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'a' * 129, 'Value': 'a'}] + Tags=[{"Key": "a" * 129, "Value": "a"}], ) - assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 128" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag value is too big (>256 chars): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'tag', 'Value': 'a' * 257}] + Tags=[{"Key": "tag", "Value": "a" * 257}], ) - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Duplicate Tags: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}] + Tags=[{"Key": "a", "Value": "a"}, {"Key": "a", "Value": "a"}], ) - assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidInput' + assert "Duplicate tag keys found." in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "InvalidInput" # Invalid characters in the tag key: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': '!', 'Value': 'a'}] + Tags=[{"Key": "!", "Value": "a"}], ) - assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must satisfy regular expression pattern:" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # If it contains both the AccountAggregationSources and the OrganizationAggregationSource with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': False - } + {"AccountIds": ["012345678910"], "AllAwsRegions": False} ], OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AllAwsRegions': False - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AllAwsRegions": False, + }, ) - assert 'AccountAggregationSource and the OrganizationAggregationSource' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "AccountAggregationSource and the OrganizationAggregationSource" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # If it contains neither: with assert_raises(ClientError) as ce: - client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - ) - assert 'AccountAggregationSource or the OrganizationAggregationSource' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + client.put_configuration_aggregator(ConfigurationAggregatorName="testing") + assert ( + "AccountAggregationSource or the OrganizationAggregationSource" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # Just make one: account_aggregation_source = { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': False + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": False, } result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[account_aggregation_source], ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing' - assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source] - assert 'arn:aws:config:us-west-2:123456789012:config-aggregator/config-aggregator-' in \ - result['ConfigurationAggregator']['ConfigurationAggregatorArn'] - assert result['ConfigurationAggregator']['CreationTime'] == result['ConfigurationAggregator']['LastUpdatedTime'] - - # Update the existing one: - original_arn = result['ConfigurationAggregator']['ConfigurationAggregatorArn'] - account_aggregation_source.pop('AwsRegions') - account_aggregation_source['AllAwsRegions'] = True - result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - AccountAggregationSources=[account_aggregation_source] + assert result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testing" + assert result["ConfigurationAggregator"]["AccountAggregationSources"] == [ + account_aggregation_source + ] + assert ( + "arn:aws:config:us-west-2:{}:config-aggregator/config-aggregator-".format( + ACCOUNT_ID + ) + in result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] + ) + assert ( + result["ConfigurationAggregator"]["CreationTime"] + == result["ConfigurationAggregator"]["LastUpdatedTime"] ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing' - assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source] - assert result['ConfigurationAggregator']['ConfigurationAggregatorArn'] == original_arn + # Update the existing one: + original_arn = result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] + account_aggregation_source.pop("AwsRegions") + account_aggregation_source["AllAwsRegions"] = True + result = client.put_configuration_aggregator( + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], + ) + + assert result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testing" + assert result["ConfigurationAggregator"]["AccountAggregationSources"] == [ + account_aggregation_source + ] + assert ( + result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] == original_arn + ) # Make an org one: result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testingOrg', + ConfigurationAggregatorName="testingOrg", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': ['us-east-1', 'us-west-2'] - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + }, ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testingOrg' - assert result['ConfigurationAggregator']['OrganizationAggregationSource'] == { - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': False + assert ( + result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testingOrg" + ) + assert result["ConfigurationAggregator"]["OrganizationAggregationSource"] == { + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": False, } @mock_config def test_describe_configuration_aggregators(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any config aggregators: - assert not client.describe_configuration_aggregators()['ConfigurationAggregators'] + assert not client.describe_configuration_aggregators()["ConfigurationAggregators"] # Make 10 config aggregators: for x in range(0, 10): client.put_configuration_aggregator( - ConfigurationAggregatorName='testing{}'.format(x), + ConfigurationAggregatorName="testing{}".format(x), AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) # Describe with an incorrect name: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(ConfigurationAggregatorNames=['DoesNotExist']) - assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["DoesNotExist"] + ) + assert ( + "The configuration aggregator does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) # Error describe with more than 1 item in the list: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing0', 'DoesNotExist']) - assert 'At least one of the configuration aggregators does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing0", "DoesNotExist"] + ) + assert ( + "At least one of the configuration aggregators does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) # Get the normal list: result = client.describe_configuration_aggregators() - assert not result.get('NextToken') - assert len(result['ConfigurationAggregators']) == 10 + assert not result.get("NextToken") + assert len(result["ConfigurationAggregators"]) == 10 # Test filtered list: - agg_names = ['testing0', 'testing1', 'testing2'] - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=agg_names) - assert not result.get('NextToken') - assert len(result['ConfigurationAggregators']) == 3 - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == agg_names + agg_names = ["testing0", "testing1", "testing2"] + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=agg_names + ) + assert not result.get("NextToken") + assert len(result["ConfigurationAggregators"]) == 3 + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == agg_names # Test Pagination: result = client.describe_configuration_aggregators(Limit=4) - assert len(result['ConfigurationAggregators']) == 4 - assert result['NextToken'] == 'testing4' - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(0, 4)] - result = client.describe_configuration_aggregators(Limit=4, NextToken='testing4') - assert len(result['ConfigurationAggregators']) == 4 - assert result['NextToken'] == 'testing8' - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(4, 8)] - result = client.describe_configuration_aggregators(Limit=4, NextToken='testing8') - assert len(result['ConfigurationAggregators']) == 2 - assert not result.get('NextToken') - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(8, 10)] + assert len(result["ConfigurationAggregators"]) == 4 + assert result["NextToken"] == "testing4" + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(0, 4)] + result = client.describe_configuration_aggregators(Limit=4, NextToken="testing4") + assert len(result["ConfigurationAggregators"]) == 4 + assert result["NextToken"] == "testing8" + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(4, 8)] + result = client.describe_configuration_aggregators(Limit=4, NextToken="testing8") + assert len(result["ConfigurationAggregators"]) == 2 + assert not result.get("NextToken") + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(8, 10)] # Test Pagination with Filtering: - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1) - assert len(result['ConfigurationAggregators']) == 1 - assert result['NextToken'] == 'testing4' - assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing2' - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1, NextToken='testing4') - assert not result.get('NextToken') - assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing4' + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing2", "testing4"], Limit=1 + ) + assert len(result["ConfigurationAggregators"]) == 1 + assert result["NextToken"] == "testing4" + assert ( + result["ConfigurationAggregators"][0]["ConfigurationAggregatorName"] + == "testing2" + ) + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing2", "testing4"], + Limit=1, + NextToken="testing4", + ) + assert not result.get("NextToken") + assert ( + result["ConfigurationAggregators"][0]["ConfigurationAggregatorName"] + == "testing4" + ) # Test with an invalid filter: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(NextToken='WRONG') - assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException' + client.describe_configuration_aggregators(NextToken="WRONG") + assert ( + "The nextToken provided is invalid" == ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidNextTokenException" @mock_config def test_put_aggregation_authorization(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Too many tags (>50): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[ + {"Key": "{}".format(x), "Value": "{}".format(x)} for x in range(0, 51) + ], ) - assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 50" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag key is too big (>128 chars): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'a' * 129, 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "a" * 129, "Value": "a"}], ) - assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 128" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag value is too big (>256 chars): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'tag', 'Value': 'a' * 257}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "tag", "Value": "a" * 257}], ) - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Duplicate Tags: with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "a", "Value": "a"}, {"Key": "a", "Value": "a"}], ) - assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidInput' + assert "Duplicate tag keys found." in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "InvalidInput" # Invalid characters in the tag key: with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': '!', 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "!", "Value": "a"}], ) - assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must satisfy regular expression pattern:" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Put a normal one there: - result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1', - Tags=[{'Key': 'tag', 'Value': 'a'}]) + result = client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-east-1", + Tags=[{"Key": "tag", "Value": "a"}], + ) - assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \ - 'aggregation-authorization/012345678910/us-east-1' - assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910' - assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1' - assert isinstance(result['AggregationAuthorization']['CreationTime'], datetime) + assert result["AggregationAuthorization"][ + "AggregationAuthorizationArn" + ] == "arn:aws:config:us-west-2:{}:aggregation-authorization/012345678910/us-east-1".format( + ACCOUNT_ID + ) + assert result["AggregationAuthorization"]["AuthorizedAccountId"] == "012345678910" + assert result["AggregationAuthorization"]["AuthorizedAwsRegion"] == "us-east-1" + assert isinstance(result["AggregationAuthorization"]["CreationTime"], datetime) - creation_date = result['AggregationAuthorization']['CreationTime'] + creation_date = result["AggregationAuthorization"]["CreationTime"] # And again: - result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1') - assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \ - 'aggregation-authorization/012345678910/us-east-1' - assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910' - assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1' - assert result['AggregationAuthorization']['CreationTime'] == creation_date + result = client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-east-1" + ) + assert result["AggregationAuthorization"][ + "AggregationAuthorizationArn" + ] == "arn:aws:config:us-west-2:{}:aggregation-authorization/012345678910/us-east-1".format( + ACCOUNT_ID + ) + assert result["AggregationAuthorization"]["AuthorizedAccountId"] == "012345678910" + assert result["AggregationAuthorization"]["AuthorizedAwsRegion"] == "us-east-1" + assert result["AggregationAuthorization"]["CreationTime"] == creation_date @mock_config def test_describe_aggregation_authorizations(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With no aggregation authorizations: - assert not client.describe_aggregation_authorizations()['AggregationAuthorizations'] + assert not client.describe_aggregation_authorizations()["AggregationAuthorizations"] # Make 10 account authorizations: for i in range(0, 10): - client.put_aggregation_authorization(AuthorizedAccountId='{}'.format(str(i) * 12), AuthorizedAwsRegion='us-west-2') + client.put_aggregation_authorization( + AuthorizedAccountId="{}".format(str(i) * 12), + AuthorizedAwsRegion="us-west-2", + ) result = client.describe_aggregation_authorizations() - assert len(result['AggregationAuthorizations']) == 10 - assert not result.get('NextToken') + assert len(result["AggregationAuthorizations"]) == 10 + assert not result.get("NextToken") for i in range(0, 10): - assert result['AggregationAuthorizations'][i]['AuthorizedAccountId'] == str(i) * 12 + assert ( + result["AggregationAuthorizations"][i]["AuthorizedAccountId"] == str(i) * 12 + ) # Test Pagination: result = client.describe_aggregation_authorizations(Limit=4) - assert len(result['AggregationAuthorizations']) == 4 - assert result['NextToken'] == ('4' * 12) + '/us-west-2' - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(0, 4)] + assert len(result["AggregationAuthorizations"]) == 4 + assert result["NextToken"] == ("4" * 12) + "/us-west-2" + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(0, 4)] - result = client.describe_aggregation_authorizations(Limit=4, NextToken=('4' * 12) + '/us-west-2') - assert len(result['AggregationAuthorizations']) == 4 - assert result['NextToken'] == ('8' * 12) + '/us-west-2' - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(4, 8)] + result = client.describe_aggregation_authorizations( + Limit=4, NextToken=("4" * 12) + "/us-west-2" + ) + assert len(result["AggregationAuthorizations"]) == 4 + assert result["NextToken"] == ("8" * 12) + "/us-west-2" + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(4, 8)] - result = client.describe_aggregation_authorizations(Limit=4, NextToken=('8' * 12) + '/us-west-2') - assert len(result['AggregationAuthorizations']) == 2 - assert not result.get('NextToken') - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(8, 10)] + result = client.describe_aggregation_authorizations( + Limit=4, NextToken=("8" * 12) + "/us-west-2" + ) + assert len(result["AggregationAuthorizations"]) == 2 + assert not result.get("NextToken") + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(8, 10)] # Test with an invalid filter: with assert_raises(ClientError) as ce: - client.describe_aggregation_authorizations(NextToken='WRONG') - assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException' + client.describe_aggregation_authorizations(NextToken="WRONG") + assert ( + "The nextToken provided is invalid" == ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidNextTokenException" @mock_config def test_delete_aggregation_authorization(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") - client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) # Delete it: - client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.delete_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) # Verify that none are there: - assert not client.describe_aggregation_authorizations()['AggregationAuthorizations'] + assert not client.describe_aggregation_authorizations()["AggregationAuthorizations"] # Try it again -- nothing should happen: - client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.delete_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) @mock_config def test_delete_configuration_aggregator(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) - client.delete_configuration_aggregator(ConfigurationAggregatorName='testing') + client.delete_configuration_aggregator(ConfigurationAggregatorName="testing") # And again to confirm that it's deleted: with assert_raises(ClientError) as ce: - client.delete_configuration_aggregator(ConfigurationAggregatorName='testing') - assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.delete_configuration_aggregator(ConfigurationAggregatorName="testing") + assert ( + "The configuration aggregator does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) @mock_config def test_describe_configurations(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any configurations: result = client.describe_configuration_recorders() - assert not result['ConfigurationRecorders'] + assert not result["ConfigurationRecorders"] - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert not result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 2 - assert 'AWS::EC2::Volume' in result[0]['recordingGroup']['resourceTypes'] \ - and 'AWS::EC2::VPC' in result[0]['recordingGroup']['resourceTypes'] + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert not result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 2 + assert ( + "AWS::EC2::Volume" in result[0]["recordingGroup"]["resourceTypes"] + and "AWS::EC2::VPC" in result[0]["recordingGroup"]["resourceTypes"] + ) # Specify an incorrect name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorders(ConfigurationRecorderNames=['wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorders(ConfigurationRecorderNames=["wrong"]) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] # And with both a good and wrong name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorders(ConfigurationRecorderNames=['testrecorder', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorders( + ConfigurationRecorderNames=["testrecorder", "wrong"] + ) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_delivery_channels(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Try without a config recorder: with assert_raises(ClientError) as ce: client.put_delivery_channel(DeliveryChannel={}) - assert ce.exception.response['Error']['Code'] == 'NoAvailableConfigurationRecorderException' - assert ce.exception.response['Error']['Message'] == 'Configuration recorder is not available to ' \ - 'put delivery channel.' + assert ( + ce.exception.response["Error"]["Code"] + == "NoAvailableConfigurationRecorderException" + ) + assert ( + ce.exception.response["Error"]["Message"] + == "Configuration recorder is not available to " + "put delivery channel." + ) # Create a config recorder to continue testing: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Try without a name supplied: with assert_raises(ClientError) as ce: client.put_delivery_channel(DeliveryChannel={}) - assert ce.exception.response['Error']['Code'] == 'InvalidDeliveryChannelNameException' - assert 'is not valid, blank string.' in ce.exception.response['Error']['Message'] + assert ( + ce.exception.response["Error"]["Code"] == "InvalidDeliveryChannelNameException" + ) + assert "is not valid, blank string." in ce.exception.response["Error"]["Message"] # Try with a really long name: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'a' * 257}) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] + client.put_delivery_channel(DeliveryChannel={"name": "a" * 257}) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) # Without specifying a bucket name: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel'}) - assert ce.exception.response['Error']['Code'] == 'NoSuchBucketException' - assert ce.exception.response['Error']['Message'] == 'Cannot find a S3 bucket with an empty bucket name.' + client.put_delivery_channel(DeliveryChannel={"name": "testchannel"}) + assert ce.exception.response["Error"]["Code"] == "NoSuchBucketException" + assert ( + ce.exception.response["Error"]["Message"] + == "Cannot find a S3 bucket with an empty bucket name." + ) with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': ''}) - assert ce.exception.response['Error']['Code'] == 'NoSuchBucketException' - assert ce.exception.response['Error']['Message'] == 'Cannot find a S3 bucket with an empty bucket name.' + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": ""} + ) + assert ce.exception.response["Error"]["Code"] == "NoSuchBucketException" + assert ( + ce.exception.response["Error"]["Message"] + == "Cannot find a S3 bucket with an empty bucket name." + ) # With an empty string for the S3 key prefix: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', 's3BucketName': 'somebucket', 's3KeyPrefix': ''}) - assert ce.exception.response['Error']['Code'] == 'InvalidS3KeyPrefixException' - assert 'empty s3 key prefix.' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "s3KeyPrefix": "", + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidS3KeyPrefixException" + assert "empty s3 key prefix." in ce.exception.response["Error"]["Message"] # With an empty string for the SNS ARN: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', 's3BucketName': 'somebucket', 'snsTopicARN': ''}) - assert ce.exception.response['Error']['Code'] == 'InvalidSNSTopicARNException' - assert 'The sns topic arn' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "", + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidSNSTopicARNException" + assert "The sns topic arn" in ce.exception.response["Error"]["Message"] # With an invalid delivery frequency: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'WRONG'} - }) - assert ce.exception.response['Error']['Code'] == 'InvalidDeliveryFrequency' - assert 'WRONG' in ce.exception.response['Error']['Message'] - assert 'TwentyFour_Hours' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "configSnapshotDeliveryProperties": {"deliveryFrequency": "WRONG"}, + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidDeliveryFrequency" + assert "WRONG" in ce.exception.response["Error"]["Message"] + assert "TwentyFour_Hours" in ce.exception.response["Error"]["Message"] # Create a proper one: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 2 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" # Overwrite it with another proper configuration: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'snsTopicARN': 'sometopicarn', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'TwentyFour_Hours'} - }) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "sometopicarn", + "configSnapshotDeliveryProperties": { + "deliveryFrequency": "TwentyFour_Hours" + }, + } + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 4 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' - assert result[0]['snsTopicARN'] == 'sometopicarn' - assert result[0]['configSnapshotDeliveryProperties']['deliveryFrequency'] == 'TwentyFour_Hours' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" + assert result[0]["snsTopicARN"] == "sometopicarn" + assert ( + result[0]["configSnapshotDeliveryProperties"]["deliveryFrequency"] + == "TwentyFour_Hours" + ) # Can only have 1: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel2', 's3BucketName': 'somebucket'}) - assert ce.exception.response['Error']['Code'] == 'MaxNumberOfDeliveryChannelsExceededException' - assert 'because the maximum number of delivery channels: 1 is reached.' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel2", "s3BucketName": "somebucket"} + ) + assert ( + ce.exception.response["Error"]["Code"] + == "MaxNumberOfDeliveryChannelsExceededException" + ) + assert ( + "because the maximum number of delivery channels: 1 is reached." + in ce.exception.response["Error"]["Message"] + ) @mock_config def test_describe_delivery_channels(): - client = boto3.client('config', region_name='us-west-2') - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client = boto3.client("config", region_name="us-west-2") + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without any channels: result = client.describe_delivery_channels() - assert not result['DeliveryChannels'] + assert not result["DeliveryChannels"] - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 2 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" # Overwrite it with another proper configuration: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'snsTopicARN': 'sometopicarn', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'TwentyFour_Hours'} - }) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "sometopicarn", + "configSnapshotDeliveryProperties": { + "deliveryFrequency": "TwentyFour_Hours" + }, + } + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 4 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' - assert result[0]['snsTopicARN'] == 'sometopicarn' - assert result[0]['configSnapshotDeliveryProperties']['deliveryFrequency'] == 'TwentyFour_Hours' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" + assert result[0]["snsTopicARN"] == "sometopicarn" + assert ( + result[0]["configSnapshotDeliveryProperties"]["deliveryFrequency"] + == "TwentyFour_Hours" + ) # Specify an incorrect name: with assert_raises(ClientError) as ce: - client.describe_delivery_channels(DeliveryChannelNames=['wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_delivery_channels(DeliveryChannelNames=["wrong"]) + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" + assert "wrong" in ce.exception.response["Error"]["Message"] # And with both a good and wrong name: with assert_raises(ClientError) as ce: - client.describe_delivery_channels(DeliveryChannelNames=['testchannel', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_delivery_channels(DeliveryChannelNames=["testchannel", "wrong"]) + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_start_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without a config recorder: with assert_raises(ClientError) as ce: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without a delivery channel: with assert_raises(ClientError) as ce: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoAvailableDeliveryChannelException' + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoAvailableDeliveryChannelException" + ) # Make the delivery channel: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) # Start it: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") # Verify it's enabled: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] - lower_bound = (datetime.utcnow() - timedelta(minutes=5)) - assert result[0]['recording'] - assert result[0]['lastStatus'] == 'PENDING' - assert lower_bound < result[0]['lastStartTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStatusChangeTime'].replace(tzinfo=None) <= datetime.utcnow() + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] + lower_bound = datetime.utcnow() - timedelta(minutes=5) + assert result[0]["recording"] + assert result[0]["lastStatus"] == "PENDING" + assert ( + lower_bound + < result[0]["lastStartTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStatusChangeTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) @mock_config def test_stop_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without a config recorder: with assert_raises(ClientError) as ce: - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Make the delivery channel for creation: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) # Start it: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") # Verify it's disabled: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] - lower_bound = (datetime.utcnow() - timedelta(minutes=5)) - assert not result[0]['recording'] - assert result[0]['lastStatus'] == 'PENDING' - assert lower_bound < result[0]['lastStartTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStopTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStatusChangeTime'].replace(tzinfo=None) <= datetime.utcnow() + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] + lower_bound = datetime.utcnow() - timedelta(minutes=5) + assert not result[0]["recording"] + assert result[0]["lastStatus"] == "PENDING" + assert ( + lower_bound + < result[0]["lastStartTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStopTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStatusChangeTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) @mock_config def test_describe_configuration_recorder_status(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any: result = client.describe_configuration_recorder_status() - assert not result['ConfigurationRecordersStatus'] + assert not result["ConfigurationRecordersStatus"] # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without specifying a config recorder: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert not result[0]['recording'] + assert result[0]["name"] == "testrecorder" + assert not result[0]["recording"] # With a proper name: result = client.describe_configuration_recorder_status( - ConfigurationRecorderNames=['testrecorder'])['ConfigurationRecordersStatus'] + ConfigurationRecorderNames=["testrecorder"] + )["ConfigurationRecordersStatus"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert not result[0]['recording'] + assert result[0]["name"] == "testrecorder" + assert not result[0]["recording"] # Invalid name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorder_status(ConfigurationRecorderNames=['testrecorder', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorder_status( + ConfigurationRecorderNames=["testrecorder", "wrong"] + ) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_delete_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Delete it: - client.delete_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.delete_configuration_recorder(ConfigurationRecorderName="testrecorder") # Try again -- it should be deleted: with assert_raises(ClientError) as ce: - client.delete_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.delete_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) @mock_config def test_delete_delivery_channel(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Need a recorder to test the constraint on recording being enabled: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') + ) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") # With the recorder enabled: with assert_raises(ClientError) as ce: - client.delete_delivery_channel(DeliveryChannelName='testchannel') - assert ce.exception.response['Error']['Code'] == 'LastDeliveryChannelDeleteFailedException' - assert 'because there is a running configuration recorder.' in ce.exception.response['Error']['Message'] + client.delete_delivery_channel(DeliveryChannelName="testchannel") + assert ( + ce.exception.response["Error"]["Code"] + == "LastDeliveryChannelDeleteFailedException" + ) + assert ( + "because there is a running configuration recorder." + in ce.exception.response["Error"]["Message"] + ) # Stop recording: - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") # Try again: - client.delete_delivery_channel(DeliveryChannelName='testchannel') + client.delete_delivery_channel(DeliveryChannelName="testchannel") # Verify: with assert_raises(ClientError) as ce: - client.delete_delivery_channel(DeliveryChannelName='testchannel') - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' + client.delete_delivery_channel(DeliveryChannelName="testchannel") + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" + + +@mock_config +@mock_s3 +def test_list_discovered_resource(): + """NOTE: We are only really testing the Config part. For each individual service, please add tests + for that individual service's "list_config_service_resources" function. + """ + client = boto3.client("config", region_name="us-west-2") + + # With nothing created yet: + assert not client.list_discovered_resources(resourceType="AWS::S3::Bucket")[ + "resourceIdentifiers" + ] + + # Create some S3 buckets: + s3_client = boto3.client("s3", region_name="us-west-2") + for x in range(0, 10): + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + + # And with an EU bucket -- this should not show up for the us-west-2 config backend: + eu_client = boto3.client("s3", region_name="eu-west-1") + eu_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + + # Now try: + result = client.list_discovered_resources(resourceType="AWS::S3::Bucket") + assert len(result["resourceIdentifiers"]) == 10 + for x in range(0, 10): + assert result["resourceIdentifiers"][x] == { + "resourceType": "AWS::S3::Bucket", + "resourceId": "bucket{}".format(x), + "resourceName": "bucket{}".format(x), + } + assert not result.get("nextToken") + + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", resourceName="eu-bucket" + ) + assert not result["resourceIdentifiers"] + + # Test that pagination places a proper nextToken in the response and also that the limit works: + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, nextToken="bucket1" + ) + assert len(result["resourceIdentifiers"]) == 1 + assert result["nextToken"] == "bucket2" + + # Try with a resource name: + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceName="bucket1" + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") + + # Try with a resource ID: + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceIds=["bucket1"] + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") + + # Try with duplicated resource IDs: + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceIds=["bucket1", "bucket1"] + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") + + # Test with an invalid resource type: + assert not client.list_discovered_resources( + resourceType="LOL::NOT::A::RESOURCE::TYPE" + )["resourceIdentifiers"] + + # Test with an invalid page num > 100: + with assert_raises(ClientError) as ce: + client.list_discovered_resources(resourceType="AWS::S3::Bucket", limit=101) + assert "101" in ce.exception.response["Error"]["Message"] + + # Test by supplying both resourceName and also resourceIds: + with assert_raises(ClientError) as ce: + client.list_discovered_resources( + resourceType="AWS::S3::Bucket", + resourceName="whats", + resourceIds=["up", "doc"], + ) + assert ( + "Both Resource ID and Resource Name cannot be specified in the request" + in ce.exception.response["Error"]["Message"] + ) + + # More than 20 resourceIds: + resource_ids = ["{}".format(x) for x in range(0, 21)] + with assert_raises(ClientError) as ce: + client.list_discovered_resources( + resourceType="AWS::S3::Bucket", resourceIds=resource_ids + ) + assert ( + "The specified list had more than 20 resource ID's." + in ce.exception.response["Error"]["Message"] + ) + + +@mock_config +@mock_s3 +def test_list_aggregate_discovered_resource(): + """NOTE: We are only really testing the Config part. For each individual service, please add tests + for that individual service's "list_config_service_resources" function. + """ + client = boto3.client("config", region_name="us-west-2") + + # Without an aggregator: + with assert_raises(ClientError) as ce: + client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="lolno", ResourceType="AWS::S3::Bucket" + ) + assert ( + "The configuration aggregator does not exist" + in ce.exception.response["Error"]["Message"] + ) + + # Create the aggregator: + account_aggregation_source = { + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": True, + } + client.put_configuration_aggregator( + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], + ) + + # With nothing created yet: + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", ResourceType="AWS::S3::Bucket" + )["ResourceIdentifiers"] + + # Create some S3 buckets: + s3_client = boto3.client("s3", region_name="us-west-2") + for x in range(0, 10): + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + + s3_client_eu = boto3.client("s3", region_name="eu-west-1") + for x in range(10, 12): + s3_client_eu.create_bucket( + Bucket="eu-bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + + # Now try: + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", ResourceType="AWS::S3::Bucket" + ) + assert len(result["ResourceIdentifiers"]) == 12 + for x in range(0, 10): + assert result["ResourceIdentifiers"][x] == { + "SourceAccountId": ACCOUNT_ID, + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket{}".format(x), + "ResourceName": "bucket{}".format(x), + "SourceRegion": "us-west-2", + } + for x in range(11, 12): + assert result["ResourceIdentifiers"][x] == { + "SourceAccountId": ACCOUNT_ID, + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "eu-bucket{}".format(x), + "ResourceName": "eu-bucket{}".format(x), + "SourceRegion": "eu-west-1", + } + + assert not result.get("NextToken") + + # Test that pagination places a proper nextToken in the response and also that the limit works: + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert result["NextToken"] == "bucket2" + + # Try with a resource name: + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + Filters={"ResourceName": "bucket1"}, + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert not result.get("NextToken") + + # Try with a resource ID: + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + Filters={"ResourceId": "bucket1"}, + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert not result.get("NextToken") + + # Try with a region specified: + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"Region": "eu-west-1"}, + ) + assert len(result["ResourceIdentifiers"]) == 2 + assert result["ResourceIdentifiers"][0]["SourceRegion"] == "eu-west-1" + assert not result.get("NextToken") + + # Try with both name and id set to the incorrect values: + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"ResourceId": "bucket1", "ResourceName": "bucket2"}, + )["ResourceIdentifiers"] + + # Test with an invalid resource type: + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="LOL::NOT::A::RESOURCE::TYPE", + )["ResourceIdentifiers"] + + # Try with correct name but incorrect region: + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"ResourceId": "bucket1", "Region": "us-west-1"}, + )["ResourceIdentifiers"] + + # Test with an invalid page num > 100: + with assert_raises(ClientError) as ce: + client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=101, + ) + assert "101" in ce.exception.response["Error"]["Message"] + + +@mock_config +@mock_s3 +def test_get_resource_config_history(): + """NOTE: We are only really testing the Config part. For each individual service, please add tests + for that individual service's "get_config_resource" function. + """ + client = boto3.client("config", region_name="us-west-2") + + # With an invalid resource type: + with assert_raises(ClientError) as ce: + client.get_resource_config_history( + resourceType="NOT::A::RESOURCE", resourceId="notcreatedyet" + ) + assert ce.exception.response["Error"] == { + "Message": "Resource notcreatedyet of resourceType:NOT::A::RESOURCE is unknown or has " + "not been discovered", + "Code": "ResourceNotDiscoveredException", + } + + # With nothing created yet: + with assert_raises(ClientError) as ce: + client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="notcreatedyet" + ) + assert ce.exception.response["Error"] == { + "Message": "Resource notcreatedyet of resourceType:AWS::S3::Bucket is unknown or has " + "not been discovered", + "Code": "ResourceNotDiscoveredException", + } + + # Create an S3 bucket: + s3_client = boto3.client("s3", region_name="us-west-2") + for x in range(0, 10): + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + + # Now try: + result = client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="bucket1" + )["configurationItems"] + assert len(result) == 1 + assert result[0]["resourceName"] == result[0]["resourceId"] == "bucket1" + assert result[0]["arn"] == "arn:aws:s3:::bucket1" + + # Make a bucket in a different region and verify that it does not show up in the config backend: + s3_client = boto3.client("s3", region_name="eu-west-1") + s3_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + with assert_raises(ClientError) as ce: + client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="eu-bucket" + ) + assert ce.exception.response["Error"]["Code"] == "ResourceNotDiscoveredException" + + +@mock_config +@mock_s3 +def test_batch_get_resource_config(): + """NOTE: We are only really testing the Config part. For each individual service, please add tests + for that individual service's "get_config_resource" function. + """ + client = boto3.client("config", region_name="us-west-2") + + # With more than 100 resourceKeys: + with assert_raises(ClientError) as ce: + client.batch_get_resource_config( + resourceKeys=[ + {"resourceType": "AWS::S3::Bucket", "resourceId": "someBucket"} + ] + * 101 + ) + assert ( + "Member must have length less than or equal to 100" + in ce.exception.response["Error"]["Message"] + ) + + # With invalid resource types and resources that don't exist: + result = client.batch_get_resource_config( + resourceKeys=[ + {"resourceType": "NOT::A::RESOURCE", "resourceId": "NotAThing"}, + {"resourceType": "AWS::S3::Bucket", "resourceId": "NotAThing"}, + ] + ) + + assert not result["baseConfigurationItems"] + assert not result["unprocessedResourceKeys"] + + # Create some S3 buckets: + s3_client = boto3.client("s3", region_name="us-west-2") + for x in range(0, 10): + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + + # Get them all: + keys = [ + {"resourceType": "AWS::S3::Bucket", "resourceId": "bucket{}".format(x)} + for x in range(0, 10) + ] + result = client.batch_get_resource_config(resourceKeys=keys) + assert len(result["baseConfigurationItems"]) == 10 + buckets_missing = ["bucket{}".format(x) for x in range(0, 10)] + for r in result["baseConfigurationItems"]: + buckets_missing.remove(r["resourceName"]) + + assert not buckets_missing + + # Make a bucket in a different region and verify that it does not show up in the config backend: + s3_client = boto3.client("s3", region_name="eu-west-1") + s3_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + keys = [{"resourceType": "AWS::S3::Bucket", "resourceId": "eu-bucket"}] + result = client.batch_get_resource_config(resourceKeys=keys) + assert not result["baseConfigurationItems"] + + +@mock_config +@mock_s3 +def test_batch_get_aggregate_resource_config(): + """NOTE: We are only really testing the Config part. For each individual service, please add tests + for that individual service's "get_config_resource" function. + """ + from moto.config.models import DEFAULT_ACCOUNT_ID + + client = boto3.client("config", region_name="us-west-2") + + # Without an aggregator: + bad_ri = { + "SourceAccountId": "000000000000", + "SourceRegion": "not-a-region", + "ResourceType": "NOT::A::RESOURCE", + "ResourceId": "nope", + } + with assert_raises(ClientError) as ce: + client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="lolno", ResourceIdentifiers=[bad_ri] + ) + assert ( + "The configuration aggregator does not exist" + in ce.exception.response["Error"]["Message"] + ) + + # Create the aggregator: + account_aggregation_source = { + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": True, + } + client.put_configuration_aggregator( + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], + ) + + # With more than 100 items: + with assert_raises(ClientError) as ce: + client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=[bad_ri] * 101 + ) + assert ( + "Member must have length less than or equal to 100" + in ce.exception.response["Error"]["Message"] + ) + + # Create some S3 buckets: + s3_client = boto3.client("s3", region_name="us-west-2") + for x in range(0, 10): + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + s3_client.put_bucket_tagging( + Bucket="bucket{}".format(x), + Tagging={"TagSet": [{"Key": "Some", "Value": "Tag"}]}, + ) + + s3_client_eu = boto3.client("s3", region_name="eu-west-1") + for x in range(10, 12): + s3_client_eu.create_bucket( + Bucket="eu-bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + s3_client.put_bucket_tagging( + Bucket="eu-bucket{}".format(x), + Tagging={"TagSet": [{"Key": "Some", "Value": "Tag"}]}, + ) + + # Now try with resources that exist and ones that don't: + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket{}".format(x), + } + for x in range(0, 10) + ] + identifiers += [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "eu-west-1", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "eu-bucket{}".format(x), + } + for x in range(10, 12) + ] + identifiers += [bad_ri] + + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert result["UnprocessedResourceIdentifiers"][0] == bad_ri + + # Verify all the buckets are there: + assert len(result["BaseConfigurationItems"]) == 12 + missing_buckets = ["bucket{}".format(x) for x in range(0, 10)] + [ + "eu-bucket{}".format(x) for x in range(10, 12) + ] + + for r in result["BaseConfigurationItems"]: + missing_buckets.remove(r["resourceName"]) + + assert not missing_buckets + + # Verify that 'tags' is not in the result set: + for b in result["BaseConfigurationItems"]: + assert not b.get("tags") + assert json.loads( + b["supplementaryConfiguration"]["BucketTaggingConfiguration"] + ) == {"tagSets": [{"tags": {"Some": "Tag"}}]} + + # Verify that if the resource name and ID are correct that things are good: + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + "ResourceName": "bucket1", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["UnprocessedResourceIdentifiers"] + assert ( + len(result["BaseConfigurationItems"]) == 1 + and result["BaseConfigurationItems"][0]["resourceName"] == "bucket1" + ) + + # Verify that if the resource name and ID mismatch that we don't get a result: + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + "ResourceName": "bucket2", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["BaseConfigurationItems"] + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert ( + len(result["UnprocessedResourceIdentifiers"]) == 1 + and result["UnprocessedResourceIdentifiers"][0]["ResourceName"] == "bucket2" + ) + + # Verify that if the region is incorrect that we don't get a result: + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "eu-west-1", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["BaseConfigurationItems"] + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert ( + len(result["UnprocessedResourceIdentifiers"]) == 1 + and result["UnprocessedResourceIdentifiers"][0]["SourceRegion"] == "eu-west-1" + ) diff --git a/tests/test_core/test_auth.py b/tests/test_core/test_auth.py index 00229f808..a8fde5d8c 100644 --- a/tests/test_core/test_auth.py +++ b/tests/test_core/test_auth.py @@ -3,201 +3,267 @@ import json import boto3 import sure # noqa from botocore.exceptions import ClientError + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises from moto import mock_iam, mock_ec2, mock_s3, mock_sts, mock_elbv2, mock_rds2 from moto.core import set_initial_no_auth_action_count -from moto.iam.models import ACCOUNT_ID +from moto.core import ACCOUNT_ID +from uuid import uuid4 @mock_iam -def create_user_with_access_key(user_name='test-user'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key(user_name="test-user"): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) - return client.create_access_key(UserName=user_name)['AccessKey'] + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_inline_policy(user_name, policy_document, policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_inline_policy( + user_name, policy_document, policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) - client.put_user_policy(UserName=user_name, PolicyName=policy_name, PolicyDocument=json.dumps(policy_document)) - return client.create_access_key(UserName=user_name)['AccessKey'] + client.put_user_policy( + UserName=user_name, + PolicyName=policy_name, + PolicyDocument=json.dumps(policy_document), + ) + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_attached_policy(user_name, policy_document, policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_attached_policy( + user_name, policy_document, policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) policy_arn = client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] client.attach_user_policy(UserName=user_name, PolicyArn=policy_arn) - return client.create_access_key(UserName=user_name)['AccessKey'] + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_multiple_policies(user_name, inline_policy_document, - attached_policy_document, inline_policy_name='policy1', attached_policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_multiple_policies( + user_name, + inline_policy_document, + attached_policy_document, + inline_policy_name="policy1", + attached_policy_name="policy1", +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) policy_arn = client.create_policy( PolicyName=attached_policy_name, - PolicyDocument=json.dumps(attached_policy_document) - )['Policy']['Arn'] + PolicyDocument=json.dumps(attached_policy_document), + )["Policy"]["Arn"] client.attach_user_policy(UserName=user_name, PolicyArn=policy_arn) - client.put_user_policy(UserName=user_name, PolicyName=inline_policy_name, PolicyDocument=json.dumps(inline_policy_document)) - return client.create_access_key(UserName=user_name)['AccessKey'] + client.put_user_policy( + UserName=user_name, + PolicyName=inline_policy_name, + PolicyDocument=json.dumps(inline_policy_document), + ) + return client.create_access_key(UserName=user_name)["AccessKey"] -def create_group_with_attached_policy_and_add_user(user_name, policy_document, - group_name='test-group', policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_attached_policy_and_add_user( + user_name, policy_document, group_name="test-group", policy_name=None +): + if not policy_name: + policy_name = str(uuid4()) + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) policy_arn = client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] client.attach_group_policy(GroupName=group_name, PolicyArn=policy_arn) client.add_user_to_group(GroupName=group_name, UserName=user_name) -def create_group_with_inline_policy_and_add_user(user_name, policy_document, - group_name='test-group', policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_inline_policy_and_add_user( + user_name, policy_document, group_name="test-group", policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) client.put_group_policy( GroupName=group_name, PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) + PolicyDocument=json.dumps(policy_document), ) client.add_user_to_group(GroupName=group_name, UserName=user_name) -def create_group_with_multiple_policies_and_add_user(user_name, inline_policy_document, - attached_policy_document, group_name='test-group', - inline_policy_name='policy1', attached_policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_multiple_policies_and_add_user( + user_name, + inline_policy_document, + attached_policy_document, + group_name="test-group", + inline_policy_name="policy1", + attached_policy_name=None, +): + if not attached_policy_name: + attached_policy_name = str(uuid4()) + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) client.put_group_policy( GroupName=group_name, PolicyName=inline_policy_name, - PolicyDocument=json.dumps(inline_policy_document) + PolicyDocument=json.dumps(inline_policy_document), ) policy_arn = client.create_policy( PolicyName=attached_policy_name, - PolicyDocument=json.dumps(attached_policy_document) - )['Policy']['Arn'] + PolicyDocument=json.dumps(attached_policy_document), + )["Policy"]["Arn"] client.attach_group_policy(GroupName=group_name, PolicyArn=policy_arn) client.add_user_to_group(GroupName=group_name, UserName=user_name) @mock_iam @mock_sts -def create_role_with_attached_policy_and_assume_it(role_name, trust_policy_document, - policy_document, session_name='session1', policy_name='policy1'): - iam_client = boto3.client('iam', region_name='us-east-1') - sts_client = boto3.client('sts', region_name='us-east-1') +def create_role_with_attached_policy_and_assume_it( + role_name, + trust_policy_document, + policy_document, + session_name="session1", + policy_name="policy1", +): + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") role_arn = iam_client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy_document) + )["Role"]["Arn"] policy_arn = iam_client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] iam_client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn) - return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)['Credentials'] + return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)[ + "Credentials" + ] @mock_iam @mock_sts -def create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - policy_document, session_name='session1', policy_name='policy1'): - iam_client = boto3.client('iam', region_name='us-east-1') - sts_client = boto3.client('sts', region_name='us-east-1') +def create_role_with_inline_policy_and_assume_it( + role_name, + trust_policy_document, + policy_document, + session_name="session1", + policy_name="policy1", +): + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") role_arn = iam_client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy_document) + )["Role"]["Arn"] iam_client.put_role_policy( RoleName=role_name, PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) + PolicyDocument=json.dumps(policy_document), ) - return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)['Credentials'] + return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)[ + "Credentials" + ] @set_initial_no_auth_action_count(0) @mock_iam def test_invalid_client_token_id(): - client = boto3.client('iam', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.get_user() - ex.exception.response['Error']['Code'].should.equal('InvalidClientTokenId') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The security token included in the request is invalid.') + ex.exception.response["Error"]["Code"].should.equal("InvalidClientTokenId") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The security token included in the request is invalid." + ) @set_initial_no_auth_action_count(0) @mock_ec2 def test_auth_failure(): - client = boto3.client('ec2', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AuthFailure') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(401) - ex.exception.response['Error']['Message'].should.equal('AWS was not able to validate the provided access credentials') + ex.exception.response["Error"]["Code"].should.equal("AuthFailure") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(401) + ex.exception.response["Error"]["Message"].should.equal( + "AWS was not able to validate the provided access credentials" + ) @set_initial_no_auth_action_count(2) @mock_iam def test_signature_does_not_match(): access_key = create_user_with_access_key() - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.get_user() - ex.exception.response['Error']['Code'].should.equal('SignatureDoesNotMatch') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.') + ex.exception.response["Error"]["Code"].should.equal("SignatureDoesNotMatch") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details." + ) @set_initial_no_auth_action_count(2) @mock_ec2 def test_auth_failure_with_valid_access_key_id(): access_key = create_user_with_access_key() - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AuthFailure') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(401) - ex.exception.response['Error']['Message'].should.equal('AWS was not able to validate the provided access credentials') + ex.exception.response["Error"]["Code"].should.equal("AuthFailure") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(401) + ex.exception.response["Error"]["Message"].should.equal( + "AWS was not able to validate the provided access credentials" + ) @set_initial_no_auth_action_count(2) @mock_ec2 def test_access_denied_with_no_policy(): - user_name = 'test-user' + user_name = "test-user" access_key = create_user_with_access_key(user_name) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( account_id=ACCOUNT_ID, user_name=user_name, - operation="ec2:DescribeInstances" + operation="ec2:DescribeInstances", ) ) @@ -205,32 +271,29 @@ def test_access_denied_with_no_policy(): @set_initial_no_auth_action_count(3) @mock_ec2 def test_access_denied_with_not_allowing_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - "ec2:Describe*" - ], - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["ec2:Describe*"], "Resource": "*"} + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.run_instances(MaxCount=1, MinCount=1) - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:RunInstances" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:RunInstances" ) ) @@ -238,37 +301,30 @@ def test_access_denied_with_not_allowing_policy(): @set_initial_no_auth_action_count(3) @mock_ec2 def test_access_denied_with_denying_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - "ec2:*", - ], - "Resource": "*" - }, - { - "Effect": "Deny", - "Action": "ec2:CreateVpc", - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["ec2:*"], "Resource": "*"}, + {"Effect": "Deny", "Action": "ec2:CreateVpc", "Resource": "*"}, + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.create_vpc(CidrBlock="10.0.0.0/16") - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:CreateVpc" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:CreateVpc" ) ) @@ -276,203 +332,179 @@ def test_access_denied_with_denying_policy(): @set_initial_no_auth_action_count(3) @mock_sts def test_get_caller_identity_allowed_with_denying_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Deny", - "Action": "sts:GetCallerIdentity", - "Resource": "*" - } - ] + {"Effect": "Deny", "Action": "sts:GetCallerIdentity", "Resource": "*"} + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('sts', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.get_caller_identity().should.be.a(dict) @set_initial_no_auth_action_count(3) @mock_ec2 def test_allowed_with_wildcard_action(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "ec2:Describe*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "ec2:Describe*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.describe_tags()['Tags'].should.be.empty + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.describe_tags()["Tags"].should.be.empty @set_initial_no_auth_action_count(4) @mock_iam def test_allowed_with_explicit_action_in_attached_policy(): - user_name = 'test-user' + user_name = "test-user" attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "iam:ListGroups", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "iam:ListGroups", "Resource": "*"}], } - access_key = create_user_with_access_key_and_attached_policy(user_name, attached_policy_document) - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.list_groups()['Groups'].should.be.empty + access_key = create_user_with_access_key_and_attached_policy( + user_name, attached_policy_document + ) + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.list_groups()["Groups"].should.be.empty @set_initial_no_auth_action_count(8) @mock_s3 @mock_iam def test_s3_access_denied_with_denying_attached_group_policy(): - user_name = 'test-user' + user_name = "test-user" attached_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": "s3:ListAllMyBuckets", - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": "s3:ListAllMyBuckets", "Resource": "*"} + ], } group_attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "s3:List*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "s3:List*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_attached_policy(user_name, attached_policy_document) - create_group_with_attached_policy_and_add_user(user_name, group_attached_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_attached_policy( + user_name, attached_policy_document, policy_name="policy1" + ) + create_group_with_attached_policy_and_add_user( + user_name, group_attached_policy_document, policy_name="policy2" + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.list_buckets() - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(6) @mock_s3 @mock_iam def test_s3_access_denied_with_denying_inline_group_policy(): - user_name = 'test-user' - bucket_name = 'test-bucket' + user_name = "test-user" + bucket_name = "test-bucket" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "s3:GetObject", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - create_group_with_inline_policy_and_add_user(user_name, group_inline_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + create_group_with_inline_policy_and_add_user( + user_name, group_inline_policy_document + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: - client.get_object(Bucket=bucket_name, Key='sdfsdf') - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + client.get_object(Bucket=bucket_name, Key="sdfsdf") + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(10) @mock_iam @mock_ec2 def test_access_denied_with_many_irrelevant_policies(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "ec2:Describe*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "ec2:Describe*", "Resource": "*"}], } attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "s3:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "s3:*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "iam:List*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "iam:List*", "Resource": "*"}], } group_attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "lambda:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "lambda:*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_multiple_policies(user_name, inline_policy_document, - attached_policy_document) - create_group_with_multiple_policies_and_add_user(user_name, group_inline_policy_document, - group_attached_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_multiple_policies( + user_name, + inline_policy_document, + attached_policy_document, + attached_policy_name="policy1", + ) + create_group_with_multiple_policies_and_add_user( + user_name, + group_inline_policy_document, + group_attached_policy_document, + attached_policy_name="policy2", + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.create_key_pair(KeyName="TestKey") - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:CreateKeyPair" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:CreateKeyPair" ) ) @@ -483,14 +515,16 @@ def test_access_denied_with_many_irrelevant_policies(): @mock_ec2 @mock_elbv2 def test_allowed_with_temporary_credentials(): - role_name = 'test-role' + role_name = "test-role" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", @@ -499,30 +533,35 @@ def test_allowed_with_temporary_credentials(): "Effect": "Allow", "Action": [ "elasticloadbalancing:CreateLoadBalancer", - "ec2:DescribeSubnets" + "ec2:DescribeSubnets", ], - "Resource": "*" + "Resource": "*", } - ] + ], } - credentials = create_role_with_attached_policy_and_assume_it(role_name, trust_policy_document, attached_policy_document) - elbv2_client = boto3.client('elbv2', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) - ec2_client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) - subnets = ec2_client.describe_subnets()['Subnets'] + credentials = create_role_with_attached_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document + ) + elbv2_client = boto3.client( + "elbv2", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + ec2_client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + subnets = ec2_client.describe_subnets()["Subnets"] len(subnets).should.be.greater_than(1) elbv2_client.create_load_balancer( - Name='test-load-balancer', - Subnets=[ - subnets[0]['SubnetId'], - subnets[1]['SubnetId'] - ] - )['LoadBalancers'].should.have.length_of(1) + Name="test-load-balancer", + Subnets=[subnets[0]["SubnetId"], subnets[1]["SubnetId"]], + )["LoadBalancers"].should.have.length_of(1) @set_initial_no_auth_action_count(3) @@ -530,48 +569,48 @@ def test_allowed_with_temporary_credentials(): @mock_sts @mock_rds2 def test_access_denied_with_temporary_credentials(): - role_name = 'test-role' - session_name = 'test-session' + role_name = "test-role" + session_name = "test-session" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - 'rds:Describe*' - ], - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["rds:Describe*"], "Resource": "*"} + ], } - credentials = create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - attached_policy_document, session_name) - client = boto3.client('rds', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) + credentials = create_role_with_inline_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document, session_name + ) + client = boto3.client( + "rds", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) with assert_raises(ClientError) as ex: client.create_db_instance( - DBInstanceIdentifier='test-db-instance', - DBInstanceClass='db.t3', - Engine='aurora-postgresql' + DBInstanceIdentifier="test-db-instance", + DBInstanceClass="db.t3", + Engine="aurora-postgresql", ) - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name} is not authorized to perform: {operation}'.format( + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name} is not authorized to perform: {operation}".format( account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name, - operation="rds:CreateDBInstance" + operation="rds:CreateDBInstance", ) ) @@ -579,89 +618,95 @@ def test_access_denied_with_temporary_credentials(): @set_initial_no_auth_action_count(3) @mock_iam def test_get_user_from_credentials(): - user_name = 'new-test-user' + user_name = "new-test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "iam:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "iam:*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.get_user()['User']['UserName'].should.equal(user_name) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.get_user()["User"]["UserName"].should.equal(user_name) @set_initial_no_auth_action_count(0) @mock_s3 def test_s3_invalid_access_key_id(): - client = boto3.client('s3', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.list_buckets() - ex.exception.response['Error']['Code'].should.equal('InvalidAccessKeyId') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The AWS Access Key Id you provided does not exist in our records.') + ex.exception.response["Error"]["Code"].should.equal("InvalidAccessKeyId") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The AWS Access Key Id you provided does not exist in our records." + ) @set_initial_no_auth_action_count(3) @mock_s3 @mock_iam def test_s3_signature_does_not_match(): - bucket_name = 'test-bucket' + bucket_name = "test-bucket" access_key = create_user_with_access_key() - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: client.put_object(Bucket=bucket_name, Key="abc") - ex.exception.response['Error']['Code'].should.equal('SignatureDoesNotMatch') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The request signature we calculated does not match the signature you provided. Check your key and signing method.') + ex.exception.response["Error"]["Code"].should.equal("SignatureDoesNotMatch") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The request signature we calculated does not match the signature you provided. Check your key and signing method." + ) @set_initial_no_auth_action_count(7) @mock_s3 @mock_iam def test_s3_access_denied_not_action(): - user_name = 'test-user' - bucket_name = 'test-bucket' + user_name = "test-user" + bucket_name = "test-bucket" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "NotAction": "iam:GetUser", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "NotAction": "iam:GetUser", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - create_group_with_inline_policy_and_add_user(user_name, group_inline_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + create_group_with_inline_policy_and_add_user( + user_name, group_inline_policy_document + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: - client.delete_object(Bucket=bucket_name, Key='sdfsdf') - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + client.delete_object(Bucket=bucket_name, Key="sdfsdf") + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(4) @@ -669,38 +714,38 @@ def test_s3_access_denied_not_action(): @mock_sts @mock_s3 def test_s3_invalid_token_with_temporary_credentials(): - role_name = 'test-role' - session_name = 'test-session' - bucket_name = 'test-bucket-888' + role_name = "test-role" + session_name = "test-session" + bucket_name = "test-bucket-888" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - '*' - ], - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": ["*"], "Resource": "*"}], } - credentials = create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - attached_policy_document, session_name) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token='invalid') + credentials = create_role_with_inline_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document, session_name + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token="invalid", + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: client.list_bucket_metrics_configurations(Bucket=bucket_name) - ex.exception.response['Error']['Code'].should.equal('InvalidToken') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal('The provided token is malformed or otherwise invalid.') + ex.exception.response["Error"]["Code"].should.equal("InvalidToken") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "The provided token is malformed or otherwise invalid." + ) diff --git a/tests/test_core/test_context_manager.py b/tests/test_core/test_context_manager.py index 4824e021f..d20c2187f 100644 --- a/tests/test_core/test_context_manager.py +++ b/tests/test_core/test_context_manager.py @@ -5,8 +5,8 @@ from moto import mock_sqs, settings def test_context_manager_returns_mock(): with mock_sqs() as sqs_mock: - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="queue1") if not settings.TEST_SERVER_MODE: - list(sqs_mock.backends['us-west-1'].queues.keys()).should.equal(['queue1']) + list(sqs_mock.backends["us-west-1"].queues.keys()).should.equal(["queue1"]) diff --git a/tests/test_core/test_decorator_calls.py b/tests/test_core/test_decorator_calls.py index b7e5f7448..408ca6819 100644 --- a/tests/test_core/test_decorator_calls.py +++ b/tests/test_core/test_decorator_calls.py @@ -1,98 +1,96 @@ -from __future__ import unicode_literals -import boto -from boto.exception import EC2ResponseError -import sure # noqa -import unittest - -import tests.backport_assert_raises # noqa -from nose.tools import assert_raises - -from moto import mock_ec2_deprecated, mock_s3_deprecated - -''' -Test the different ways that the decorator can be used -''' - - -@mock_ec2_deprecated -def test_basic_connect(): - boto.connect_ec2() - - -@mock_ec2_deprecated -def test_basic_decorator(): - conn = boto.connect_ec2('the_key', 'the_secret') - list(conn.get_all_instances()).should.equal([]) - - -def test_context_manager(): - conn = boto.connect_ec2('the_key', 'the_secret') - with assert_raises(EC2ResponseError): - conn.get_all_instances() - - with mock_ec2_deprecated(): - conn = boto.connect_ec2('the_key', 'the_secret') - list(conn.get_all_instances()).should.equal([]) - - with assert_raises(EC2ResponseError): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.get_all_instances() - - -def test_decorator_start_and_stop(): - conn = boto.connect_ec2('the_key', 'the_secret') - with assert_raises(EC2ResponseError): - conn.get_all_instances() - - mock = mock_ec2_deprecated() - mock.start() - conn = boto.connect_ec2('the_key', 'the_secret') - list(conn.get_all_instances()).should.equal([]) - mock.stop() - - with assert_raises(EC2ResponseError): - conn.get_all_instances() - - -@mock_ec2_deprecated -def test_decorater_wrapped_gets_set(): - """ - Moto decorator's __wrapped__ should get set to the tests function - """ - test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal( - 'test_decorater_wrapped_gets_set') - - -@mock_ec2_deprecated -class Tester(object): - - def test_the_class(self): - conn = boto.connect_ec2() - list(conn.get_all_instances()).should.have.length_of(0) - - def test_still_the_same(self): - conn = boto.connect_ec2() - list(conn.get_all_instances()).should.have.length_of(0) - - -@mock_s3_deprecated -class TesterWithSetup(unittest.TestCase): - - def setUp(self): - self.conn = boto.connect_s3() - self.conn.create_bucket('mybucket') - - def test_still_the_same(self): - bucket = self.conn.get_bucket('mybucket') - bucket.name.should.equal("mybucket") - - -@mock_s3_deprecated -class TesterWithStaticmethod(object): - - @staticmethod - def static(*args): - assert not args or not isinstance(args[0], TesterWithStaticmethod) - - def test_no_instance_sent_to_staticmethod(self): - self.static() +from __future__ import unicode_literals +import boto +from boto.exception import EC2ResponseError +import sure # noqa +import unittest + +import tests.backport_assert_raises # noqa +from nose.tools import assert_raises + +from moto import mock_ec2_deprecated, mock_s3_deprecated + +""" +Test the different ways that the decorator can be used +""" + + +@mock_ec2_deprecated +def test_basic_connect(): + boto.connect_ec2() + + +@mock_ec2_deprecated +def test_basic_decorator(): + conn = boto.connect_ec2("the_key", "the_secret") + list(conn.get_all_instances()).should.equal([]) + + +def test_context_manager(): + conn = boto.connect_ec2("the_key", "the_secret") + with assert_raises(EC2ResponseError): + conn.get_all_instances() + + with mock_ec2_deprecated(): + conn = boto.connect_ec2("the_key", "the_secret") + list(conn.get_all_instances()).should.equal([]) + + with assert_raises(EC2ResponseError): + conn = boto.connect_ec2("the_key", "the_secret") + conn.get_all_instances() + + +def test_decorator_start_and_stop(): + conn = boto.connect_ec2("the_key", "the_secret") + with assert_raises(EC2ResponseError): + conn.get_all_instances() + + mock = mock_ec2_deprecated() + mock.start() + conn = boto.connect_ec2("the_key", "the_secret") + list(conn.get_all_instances()).should.equal([]) + mock.stop() + + with assert_raises(EC2ResponseError): + conn.get_all_instances() + + +@mock_ec2_deprecated +def test_decorater_wrapped_gets_set(): + """ + Moto decorator's __wrapped__ should get set to the tests function + """ + test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal( + "test_decorater_wrapped_gets_set" + ) + + +@mock_ec2_deprecated +class Tester(object): + def test_the_class(self): + conn = boto.connect_ec2() + list(conn.get_all_instances()).should.have.length_of(0) + + def test_still_the_same(self): + conn = boto.connect_ec2() + list(conn.get_all_instances()).should.have.length_of(0) + + +@mock_s3_deprecated +class TesterWithSetup(unittest.TestCase): + def setUp(self): + self.conn = boto.connect_s3() + self.conn.create_bucket("mybucket") + + def test_still_the_same(self): + bucket = self.conn.get_bucket("mybucket") + bucket.name.should.equal("mybucket") + + +@mock_s3_deprecated +class TesterWithStaticmethod(object): + @staticmethod + def static(*args): + assert not args or not isinstance(args[0], TesterWithStaticmethod) + + def test_no_instance_sent_to_staticmethod(self): + self.static() diff --git a/tests/test_core/test_instance_metadata.py b/tests/test_core/test_instance_metadata.py index b66f9637e..d30138d5d 100644 --- a/tests/test_core/test_instance_metadata.py +++ b/tests/test_core/test_instance_metadata.py @@ -1,46 +1,48 @@ -from __future__ import unicode_literals -import sure # noqa -from nose.tools import assert_raises -import requests - -from moto import mock_ec2, settings - -if settings.TEST_SERVER_MODE: - BASE_URL = 'http://localhost:5000' -else: - BASE_URL = 'http://169.254.169.254' - - -@mock_ec2 -def test_latest_meta_data(): - res = requests.get("{0}/latest/meta-data/".format(BASE_URL)) - res.content.should.equal(b"iam") - - -@mock_ec2 -def test_meta_data_iam(): - res = requests.get("{0}/latest/meta-data/iam".format(BASE_URL)) - json_response = res.json() - default_role = json_response['security-credentials']['default-role'] - default_role.should.contain('AccessKeyId') - default_role.should.contain('SecretAccessKey') - default_role.should.contain('Token') - default_role.should.contain('Expiration') - - -@mock_ec2 -def test_meta_data_security_credentials(): - res = requests.get( - "{0}/latest/meta-data/iam/security-credentials/".format(BASE_URL)) - res.content.should.equal(b"default-role") - - -@mock_ec2 -def test_meta_data_default_role(): - res = requests.get( - "{0}/latest/meta-data/iam/security-credentials/default-role".format(BASE_URL)) - json_response = res.json() - json_response.should.contain('AccessKeyId') - json_response.should.contain('SecretAccessKey') - json_response.should.contain('Token') - json_response.should.contain('Expiration') +from __future__ import unicode_literals +import sure # noqa +from nose.tools import assert_raises +import requests + +from moto import mock_ec2, settings + +if settings.TEST_SERVER_MODE: + BASE_URL = "http://localhost:5000" +else: + BASE_URL = "http://169.254.169.254" + + +@mock_ec2 +def test_latest_meta_data(): + res = requests.get("{0}/latest/meta-data/".format(BASE_URL)) + res.content.should.equal(b"iam") + + +@mock_ec2 +def test_meta_data_iam(): + res = requests.get("{0}/latest/meta-data/iam".format(BASE_URL)) + json_response = res.json() + default_role = json_response["security-credentials"]["default-role"] + default_role.should.contain("AccessKeyId") + default_role.should.contain("SecretAccessKey") + default_role.should.contain("Token") + default_role.should.contain("Expiration") + + +@mock_ec2 +def test_meta_data_security_credentials(): + res = requests.get( + "{0}/latest/meta-data/iam/security-credentials/".format(BASE_URL) + ) + res.content.should.equal(b"default-role") + + +@mock_ec2 +def test_meta_data_default_role(): + res = requests.get( + "{0}/latest/meta-data/iam/security-credentials/default-role".format(BASE_URL) + ) + json_response = res.json() + json_response.should.contain("AccessKeyId") + json_response.should.contain("SecretAccessKey") + json_response.should.contain("Token") + json_response.should.contain("Expiration") diff --git a/tests/test_core/test_moto_api.py b/tests/test_core/test_moto_api.py index 47dbe5a4a..6482d903e 100644 --- a/tests/test_core/test_moto_api.py +++ b/tests/test_core/test_moto_api.py @@ -1,33 +1,37 @@ -from __future__ import unicode_literals -import sure # noqa -from nose.tools import assert_raises -import requests - -import boto3 -from moto import mock_sqs, settings - -base_url = "http://localhost:5000" if settings.TEST_SERVER_MODE else "http://motoapi.amazonaws.com" - - -@mock_sqs -def test_reset_api(): - conn = boto3.client("sqs", region_name='us-west-1') - conn.create_queue(QueueName="queue1") - conn.list_queues()['QueueUrls'].should.have.length_of(1) - - res = requests.post("{base_url}/moto-api/reset".format(base_url=base_url)) - res.content.should.equal(b'{"status": "ok"}') - - conn.list_queues().shouldnt.contain('QueueUrls') # No more queues - - -@mock_sqs -def test_data_api(): - conn = boto3.client("sqs", region_name='us-west-1') - conn.create_queue(QueueName="queue1") - - res = requests.post("{base_url}/moto-api/data.json".format(base_url=base_url)) - queues = res.json()['sqs']['Queue'] - len(queues).should.equal(1) - queue = queues[0] - queue['name'].should.equal("queue1") +from __future__ import unicode_literals +import sure # noqa +from nose.tools import assert_raises +import requests + +import boto3 +from moto import mock_sqs, settings + +base_url = ( + "http://localhost:5000" + if settings.TEST_SERVER_MODE + else "http://motoapi.amazonaws.com" +) + + +@mock_sqs +def test_reset_api(): + conn = boto3.client("sqs", region_name="us-west-1") + conn.create_queue(QueueName="queue1") + conn.list_queues()["QueueUrls"].should.have.length_of(1) + + res = requests.post("{base_url}/moto-api/reset".format(base_url=base_url)) + res.content.should.equal(b'{"status": "ok"}') + + conn.list_queues().shouldnt.contain("QueueUrls") # No more queues + + +@mock_sqs +def test_data_api(): + conn = boto3.client("sqs", region_name="us-west-1") + conn.create_queue(QueueName="queue1") + + res = requests.post("{base_url}/moto-api/data.json".format(base_url=base_url)) + queues = res.json()["sqs"]["Queue"] + len(queues).should.equal(1) + queue = queues[0] + queue["name"].should.equal("queue1") diff --git a/tests/test_core/test_nested.py b/tests/test_core/test_nested.py index ec10a69b9..04b04257c 100644 --- a/tests/test_core/test_nested.py +++ b/tests/test_core/test_nested.py @@ -1,29 +1,28 @@ -from __future__ import unicode_literals -import unittest - -from boto.sqs.connection import SQSConnection -from boto.sqs.message import Message -from boto.ec2 import EC2Connection - -from moto import mock_sqs_deprecated, mock_ec2_deprecated - - -class TestNestedDecorators(unittest.TestCase): - - @mock_sqs_deprecated - def setup_sqs_queue(self): - conn = SQSConnection() - q = conn.create_queue('some-queue') - - m = Message() - m.set_body('This is my first message.') - q.write(m) - - self.assertEqual(q.count(), 1) - - @mock_ec2_deprecated - def test_nested(self): - self.setup_sqs_queue() - - conn = EC2Connection() - conn.run_instances('ami-123456') +from __future__ import unicode_literals +import unittest + +from boto.sqs.connection import SQSConnection +from boto.sqs.message import Message +from boto.ec2 import EC2Connection + +from moto import mock_sqs_deprecated, mock_ec2_deprecated + + +class TestNestedDecorators(unittest.TestCase): + @mock_sqs_deprecated + def setup_sqs_queue(self): + conn = SQSConnection() + q = conn.create_queue("some-queue") + + m = Message() + m.set_body("This is my first message.") + q.write(m) + + self.assertEqual(q.count(), 1) + + @mock_ec2_deprecated + def test_nested(self): + self.setup_sqs_queue() + + conn = EC2Connection() + conn.run_instances("ami-123456") diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py new file mode 100644 index 000000000..2c44d52ce --- /dev/null +++ b/tests/test_core/test_request_mocking.py @@ -0,0 +1,23 @@ +import requests +import sure # noqa + +import boto3 +from moto import mock_sqs, settings + + +@mock_sqs +def test_passthrough_requests(): + conn = boto3.client("sqs", region_name="us-west-1") + conn.create_queue(QueueName="queue1") + + res = requests.get("https://httpbin.org/ip") + assert res.status_code == 200 + + +if not settings.TEST_SERVER_MODE: + + @mock_sqs + def test_requests_to_amazon_subdomains_dont_work(): + res = requests.get("https://fakeservice.amazonaws.com/foo/bar") + assert res.content == b"The method is not implemented" + assert res.status_code == 400 diff --git a/tests/test_core/test_responses.py b/tests/test_core/test_responses.py index d0f672ab8..587e3584b 100644 --- a/tests/test_core/test_responses.py +++ b/tests/test_core/test_responses.py @@ -9,81 +9,86 @@ from moto.core.responses import flatten_json_request_body def test_flatten_json_request_body(): - spec = AWSServiceSpec( - 'data/emr/2009-03-31/service-2.json').input_spec('RunJobFlow') + spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json").input_spec("RunJobFlow") body = { - 'Name': 'cluster', - 'Instances': { - 'Ec2KeyName': 'ec2key', - 'InstanceGroups': [ - {'InstanceRole': 'MASTER', - 'InstanceType': 'm1.small'}, - {'InstanceRole': 'CORE', - 'InstanceType': 'm1.medium'}, + "Name": "cluster", + "Instances": { + "Ec2KeyName": "ec2key", + "InstanceGroups": [ + {"InstanceRole": "MASTER", "InstanceType": "m1.small"}, + {"InstanceRole": "CORE", "InstanceType": "m1.medium"}, ], - 'Placement': {'AvailabilityZone': 'us-east-1'}, + "Placement": {"AvailabilityZone": "us-east-1"}, }, - 'Steps': [ - {'HadoopJarStep': { - 'Properties': [ - {'Key': 'k1', 'Value': 'v1'}, - {'Key': 'k2', 'Value': 'v2'} - ], - 'Args': ['arg1', 'arg2']}}, + "Steps": [ + { + "HadoopJarStep": { + "Properties": [ + {"Key": "k1", "Value": "v1"}, + {"Key": "k2", "Value": "v2"}, + ], + "Args": ["arg1", "arg2"], + } + } + ], + "Configurations": [ + { + "Classification": "class", + "Properties": {"propkey1": "propkey1", "propkey2": "propkey2"}, + }, + {"Classification": "anotherclass", "Properties": {"propkey3": "propkey3"}}, ], - 'Configurations': [ - {'Classification': 'class', - 'Properties': {'propkey1': 'propkey1', - 'propkey2': 'propkey2'}}, - {'Classification': 'anotherclass', - 'Properties': {'propkey3': 'propkey3'}}, - ] } - flat = flatten_json_request_body('', body, spec) - flat['Name'].should.equal(body['Name']) - flat['Instances.Ec2KeyName'].should.equal(body['Instances']['Ec2KeyName']) + flat = flatten_json_request_body("", body, spec) + flat["Name"].should.equal(body["Name"]) + flat["Instances.Ec2KeyName"].should.equal(body["Instances"]["Ec2KeyName"]) for idx in range(2): - flat['Instances.InstanceGroups.member.' + str(idx + 1) + '.InstanceRole'].should.equal( - body['Instances']['InstanceGroups'][idx]['InstanceRole']) - flat['Instances.InstanceGroups.member.' + str(idx + 1) + '.InstanceType'].should.equal( - body['Instances']['InstanceGroups'][idx]['InstanceType']) - flat['Instances.Placement.AvailabilityZone'].should.equal( - body['Instances']['Placement']['AvailabilityZone']) + flat[ + "Instances.InstanceGroups.member." + str(idx + 1) + ".InstanceRole" + ].should.equal(body["Instances"]["InstanceGroups"][idx]["InstanceRole"]) + flat[ + "Instances.InstanceGroups.member." + str(idx + 1) + ".InstanceType" + ].should.equal(body["Instances"]["InstanceGroups"][idx]["InstanceType"]) + flat["Instances.Placement.AvailabilityZone"].should.equal( + body["Instances"]["Placement"]["AvailabilityZone"] + ) for idx in range(1): - prefix = 'Steps.member.' + str(idx + 1) + '.HadoopJarStep' - step = body['Steps'][idx]['HadoopJarStep'] + prefix = "Steps.member." + str(idx + 1) + ".HadoopJarStep" + step = body["Steps"][idx]["HadoopJarStep"] i = 0 - while prefix + '.Properties.member.' + str(i + 1) + '.Key' in flat: - flat[prefix + '.Properties.member.' + - str(i + 1) + '.Key'].should.equal(step['Properties'][i]['Key']) - flat[prefix + '.Properties.member.' + - str(i + 1) + '.Value'].should.equal(step['Properties'][i]['Value']) + while prefix + ".Properties.member." + str(i + 1) + ".Key" in flat: + flat[prefix + ".Properties.member." + str(i + 1) + ".Key"].should.equal( + step["Properties"][i]["Key"] + ) + flat[prefix + ".Properties.member." + str(i + 1) + ".Value"].should.equal( + step["Properties"][i]["Value"] + ) i += 1 i = 0 - while prefix + '.Args.member.' + str(i + 1) in flat: - flat[prefix + '.Args.member.' + - str(i + 1)].should.equal(step['Args'][i]) + while prefix + ".Args.member." + str(i + 1) in flat: + flat[prefix + ".Args.member." + str(i + 1)].should.equal(step["Args"][i]) i += 1 for idx in range(2): - flat['Configurations.member.' + str(idx + 1) + '.Classification'].should.equal( - body['Configurations'][idx]['Classification']) + flat["Configurations.member." + str(idx + 1) + ".Classification"].should.equal( + body["Configurations"][idx]["Classification"] + ) props = {} i = 1 - keyfmt = 'Configurations.member.{0}.Properties.entry.{1}' + keyfmt = "Configurations.member.{0}.Properties.entry.{1}" key = keyfmt.format(idx + 1, i) - while key + '.key' in flat: - props[flat[key + '.key']] = flat[key + '.value'] + while key + ".key" in flat: + props[flat[key + ".key"]] = flat[key + ".value"] i += 1 key = keyfmt.format(idx + 1, i) - props.should.equal(body['Configurations'][idx]['Properties']) + props.should.equal(body["Configurations"][idx]["Properties"]) def test_parse_qs_unicode_decode_error(): body = b'{"key": "%D0"}, "C": "#0 = :0"}' - request = AWSPreparedRequest('GET', 'http://request', {'foo': 'bar'}, body, False) + request = AWSPreparedRequest("GET", "http://request", {"foo": "bar"}, body, False) BaseResponse().setup_class(request, request.url, request.headers) diff --git a/tests/test_core/test_server.py b/tests/test_core/test_server.py index ef04ae049..5514223af 100644 --- a/tests/test_core/test_server.py +++ b/tests/test_core/test_server.py @@ -1,47 +1,49 @@ -from __future__ import unicode_literals -from mock import patch -import sure # noqa - -from moto.server import main, create_backend_app, DomainDispatcherApplication - - -def test_wrong_arguments(): - try: - main(["name", "test1", "test2", "test3"]) - assert False, ("main() when called with the incorrect number of args" - " should raise a system exit") - except SystemExit: - pass - - -@patch('moto.server.run_simple') -def test_right_arguments(run_simple): - main(["s3"]) - func_call = run_simple.call_args[0] - func_call[0].should.equal("127.0.0.1") - func_call[1].should.equal(5000) - - -@patch('moto.server.run_simple') -def test_port_argument(run_simple): - main(["s3", "--port", "8080"]) - func_call = run_simple.call_args[0] - func_call[0].should.equal("127.0.0.1") - func_call[1].should.equal(8080) - - -def test_domain_dispatched(): - dispatcher = DomainDispatcherApplication(create_backend_app) - backend_app = dispatcher.get_application( - {"HTTP_HOST": "email.us-east1.amazonaws.com"}) - keys = list(backend_app.view_functions.keys()) - keys[0].should.equal('EmailResponse.dispatch') - - -def test_domain_dispatched_with_service(): - # If we pass a particular service, always return that. - dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") - backend_app = dispatcher.get_application( - {"HTTP_HOST": "s3.us-east1.amazonaws.com"}) - keys = set(backend_app.view_functions.keys()) - keys.should.contain('ResponseObject.key_response') +from __future__ import unicode_literals +from mock import patch +import sure # noqa + +from moto.server import main, create_backend_app, DomainDispatcherApplication + + +def test_wrong_arguments(): + try: + main(["name", "test1", "test2", "test3"]) + assert False, ( + "main() when called with the incorrect number of args" + " should raise a system exit" + ) + except SystemExit: + pass + + +@patch("moto.server.run_simple") +def test_right_arguments(run_simple): + main(["s3"]) + func_call = run_simple.call_args[0] + func_call[0].should.equal("127.0.0.1") + func_call[1].should.equal(5000) + + +@patch("moto.server.run_simple") +def test_port_argument(run_simple): + main(["s3", "--port", "8080"]) + func_call = run_simple.call_args[0] + func_call[0].should.equal("127.0.0.1") + func_call[1].should.equal(8080) + + +def test_domain_dispatched(): + dispatcher = DomainDispatcherApplication(create_backend_app) + backend_app = dispatcher.get_application( + {"HTTP_HOST": "email.us-east1.amazonaws.com"} + ) + keys = list(backend_app.view_functions.keys()) + keys[0].should.equal("EmailResponse.dispatch") + + +def test_domain_dispatched_with_service(): + # If we pass a particular service, always return that. + dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") + backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) + keys = set(backend_app.view_functions.keys()) + keys.should.contain("ResponseObject.key_response") diff --git a/tests/test_core/test_socket.py b/tests/test_core/test_socket.py index 2e73d7b5f..5e446ca1a 100644 --- a/tests/test_core/test_socket.py +++ b/tests/test_core/test_socket.py @@ -6,16 +6,16 @@ from six import PY3 class TestSocketPair(unittest.TestCase): - @mock_dynamodb2_deprecated def test_asyncio_deprecated(self): if PY3: self.assertIn( - 'moto.packages.httpretty.core.fakesock.socket', + "moto.packages.httpretty.core.fakesock.socket", str(socket.socket), - 'Our mock should be present' + "Our mock should be present", ) import asyncio + self.assertIsNotNone(asyncio.get_event_loop()) @mock_dynamodb2_deprecated @@ -24,9 +24,9 @@ class TestSocketPair(unittest.TestCase): # In Python2, the fakesocket is not set, for some reason. if PY3: self.assertIn( - 'moto.packages.httpretty.core.fakesock.socket', + "moto.packages.httpretty.core.fakesock.socket", str(socket.socket), - 'Our mock should be present' + "Our mock should be present", ) a, b = socket.socketpair() self.assertIsNotNone(a) @@ -36,7 +36,6 @@ class TestSocketPair(unittest.TestCase): if b: b.close() - @mock_dynamodb2 def test_socket_pair(self): a, b = socket.socketpair() diff --git a/tests/test_core/test_url_mapping.py b/tests/test_core/test_url_mapping.py index b58e991c4..4dccc4f21 100644 --- a/tests/test_core/test_url_mapping.py +++ b/tests/test_core/test_url_mapping.py @@ -1,22 +1,23 @@ -from __future__ import unicode_literals -import sure # noqa - -from moto.core.utils import convert_regex_to_flask_path - - -def test_flask_path_converting_simple(): - convert_regex_to_flask_path("/").should.equal("/") - convert_regex_to_flask_path("/$").should.equal("/") - - convert_regex_to_flask_path("/foo").should.equal("/foo") - - convert_regex_to_flask_path("/foo/bar/").should.equal("/foo/bar/") - - -def test_flask_path_converting_regex(): - convert_regex_to_flask_path( - "/(?P[a-zA-Z0-9\-_]+)").should.equal('/') - - convert_regex_to_flask_path("(?P\d+)/(?P.*)$").should.equal( - '/' - ) +from __future__ import unicode_literals +import sure # noqa + +from moto.core.utils import convert_regex_to_flask_path + + +def test_flask_path_converting_simple(): + convert_regex_to_flask_path("/").should.equal("/") + convert_regex_to_flask_path("/$").should.equal("/") + + convert_regex_to_flask_path("/foo").should.equal("/foo") + + convert_regex_to_flask_path("/foo/bar/").should.equal("/foo/bar/") + + +def test_flask_path_converting_regex(): + convert_regex_to_flask_path("/(?P[a-zA-Z0-9\-_]+)").should.equal( + '/' + ) + + convert_regex_to_flask_path("(?P\d+)/(?P.*)$").should.equal( + '/' + ) diff --git a/tests/test_core/test_utils.py b/tests/test_core/test_utils.py index 22449a910..d0dd97688 100644 --- a/tests/test_core/test_utils.py +++ b/tests/test_core/test_utils.py @@ -1,30 +1,62 @@ -from __future__ import unicode_literals - -import sure # noqa -from freezegun import freeze_time - -from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase, unix_time - - -def test_camelcase_to_underscores(): - cases = { - "theNewAttribute": "the_new_attribute", - "attri bute With Space": "attribute_with_space", - "FirstLetterCapital": "first_letter_capital", - "ListMFADevices": "list_mfa_devices", - } - for arg, expected in cases.items(): - camelcase_to_underscores(arg).should.equal(expected) - - -def test_underscores_to_camelcase(): - cases = { - "the_new_attribute": "theNewAttribute", - } - for arg, expected in cases.items(): - underscores_to_camelcase(arg).should.equal(expected) - - -@freeze_time("2015-01-01 12:00:00") -def test_unix_time(): - unix_time().should.equal(1420113600.0) +from __future__ import unicode_literals + +import copy +import sys + +import sure # noqa +from freezegun import freeze_time + +from moto.core.utils import ( + camelcase_to_underscores, + underscores_to_camelcase, + unix_time, + py2_strip_unicode_keys, +) + + +def test_camelcase_to_underscores(): + cases = { + "theNewAttribute": "the_new_attribute", + "attri bute With Space": "attribute_with_space", + "FirstLetterCapital": "first_letter_capital", + "ListMFADevices": "list_mfa_devices", + } + for arg, expected in cases.items(): + camelcase_to_underscores(arg).should.equal(expected) + + +def test_underscores_to_camelcase(): + cases = {"the_new_attribute": "theNewAttribute"} + for arg, expected in cases.items(): + underscores_to_camelcase(arg).should.equal(expected) + + +@freeze_time("2015-01-01 12:00:00") +def test_unix_time(): + unix_time().should.equal(1420113600.0) + + +if sys.version_info[0] < 3: + # Tests for unicode removals (Python 2 only) + def _verify_no_unicode(blob): + """Verify that no unicode values exist""" + if type(blob) == dict: + for key, value in blob.items(): + assert type(key) != unicode + _verify_no_unicode(value) + + elif type(blob) in [list, set]: + for item in blob: + _verify_no_unicode(item) + + assert blob != unicode + + def test_py2_strip_unicode_keys(): + bad_dict = { + "some": "value", + "a": {"nested": ["List", "of", {"unicode": "values"}]}, + "and a": {"nested", "set", "of", 5, "values"}, + } + + result = py2_strip_unicode_keys(copy.deepcopy(bad_dict)) + _verify_no_unicode(result) diff --git a/tests/test_datapipeline/test_datapipeline.py b/tests/test_datapipeline/test_datapipeline.py index 7cf76f5d7..b540d120e 100644 --- a/tests/test_datapipeline/test_datapipeline.py +++ b/tests/test_datapipeline/test_datapipeline.py @@ -9,8 +9,8 @@ from moto.datapipeline.utils import remove_capitalization_of_dict_keys def get_value_from_fields(key, fields): for field in fields: - if field['key'] == key: - return field['stringValue'] + if field["key"] == key: + return field["stringValue"] @mock_datapipeline_deprecated @@ -20,62 +20,46 @@ def test_create_pipeline(): res = conn.create_pipeline("mypipeline", "some-unique-id") pipeline_id = res["pipelineId"] - pipeline_descriptions = conn.describe_pipelines( - [pipeline_id])["pipelineDescriptionList"] + pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ + "pipelineDescriptionList" + ] pipeline_descriptions.should.have.length_of(1) pipeline_description = pipeline_descriptions[0] - pipeline_description['name'].should.equal("mypipeline") + pipeline_description["name"].should.equal("mypipeline") pipeline_description["pipelineId"].should.equal(pipeline_id) - fields = pipeline_description['fields'] + fields = pipeline_description["fields"] - get_value_from_fields('@pipelineState', fields).should.equal("PENDING") - get_value_from_fields('uniqueId', fields).should.equal("some-unique-id") + get_value_from_fields("@pipelineState", fields).should.equal("PENDING") + get_value_from_fields("uniqueId", fields).should.equal("some-unique-id") PIPELINE_OBJECTS = [ { "id": "Default", "name": "Default", - "fields": [{ - "key": "workerGroup", - "stringValue": "workerGroup" - }] + "fields": [{"key": "workerGroup", "stringValue": "workerGroup"}], }, { "id": "Schedule", "name": "Schedule", - "fields": [{ - "key": "startDateTime", - "stringValue": "2012-12-12T00:00:00" - }, { - "key": "type", - "stringValue": "Schedule" - }, { - "key": "period", - "stringValue": "1 hour" - }, { - "key": "endDateTime", - "stringValue": "2012-12-21T18:00:00" - }] + "fields": [ + {"key": "startDateTime", "stringValue": "2012-12-12T00:00:00"}, + {"key": "type", "stringValue": "Schedule"}, + {"key": "period", "stringValue": "1 hour"}, + {"key": "endDateTime", "stringValue": "2012-12-21T18:00:00"}, + ], }, { "id": "SayHello", "name": "SayHello", - "fields": [{ - "key": "type", - "stringValue": "ShellCommandActivity" - }, { - "key": "command", - "stringValue": "echo hello" - }, { - "key": "parent", - "refValue": "Default" - }, { - "key": "schedule", - "refValue": "Schedule" - }] - } + "fields": [ + {"key": "type", "stringValue": "ShellCommandActivity"}, + {"key": "command", "stringValue": "echo hello"}, + {"key": "parent", "refValue": "Default"}, + {"key": "schedule", "refValue": "Schedule"}, + ], + }, ] @@ -88,14 +72,13 @@ def test_creating_pipeline_definition(): conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) pipeline_definition = conn.get_pipeline_definition(pipeline_id) - pipeline_definition['pipelineObjects'].should.have.length_of(3) - default_object = pipeline_definition['pipelineObjects'][0] - default_object['name'].should.equal("Default") - default_object['id'].should.equal("Default") - default_object['fields'].should.equal([{ - "key": "workerGroup", - "stringValue": "workerGroup" - }]) + pipeline_definition["pipelineObjects"].should.have.length_of(3) + default_object = pipeline_definition["pipelineObjects"][0] + default_object["name"].should.equal("Default") + default_object["id"].should.equal("Default") + default_object["fields"].should.equal( + [{"key": "workerGroup", "stringValue": "workerGroup"}] + ) @mock_datapipeline_deprecated @@ -107,15 +90,15 @@ def test_describing_pipeline_objects(): conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) objects = conn.describe_objects(["Schedule", "Default"], pipeline_id)[ - 'pipelineObjects'] + "pipelineObjects" + ] objects.should.have.length_of(2) - default_object = [x for x in objects if x['id'] == 'Default'][0] - default_object['name'].should.equal("Default") - default_object['fields'].should.equal([{ - "key": "workerGroup", - "stringValue": "workerGroup" - }]) + default_object = [x for x in objects if x["id"] == "Default"][0] + default_object["name"].should.equal("Default") + default_object["fields"].should.equal( + [{"key": "workerGroup", "stringValue": "workerGroup"}] + ) @mock_datapipeline_deprecated @@ -127,13 +110,14 @@ def test_activate_pipeline(): pipeline_id = res["pipelineId"] conn.activate_pipeline(pipeline_id) - pipeline_descriptions = conn.describe_pipelines( - [pipeline_id])["pipelineDescriptionList"] + pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ + "pipelineDescriptionList" + ] pipeline_descriptions.should.have.length_of(1) pipeline_description = pipeline_descriptions[0] - fields = pipeline_description['fields'] + fields = pipeline_description["fields"] - get_value_from_fields('@pipelineState', fields).should.equal("SCHEDULED") + get_value_from_fields("@pipelineState", fields).should.equal("SCHEDULED") @mock_datapipeline_deprecated @@ -160,14 +144,12 @@ def test_listing_pipelines(): response["hasMoreResults"].should.be(False) response["marker"].should.be.none response["pipelineIdList"].should.have.length_of(2) - response["pipelineIdList"].should.contain({ - "id": res1["pipelineId"], - "name": "mypipeline1", - }) - response["pipelineIdList"].should.contain({ - "id": res2["pipelineId"], - "name": "mypipeline2" - }) + response["pipelineIdList"].should.contain( + {"id": res1["pipelineId"], "name": "mypipeline1"} + ) + response["pipelineIdList"].should.contain( + {"id": res2["pipelineId"], "name": "mypipeline2"} + ) @mock_datapipeline_deprecated @@ -179,7 +161,7 @@ def test_listing_paginated_pipelines(): response = conn.list_pipelines() response["hasMoreResults"].should.be(True) - response["marker"].should.equal(response["pipelineIdList"][-1]['id']) + response["marker"].should.equal(response["pipelineIdList"][-1]["id"]) response["pipelineIdList"].should.have.length_of(50) @@ -188,17 +170,13 @@ def test_remove_capitalization_of_dict_keys(): result = remove_capitalization_of_dict_keys( { "Id": "IdValue", - "Fields": [{ - "Key": "KeyValue", - "StringValue": "StringValueValue" - }] + "Fields": [{"Key": "KeyValue", "StringValue": "StringValueValue"}], } ) - result.should.equal({ - "id": "IdValue", - "fields": [{ - "key": "KeyValue", - "stringValue": "StringValueValue" - }], - }) + result.should.equal( + { + "id": "IdValue", + "fields": [{"key": "KeyValue", "stringValue": "StringValueValue"}], + } + ) diff --git a/tests/test_datapipeline/test_server.py b/tests/test_datapipeline/test_server.py index 7cb2657da..49b8c39ce 100644 --- a/tests/test_datapipeline/test_server.py +++ b/tests/test_datapipeline/test_server.py @@ -1,28 +1,26 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_datapipeline - -''' -Test the different server responses -''' - - -@mock_datapipeline -def test_list_streams(): - backend = server.create_backend_app("datapipeline") - test_client = backend.test_client() - - res = test_client.post('/', - data={"pipelineIds": ["ASdf"]}, - headers={ - "X-Amz-Target": "DataPipeline.DescribePipelines"}, - ) - - json_data = json.loads(res.data.decode("utf-8")) - json_data.should.equal({ - 'pipelineDescriptionList': [] - }) +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_datapipeline + +""" +Test the different server responses +""" + + +@mock_datapipeline +def test_list_streams(): + backend = server.create_backend_app("datapipeline") + test_client = backend.test_client() + + res = test_client.post( + "/", + data={"pipelineIds": ["ASdf"]}, + headers={"X-Amz-Target": "DataPipeline.DescribePipelines"}, + ) + + json_data = json.loads(res.data.decode("utf-8")) + json_data.should.equal({"pipelineDescriptionList": []}) diff --git a/tests/test_datasync/__init__.py b/tests/test_datasync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_datasync/test_datasync.py b/tests/test_datasync/test_datasync.py new file mode 100644 index 000000000..e3ea87675 --- /dev/null +++ b/tests/test_datasync/test_datasync.py @@ -0,0 +1,425 @@ +import logging + +import boto +import boto3 +from botocore.exceptions import ClientError +from moto import mock_datasync +from nose.tools import assert_raises + + +def create_locations(client, create_smb=False, create_s3=False): + """ + Convenience function for creating locations. + Locations must exist before tasks can be created. + """ + smb_arn = None + s3_arn = None + if create_smb: + response = client.create_location_smb( + ServerHostname="host", + Subdirectory="somewhere", + User="", + Password="", + AgentArns=["stuff"], + ) + smb_arn = response["LocationArn"] + if create_s3: + response = client.create_location_s3( + S3BucketArn="arn:aws:s3:::my_bucket", + Subdirectory="dir", + S3Config={"BucketAccessRoleArn": "role"}, + ) + s3_arn = response["LocationArn"] + return {"smb_arn": smb_arn, "s3_arn": s3_arn} + + +@mock_datasync +def test_create_location_smb(): + client = boto3.client("datasync", region_name="us-east-1") + response = client.create_location_smb( + ServerHostname="host", + Subdirectory="somewhere", + User="", + Password="", + AgentArns=["stuff"], + ) + assert "LocationArn" in response + + +@mock_datasync +def test_describe_location_smb(): + client = boto3.client("datasync", region_name="us-east-1") + agent_arns = ["stuff"] + user = "user" + response = client.create_location_smb( + ServerHostname="host", + Subdirectory="somewhere", + User=user, + Password="", + AgentArns=agent_arns, + ) + response = client.describe_location_smb(LocationArn=response["LocationArn"]) + assert "LocationArn" in response + assert "LocationUri" in response + assert response["User"] == user + assert response["AgentArns"] == agent_arns + + +@mock_datasync +def test_create_location_s3(): + client = boto3.client("datasync", region_name="us-east-1") + response = client.create_location_s3( + S3BucketArn="arn:aws:s3:::my_bucket", + Subdirectory="dir", + S3Config={"BucketAccessRoleArn": "role"}, + ) + assert "LocationArn" in response + + +@mock_datasync +def test_describe_location_s3(): + client = boto3.client("datasync", region_name="us-east-1") + s3_config = {"BucketAccessRoleArn": "role"} + response = client.create_location_s3( + S3BucketArn="arn:aws:s3:::my_bucket", Subdirectory="dir", S3Config=s3_config + ) + response = client.describe_location_s3(LocationArn=response["LocationArn"]) + assert "LocationArn" in response + assert "LocationUri" in response + assert response["S3Config"] == s3_config + + +@mock_datasync +def test_describe_location_wrong(): + client = boto3.client("datasync", region_name="us-east-1") + agent_arns = ["stuff"] + user = "user" + response = client.create_location_smb( + ServerHostname="host", + Subdirectory="somewhere", + User=user, + Password="", + AgentArns=agent_arns, + ) + with assert_raises(ClientError) as e: + response = client.describe_location_s3(LocationArn=response["LocationArn"]) + + +@mock_datasync +def test_list_locations(): + client = boto3.client("datasync", region_name="us-east-1") + response = client.list_locations() + assert len(response["Locations"]) == 0 + + create_locations(client, create_smb=True) + response = client.list_locations() + assert len(response["Locations"]) == 1 + assert response["Locations"][0]["LocationUri"] == "smb://host/somewhere" + + create_locations(client, create_s3=True) + response = client.list_locations() + assert len(response["Locations"]) == 2 + assert response["Locations"][1]["LocationUri"] == "s3://my_bucket/dir" + + create_locations(client, create_s3=True) + response = client.list_locations() + assert len(response["Locations"]) == 3 + assert response["Locations"][2]["LocationUri"] == "s3://my_bucket/dir" + + +@mock_datasync +def test_delete_location(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_smb=True) + response = client.list_locations() + assert len(response["Locations"]) == 1 + location_arn = locations["smb_arn"] + + response = client.delete_location(LocationArn=location_arn) + response = client.list_locations() + assert len(response["Locations"]) == 0 + + with assert_raises(ClientError) as e: + response = client.delete_location(LocationArn=location_arn) + + +@mock_datasync +def test_create_task(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_smb=True, create_s3=True) + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + ) + assert "TaskArn" in response + + +@mock_datasync +def test_create_task_fail(): + """ Test that Locations must exist before a Task can be created """ + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_smb=True, create_s3=True) + with assert_raises(ClientError) as e: + response = client.create_task( + SourceLocationArn="1", DestinationLocationArn=locations["s3_arn"] + ) + with assert_raises(ClientError) as e: + response = client.create_task( + SourceLocationArn=locations["smb_arn"], DestinationLocationArn="2" + ) + + +@mock_datasync +def test_list_tasks(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + ) + response = client.create_task( + SourceLocationArn=locations["s3_arn"], + DestinationLocationArn=locations["smb_arn"], + Name="task_name", + ) + response = client.list_tasks() + tasks = response["Tasks"] + assert len(tasks) == 2 + + task = tasks[0] + assert task["Status"] == "AVAILABLE" + assert "Name" not in task + + task = tasks[1] + assert task["Status"] == "AVAILABLE" + assert task["Name"] == "task_name" + + +@mock_datasync +def test_describe_task(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + task_arn = response["TaskArn"] + + response = client.describe_task(TaskArn=task_arn) + + assert "TaskArn" in response + assert "Status" in response + assert "SourceLocationArn" in response + assert "DestinationLocationArn" in response + + +@mock_datasync +def test_describe_task_not_exist(): + client = boto3.client("datasync", region_name="us-east-1") + + with assert_raises(ClientError) as e: + client.describe_task(TaskArn="abc") + + +@mock_datasync +def test_update_task(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + initial_name = "Initial_Name" + updated_name = "Updated_Name" + initial_options = { + "VerifyMode": "NONE", + "Atime": "BEST_EFFORT", + "Mtime": "PRESERVE", + } + updated_options = { + "VerifyMode": "POINT_IN_TIME_CONSISTENT", + "Atime": "BEST_EFFORT", + "Mtime": "PRESERVE", + } + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name=initial_name, + Options=initial_options, + ) + task_arn = response["TaskArn"] + response = client.describe_task(TaskArn=task_arn) + assert response["TaskArn"] == task_arn + assert response["Name"] == initial_name + assert response["Options"] == initial_options + + response = client.update_task( + TaskArn=task_arn, Name=updated_name, Options=updated_options + ) + + response = client.describe_task(TaskArn=task_arn) + assert response["TaskArn"] == task_arn + assert response["Name"] == updated_name + assert response["Options"] == updated_options + + with assert_raises(ClientError) as e: + client.update_task(TaskArn="doesnt_exist") + + +@mock_datasync +def test_delete_task(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + + response = client.list_tasks() + assert len(response["Tasks"]) == 1 + task_arn = response["Tasks"][0]["TaskArn"] + assert task_arn is not None + + response = client.delete_task(TaskArn=task_arn) + response = client.list_tasks() + assert len(response["Tasks"]) == 0 + + with assert_raises(ClientError) as e: + response = client.delete_task(TaskArn=task_arn) + + +@mock_datasync +def test_start_task_execution(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + task_arn = response["TaskArn"] + response = client.describe_task(TaskArn=task_arn) + assert "CurrentTaskExecutionArn" not in response + + response = client.start_task_execution(TaskArn=task_arn) + assert "TaskExecutionArn" in response + task_execution_arn = response["TaskExecutionArn"] + + response = client.describe_task(TaskArn=task_arn) + assert response["CurrentTaskExecutionArn"] == task_execution_arn + + +@mock_datasync +def test_start_task_execution_twice(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + task_arn = response["TaskArn"] + + response = client.start_task_execution(TaskArn=task_arn) + assert "TaskExecutionArn" in response + task_execution_arn = response["TaskExecutionArn"] + + with assert_raises(ClientError) as e: + response = client.start_task_execution(TaskArn=task_arn) + + +@mock_datasync +def test_describe_task_execution(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + task_arn = response["TaskArn"] + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "AVAILABLE" + + response = client.start_task_execution(TaskArn=task_arn) + task_execution_arn = response["TaskExecutionArn"] + + # Each time task_execution is described the Status will increment + # This is a simple way to simulate a task being executed + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "INITIALIZING" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "RUNNING" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "PREPARING" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "RUNNING" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "TRANSFERRING" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "RUNNING" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "VERIFYING" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "RUNNING" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "SUCCESS" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "AVAILABLE" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["TaskExecutionArn"] == task_execution_arn + assert response["Status"] == "SUCCESS" + response = client.describe_task(TaskArn=task_arn) + assert response["Status"] == "AVAILABLE" + + +@mock_datasync +def test_describe_task_execution_not_exist(): + client = boto3.client("datasync", region_name="us-east-1") + + with assert_raises(ClientError) as e: + client.describe_task_execution(TaskExecutionArn="abc") + + +@mock_datasync +def test_cancel_task_execution(): + client = boto3.client("datasync", region_name="us-east-1") + locations = create_locations(client, create_s3=True, create_smb=True) + + response = client.create_task( + SourceLocationArn=locations["smb_arn"], + DestinationLocationArn=locations["s3_arn"], + Name="task_name", + ) + task_arn = response["TaskArn"] + + response = client.start_task_execution(TaskArn=task_arn) + task_execution_arn = response["TaskExecutionArn"] + + response = client.describe_task(TaskArn=task_arn) + assert response["CurrentTaskExecutionArn"] == task_execution_arn + assert response["Status"] == "RUNNING" + + response = client.cancel_task_execution(TaskExecutionArn=task_execution_arn) + + response = client.describe_task(TaskArn=task_arn) + assert "CurrentTaskExecutionArn" not in response + assert response["Status"] == "AVAILABLE" + + response = client.describe_task_execution(TaskExecutionArn=task_execution_arn) + assert response["Status"] == "ERROR" diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index 2c675756f..931e57e06 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -1,54 +1,51 @@ -from __future__ import unicode_literals -import six -import boto -import boto.dynamodb -import sure # noqa -import requests -import tests.backport_assert_raises -from nose.tools import assert_raises - -from moto import mock_dynamodb, mock_dynamodb_deprecated -from moto.dynamodb import dynamodb_backend - -from boto.exception import DynamoDBResponseError - - -@mock_dynamodb_deprecated -def test_list_tables(): - name = 'TestTable' - dynamodb_backend.create_table( - name, hash_key_attr="name", hash_key_type="S") - conn = boto.connect_dynamodb('the_key', 'the_secret') - assert conn.list_tables() == ['TestTable'] - - -@mock_dynamodb_deprecated -def test_list_tables_layer_1(): - dynamodb_backend.create_table( - "test_1", hash_key_attr="name", hash_key_type="S") - dynamodb_backend.create_table( - "test_2", hash_key_attr="name", hash_key_type="S") - conn = boto.connect_dynamodb('the_key', 'the_secret') - res = conn.layer1.list_tables(limit=1) - expected = {"TableNames": ["test_1"], "LastEvaluatedTableName": "test_1"} - res.should.equal(expected) - - res = conn.layer1.list_tables(limit=1, start_table="test_1") - expected = {"TableNames": ["test_2"]} - res.should.equal(expected) - - -@mock_dynamodb_deprecated -def test_describe_missing_table(): - conn = boto.connect_dynamodb('the_key', 'the_secret') - with assert_raises(DynamoDBResponseError): - conn.describe_table('messages') - - -@mock_dynamodb_deprecated -def test_dynamodb_with_connect_to_region(): - # this will work if connected with boto.connect_dynamodb() - dynamodb = boto.dynamodb.connect_to_region('us-west-2') - - schema = dynamodb.create_schema('column1', str(), 'column2', int()) - dynamodb.create_table('table1', schema, 200, 200) +from __future__ import unicode_literals +import six +import boto +import boto.dynamodb +import sure # noqa +import requests +import tests.backport_assert_raises +from nose.tools import assert_raises + +from moto import mock_dynamodb, mock_dynamodb_deprecated +from moto.dynamodb import dynamodb_backend + +from boto.exception import DynamoDBResponseError + + +@mock_dynamodb_deprecated +def test_list_tables(): + name = "TestTable" + dynamodb_backend.create_table(name, hash_key_attr="name", hash_key_type="S") + conn = boto.connect_dynamodb("the_key", "the_secret") + assert conn.list_tables() == ["TestTable"] + + +@mock_dynamodb_deprecated +def test_list_tables_layer_1(): + dynamodb_backend.create_table("test_1", hash_key_attr="name", hash_key_type="S") + dynamodb_backend.create_table("test_2", hash_key_attr="name", hash_key_type="S") + conn = boto.connect_dynamodb("the_key", "the_secret") + res = conn.layer1.list_tables(limit=1) + expected = {"TableNames": ["test_1"], "LastEvaluatedTableName": "test_1"} + res.should.equal(expected) + + res = conn.layer1.list_tables(limit=1, start_table="test_1") + expected = {"TableNames": ["test_2"]} + res.should.equal(expected) + + +@mock_dynamodb_deprecated +def test_describe_missing_table(): + conn = boto.connect_dynamodb("the_key", "the_secret") + with assert_raises(DynamoDBResponseError): + conn.describe_table("messages") + + +@mock_dynamodb_deprecated +def test_dynamodb_with_connect_to_region(): + # this will work if connected with boto.connect_dynamodb() + dynamodb = boto.dynamodb.connect_to_region("us-west-2") + + schema = dynamodb.create_schema("column1", str(), "column2", int()) + dynamodb.create_table("table1", schema, 200, 200) diff --git a/tests/test_dynamodb/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb/test_dynamodb_table_with_range_key.py index ee6738934..6986ae9b3 100644 --- a/tests/test_dynamodb/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb/test_dynamodb_table_with_range_key.py @@ -13,17 +13,14 @@ from boto.exception import DynamoDBResponseError def create_table(conn): message_table_schema = conn.create_schema( - hash_key_name='forum_name', + hash_key_name="forum_name", hash_key_proto_value=str, - range_key_name='subject', - range_key_proto_value=str + range_key_name="subject", + range_key_proto_value=str, ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) return table @@ -35,29 +32,23 @@ def test_create_table(): create_table(conn) expected = { - 'Table': { - 'CreationDateTime': 1326499200.0, - 'ItemCount': 0, - 'KeySchema': { - 'HashKeyElement': { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - 'RangeKeyElement': { - 'AttributeName': 'subject', - 'AttributeType': 'S' - } + "Table": { + "CreationDateTime": 1326499200.0, + "ItemCount": 0, + "KeySchema": { + "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"}, + "RangeKeyElement": {"AttributeName": "subject", "AttributeType": "S"}, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 10, }, - 'TableName': 'messages', - 'TableSizeBytes': 0, - 'TableStatus': 'ACTIVE' + "TableName": "messages", + "TableSizeBytes": 0, + "TableStatus": "ACTIVE", } } - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @mock_dynamodb_deprecated @@ -66,11 +57,12 @@ def test_delete_table(): create_table(conn) conn.list_tables().should.have.length_of(1) - conn.layer1.delete_table('messages') + conn.layer1.delete_table("messages") conn.list_tables().should.have.length_of(0) - conn.layer1.delete_table.when.called_with( - 'messages').should.throw(DynamoDBResponseError) + conn.layer1.delete_table.when.called_with("messages").should.throw( + DynamoDBResponseError + ) @mock_dynamodb_deprecated @@ -93,45 +85,47 @@ def test_item_add_and_describe_and_update(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() table.has_item("LOLCat Forum", "Check this out!").should.equal(True) returned_item = table.get_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", + range_key="Check this out!", + attributes_to_get=["Body", "SentBy"], + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) - item['SentBy'] = 'User B' + item["SentBy"] = "User B" item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", + range_key="Check this out!", + attributes_to_get=["Body", "SentBy"], + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @mock_dynamodb_deprecated @@ -139,11 +133,8 @@ def test_item_put_without_table(): conn = boto.connect_dynamodb() conn.layer1.put_item.when.called_with( - table_name='undeclared-table', - item=dict( - hash_key='LOLCat Forum', - range_key='Check this out!', - ), + table_name="undeclared-table", + item=dict(hash_key="LOLCat Forum", range_key="Check this out!"), ).should.throw(DynamoDBResponseError) @@ -152,10 +143,9 @@ def test_get_missing_item(): conn = boto.connect_dynamodb() table = create_table(conn) - table.get_item.when.called_with( - hash_key='tester', - range_key='other', - ).should.throw(DynamoDBKeyNotFoundError) + table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw( + DynamoDBKeyNotFoundError + ) table.has_item("foobar", "more").should.equal(False) @@ -164,11 +154,8 @@ def test_get_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.get_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - 'RangeKeyElement': {'S': 'test-range'}, - }, + table_name="undeclared-table", + key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, ).should.throw(DynamoDBKeyNotFoundError) @@ -182,10 +169,7 @@ def test_get_item_without_range_key(): range_key_proto_value=int, ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) hash_key = 3241526475 @@ -193,8 +177,9 @@ def test_get_item_without_range_key(): new_item = table.new_item(hash_key=hash_key, range_key=range_key) new_item.put() - table.get_item.when.called_with( - hash_key=hash_key).should.throw(DynamoDBValidationError) + table.get_item.when.called_with(hash_key=hash_key).should.throw( + DynamoDBValidationError + ) @mock_dynamodb_deprecated @@ -203,14 +188,12 @@ def test_delete_item(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() @@ -218,7 +201,7 @@ def test_delete_item(): table.item_count.should.equal(1) response = item.delete() - response.should.equal({u'Attributes': [], u'ConsumedCapacityUnits': 0.5}) + response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) table.refresh() table.item_count.should.equal(0) @@ -231,31 +214,31 @@ def test_delete_item_with_attribute_response(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() table.refresh() table.item_count.should.equal(1) - response = item.delete(return_values='ALL_OLD') - response.should.equal({ - 'Attributes': { - 'Body': 'http://url_to_lolcat.gif', - 'forum_name': 'LOLCat Forum', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'SentBy': 'User A', - 'subject': 'Check this out!' - }, - 'ConsumedCapacityUnits': 0.5 - }) + response = item.delete(return_values="ALL_OLD") + response.should.equal( + { + "Attributes": { + "Body": "http://url_to_lolcat.gif", + "forum_name": "LOLCat Forum", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "SentBy": "User A", + "subject": "Check this out!", + }, + "ConsumedCapacityUnits": 0.5, + } + ) table.refresh() table.item_count.should.equal(0) @@ -267,11 +250,8 @@ def test_delete_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.delete_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - 'RangeKeyElement': {'S': 'test-range'}, - }, + table_name="undeclared-table", + key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, ).should.throw(DynamoDBResponseError) @@ -281,54 +261,42 @@ def test_query(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item.put() - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('1')) - results.response['Items'].should.have.length_of(3) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("1")) + results.response["Items"].should.have.length_of(3) - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('234')) - results.response['Items'].should.have.length_of(2) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("234")) + results.response["Items"].should.have.length_of(2) - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('9999')) - results.response['Items'].should.have.length_of(0) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("9999")) + results.response["Items"].should.have.length_of(0) - results = table.query(hash_key='the-key', - range_key_condition=condition.CONTAINS('12')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.CONTAINS("12") + ) + results.response["Items"].should.have.length_of(1) - results = table.query(hash_key='the-key', - range_key_condition=condition.BEGINS_WITH('7')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.BEGINS_WITH("7") + ) + results.response["Items"].should.have.length_of(1) - results = table.query(hash_key='the-key', - range_key_condition=condition.BETWEEN('567', '890')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.BETWEEN("567", "890") + ) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -336,12 +304,10 @@ def test_query_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.query.when.called_with( - table_name='undeclared-table', - hash_key_value={'S': 'the-key'}, + table_name="undeclared-table", + hash_key_value={"S": "the-key"}, range_key_conditions={ - "AttributeValueList": [{ - "S": "User B" - }], + "AttributeValueList": [{"S": "User B"}], "ComparisonOperator": "EQ", }, ).should.throw(DynamoDBResponseError) @@ -353,61 +319,49 @@ def test_scan(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='the-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item.put() results = table.scan() - results.response['Items'].should.have.length_of(3) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'SentBy': condition.EQ('User B')}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Body': condition.BEGINS_WITH('http')}) - results.response['Items'].should.have.length_of(3) + results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'Ids': condition.CONTAINS(2)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NOT_NULL()}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NULL()}) - results.response['Items'].should.have.length_of(2) + results = table.scan(scan_filter={"Ids": condition.NULL()}) + results.response["Items"].should.have.length_of(2) - results = table.scan(scan_filter={'PK': condition.BETWEEN(8, 9)}) - results.response['Items'].should.have.length_of(0) + results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) + results.response["Items"].should.have.length_of(0) - results = table.scan(scan_filter={'PK': condition.BETWEEN(5, 8)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -415,13 +369,11 @@ def test_scan_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(DynamoDBResponseError) @@ -433,7 +385,7 @@ def test_scan_after_has_item(): table = create_table(conn) list(table.scan()).should.equal([]) - table.has_item(hash_key='the-key', range_key='123') + table.has_item(hash_key="the-key", range_key="123") list(table.scan()).should.equal([]) @@ -446,27 +398,31 @@ def test_write_batch(): batch_list = conn.new_batch_write_list() items = [] - items.append(table.new_item( - hash_key='the-key', - range_key='123', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }, - )) + items.append( + table.new_item( + hash_key="the-key", + range_key="123", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + }, + ) + ) - items.append(table.new_item( - hash_key='the-key', - range_key='789', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, - }, - )) + items.append( + table.new_item( + hash_key="the-key", + range_key="789", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, + }, + ) + ) batch_list.add_batch(table, puts=items) conn.batch_write_item(batch_list) @@ -475,7 +431,7 @@ def test_write_batch(): table.item_count.should.equal(2) batch_list = conn.new_batch_write_list() - batch_list.add_batch(table, deletes=[('the-key', '789')]) + batch_list.add_batch(table, deletes=[("the-key", "789")]) conn.batch_write_item(batch_list) table.refresh() @@ -488,39 +444,27 @@ def test_batch_read(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='another-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="another-key", range_key="789", attrs=item_data) item.put() - items = table.batch_get_item([('the-key', '123'), ('another-key', '789')]) + items = table.batch_get_item([("the-key", "123"), ("another-key", "789")]) # Iterate through so that batch_item gets called count = len([x for x in items]) count.should.equal(2) diff --git a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py index c31b1994d..c5031b5d1 100644 --- a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py @@ -13,15 +13,11 @@ from boto.exception import DynamoDBResponseError def create_table(conn): message_table_schema = conn.create_schema( - hash_key_name='forum_name', - hash_key_proto_value=str, + hash_key_name="forum_name", hash_key_proto_value=str ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) return table @@ -33,25 +29,22 @@ def test_create_table(): create_table(conn) expected = { - 'Table': { - 'CreationDateTime': 1326499200.0, - 'ItemCount': 0, - 'KeySchema': { - 'HashKeyElement': { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, + "Table": { + "CreationDateTime": 1326499200.0, + "ItemCount": 0, + "KeySchema": { + "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"} }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 10, }, - 'TableName': 'messages', - 'TableSizeBytes': 0, - 'TableStatus': 'ACTIVE', + "TableName": "messages", + "TableSizeBytes": 0, + "TableStatus": "ACTIVE", } } - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @mock_dynamodb_deprecated @@ -60,11 +53,12 @@ def test_delete_table(): create_table(conn) conn.list_tables().should.have.length_of(1) - conn.layer1.delete_table('messages') + conn.layer1.delete_table("messages") conn.list_tables().should.have.length_of(0) - conn.layer1.delete_table.when.called_with( - 'messages').should.throw(DynamoDBResponseError) + conn.layer1.delete_table.when.called_with("messages").should.throw( + DynamoDBResponseError + ) @mock_dynamodb_deprecated @@ -87,38 +81,37 @@ def test_item_add_and_describe_and_update(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) - item['SentBy'] = 'User B' + item["SentBy"] = "User B" item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @mock_dynamodb_deprecated @@ -126,10 +119,7 @@ def test_item_put_without_table(): conn = boto.connect_dynamodb() conn.layer1.put_item.when.called_with( - table_name='undeclared-table', - item=dict( - hash_key='LOLCat Forum', - ), + table_name="undeclared-table", item=dict(hash_key="LOLCat Forum") ).should.throw(DynamoDBResponseError) @@ -138,9 +128,9 @@ def test_get_missing_item(): conn = boto.connect_dynamodb() table = create_table(conn) - table.get_item.when.called_with( - hash_key='tester', - ).should.throw(DynamoDBKeyNotFoundError) + table.get_item.when.called_with(hash_key="tester").should.throw( + DynamoDBKeyNotFoundError + ) @mock_dynamodb_deprecated @@ -148,10 +138,7 @@ def test_get_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.get_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - }, + table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} ).should.throw(DynamoDBKeyNotFoundError) @@ -161,21 +148,18 @@ def test_delete_item(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() table.refresh() table.item_count.should.equal(1) response = item.delete() - response.should.equal({u'Attributes': [], u'ConsumedCapacityUnits': 0.5}) + response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) table.refresh() table.item_count.should.equal(0) @@ -188,29 +172,28 @@ def test_delete_item_with_attribute_response(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() table.refresh() table.item_count.should.equal(1) - response = item.delete(return_values='ALL_OLD') - response.should.equal({ - u'Attributes': { - u'Body': u'http://url_to_lolcat.gif', - u'forum_name': u'LOLCat Forum', - u'ReceivedTime': u'12/9/2011 11:36:03 PM', - u'SentBy': u'User A', - }, - u'ConsumedCapacityUnits': 0.5 - }) + response = item.delete(return_values="ALL_OLD") + response.should.equal( + { + "Attributes": { + "Body": "http://url_to_lolcat.gif", + "forum_name": "LOLCat Forum", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "SentBy": "User A", + }, + "ConsumedCapacityUnits": 0.5, + } + ) table.refresh() table.item_count.should.equal(0) @@ -222,10 +205,7 @@ def test_delete_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.delete_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - }, + table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} ).should.throw(DynamoDBResponseError) @@ -235,18 +215,15 @@ def test_query(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", attrs=item_data) item.put() - results = table.query(hash_key='the-key') - results.response['Items'].should.have.length_of(1) + results = table.query(hash_key="the-key") + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -254,8 +231,7 @@ def test_query_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.query.when.called_with( - table_name='undeclared-table', - hash_key_value={'S': 'the-key'}, + table_name="undeclared-table", hash_key_value={"S": "the-key"} ).should.throw(DynamoDBResponseError) @@ -265,58 +241,49 @@ def test_scan(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key2', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key2", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='the-key3', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key3", attrs=item_data) item.put() results = table.scan() - results.response['Items'].should.have.length_of(3) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'SentBy': condition.EQ('User B')}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Body': condition.BEGINS_WITH('http')}) - results.response['Items'].should.have.length_of(3) + results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'Ids': condition.CONTAINS(2)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NOT_NULL()}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NULL()}) - results.response['Items'].should.have.length_of(2) + results = table.scan(scan_filter={"Ids": condition.NULL()}) + results.response["Items"].should.have.length_of(2) - results = table.scan(scan_filter={'PK': condition.BETWEEN(8, 9)}) - results.response['Items'].should.have.length_of(0) + results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) + results.response["Items"].should.have.length_of(0) - results = table.scan(scan_filter={'PK': condition.BETWEEN(5, 8)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -324,13 +291,11 @@ def test_scan_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(DynamoDBResponseError) @@ -342,7 +307,7 @@ def test_scan_after_has_item(): table = create_table(conn) list(table.scan()).should.equal([]) - table.has_item('the-key') + table.has_item("the-key") list(table.scan()).should.equal([]) @@ -355,25 +320,29 @@ def test_write_batch(): batch_list = conn.new_batch_write_list() items = [] - items.append(table.new_item( - hash_key='the-key', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }, - )) + items.append( + table.new_item( + hash_key="the-key", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + }, + ) + ) - items.append(table.new_item( - hash_key='the-key2', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, - }, - )) + items.append( + table.new_item( + hash_key="the-key2", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, + }, + ) + ) batch_list.add_batch(table, puts=items) conn.batch_write_item(batch_list) @@ -382,7 +351,7 @@ def test_write_batch(): table.item_count.should.equal(2) batch_list = conn.new_batch_write_list() - batch_list.add_batch(table, deletes=[('the-key')]) + batch_list.add_batch(table, deletes=[("the-key")]) conn.batch_write_item(batch_list) table.refresh() @@ -395,36 +364,27 @@ def test_batch_read(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key1', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key1", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key2', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key2", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='another-key', - attrs=item_data, - ) + item = table.new_item(hash_key="another-key", attrs=item_data) item.put() - items = table.batch_get_item([('the-key1'), ('another-key')]) + items = table.batch_get_item([("the-key1"), ("another-key")]) # Iterate through so that batch_item gets called count = len([x for x in items]) count.should.have.equal(2) diff --git a/tests/test_dynamodb/test_server.py b/tests/test_dynamodb/test_server.py index a9fb7607e..310643628 100644 --- a/tests/test_dynamodb/test_server.py +++ b/tests/test_dynamodb/test_server.py @@ -1,20 +1,20 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_table_list(): - backend = server.create_backend_app("dynamodb") - test_client = backend.test_client() - - res = test_client.get('/') - res.status_code.should.equal(404) - - headers = {'X-Amz-Target': 'TestTable.ListTables'} - res = test_client.get('/', headers=headers) - res.data.should.contain(b'TableNames') +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_table_list(): + backend = server.create_backend_app("dynamodb") + test_client = backend.test_client() + + res = test_client.get("/") + res.status_code.should.equal(404) + + headers = {"X-Amz-Target": "TestTable.ListTables"} + res = test_client.get("/", headers=headers) + res.data.should.contain(b"TableNames") diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index fb6c0e17d..1a8a70615 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -11,7 +11,7 @@ import requests from moto import mock_dynamodb2, mock_dynamodb2_deprecated from moto.dynamodb2 import dynamodb_backend2 from boto.exception import JSONResponseError -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, ParamValidationError from tests.helpers import requires_boto_gte import tests.backport_assert_raises @@ -19,6 +19,7 @@ import moto.dynamodb2.comparisons import moto.dynamodb2.models from nose.tools import assert_raises + try: import boto.dynamodb2 except ImportError: @@ -28,16 +29,18 @@ except ImportError: @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_list_tables(): - name = 'TestTable' + name = "TestTable" # Should make tables properly with boto - dynamodb_backend2.create_table(name, schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'forum_name'}, - {u'KeyType': u'RANGE', u'AttributeName': u'subject'} - ]) + dynamodb_backend2.create_table( + name, + schema=[ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, + ], + ) conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" + ) assert conn.list_tables()["TableNames"] == [name] @@ -45,16 +48,15 @@ def test_list_tables(): @mock_dynamodb2_deprecated def test_list_tables_layer_1(): # Should make tables properly with boto - dynamodb_backend2.create_table("test_1", schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'name'} - ]) - dynamodb_backend2.create_table("test_2", schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'name'} - ]) + dynamodb_backend2.create_table( + "test_1", schema=[{"KeyType": "HASH", "AttributeName": "name"}] + ) + dynamodb_backend2.create_table( + "test_2", schema=[{"KeyType": "HASH", "AttributeName": "name"}] + ) conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" + ) res = conn.list_tables(limit=1) expected = {"TableNames": ["test_1"], "LastEvaluatedTableName": "test_1"} @@ -69,30 +71,36 @@ def test_list_tables_layer_1(): @mock_dynamodb2_deprecated def test_describe_missing_table(): conn = boto.dynamodb2.connect_to_region( - 'us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-west-2", aws_access_key_id="ak", aws_secret_access_key="sk" + ) with assert_raises(JSONResponseError): - conn.describe_table('messages') + conn.describe_table("messages") @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] + arn = table_description["Table"]["TableArn"] # Tag table - tags = [{'Key': 'TestTag', 'Value': 'TestValue'}, {'Key': 'TestTag2', 'Value': 'TestValue2'}] + tags = [ + {"Key": "TestTag", "Value": "TestValue"}, + {"Key": "TestTag2", "Value": "TestValue2"}, + ] conn.tag_resource(ResourceArn=arn, Tags=tags) # Check tags @@ -100,28 +108,32 @@ def test_list_table_tags(): assert resp["Tags"] == tags # Remove 1 tag - conn.untag_resource(ResourceArn=arn, TagKeys=['TestTag']) + conn.untag_resource(ResourceArn=arn, TagKeys=["TestTag"]) # Check tags resp = conn.list_tags_of_resource(ResourceArn=arn) - assert resp["Tags"] == [{'Key': 'TestTag2', 'Value': 'TestValue2'}] + assert resp["Tags"] == [{"Key": "TestTag2", "Value": "TestValue2"}] @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags_empty(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] - tags = [{'Key':'TestTag', 'Value': 'TestValue'}] + arn = table_description["Table"]["TableArn"] + tags = [{"Key": "TestTag", "Value": "TestValue"}] # conn.tag_resource(ResourceArn=arn, # Tags=tags) resp = conn.list_tags_of_resource(ResourceArn=arn) @@ -131,773 +143,1267 @@ def test_list_table_tags_empty(): @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags_paginated(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] + arn = table_description["Table"]["TableArn"] for i in range(11): - tags = [{'Key':'TestTag%d' % i, 'Value': 'TestValue'}] - conn.tag_resource(ResourceArn=arn, - Tags=tags) + tags = [{"Key": "TestTag%d" % i, "Value": "TestValue"}] + conn.tag_resource(ResourceArn=arn, Tags=tags) resp = conn.list_tags_of_resource(ResourceArn=arn) assert len(resp["Tags"]) == 10 - assert 'NextToken' in resp.keys() - resp2 = conn.list_tags_of_resource(ResourceArn=arn, - NextToken=resp['NextToken']) + assert "NextToken" in resp.keys() + resp2 = conn.list_tags_of_resource(ResourceArn=arn, NextToken=resp["NextToken"]) assert len(resp2["Tags"]) == 1 - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_not_found_table_tags(): - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - arn = 'DymmyArn' + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + arn = "DymmyArn" try: conn.list_tags_of_resource(ResourceArn=arn) except ClientError as exception: - assert exception.response['Error']['Code'] == "ResourceNotFoundException" + assert exception.response["Error"]["Code"] == "ResourceNotFoundException" @requires_boto_gte("2.9") @mock_dynamodb2 def test_item_add_empty_string_exception(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) with assert_raises(ClientError) as ex: conn.put_item( TableName=name, Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": ""}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'One or more parameter values were invalid: An AttributeValue may not contain an empty string' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "One or more parameter values were invalid: An AttributeValue may not contain an empty string" ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_update_item_with_empty_string_exception(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) conn.put_item( TableName=name, Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, ) with assert_raises(ClientError) as ex: conn.update_item( TableName=name, - Key={ - 'forum_name': { 'S': 'LOLCat Forum'}, - }, - UpdateExpression='set Body=:Body', - ExpressionAttributeValues={ - ':Body': {'S': ''} - }) + Key={"forum_name": {"S": "LOLCat Forum"}}, + UpdateExpression="set Body=:Body", + ExpressionAttributeValues={":Body": {"S": ""}}, + ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'One or more parameter values were invalid: An AttributeValue may not contain an empty string' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "One or more parameter values were invalid: An AttributeValue may not contain an empty string" ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_query_invalid_table(): - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) try: - conn.query(TableName='invalid_table', KeyConditionExpression='index1 = :partitionkeyval', ExpressionAttributeValues={':partitionkeyval': {'S':'test'}}) + conn.query( + TableName="invalid_table", + KeyConditionExpression="index1 = :partitionkeyval", + ExpressionAttributeValues={":partitionkeyval": {"S": "test"}}, + ) except ClientError as exception: - assert exception.response['Error']['Code'] == "ResourceNotFoundException" + assert exception.response["Error"]["Code"] == "ResourceNotFoundException" @requires_boto_gte("2.9") @mock_dynamodb2 def test_scan_returns_consumed_capacity(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) - - conn.put_item( - TableName=name, - Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } - ) - - response = conn.scan( - TableName=name, + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", ) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert response['ConsumedCapacity']['TableName'] == name + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + + conn.put_item( + TableName=name, + Item={ + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, + ) + + response = conn.scan(TableName=name) + + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert response["ConsumedCapacity"]["TableName"] == name @requires_boto_gte("2.9") @mock_dynamodb2 def test_put_item_with_special_chars(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) conn.put_item( - TableName=name, - Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - '"': {"S": "foo"}, - } - ) + TableName=name, + Item={ + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + '"': {"S": "foo"}, + }, + ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_query_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} ) - assert 'ConsumedCapacity' in results - assert 'CapacityUnits' in results['ConsumedCapacity'] - assert results['ConsumedCapacity']['CapacityUnits'] == 1 + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + + assert "ConsumedCapacity" in results + assert "CapacityUnits" in results["ConsumedCapacity"] + assert results["ConsumedCapacity"]["CapacityUnits"] == 1 @mock_dynamodb2 -def test_basic_projection_expressions(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') +def test_basic_projection_expression_using_get_item(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message' - }) - # Test a query returning all items - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body, subject' + result = table.get_item( + Key={"forum_name": "the-key", "subject": "123"}, + ProjectionExpression="body, subject", ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '1234', - 'body': 'yet another test message' - }) - - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body' - ) - - assert 'body' in results['Items'][0] - assert 'subject' not in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'body' in results['Items'][1] - assert 'subject' not in results['Items'][1] - assert results['Items'][1]['body'] == 'yet another test message' + result["Item"].should.be.equal({"subject": "123", "body": "some test message"}) # The projection expression should not remove data from storage - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), + result = table.get_item(Key={"forum_name": "the-key", "subject": "123"}) + + result["Item"].should.be.equal( + {"forum_name": "the-key", "subject": "123", "body": "some test message"} ) - assert 'subject' in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'forum_name' in results['Items'][1] + + +@mock_dynamodb2 +def test_basic_projection_expressions_using_query(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + table = dynamodb.create_table( + TableName="users", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + } + ) + # Test a query returning all items + results = table.query( + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", + ) + + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) + + results = table.query( + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body", + ) + + assert "body" in results["Items"][0] + assert "subject" not in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "body" in results["Items"][1] + assert "subject" not in results["Items"][1] + assert results["Items"][1]["body"] == "yet another test message" + + # The projection expression should not remove data from storage + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + assert "subject" in results["Items"][0] + assert "body" in results["Items"][1] + assert "forum_name" in results["Items"][1] + @mock_dynamodb2 def test_basic_projection_expressions_using_scan(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message' - }) # Test a scan returning all items results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body, subject' + FilterExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '1234', - 'body': 'yet another test message' - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body' + FilterExpression=Key("forum_name").eq("the-key"), ProjectionExpression="body" ) - assert 'body' in results['Items'][0] - assert 'subject' not in results['Items'][0] - assert 'forum_name' not in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'subject' not in results['Items'][1] - assert 'forum_name' not in results['Items'][1] + assert "body" in results["Items"][0] + assert "subject" not in results["Items"][0] + assert "forum_name" not in results["Items"][0] + assert "body" in results["Items"][1] + assert "subject" not in results["Items"][1] + assert "forum_name" not in results["Items"][1] # The projection expression should not remove data from storage - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ) - assert 'subject' in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'forum_name' in results['Items'][1] + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + assert "subject" in results["Items"][0] + assert "body" in results["Items"][1] + assert "forum_name" in results["Items"][1] @mock_dynamodb2 -def test_basic_projection_expressions_with_attr_expression_names(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') +def test_nested_projection_expression_using_get_item(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a get_item returning all items + result = table.get_item( + Key={"forum_name": "key1"}, + ProjectionExpression="nested.level1.id, nested.level2", + )["Item"] + result.should.equal( + {"nested": {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}}} + ) + # Assert actual data has not been deleted + result = table.get_item(Key={"forum_name": "key1"})["Item"] + result.should.equal( + { + "foo": "bar", + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + } + ) + + +@mock_dynamodb2 +def test_basic_projection_expressions_using_query(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + } + ) + + # Test a query returning all items + result = table.query( + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", + )["Items"][0] + + assert "body" in result + assert result["body"] == "some test message" + assert "subject" in result + assert "forum_name" not in result + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) + + items = table.query( + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body", + )["Items"] + + assert "body" in items[0] + assert "subject" not in items[0] + assert items[0]["body"] == "some test message" + assert "body" in items[1] + assert "subject" not in items[1] + assert items[1]["body"] == "yet another test message" + + # The projection expression should not remove data from storage + items = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"))["Items"] + assert "subject" in items[0] + assert "body" in items[1] + assert "forum_name" in items[1] + + +@mock_dynamodb2 +def test_nested_projection_expression_using_query(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a query returning all items + result = table.query( + KeyConditionExpression=Key("forum_name").eq("key1"), + ProjectionExpression="nested.level1.id, nested.level2", + )["Items"][0] + + assert "nested" in result + result["nested"].should.equal( + {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}} + ) + assert "foo" not in result + # Assert actual data has not been deleted + result = table.query(KeyConditionExpression=Key("forum_name").eq("key1"))["Items"][ + 0 + ] + result.should.equal( + { + "foo": "bar", + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + } + ) + + +@mock_dynamodb2 +def test_basic_projection_expressions_using_scan(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + } + ) + # Test a scan returning all items + results = table.scan( + FilterExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", + )["Items"] + + results.should.equal([{"body": "some test message", "subject": "123"}]) + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) + + results = table.scan( + FilterExpression=Key("forum_name").eq("the-key"), ProjectionExpression="body" + )["Items"] + + assert {"body": "some test message"} in results + assert {"body": "yet another test message"} in results + + # The projection expression should not remove data from storage + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + assert "subject" in results["Items"][0] + assert "body" in results["Items"][1] + assert "forum_name" in results["Items"][1] + + +@mock_dynamodb2 +def test_nested_projection_expression_using_scan(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a scan + results = table.scan( + FilterExpression=Key("forum_name").eq("key1"), + ProjectionExpression="nested.level1.id, nested.level2", + )["Items"] + results.should.equal( + [ + { + "nested": { + "level1": {"id": "id1"}, + "level2": {"include": "all", "id": "id2"}, + } + } + ] + ) + # Assert original data is still there + results = table.scan(FilterExpression=Key("forum_name").eq("key1"))["Items"] + results.should.equal( + [ + { + "forum_name": "key1", + "foo": "bar", + "nested": { + "level1": {"att": "irrelevant", "id": "id1"}, + "level2": {"include": "all", "id": "id2"}, + "level3": {"id": "irrelevant"}, + }, + } + ] + ) + + +@mock_dynamodb2 +def test_basic_projection_expression_using_get_item_with_attr_expression_names(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", } ) - table = dynamodb.Table('users') - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) + result = table.get_item( + Key={"forum_name": "the-key", "subject": "123"}, + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, + ) - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message', - 'attachment': 'something' - }) + result["Item"].should.be.equal( + {"subject": "123", "body": "some test message", "attachment": "something"} + ) + + +@mock_dynamodb2 +def test_basic_projection_expressions_using_query_with_attr_expression_names(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + table = dynamodb.create_table( + TableName="users", + KeySchema=[ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", + } + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) # Test a query returning all items results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, + ) + + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] + assert results["Items"][0]["subject"] == "123" + assert "attachment" in results["Items"][0] + assert results["Items"][0]["attachment"] == "something" + + +@mock_dynamodb2 +def test_nested_projection_expression_using_get_item_with_attr_expression(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a get_item returning all items + result = table.get_item( + Key={"forum_name": "key1"}, + ProjectionExpression="#nst.level1.id, #nst.#lvl2", + ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, + )["Item"] + result.should.equal( + {"nested": {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}}} + ) + # Assert actual data has not been deleted + result = table.get_item(Key={"forum_name": "key1"})["Item"] + result.should.equal( + { + "foo": "bar", + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + } + ) + + +@mock_dynamodb2 +def test_nested_projection_expression_using_query_with_attr_expression_names(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a query returning all items + result = table.query( + KeyConditionExpression=Key("forum_name").eq("key1"), + ProjectionExpression="#nst.level1.id, #nst.#lvl2", + ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, + )["Items"][0] + + assert "nested" in result + result["nested"].should.equal( + {"level1": {"id": "id1"}, "level2": {"id": "id2", "include": "all"}} + ) + assert "foo" not in result + # Assert actual data has not been deleted + result = table.query(KeyConditionExpression=Key("forum_name").eq("key1"))["Items"][ + 0 + ] + result.should.equal( + { + "foo": "bar", + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + } ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] - assert results['Items'][0]['subject'] == '123' - assert 'attachment' in results['Items'][0] - assert results['Items'][0]['attachment'] == 'something' @mock_dynamodb2 def test_basic_projection_expressions_using_scan_with_attr_expression_names(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", } ) - table = dynamodb.Table('users') - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message', - 'attachment': 'something' - }) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) # Test a scan returning all items results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + FilterExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - assert 'body' in results['Items'][0] - assert 'attachment' in results['Items'][0] - assert 'subject' in results['Items'][0] - assert 'form_name' not in results['Items'][0] + assert "body" in results["Items"][0] + assert "attachment" in results["Items"][0] + assert "subject" in results["Items"][0] + assert "form_name" not in results["Items"][0] # Test without a FilterExpression results = table.scan( - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - assert 'body' in results['Items'][0] - assert 'attachment' in results['Items'][0] - assert 'subject' in results['Items'][0] - assert 'form_name' not in results['Items'][0] + assert "body" in results["Items"][0] + assert "attachment" in results["Items"][0] + assert "subject" in results["Items"][0] + assert "form_name" not in results["Items"][0] + + +@mock_dynamodb2 +def test_nested_projection_expression_using_scan_with_attr_expression_names(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + dynamodb.create_table( + TableName="users", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + table.put_item( + Item={ + "forum_name": "key1", + "nested": { + "level1": {"id": "id1", "att": "irrelevant"}, + "level2": {"id": "id2", "include": "all"}, + "level3": {"id": "irrelevant"}, + }, + "foo": "bar", + } + ) + table.put_item( + Item={ + "forum_name": "key2", + "nested": {"id": "id2", "incode": "code2"}, + "foo": "bar", + } + ) + + # Test a scan + results = table.scan( + FilterExpression=Key("forum_name").eq("key1"), + ProjectionExpression="nested.level1.id, nested.level2", + ExpressionAttributeNames={"#nst": "nested", "#lvl2": "level2"}, + )["Items"] + results.should.equal( + [ + { + "nested": { + "level1": {"id": "id1"}, + "level2": {"include": "all", "id": "id2"}, + } + } + ] + ) + # Assert original data is still there + results = table.scan(FilterExpression=Key("forum_name").eq("key1"))["Items"] + results.should.equal( + [ + { + "forum_name": "key1", + "foo": "bar", + "nested": { + "level1": {"att": "irrelevant", "id": "id1"}, + "level2": {"include": "all", "id": "id2"}, + "level3": {"id": "irrelevant"}, + }, + } + ] + ) @mock_dynamodb2 def test_put_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - response = table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + response = table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - assert 'ConsumedCapacity' in response + assert "ConsumedCapacity" in response @mock_dynamodb2 def test_update_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - response = table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='set body=:tb', - ExpressionAttributeValues={ - ':tb': 'a new message' - }) + response = table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="set body=:tb", + ExpressionAttributeValues={":tb": "a new message"}, + ) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert 'TableName' in response['ConsumedCapacity'] + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert "TableName" in response["ConsumedCapacity"] @mock_dynamodb2 def test_get_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - response = table.get_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }) + response = table.get_item(Key={"forum_name": "the-key", "subject": "123"}) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert 'TableName' in response['ConsumedCapacity'] + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert "TableName" in response["ConsumedCapacity"] def test_filter_expression(): - row1 = moto.dynamodb2.models.Item(None, None, None, None, {'Id': {'N': '8'}, 'Subs': {'N': '5'}, 'Desc': {'S': 'Some description'}, 'KV': {'SS': ['test1', 'test2']}}) - row2 = moto.dynamodb2.models.Item(None, None, None, None, {'Id': {'N': '8'}, 'Subs': {'N': '10'}, 'Desc': {'S': 'A description'}, 'KV': {'SS': ['test3', 'test4']}}) + row1 = moto.dynamodb2.models.Item( + None, + None, + None, + None, + { + "Id": {"N": "8"}, + "Subs": {"N": "5"}, + "Desc": {"S": "Some description"}, + "KV": {"SS": ["test1", "test2"]}, + }, + ) + row2 = moto.dynamodb2.models.Item( + None, + None, + None, + None, + { + "Id": {"N": "8"}, + "Subs": {"N": "10"}, + "Desc": {"S": "A description"}, + "KV": {"SS": ["test3", "test4"]}, + }, + ) # NOT test 1 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT attribute_not_exists(Id)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "NOT attribute_not_exists(Id)", {}, {} + ) filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "NOT (Id = :v0)", {}, {":v0": {"N": "8"}} + ) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id > :v0 AND Subs < :v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "7"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id = :v0 OR Id=:v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "8"}} + ) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id BETWEEN :v0 AND :v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "10"}} + ) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id = :v0 AND (Subs = :v0 OR Subs = :v1)", + {}, + {":v0": {"N": "8"}, ":v1": {"N": "5"}}, + ) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { - ':v0': {'N': '7'}, - ':v1': {'N': '8'}, - ':v2': {'N': '9'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id IN (:v0, :v1, :v2)", + {}, + {":v0": {"N": "7"}, ":v1": {"N": "8"}, ":v2": {"N": "9"}}, + ) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "attribute_exists(Id) AND attribute_not_exists (User)", {}, {} + ) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "attribute_type(Id, :v0)", {}, {":v0": {"S": "N"}} + ) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "begins_with(Desc, :v0)", {}, {":v0": {"S": "Some"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "contains(KV, :v0)", {}, {":v0": {"S": "test1"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # size function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('size(Desc) > size(KV)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "size(Desc) > size(KV)", {}, {} + ) filter_expr.expr(row1).should.be(True) # Expression from @batkuip filter_expr = moto.dynamodb2.comparisons.get_filter_expression( - '(#n0 < :v0 AND attribute_not_exists(#n1))', - {'#n0': 'Subs', '#n1': 'fanout_ts'}, - {':v0': {'N': '7'}} + "(#n0 < :v0 AND attribute_not_exists(#n1))", + {"#n0": "Subs", "#n1": "fanout_ts"}, + {":v0": {"N": "7"}}, ) filter_expr.expr(row1).should.be(True) # Expression from to check contains on string value filter_expr = moto.dynamodb2.comparisons.get_filter_expression( - 'contains(#n0, :v0)', - {'#n0': 'Desc'}, - {':v0': {'S': 'Some'}} + "contains(#n0, :v0)", {"#n0": "Desc"}, {":v0": {"S": "Some"}} ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -905,1096 +1411,1074 @@ def test_filter_expression(): @mock_dynamodb2 def test_query_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'nested': {'M': { - 'version': {'S': 'version1'}, - 'contents': {'L': [ - {'S': 'value1'}, {'S': 'value2'}, - ]}, - }}, - } + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "nested": { + "M": { + "version": {"S": "version1"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app2'}, - 'nested': {'M': { - 'version': {'S': 'version2'}, - 'contents': {'L': [ - {'S': 'value1'}, {'S': 'value2'}, - ]}, - }}, - } + "client": {"S": "client1"}, + "app": {"S": "app2"}, + "nested": { + "M": { + "version": {"S": "version2"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, ) - table = dynamodb.Table('test1') - response = table.query( - KeyConditionExpression=Key('client').eq('client1') - ) - assert response['Count'] == 2 + table = dynamodb.Table("test1") + response = table.query(KeyConditionExpression=Key("client").eq("client1")) + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('app').eq('app2') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("app").eq("app2"), ) - assert response['Count'] == 1 - assert response['Items'][0]['app'] == 'app2' + assert response["Count"] == 1 + assert response["Items"][0]["app"] == "app2" response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('app').contains('app') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("app").contains("app"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('nested.version').contains('version') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("nested.version").contains("version"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('nested.contents[0]').eq('value1') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("nested.contents[0]").eq("value1"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 + + +@mock_dynamodb2 +def test_query_filter_overlapping_expression_prefixes(): + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + + # Create the DynamoDB table. + client.create_table( + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "nested": { + "M": { + "version": {"S": "version1"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, + ) + + table = dynamodb.Table("test1") + response = table.query( + KeyConditionExpression=Key("client").eq("client1") & Key("app").eq("app1"), + ProjectionExpression="#1, #10, nested", + ExpressionAttributeNames={"#1": "client", "#10": "app"}, + ) + + assert response["Count"] == 1 + assert response["Items"][0] == { + "client": "client1", + "app": "app1", + "nested": {"version": "version1", "contents": ["value1", "value2"]}, + } @mock_dynamodb2 def test_scan_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app1"}} ) - table = dynamodb.Table('test1') - response = table.scan( - FilterExpression=Attr('app').eq('app2') - ) - assert response['Count'] == 0 + table = dynamodb.Table("test1") + response = table.scan(FilterExpression=Attr("app").eq("app2")) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('app').eq('app1') - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").eq("app1")) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne('app2') - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").ne("app2")) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne('app1') - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("app").ne("app1")) + assert response["Count"] == 0 @mock_dynamodb2 def test_scan_filter2(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'N': '1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"N": "1"}} ) response = client.scan( - TableName='test1', - Select='ALL_ATTRIBUTES', - FilterExpression='#tb >= :dt', + TableName="test1", + Select="ALL_ATTRIBUTES", + FilterExpression="#tb >= :dt", ExpressionAttributeNames={"#tb": "app"}, - ExpressionAttributeValues={":dt": {"N": str(1)}} + ExpressionAttributeValues={":dt": {"N": str(1)}}, ) - assert response['Count'] == 1 + assert response["Count"] == 1 @mock_dynamodb2 def test_scan_filter3(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'N': '1'}, - 'active': {'BOOL': True} - } + TableName="test1", + Item={"client": {"S": "client1"}, "app": {"N": "1"}, "active": {"BOOL": True}}, ) - table = dynamodb.Table('test1') - response = table.scan( - FilterExpression=Attr('active').eq(True) - ) - assert response['Count'] == 1 + table = dynamodb.Table("test1") + response = table.scan(FilterExpression=Attr("active").eq(True)) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('active').ne(True) - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("active").ne(True)) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('active').ne(False) - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("active").ne(False)) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne(1) - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("app").ne(1)) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('app').ne(2) - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").ne(2)) + assert response["Count"] == 1 @mock_dynamodb2 def test_scan_filter4(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") response = table.scan( - FilterExpression=Attr('epoch_ts').lt(7) & Attr('fanout_ts').not_exists() + FilterExpression=Attr("epoch_ts").lt(7) & Attr("fanout_ts").not_exists() ) # Just testing - assert response['Count'] == 0 + assert response["Count"] == 0 @mock_dynamodb2 def test_bad_scan_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") # Bad expression try: - table.scan( - FilterExpression='client test' - ) + table.scan(FilterExpression="client test") except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationError') + err.response["Error"]["Code"].should.equal("ValidationError") else: - raise RuntimeError('Should of raised ResourceInUseException') + raise RuntimeError("Should have raised ResourceInUseException") @mock_dynamodb2 def test_create_table_pay_per_request(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - BillingMode="PAY_PER_REQUEST" + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + BillingMode="PAY_PER_REQUEST", ) @mock_dynamodb2 def test_create_table_error_pay_per_request_with_provisioned_param(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") try: client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123}, - BillingMode="PAY_PER_REQUEST" + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + BillingMode="PAY_PER_REQUEST", ) except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationException') + err.response["Error"]["Code"].should.equal("ValidationException") @mock_dynamodb2 def test_duplicate_create(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) try: client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceInUseException') + err.response["Error"]["Code"].should.equal("ResourceInUseException") else: - raise RuntimeError('Should of raised ResourceInUseException') + raise RuntimeError("Should have raised ResourceInUseException") @mock_dynamodb2 def test_delete_table(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - client.delete_table(TableName='test1') + client.delete_table(TableName="test1") resp = client.list_tables() - len(resp['TableNames']).should.equal(0) + len(resp["TableNames"]).should.equal(0) try: - client.delete_table(TableName='test1') + client.delete_table(TableName="test1") except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should have raised ResourceNotFoundException") @mock_dynamodb2 def test_delete_item(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app1"}} ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app2"}} ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") response = table.scan() - assert response['Count'] == 2 + assert response["Count"] == 2 # Test ReturnValues validation with assert_raises(ClientError) as ex: - table.delete_item(Key={'client': 'client1', 'app': 'app1'}, - ReturnValues='ALL_NEW') + table.delete_item( + Key={"client": "client1", "app": "app1"}, ReturnValues="ALL_NEW" + ) # Test deletion and returning old value - response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') - response['Attributes'].should.contain('client') - response['Attributes'].should.contain('app') + response = table.delete_item( + Key={"client": "client1", "app": "app1"}, ReturnValues="ALL_OLD" + ) + response["Attributes"].should.contain("client") + response["Attributes"].should.contain("app") response = table.scan() - assert response['Count'] == 1 + assert response["Count"] == 1 # Test deletion returning nothing - response = table.delete_item(Key={'client': 'client1', 'app': 'app2'}) - len(response['Attributes']).should.equal(0) + response = table.delete_item(Key={"client": "client1", "app": "app2"}) + len(response["Attributes"]).should.equal(0) response = table.scan() - assert response['Count'] == 0 + assert response["Count"] == 0 @mock_dynamodb2 def test_describe_limits(): - client = boto3.client('dynamodb', region_name='eu-central-1') + client = boto3.client("dynamodb", region_name="eu-central-1") resp = client.describe_limits() - resp['AccountMaxReadCapacityUnits'].should.equal(20000) - resp['AccountMaxWriteCapacityUnits'].should.equal(20000) - resp['TableMaxWriteCapacityUnits'].should.equal(10000) - resp['TableMaxReadCapacityUnits'].should.equal(10000) + resp["AccountMaxReadCapacityUnits"].should.equal(20000) + resp["AccountMaxWriteCapacityUnits"].should.equal(20000) + resp["TableMaxWriteCapacityUnits"].should.equal(10000) + resp["TableMaxReadCapacityUnits"].should.equal(10000) @mock_dynamodb2 def test_set_ttl(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.update_time_to_live( - TableName='test1', - TimeToLiveSpecification={ - 'Enabled': True, - 'AttributeName': 'expire' - } + TableName="test1", + TimeToLiveSpecification={"Enabled": True, "AttributeName": "expire"}, ) - resp = client.describe_time_to_live(TableName='test1') - resp['TimeToLiveDescription']['TimeToLiveStatus'].should.equal('ENABLED') - resp['TimeToLiveDescription']['AttributeName'].should.equal('expire') + resp = client.describe_time_to_live(TableName="test1") + resp["TimeToLiveDescription"]["TimeToLiveStatus"].should.equal("ENABLED") + resp["TimeToLiveDescription"]["AttributeName"].should.equal("expire") client.update_time_to_live( - TableName='test1', - TimeToLiveSpecification={ - 'Enabled': False, - 'AttributeName': 'expire' - } + TableName="test1", + TimeToLiveSpecification={"Enabled": False, "AttributeName": "expire"}, ) - resp = client.describe_time_to_live(TableName='test1') - resp['TimeToLiveDescription']['TimeToLiveStatus'].should.equal('DISABLED') + resp = client.describe_time_to_live(TableName="test1") + resp["TimeToLiveDescription"]["TimeToLiveStatus"].should.equal("DISABLED") # https://github.com/spulec/moto/issues/1043 @mock_dynamodb2 def test_query_missing_expr_names(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + ) + client.put_item( + TableName="test1", Item={"client": {"S": "test1"}, "app": {"S": "test1"}} + ) + client.put_item( + TableName="test1", Item={"client": {"S": "test2"}, "app": {"S": "test2"}} ) - client.put_item(TableName='test1', Item={'client': {'S': 'test1'}, 'app': {'S': 'test1'}}) - client.put_item(TableName='test1', Item={'client': {'S': 'test2'}, 'app': {'S': 'test2'}}) - resp = client.query(TableName='test1', KeyConditionExpression='client=:client', - ExpressionAttributeValues={':client': {'S': 'test1'}}) + resp = client.query( + TableName="test1", + KeyConditionExpression="client=:client", + ExpressionAttributeValues={":client": {"S": "test1"}}, + ) - resp['Count'].should.equal(1) - resp['Items'][0]['client']['S'].should.equal('test1') + resp["Count"].should.equal(1) + resp["Items"][0]["client"]["S"].should.equal("test1") - resp = client.query(TableName='test1', KeyConditionExpression=':name=test2', - ExpressionAttributeNames={':name': 'client'}) + resp = client.query( + TableName="test1", + KeyConditionExpression=":name=test2", + ExpressionAttributeNames={":name": "client"}, + ) - resp['Count'].should.equal(1) - resp['Items'][0]['client']['S'].should.equal('test2') + resp["Count"].should.equal(1) + resp["Items"][0]["client"]["S"].should.equal("test2") # https://github.com/spulec/moto/issues/2328 @mock_dynamodb2 def test_update_item_with_list(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='Table', - KeySchema=[ - { - 'AttributeName': 'key', - 'KeyType': 'HASH' - } - ], - AttributeDefinitions=[ - { - 'AttributeName': 'key', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } + TableName="Table", + KeySchema=[{"AttributeName": "key", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "key", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - table = dynamodb.Table('Table') + table = dynamodb.Table("Table") table.update_item( - Key={'key': 'the-key'}, - AttributeUpdates={ - 'list': {'Value': [1, 2], 'Action': 'PUT'} - } + Key={"key": "the-key"}, + AttributeUpdates={"list": {"Value": [1, 2], "Action": "PUT"}}, ) - resp = table.get_item(Key={'key': 'the-key'}) - resp['Item'].should.equal({ - 'key': 'the-key', - 'list': [1, 2] - }) + resp = table.get_item(Key={"key": "the-key"}) + resp["Item"].should.equal({"key": "the-key", "list": [1, 2]}) # https://github.com/spulec/moto/issues/1342 @mock_dynamodb2 def test_update_item_on_map(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') - client = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": {"nested": {"data": "test"}}, } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': {'nested': {'data': 'test'}}, - }) resp = table.scan() - resp['Items'][0]['body'].should.equal({'nested': {'data': 'test'}}) + resp["Items"][0]["body"].should.equal({"nested": {"data": "test"}}) # Nonexistent nested attributes are supported for existing top-level attributes. - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='SET body.#nested.#data = :tb, body.nested.#nonexistentnested.#data = :tb2', + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="SET body.#nested.#data = :tb, body.nested.#nonexistentnested.#data = :tb2", ExpressionAttributeNames={ - '#nested': 'nested', - '#nonexistentnested': 'nonexistentnested', - '#data': 'data' + "#nested": "nested", + "#nonexistentnested": "nonexistentnested", + "#data": "data", }, - ExpressionAttributeValues={ - ':tb': 'new_value', - ':tb2': 'other_value' - }) + ExpressionAttributeValues={":tb": "new_value", ":tb2": "other_value"}, + ) resp = table.scan() - resp['Items'][0]['body'].should.equal({ - 'nested': { - 'data': 'new_value', - 'nonexistentnested': {'data': 'other_value'} - } - }) + resp["Items"][0]["body"].should.equal( + {"nested": {"data": "new_value", "nonexistentnested": {"data": "other_value"}}} + ) # Test nested value for a nonexistent attribute. with assert_raises(client.exceptions.ConditionalCheckFailedException): - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='SET nonexistent.#nested = :tb', - ExpressionAttributeNames={ - '#nested': 'nested' - }, - ExpressionAttributeValues={ - ':tb': 'new_value' - }) - + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="SET nonexistent.#nested = :tb", + ExpressionAttributeNames={"#nested": "nested"}, + ExpressionAttributeValues={":tb": "new_value"}, + ) # https://github.com/spulec/moto/issues/1358 @mock_dynamodb2 def test_update_if_not_exists(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, # if_not_exists without space - UpdateExpression='SET created_at=if_not_exists(created_at,:created_at)', - ExpressionAttributeValues={ - ':created_at': 123 - } + UpdateExpression="SET created_at=if_not_exists(created_at,:created_at)", + ExpressionAttributeValues={":created_at": 123}, ) resp = table.scan() - assert resp['Items'][0]['created_at'] == 123 + assert resp["Items"][0]["created_at"] == 123 - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, # if_not_exists with space - UpdateExpression='SET created_at = if_not_exists (created_at, :created_at)', - ExpressionAttributeValues={ - ':created_at': 456 - } + UpdateExpression="SET created_at = if_not_exists (created_at, :created_at)", + ExpressionAttributeValues={":created_at": 456}, ) resp = table.scan() # Still the original value - assert resp['Items'][0]['created_at'] == 123 + assert resp["Items"][0]["created_at"] == 123 # https://github.com/spulec/moto/issues/1937 @mock_dynamodb2 def test_update_return_attributes(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='moto-test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1} + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) def update(col, to, rv): return dynamodb.update_item( - TableName='moto-test', - Key={'id': {'S': 'foo'}}, - AttributeUpdates={col: {'Value': {'S': to}, 'Action': 'PUT'}}, - ReturnValues=rv + TableName="moto-test", + Key={"id": {"S": "foo"}}, + AttributeUpdates={col: {"Value": {"S": to}, "Action": "PUT"}}, + ReturnValues=rv, ) - r = update('col1', 'val1', 'ALL_NEW') - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + r = update("col1", "val1", "ALL_NEW") + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} - r = update('col1', 'val2', 'ALL_OLD') - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + r = update("col1", "val2", "ALL_OLD") + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} - r = update('col2', 'val3', 'UPDATED_NEW') - assert r['Attributes'] == {'col2': {'S': 'val3'}} + r = update("col2", "val3", "UPDATED_NEW") + assert r["Attributes"] == {"col2": {"S": "val3"}} - r = update('col2', 'val4', 'UPDATED_OLD') - assert r['Attributes'] == {'col2': {'S': 'val3'}} + r = update("col2", "val4", "UPDATED_OLD") + assert r["Attributes"] == {"col2": {"S": "val3"}} - r = update('col1', 'val5', 'NONE') - assert r['Attributes'] == {} + r = update("col1", "val5", "NONE") + assert r["Attributes"] == {} with assert_raises(ClientError) as ex: - r = update('col1', 'val6', 'WRONG') + r = update("col1", "val6", "WRONG") @mock_dynamodb2 def test_put_return_attributes(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='moto-test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1} + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) r = dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val1'}}, - ReturnValues='NONE' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val1"}}, + ReturnValues="NONE", ) - assert 'Attributes' not in r + assert "Attributes" not in r r = dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, - ReturnValues='ALL_OLD' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val2"}}, + ReturnValues="ALL_OLD", ) - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} with assert_raises(ClientError) as ex: dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val3'}}, - ReturnValues='ALL_NEW' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val3"}}, + ReturnValues="ALL_NEW", ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "Return values set to invalid value" + ) @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'user_id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'user_id', - 'AttributeType': 'N', - }, - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - }, + TableName="users", + KeySchema=[{"AttributeName": "user_id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "user_id", "AttributeType": "N"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.update( - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - ], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], GlobalSecondaryIndexUpdates=[ - {'Create': - { - 'IndexName': 'forum_name_index', - 'KeySchema': [ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH', - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + { + "Create": { + "IndexName": "forum_name_index", + "KeySchema": [{"AttributeName": "forum_name", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, }, } } - ] + ], ) next_user_id = 1 - for my_forum_name in ['cats', 'dogs']: - for my_subject in ['my pet is the cutest', 'wow look at what my pet did', "don't you love my pet?"]: - table.put_item(Item={'user_id': next_user_id, 'forum_name': my_forum_name, 'subject': my_subject}) + for my_forum_name in ["cats", "dogs"]: + for my_subject in [ + "my pet is the cutest", + "wow look at what my pet did", + "don't you love my pet?", + ]: + table.put_item( + Item={ + "user_id": next_user_id, + "forum_name": my_forum_name, + "subject": my_subject, + } + ) next_user_id += 1 # get all the cat users forum_only_query_response = table.query( - IndexName='forum_name_index', - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('forum_name').eq('cats'), + IndexName="forum_name_index", + Select="ALL_ATTRIBUTES", + KeyConditionExpression=Key("forum_name").eq("cats"), ) - forum_only_items = forum_only_query_response['Items'] + forum_only_items = forum_only_query_response["Items"] assert len(forum_only_items) == 3 for item in forum_only_items: - assert item['forum_name'] == 'cats' + assert item["forum_name"] == "cats" # query all cat users with a particular subject forum_and_subject_query_results = table.query( - IndexName='forum_name_index', - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('forum_name').eq('cats'), - FilterExpression=Attr('subject').eq('my pet is the cutest'), + IndexName="forum_name_index", + Select="ALL_ATTRIBUTES", + KeyConditionExpression=Key("forum_name").eq("cats"), + FilterExpression=Attr("subject").eq("my pet is the cutest"), ) - forum_and_subject_items = forum_and_subject_query_results['Items'] + forum_and_subject_items = forum_and_subject_query_results["Items"] assert len(forum_and_subject_items) == 1 - assert forum_and_subject_items[0] == {'user_id': Decimal('1'), 'forum_name': 'cats', - 'subject': 'my pet is the cutest'} + assert forum_and_subject_items[0] == { + "user_id": Decimal("1"), + "forum_name": "cats", + "subject": "my pet is the cutest", + } @mock_dynamodb2 def test_dynamodb_streams_1(): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' - } + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", } - assert 'LatestStreamLabel' in resp['TableDescription'] - assert 'LatestStreamArn' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] + assert "LatestStreamArn" in resp["TableDescription"] - resp = conn.delete_table(TableName='test-streams') + resp = conn.delete_table(TableName="test-streams") - assert 'StreamSpecification' in resp['TableDescription'] + assert "StreamSpecification" in resp["TableDescription"] @mock_dynamodb2 def test_dynamodb_streams_2(): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-stream-update', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test-stream-update", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - assert 'StreamSpecification' not in resp['TableDescription'] + assert "StreamSpecification" not in resp["TableDescription"] resp = conn.update_table( - TableName='test-stream-update', - StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_IMAGE' - } + TableName="test-stream-update", + StreamSpecification={"StreamEnabled": True, "StreamViewType": "NEW_IMAGE"}, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'NEW_IMAGE' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "NEW_IMAGE", } - assert 'LatestStreamLabel' in resp['TableDescription'] - assert 'LatestStreamArn' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] + assert "LatestStreamArn" in resp["TableDescription"] @mock_dynamodb2 def test_condition_expressions(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - } + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match', + ConditionExpression="attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match", ExpressionAttributeNames={ - '#existing': 'existing', - '#nonexistent': 'nope', - '#match': 'match', + "#existing": "existing", + "#nonexistent": "nope", + "#match": "match", }, + ExpressionAttributeValues={":match": {"S": "match"}}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="NOT(attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2))", + ExpressionAttributeNames={"#nonexistent1": "nope", "#nonexistent2": "nope2"}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="attribute_exists(#nonexistent) OR attribute_exists(#existing)", + ExpressionAttributeNames={"#nonexistent": "nope", "#existing": "existing"}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="#client BETWEEN :a AND :z", + ExpressionAttributeNames={"#client": "client"}, + ExpressionAttributeValues={":a": {"S": "a"}, ":z": {"S": "z"}}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="#client IN (:client1, :client2)", + ExpressionAttributeNames={"#client": "client"}, ExpressionAttributeValues={ - ':match': {'S': 'match'} - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + ":client1": {"S": "client1"}, + ":client2": {"S": "client2"}, }, - ConditionExpression='NOT(attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2))', - ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)', - ExpressionAttributeNames={ - '#nonexistent': 'nope', - '#existing': 'existing' - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='#client BETWEEN :a AND :z', - ExpressionAttributeNames={ - '#client': 'client', - }, - ExpressionAttributeValues={ - ':a': {'S': 'a'}, - ':z': {'S': 'z'}, - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='#client IN (:client1, :client2)', - ExpressionAttributeNames={ - '#client': 'client', - }, - ExpressionAttributeValues={ - ':client1': {'S': 'client1'}, - ':client2': {'S': 'client2'}, - } ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2)', + ConditionExpression="attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2)", ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } + "#nonexistent1": "nope", + "#nonexistent2": "nope2", + }, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='NOT(attribute_not_exists(#nonexistent1) AND attribute_not_exists(#nonexistent2))', + ConditionExpression="NOT(attribute_not_exists(#nonexistent1) AND attribute_not_exists(#nonexistent2))", ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } + "#nonexistent1": "nope", + "#nonexistent2": "nope2", + }, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match', + ConditionExpression="attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match", ExpressionAttributeNames={ - '#existing': 'existing', - '#nonexistent': 'nope', - '#match': 'match', + "#existing": "existing", + "#nonexistent": "nope", + "#match": "match", }, - ExpressionAttributeValues={ - ':match': {'S': 'match2'} - } + ExpressionAttributeValues={":match": {"S": "match2"}}, ) # Make sure update_item honors ConditionExpression as well client.update_item( - TableName='test1', - Key={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - }, - UpdateExpression='set #match=:match', - ConditionExpression='attribute_exists(#existing)', - ExpressionAttributeNames={ - '#existing': 'existing', - '#match': 'match', - }, - ExpressionAttributeValues={ - ':match': {'S': 'match'} - } + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + UpdateExpression="set #match=:match", + ConditionExpression="attribute_exists(#existing)", + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, + ExpressionAttributeValues={":match": {"S": "match"}}, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.update_item( - TableName='test1', - Key={ - 'client': { 'S': 'client1'}, - 'app': { 'S': 'app1'}, - }, - UpdateExpression='set #match=:match', - ConditionExpression='attribute_not_exists(#existing)', - ExpressionAttributeValues={ - ':match': {'S': 'match'} - }, - ExpressionAttributeNames={ - '#existing': 'existing', - '#match': 'match', - }, + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + UpdateExpression="set #match=:match", + ConditionExpression="attribute_not_exists(#existing)", + ExpressionAttributeValues={":match": {"S": "match"}}, + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, + ) + + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.delete_item( + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + ConditionExpression="attribute_not_exists(#existing)", + ExpressionAttributeValues={":match": {"S": "match"}}, + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, ) @mock_dynamodb2 def test_condition_expression__attr_doesnt_exist(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) client.put_item( - TableName='test', - Item={ - 'forum_name': {'S': 'foo'}, - 'ttl': {'N': 'bar'}, - } + TableName="test", Item={"forum_name": {"S": "foo"}, "ttl": {"N": "bar"}} ) - def update_if_attr_doesnt_exist(): # Test nonexistent top-level attribute. client.update_item( - TableName='test', - Key={ - 'forum_name': {'S': 'the-key'}, - 'subject': {'S': 'the-subject'}, - }, - UpdateExpression='set #new_state=:new_state, #ttl=:ttl', - ConditionExpression='attribute_not_exists(#new_state)', - ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'}, + TableName="test", + Key={"forum_name": {"S": "the-key"}, "subject": {"S": "the-subject"}}, + UpdateExpression="set #new_state=:new_state, #ttl=:ttl", + ConditionExpression="attribute_not_exists(#new_state)", + ExpressionAttributeNames={"#new_state": "foobar", "#ttl": "ttl"}, ExpressionAttributeValues={ - ':new_state': {'S': 'some-value'}, - ':ttl': {'N': '12345.67'}, + ":new_state": {"S": "some-value"}, + ":ttl": {"N": "12345.67"}, }, - ReturnValues='ALL_NEW', + ReturnValues="ALL_NEW", ) update_if_attr_doesnt_exist() @@ -2006,190 +2490,1060 @@ def test_condition_expression__attr_doesnt_exist(): @mock_dynamodb2 def test_condition_expression__or_order(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) # ensure that the RHS of the OR expression is not evaluated if the LHS # returns true (as it would result an error) client.update_item( - TableName='test', - Key={ - 'forum_name': {'S': 'the-key'}, - }, - UpdateExpression='set #ttl=:ttl', - ConditionExpression='attribute_not_exists(#ttl) OR #ttl <= :old_ttl', - ExpressionAttributeNames={'#ttl': 'ttl'}, - ExpressionAttributeValues={ - ':ttl': {'N': '6'}, - ':old_ttl': {'N': '5'}, - } + TableName="test", + Key={"forum_name": {"S": "the-key"}}, + UpdateExpression="set #ttl=:ttl", + ConditionExpression="attribute_not_exists(#ttl) OR #ttl <= :old_ttl", + ExpressionAttributeNames={"#ttl": "ttl"}, + ExpressionAttributeValues={":ttl": {"N": "6"}, ":old_ttl": {"N": "5"}}, ) +@mock_dynamodb2 +def test_condition_expression__and_order(): + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + + # ensure that the RHS of the AND expression is not evaluated if the LHS + # returns true (as it would result an error) + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( + TableName="test", + Key={"forum_name": {"S": "the-key"}}, + UpdateExpression="set #ttl=:ttl", + ConditionExpression="attribute_exists(#ttl) AND #ttl <= :old_ttl", + ExpressionAttributeNames={"#ttl": "ttl"}, + ExpressionAttributeValues={":ttl": {"N": "6"}, ":old_ttl": {"N": "5"}}, + ) + + @mock_dynamodb2 def test_query_gsi_with_range_key(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_hash_key', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_range_key', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_hash_key", "AttributeType": "S"}, + {"AttributeName": "gsi_range_key", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_hash_key', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'gsi_range_key', - 'KeyType': 'RANGE' - }, + "IndexName": "test_gsi", + "KeySchema": [ + {"AttributeName": "gsi_hash_key", "KeyType": "HASH"}, + {"AttributeName": "gsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': 'test1'}, - 'gsi_hash_key': {'S': 'key1'}, - 'gsi_range_key': {'S': 'range1'}, - } + "id": {"S": "test1"}, + "gsi_hash_key": {"S": "key1"}, + "gsi_range_key": {"S": "range1"}, + }, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': 'test2'}, - 'gsi_hash_key': {'S': 'key1'}, - } + TableName="test", Item={"id": {"S": "test2"}, "gsi_hash_key": {"S": "key1"}} ) - res = dynamodb.query(TableName='test', IndexName='test_gsi', - KeyConditionExpression='gsi_hash_key = :gsi_hash_key AND gsi_range_key = :gsi_range_key', - ExpressionAttributeValues={ - ':gsi_hash_key': {'S': 'key1'}, - ':gsi_range_key': {'S': 'range1'} - }) + res = dynamodb.query( + TableName="test", + IndexName="test_gsi", + KeyConditionExpression="gsi_hash_key = :gsi_hash_key AND gsi_range_key = :gsi_range_key", + ExpressionAttributeValues={ + ":gsi_hash_key": {"S": "key1"}, + ":gsi_range_key": {"S": "range1"}, + }, + ) res.should.have.key("Count").equal(1) res.should.have.key("Items") - res['Items'][0].should.equal({ - 'id': {'S': 'test1'}, - 'gsi_hash_key': {'S': 'key1'}, - 'gsi_range_key': {'S': 'range1'}, - }) + res["Items"][0].should.equal( + { + "id": {"S": "test1"}, + "gsi_hash_key": {"S": "key1"}, + "gsi_range_key": {"S": "range1"}, + } + ) @mock_dynamodb2 def test_scan_by_non_exists_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_col', - 'KeyType': 'HASH' - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "gsi_col", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) with assert_raises(ClientError) as ex: - dynamodb.scan(TableName='test', IndexName='non_exists_index') + dynamodb.scan(TableName="test", IndexName="non_exists_index") - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'The table does not have the specified index: non_exists_index' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "The table does not have the specified index: non_exists_index" + ) + + +@mock_dynamodb2 +def test_query_by_non_exists_index(): + dynamodb = boto3.client("dynamodb", region_name="us-east-1") + + dynamodb.create_table( + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + GlobalSecondaryIndexes=[ + { + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "gsi_col", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, + }, + } + ], + ) + + with assert_raises(ClientError) as ex: + dynamodb.query( + TableName="test", + IndexName="non_exists_index", + KeyConditionExpression="CarModel=M", + ) + + ex.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + ex.exception.response["Error"]["Message"].should.equal( + "Invalid index: non_exists_index for table: test. Available indexes are: test_gsi" ) @mock_dynamodb2 def test_batch_items_returns_all(): dynamodb = _create_user_table() - returned_items = dynamodb.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user1'} - }, { - 'username': {'S': 'user2'} - }, { - 'username': {'S': 'user3'} - }], - 'ConsistentRead': True + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + } } - })['Responses']['users'] + )["Responses"]["users"] assert len(returned_items) == 3 - assert [item['username']['S'] for item in returned_items] == ['user1', 'user2', 'user3'] + assert [item["username"]["S"] for item in returned_items] == [ + "user1", + "user2", + "user3", + ] + + +@mock_dynamodb2 +def test_batch_items_with_basic_projection_expression(): + dynamodb = _create_user_table() + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + "ProjectionExpression": "username", + } + } + )["Responses"]["users"] + + returned_items.should.have.length_of(3) + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item.get("foo") for item in returned_items].should.be.equal([None, None, None]) + + # The projection expression should not remove data from storage + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + } + } + )["Responses"]["users"] + + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item["foo"]["S"] for item in returned_items].should.be.equal(["bar", "bar", "bar"]) + + +@mock_dynamodb2 +def test_batch_items_with_basic_projection_expression_and_attr_expression_names(): + dynamodb = _create_user_table() + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + "ProjectionExpression": "#rl", + "ExpressionAttributeNames": {"#rl": "username"}, + } + } + )["Responses"]["users"] + + returned_items.should.have.length_of(3) + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item.get("foo") for item in returned_items].should.be.equal([None, None, None]) @mock_dynamodb2 def test_batch_items_should_throw_exception_for_duplicate_request(): client = _create_user_table() with assert_raises(ClientError) as ex: - client.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user0'} - }], - 'ConsistentRead': True - }}) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Provided list of item keys contains duplicates') + client.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user0"}}, + ], + "ConsistentRead": True, + } + } + ) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Provided list of item keys contains duplicates" + ) + + +@mock_dynamodb2 +def test_index_with_unknown_attributes_should_fail(): + dynamodb = boto3.client("dynamodb", region_name="us-east-1") + + expected_exception = ( + "Some index key attributes are not defined in AttributeDefinitions." + ) + + with assert_raises(ClientError) as ex: + dynamodb.create_table( + AttributeDefinitions=[ + {"AttributeName": "customer_nr", "AttributeType": "S"}, + {"AttributeName": "last_name", "AttributeType": "S"}, + ], + TableName="table_with_missing_attribute_definitions", + KeySchema=[ + {"AttributeName": "customer_nr", "KeyType": "HASH"}, + {"AttributeName": "last_name", "KeyType": "RANGE"}, + ], + LocalSecondaryIndexes=[ + { + "IndexName": "indexthataddsanadditionalattribute", + "KeySchema": [ + {"AttributeName": "customer_nr", "KeyType": "HASH"}, + {"AttributeName": "postcode", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + } + ], + BillingMode="PAY_PER_REQUEST", + ) + + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.contain(expected_exception) + + +@mock_dynamodb2 +def test_update_list_index__set_existing_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="set itemlist[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar2_update"}}, + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result["id"].should.equal({"S": "foo"}) + result["itemlist"].should.equal( + {"L": [{"S": "bar1"}, {"S": "bar2_update"}, {"S": "bar3"}]} + ) + + +@mock_dynamodb2 +def test_update_list_index__set_existing_nested_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}} + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar2_update"}}, + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + result["id"].should.equal({"S": "foo2"}) + result["itemmap"]["M"]["itemlist"]["L"].should.equal( + [{"S": "bar1"}, {"S": "bar2_update"}, {"S": "bar3"}] + ) + + +@mock_dynamodb2 +def test_update_list_index__set_index_out_of_range(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="set itemlist[10]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar10"}}, + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + assert result["id"] == {"S": "foo"} + assert result["itemlist"] == { + "L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}, {"S": "bar10"}] + } + + +@mock_dynamodb2 +def test_update_list_index__set_nested_index_out_of_range(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}} + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[10]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar10"}}, + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + assert result["itemmap"]["M"]["itemlist"]["L"] == [ + {"S": "bar1"}, + {"S": "bar2"}, + {"S": "bar3"}, + {"S": "bar10"}, + ] + + +@mock_dynamodb2 +def test_update_list_index__set_double_nested_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": { + "itemlist": { + "L": [ + {"M": {"foo": {"S": "bar11"}, "foos": {"S": "bar12"}}}, + {"M": {"foo": {"S": "bar21"}, "foos": {"S": "bar21"}}}, + ] + } + } + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[1].foos=:Item", + ExpressionAttributeValues={":Item": {"S": "bar22"}}, + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + len(result["itemmap"]["M"]["itemlist"]["L"]).should.equal(2) + result["itemmap"]["M"]["itemlist"]["L"][0].should.equal( + {"M": {"foo": {"S": "bar11"}, "foos": {"S": "bar12"}}} + ) # unchanged + result["itemmap"]["M"]["itemlist"]["L"][1].should.equal( + {"M": {"foo": {"S": "bar21"}, "foos": {"S": "bar22"}}} + ) # updated + + +@mock_dynamodb2 +def test_update_list_index__set_index_of_a_string(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, Item={"id": {"S": "foo2"}, "itemstr": {"S": "somestring"}} + ) + with assert_raises(ClientError) as ex: + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemstr[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "string_update"}}, + ) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})[ + "Item" + ] + + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "The document path provided in the update expression is invalid for update" + ) + + +@mock_dynamodb2 +def test_remove_top_level_attribute(): + table_name = "test_remove" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, Item={"id": {"S": "foo"}, "item": {"S": "bar"}} + ) + client.update_item( + TableName=table_name, Key={"id": {"S": "foo"}}, UpdateExpression="REMOVE item" + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result.should.equal({"id": {"S": "foo"}}) + + +@mock_dynamodb2 +def test_remove_list_index__remove_existing_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="REMOVE itemlist[1]", + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result["id"].should.equal({"S": "foo"}) + result["itemlist"].should.equal({"L": [{"S": "bar1"}, {"S": "bar3"}]}) + + +@mock_dynamodb2 +def test_remove_list_index__remove_existing_nested_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": {"M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}]}}}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="REMOVE itemmap.itemlist[1]", + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + result["id"].should.equal({"S": "foo2"}) + result["itemmap"]["M"]["itemlist"]["L"].should.equal([{"S": "bar1"}]) + + +@mock_dynamodb2 +def test_remove_list_index__remove_existing_double_nested_index(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": { + "itemlist": { + "L": [ + {"M": {"foo00": {"S": "bar1"}, "foo01": {"S": "bar2"}}}, + {"M": {"foo10": {"S": "bar1"}, "foo11": {"S": "bar2"}}}, + ] + } + } + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="REMOVE itemmap.itemlist[1].foo10", + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + assert result["itemmap"]["M"]["itemlist"]["L"][0]["M"].should.equal( + {"foo00": {"S": "bar1"}, "foo01": {"S": "bar2"}} + ) # untouched + assert result["itemmap"]["M"]["itemlist"]["L"][1]["M"].should.equal( + {"foo11": {"S": "bar2"}} + ) # changed + + +@mock_dynamodb2 +def test_remove_list_index__remove_index_out_of_range(): + table_name = "test_list_index_access" + client = create_table_with_list(table_name) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="REMOVE itemlist[10]", + ) + # + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + assert result["id"] == {"S": "foo"} + assert result["itemlist"] == {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]} + + +def create_table_with_list(table_name): + client = boto3.client("dynamodb", region_name="us-east-1") + client.create_table( + TableName=table_name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + BillingMode="PAY_PER_REQUEST", + ) + return client + + +@mock_dynamodb2 +def test_sorted_query_with_numerical_sort_key(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + dynamodb.create_table( + TableName="CarCollection", + KeySchema=[ + {"AttributeName": "CarModel", "KeyType": "HASH"}, + {"AttributeName": "CarPrice", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "CarModel", "AttributeType": "S"}, + {"AttributeName": "CarPrice", "AttributeType": "N"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + + def create_item(price): + return {"CarModel": "M", "CarPrice": price} + + table = dynamodb.Table("CarCollection") + items = list(map(create_item, [2, 1, 10, 3])) + for item in items: + table.put_item(Item=item) + + response = table.query(KeyConditionExpression=Key("CarModel").eq("M")) + + response_items = response["Items"] + assert len(items) == len(response_items) + assert all(isinstance(item["CarPrice"], Decimal) for item in response_items) + response_prices = [item["CarPrice"] for item in response_items] + expected_prices = [Decimal(item["CarPrice"]) for item in items] + expected_prices.sort() + assert ( + expected_prices == response_prices + ), "result items are not sorted by numerical value" + + +# https://github.com/spulec/moto/issues/1874 +@mock_dynamodb2 +def test_item_size_is_under_400KB(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + client = boto3.client("dynamodb", region_name="us-east-1") + + dynamodb.create_table( + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + table = dynamodb.Table("moto-test") + + large_item = "x" * 410 * 1000 + assert_failure_due_to_item_size( + func=client.put_item, + TableName="moto-test", + Item={"id": {"S": "foo"}, "item": {"S": large_item}}, + ) + assert_failure_due_to_item_size( + func=table.put_item, Item={"id": "bar", "item": large_item} + ) + assert_failure_due_to_item_size( + func=client.update_item, + TableName="moto-test", + Key={"id": {"S": "foo2"}}, + UpdateExpression="set item=:Item", + ExpressionAttributeValues={":Item": {"S": large_item}}, + ) + # Assert op fails when updating a nested item + assert_failure_due_to_item_size( + func=table.put_item, Item={"id": "bar", "itemlist": [{"item": large_item}]} + ) + assert_failure_due_to_item_size( + func=client.put_item, + TableName="moto-test", + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"M": {"item1": {"S": large_item}}}]}, + }, + ) + + +def assert_failure_due_to_item_size(func, **kwargs): + with assert_raises(ClientError) as ex: + func(**kwargs) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Item size has exceeded the maximum allowed size" + ) + + +@mock_dynamodb2 +# https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Query.html#DDB-Query-request-KeyConditionExpression +def test_hash_key_cannot_use_begins_with_operations(): + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + table = dynamodb.create_table( + TableName="test-table", + KeySchema=[{"AttributeName": "key", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "key", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + + items = [ + {"key": "prefix-$LATEST", "value": "$LATEST"}, + {"key": "prefix-DEV", "value": "DEV"}, + {"key": "prefix-PROD", "value": "PROD"}, + ] + + with table.batch_writer() as batch: + for item in items: + batch.put_item(Item=item) + + table = dynamodb.Table("test-table") + with assert_raises(ClientError) as ex: + table.query(KeyConditionExpression=Key("key").begins_with("prefix-")) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Query key condition not supported" + ) + + +@mock_dynamodb2 +def test_update_supports_complex_expression_attribute_values(): + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression=( + "SET MD5 = :md5," "MyStringSet = :string_set," "MyMap = :map" + ), + ExpressionAttributeValues={ + ":md5": {"S": "md5-of-file"}, + ":string_set": {"SS": ["string1", "string2"]}, + ":map": {"M": {"EntryKey": {"SS": ["thing1", "thing2"]}}}, + }, + ) + result = client.get_item( + TableName="TestTable", Key={"SHA256": {"S": "sha-of-file"}} + )["Item"] + result.should.equal( + { + "MyStringSet": {"SS": ["string1", "string2"]}, + "MyMap": {"M": {"EntryKey": {"SS": ["thing1", "thing2"]}}}, + "SHA256": {"S": "sha-of-file"}, + "MD5": {"S": "md5-of-file"}, + } + ) + + +@mock_dynamodb2 +def test_update_supports_list_append(): + # Verify whether the list_append operation works as expected + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={"SHA256": {"S": "sha-of-file"}, "crontab": {"L": [{"S": "bar1"}]}}, + ) + + # Update item using list_append expression + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression="SET crontab = list_append(crontab, :i)", + ExpressionAttributeValues={":i": {"L": [{"S": "bar2"}]}}, + ) + + # Verify item is appended to the existing list + result = client.get_item( + TableName="TestTable", Key={"SHA256": {"S": "sha-of-file"}} + )["Item"] + result.should.equal( + { + "SHA256": {"S": "sha-of-file"}, + "crontab": {"L": [{"S": "bar1"}, {"S": "bar2"}]}, + } + ) + + +@mock_dynamodb2 +def test_update_supports_nested_list_append(): + # Verify whether we can append a list that's inside a map + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={ + "id": {"S": "nested_list_append"}, + "a": {"M": {"b": {"L": [{"S": "bar1"}]}}}, + }, + ) + + # Update item using list_append expression + client.update_item( + TableName="TestTable", + Key={"id": {"S": "nested_list_append"}}, + UpdateExpression="SET a.#b = list_append(a.#b, :i)", + ExpressionAttributeValues={":i": {"L": [{"S": "bar2"}]}}, + ExpressionAttributeNames={"#b": "b"}, + ) + + # Verify item is appended to the existing list + result = client.get_item( + TableName="TestTable", Key={"id": {"S": "nested_list_append"}} + )["Item"] + result.should.equal( + { + "id": {"S": "nested_list_append"}, + "a": {"M": {"b": {"L": [{"S": "bar1"}, {"S": "bar2"}]}}}, + } + ) + + +@mock_dynamodb2 +def test_update_supports_multiple_levels_nested_list_append(): + # Verify whether we can append a list that's inside a map that's inside a map (Inception!) + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={ + "id": {"S": "nested_list_append"}, + "a": {"M": {"b": {"M": {"c": {"L": [{"S": "bar1"}]}}}}}, + }, + ) + + # Update item using list_append expression + client.update_item( + TableName="TestTable", + Key={"id": {"S": "nested_list_append"}}, + UpdateExpression="SET a.#b.c = list_append(a.#b.#c, :i)", + ExpressionAttributeValues={":i": {"L": [{"S": "bar2"}]}}, + ExpressionAttributeNames={"#b": "b", "#c": "c"}, + ) + + # Verify item is appended to the existing list + result = client.get_item( + TableName="TestTable", Key={"id": {"S": "nested_list_append"}} + )["Item"] + result.should.equal( + { + "id": {"S": "nested_list_append"}, + "a": {"M": {"b": {"M": {"c": {"L": [{"S": "bar1"}, {"S": "bar2"}]}}}}}, + } + ) + + +@mock_dynamodb2 +def test_update_supports_nested_list_append_onto_another_list(): + # Verify whether we can take the contents of one list, and use that to fill another list + # Note that the contents of the other list is completely overwritten + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={ + "id": {"S": "list_append_another"}, + "a": {"M": {"b": {"L": [{"S": "bar1"}]}, "c": {"L": [{"S": "car1"}]}}}, + }, + ) + + # Update item using list_append expression + client.update_item( + TableName="TestTable", + Key={"id": {"S": "list_append_another"}}, + UpdateExpression="SET a.#c = list_append(a.#b, :i)", + ExpressionAttributeValues={":i": {"L": [{"S": "bar2"}]}}, + ExpressionAttributeNames={"#b": "b", "#c": "c"}, + ) + + # Verify item is appended to the existing list + result = client.get_item( + TableName="TestTable", Key={"id": {"S": "list_append_another"}} + )["Item"] + result.should.equal( + { + "id": {"S": "list_append_another"}, + "a": { + "M": { + "b": {"L": [{"S": "bar1"}]}, + "c": {"L": [{"S": "bar1"}, {"S": "bar2"}]}, + } + }, + } + ) + + +@mock_dynamodb2 +def test_update_catches_invalid_list_append_operation(): + client = boto3.client("dynamodb", region_name="us-east-1") + + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={"SHA256": {"S": "sha-of-file"}, "crontab": {"L": [{"S": "bar1"}]}}, + ) + + # Update item using invalid list_append expression + with assert_raises(ParamValidationError) as ex: + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression="SET crontab = list_append(crontab, :i)", + ExpressionAttributeValues={":i": [{"S": "bar2"}]}, + ) + + # Verify correct error is returned + str(ex.exception).should.match("Parameter validation failed:") + str(ex.exception).should.match( + "Invalid type for parameter ExpressionAttributeValues." + ) def _create_user_table(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='users', - KeySchema=[{'AttributeName': 'username', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'username', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5} + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="users", Item={"username": {"S": "user1"}, "foo": {"S": "bar"}} + ) + client.put_item( + TableName="users", Item={"username": {"S": "user2"}, "foo": {"S": "bar"}} + ) + client.put_item( + TableName="users", Item={"username": {"S": "user3"}, "foo": {"S": "bar"}} ) - client.put_item(TableName='users', Item={'username': {'S': 'user1'}, 'foo': {'S': 'bar'}}) - client.put_item(TableName='users', Item={'username': {'S': 'user2'}, 'foo': {'S': 'bar'}}) - client.put_item(TableName='users', Item={'username': {'S': 'user3'}, 'foo': {'S': 'bar'}}) return client + + +@mock_dynamodb2 +def test_update_item_if_original_value_is_none(): + dynamo = boto3.resource("dynamodb", region_name="eu-central-1") + dynamo.create_table( + AttributeDefinitions=[{"AttributeName": "job_id", "AttributeType": "S"}], + TableName="origin-rbu-dev", + KeySchema=[{"AttributeName": "job_id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + table = dynamo.Table("origin-rbu-dev") + table.put_item(Item={"job_id": "a", "job_name": None}) + table.update_item( + Key={"job_id": "a"}, + UpdateExpression="SET job_name = :output", + ExpressionAttributeValues={":output": "updated"}, + ) + table.scan()["Items"][0]["job_name"].should.equal("updated") + + +@mock_dynamodb2 +def test_update_nested_item_if_original_value_is_none(): + dynamo = boto3.resource("dynamodb", region_name="eu-central-1") + dynamo.create_table( + AttributeDefinitions=[{"AttributeName": "job_id", "AttributeType": "S"}], + TableName="origin-rbu-dev", + KeySchema=[{"AttributeName": "job_id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + table = dynamo.Table("origin-rbu-dev") + table.put_item(Item={"job_id": "a", "job_details": {"job_name": None}}) + table.update_item( + Key={"job_id": "a"}, + UpdateExpression="SET job_details.job_name = :output", + ExpressionAttributeValues={":output": "updated"}, + ) + table.scan()["Items"][0]["job_details"]["job_name"].should.equal("updated") + + +@mock_dynamodb2 +def test_allow_update_to_item_with_different_type(): + dynamo = boto3.resource("dynamodb", region_name="eu-central-1") + dynamo.create_table( + AttributeDefinitions=[{"AttributeName": "job_id", "AttributeType": "S"}], + TableName="origin-rbu-dev", + KeySchema=[{"AttributeName": "job_id", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) + table = dynamo.Table("origin-rbu-dev") + table.put_item(Item={"job_id": "a", "job_details": {"job_name": {"nested": "yes"}}}) + table.put_item(Item={"job_id": "b", "job_details": {"job_name": {"nested": "yes"}}}) + table.update_item( + Key={"job_id": "a"}, + UpdateExpression="SET job_details.job_name = :output", + ExpressionAttributeValues={":output": "updated"}, + ) + table.get_item(Key={"job_id": "a"})["Item"]["job_details"][ + "job_name" + ].should.be.equal("updated") + table.get_item(Key={"job_id": "b"})["Item"]["job_details"][ + "job_name" + ].should.be.equal({"nested": "yes"}) diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index e64d7d196..7c7770874 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -11,6 +11,7 @@ from freezegun import freeze_time from moto import mock_dynamodb2, mock_dynamodb2_deprecated from boto.exception import JSONResponseError from tests.helpers import requires_boto_gte + try: from boto.dynamodb2.fields import GlobalAllIndex, HashKey, RangeKey, AllIndex from boto.dynamodb2.table import Item, Table @@ -22,36 +23,28 @@ except ImportError: def create_table(): - table = Table.create('messages', schema=[ - HashKey('forum_name'), - RangeKey('subject'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("forum_name"), RangeKey("subject")], + throughput={"read": 10, "write": 10}, + ) return table def create_table_with_local_indexes(): table = Table.create( - 'messages', - schema=[ - HashKey('forum_name'), - RangeKey('subject'), - ], - throughput={ - 'read': 10, - 'write': 10, - }, + "messages", + schema=[HashKey("forum_name"), RangeKey("subject")], + throughput={"read": 10, "write": 10}, indexes=[ AllIndex( - 'threads_index', + "threads_index", parts=[ - HashKey('forum_name', data_type=STRING), - RangeKey('threads', data_type=NUMBER), - ] + HashKey("forum_name", data_type=STRING), + RangeKey("threads", data_type=NUMBER), + ], ) - ] + ], ) return table @@ -67,25 +60,28 @@ def iterate_results(res): def test_create_table(): table = create_table() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - {'AttributeName': 'subject', 'AttributeType': 'S'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10 + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'}, - {'KeyType': 'RANGE', 'AttributeName': 'subject'} + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, ], - 'LocalSecondaryIndexes': [], - 'ItemCount': 0, 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [] + "LocalSecondaryIndexes": [], + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], } } table.describe().should.equal(expected) @@ -97,38 +93,38 @@ def test_create_table(): def test_create_table_with_local_index(): table = create_table_with_local_indexes() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - {'AttributeName': 'subject', 'AttributeType': 'S'}, - {'AttributeName': 'threads', 'AttributeType': 'N'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "threads", "AttributeType": "N"}, ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, - 'WriteCapacityUnits': 10, - 'ReadCapacityUnits': 10, + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'}, - {'KeyType': 'RANGE', 'AttributeName': 'subject'} + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, ], - 'LocalSecondaryIndexes': [ + "LocalSecondaryIndexes": [ { - 'IndexName': 'threads_index', - 'KeySchema': [ - {'AttributeName': 'forum_name', 'KeyType': 'HASH'}, - {'AttributeName': 'threads', 'KeyType': 'RANGE'} + "IndexName": "threads_index", + "KeySchema": [ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "threads", "KeyType": "RANGE"}, ], - 'Projection': {'ProjectionType': 'ALL'} + "Projection": {"ProjectionType": "ALL"}, } ], - 'ItemCount': 0, - 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [] + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], } } table.describe().should.equal(expected) @@ -143,8 +139,7 @@ def test_delete_table(): table.delete() conn.list_tables()["TableNames"].should.have.length_of(0) - conn.delete_table.when.called_with( - 'messages').should.throw(JSONResponseError) + conn.delete_table.when.called_with("messages").should.throw(JSONResponseError) @requires_boto_gte("2.9") @@ -153,18 +148,12 @@ def test_update_table_throughput(): table = create_table() table.throughput["read"].should.equal(10) table.throughput["write"].should.equal(10) - table.update(throughput={ - 'read': 5, - 'write': 15, - }) + table.update(throughput={"read": 5, "write": 15}) table.throughput["read"].should.equal(5) table.throughput["write"].should.equal(15) - table.update(throughput={ - 'read': 5, - 'write': 6, - }) + table.update(throughput={"read": 5, "write": 6}) table.describe() @@ -176,44 +165,45 @@ def test_update_table_throughput(): @mock_dynamodb2_deprecated def test_item_add_and_describe_and_update(): table = create_table() - ok = table.put_item(data={ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + ok = table.put_item( + data={ + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) ok.should.equal(True) - table.get_item(forum_name="LOLCat Forum", - subject='Check this out!').should_not.be.none + table.get_item( + forum_name="LOLCat Forum", subject="Check this out!" + ).should_not.be.none - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='Check this out!' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="Check this out!") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.save(overwrite=True) - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='Check this out!' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="Check this out!") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) @requires_boto_gte("2.9") @@ -222,40 +212,38 @@ def test_item_partial_save(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'subject': 'The LOLz', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "subject": "The LOLz", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) - returned_item = table.get_item( - forum_name="LOLCat Forum", subject='The LOLz') + returned_item = table.get_item(forum_name="LOLCat Forum", subject="The LOLz") - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.partial_save() - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='The LOLz' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="The LOLz") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "The LOLz", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'The LOLz', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_item_put_without_table(): - table = Table('undeclared-table') + table = Table("undeclared-table") item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save.when.called_with().should.throw(JSONResponseError) @@ -266,36 +254,35 @@ def test_item_put_without_table(): def test_get_missing_item(): table = create_table() - table.get_item.when.called_with( - hash_key='tester', - range_key='other', - ).should.throw(ValidationException) + table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw( + ValidationException + ) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_item_with_undeclared_table(): - table = Table('undeclared-table') - table.get_item.when.called_with( - test_hash=3241526475).should.throw(JSONResponseError) + table = Table("undeclared-table") + table.get_item.when.called_with(test_hash=3241526475).should.throw( + JSONResponseError + ) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_item_without_range_key(): - table = Table.create('messages', schema=[ - HashKey('test_hash'), - RangeKey('test_range'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("test_hash"), RangeKey("test_range")], + throughput={"read": 10, "write": 10}, + ) hash_key = 3241526475 range_key = 1234567890987 - table.put_item(data={'test_hash': hash_key, 'test_range': range_key}) - table.get_item.when.called_with( - test_hash=hash_key).should.throw(ValidationException) + table.put_item(data={"test_hash": hash_key, "test_range": range_key}) + table.get_item.when.called_with(test_hash=hash_key).should.throw( + ValidationException + ) @requires_boto_gte("2.30.0") @@ -303,13 +290,13 @@ def test_get_item_without_range_key(): def test_delete_item(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) - item['subject'] = 'Check this out!' + item["subject"] = "Check this out!" item.save() table.count().should.equal(1) @@ -326,10 +313,10 @@ def test_delete_item(): def test_delete_item_with_undeclared_table(): table = Table("undeclared-table") item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.delete.when.called_with().should.throw(JSONResponseError) @@ -341,70 +328,65 @@ def test_query(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'subject': 'Check this out!' + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "subject": "Check this out!", } item = Item(table, item_data) item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '456' + item["forum_name"] = "the-key" + item["subject"] = "456" item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '123' + item["forum_name"] = "the-key" + item["subject"] = "123" item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '789' + item["forum_name"] = "the-key" + item["subject"] = "789" item.save(overwrite=True) table.count().should.equal(4) - results = table.query_2(forum_name__eq='the-key', - subject__gt='1', consistent=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", consistent=True) expected = ["123", "456", "789"] for index, item in enumerate(results): item["subject"].should.equal(expected[index]) - results = table.query_2(forum_name__eq="the-key", - subject__gt='1', reverse=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", reverse=True) for index, item in enumerate(results): item["subject"].should.equal(expected[len(expected) - 1 - index]) - results = table.query_2(forum_name__eq='the-key', - subject__gt='1', consistent=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", consistent=True) sum(1 for _ in results).should.equal(3) - results = table.query_2(forum_name__eq='the-key', - subject__gt='234', consistent=True) + results = table.query_2( + forum_name__eq="the-key", subject__gt="234", consistent=True + ) sum(1 for _ in results).should.equal(2) - results = table.query_2(forum_name__eq='the-key', subject__gt='9999') + results = table.query_2(forum_name__eq="the-key", subject__gt="9999") sum(1 for _ in results).should.equal(0) - results = table.query_2(forum_name__eq='the-key', subject__beginswith='12') + results = table.query_2(forum_name__eq="the-key", subject__beginswith="12") sum(1 for _ in results).should.equal(1) - results = table.query_2(forum_name__eq='the-key', subject__beginswith='7') + results = table.query_2(forum_name__eq="the-key", subject__beginswith="7") sum(1 for _ in results).should.equal(1) - results = table.query_2(forum_name__eq='the-key', - subject__between=['567', '890']) + results = table.query_2(forum_name__eq="the-key", subject__between=["567", "890"]) sum(1 for _ in results).should.equal(1) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_query_with_undeclared_table(): - table = Table('undeclared') + table = Table("undeclared") results = table.query( - forum_name__eq='Amazon DynamoDB', - subject__beginswith='DynamoDB', - limit=1 + forum_name__eq="Amazon DynamoDB", subject__beginswith="DynamoDB", limit=1 ) iterate_results.when.called_with(results).should.throw(JSONResponseError) @@ -414,30 +396,30 @@ def test_query_with_undeclared_table(): def test_scan(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '456' + item_data["forum_name"] = "the-key" + item_data["subject"] = "456" item = Item(table, item_data) item.save() - item['forum_name'] = 'the-key' - item['subject'] = '123' + item["forum_name"] = "the-key" + item["subject"] = "123" item.save() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:09 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:09 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '789' + item_data["forum_name"] = "the-key" + item_data["subject"] = "789" item = Item(table, item_data) item.save() @@ -445,10 +427,10 @@ def test_scan(): results = table.scan() sum(1 for _ in results).should.equal(3) - results = table.scan(SentBy__eq='User B') + results = table.scan(SentBy__eq="User B") sum(1 for _ in results).should.equal(1) - results = table.scan(Body__beginswith='http') + results = table.scan(Body__beginswith="http") sum(1 for _ in results).should.equal(3) results = table.scan(Ids__null=False) @@ -469,13 +451,11 @@ def test_scan(): def test_scan_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(JSONResponseError) @@ -486,27 +466,28 @@ def test_scan_with_undeclared_table(): def test_write_batch(): table = create_table() with table.batch_write() as batch: - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '123', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '789', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "123", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "789", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) table.count().should.equal(2) with table.batch_write() as batch: - batch.delete_item( - forum_name='the-key', - subject='789' - ) + batch.delete_item(forum_name="the-key", subject="789") table.count().should.equal(1) @@ -516,37 +497,37 @@ def test_write_batch(): def test_batch_read(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '456' + item_data["forum_name"] = "the-key" + item_data["subject"] = "456" item = Item(table, item_data) item.save() item = Item(table, item_data) - item_data['forum_name'] = 'the-key' - item_data['subject'] = '123' + item_data["forum_name"] = "the-key" + item_data["subject"] = "123" item.save() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } item = Item(table, item_data) - item_data['forum_name'] = 'another-key' - item_data['subject'] = '789' + item_data["forum_name"] = "another-key" + item_data["subject"] = "789" item.save() results = table.batch_get( keys=[ - {'forum_name': 'the-key', 'subject': '123'}, - {'forum_name': 'another-key', 'subject': '789'}, + {"forum_name": "the-key", "subject": "123"}, + {"forum_name": "another-key", "subject": "789"}, ] ) @@ -560,95 +541,76 @@ def test_batch_read(): def test_get_key_fields(): table = create_table() kf = table.get_key_fields() - kf.should.equal(['forum_name', 'subject']) + kf.should.equal(["forum_name", "subject"]) @mock_dynamodb2_deprecated def test_create_with_global_indexes(): conn = boto.dynamodb2.layer1.DynamoDBConnection() - Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('version'), - ], global_indexes=[ - GlobalAllIndex('topic-created_at-index', - parts=[ - HashKey('topic'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 6, - 'write': 1 - } - ), - ]) + Table.create( + "messages", + schema=[HashKey("subject"), RangeKey("version")], + global_indexes=[ + GlobalAllIndex( + "topic-created_at-index", + parts=[HashKey("topic"), RangeKey("created_at", data_type="N")], + throughput={"read": 6, "write": 1}, + ) + ], + ) table_description = conn.describe_table("messages") - table_description['Table']["GlobalSecondaryIndexes"].should.equal([ - { - "IndexName": "topic-created_at-index", - "KeySchema": [ - { - "AttributeName": "topic", - "KeyType": "HASH" + table_description["Table"]["GlobalSecondaryIndexes"].should.equal( + [ + { + "IndexName": "topic-created_at-index", + "KeySchema": [ + {"AttributeName": "topic", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 6, + "WriteCapacityUnits": 1, }, - { - "AttributeName": "created_at", - "KeyType": "RANGE" - }, - ], - "Projection": { - "ProjectionType": "ALL" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": 6, - "WriteCapacityUnits": 1, } - } - ]) + ] + ) @mock_dynamodb2_deprecated def test_query_with_global_indexes(): - table = Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('version'), - ], global_indexes=[ - GlobalAllIndex('topic-created_at-index', - parts=[ - HashKey('topic'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 6, - 'write': 1 - } - ), - GlobalAllIndex('status-created_at-index', - parts=[ - HashKey('status'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 2, - 'write': 1 - } - ) - ]) + table = Table.create( + "messages", + schema=[HashKey("subject"), RangeKey("version")], + global_indexes=[ + GlobalAllIndex( + "topic-created_at-index", + parts=[HashKey("topic"), RangeKey("created_at", data_type="N")], + throughput={"read": 6, "write": 1}, + ), + GlobalAllIndex( + "status-created_at-index", + parts=[HashKey("status"), RangeKey("created_at", data_type="N")], + throughput={"read": 2, "write": 1}, + ), + ], + ) item_data = { - 'subject': 'Check this out!', - 'version': '1', - 'created_at': 0, - 'status': 'inactive' + "subject": "Check this out!", + "version": "1", + "created_at": 0, + "status": "inactive", } item = Item(table, item_data) item.save(overwrite=True) - item['version'] = '2' + item["version"] = "2" item.save(overwrite=True) - results = table.query(status__eq='active') + results = table.query(status__eq="active") list(results).should.have.length_of(0) @@ -656,19 +618,20 @@ def test_query_with_global_indexes(): def test_query_with_local_indexes(): table = create_table_with_local_indexes() item_data = { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, - 'status': 'inactive' + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, + "status": "inactive", } item = Item(table, item_data) item.save(overwrite=True) - item['version'] = '2' + item["version"] = "2" item.save(overwrite=True) - results = table.query(forum_name__eq='Cool Forum', - index='threads_index', threads__eq=1) + results = table.query( + forum_name__eq="Cool Forum", index="threads_index", threads__eq=1 + ) list(results).should.have.length_of(1) @@ -678,29 +641,29 @@ def test_query_filter_eq(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query_2( - forum_name__eq='Cool Forum', index='threads_index', threads__eq=5 + forum_name__eq="Cool Forum", index="threads_index", threads__eq=5 ) list(results).should.have.length_of(1) @@ -711,30 +674,30 @@ def test_query_filter_lt(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__lt=5 + forum_name__eq="Cool Forum", index="threads_index", threads__lt=5 ) results = list(results) results.should.have.length_of(2) @@ -746,30 +709,30 @@ def test_query_filter_gt(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__gt=1 + forum_name__eq="Cool Forum", index="threads_index", threads__gt=1 ) list(results).should.have.length_of(1) @@ -780,30 +743,30 @@ def test_query_filter_lte(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__lte=5 + forum_name__eq="Cool Forum", index="threads_index", threads__lte=5 ) list(results).should.have.length_of(3) @@ -814,30 +777,30 @@ def test_query_filter_gte(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__gte=1 + forum_name__eq="Cool Forum", index="threads_index", threads__gte=1 ) list(results).should.have.length_of(2) @@ -848,37 +811,33 @@ def test_query_non_hash_range_key(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '3', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "3", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '2', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "2", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) - results = table.query( - forum_name__eq='Cool Forum', version__gt="2" - ) + results = table.query(forum_name__eq="Cool Forum", version__gt="2") results = list(results) results.should.have.length_of(1) - results = table.query( - forum_name__eq='Cool Forum', version__lt="3" - ) + results = table.query(forum_name__eq="Cool Forum", version__lt="3") results = list(results) results.should.have.length_of(2) @@ -887,94 +846,83 @@ def test_query_non_hash_range_key(): def test_reverse_query(): conn = boto.dynamodb2.layer1.DynamoDBConnection() - table = Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('created_at', data_type='N') - ]) + table = Table.create( + "messages", schema=[HashKey("subject"), RangeKey("created_at", data_type="N")] + ) for i in range(10): - table.put_item({ - 'subject': "Hi", - 'created_at': i - }) + table.put_item({"subject": "Hi", "created_at": i}) - results = table.query_2(subject__eq="Hi", - created_at__lt=6, - limit=4, - reverse=True) + results = table.query_2(subject__eq="Hi", created_at__lt=6, limit=4, reverse=True) expected = [Decimal(5), Decimal(4), Decimal(3), Decimal(2)] - [r['created_at'] for r in results].should.equal(expected) + [r["created_at"] for r in results].should.equal(expected) @mock_dynamodb2_deprecated def test_lookup(): from decimal import Decimal - table = Table.create('messages', schema=[ - HashKey('test_hash'), - RangeKey('test_range'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + + table = Table.create( + "messages", + schema=[HashKey("test_hash"), RangeKey("test_range")], + throughput={"read": 10, "write": 10}, + ) hash_key = 3241526475 range_key = 1234567890987 - data = {'test_hash': hash_key, 'test_range': range_key} + data = {"test_hash": hash_key, "test_range": range_key} table.put_item(data=data) message = table.lookup(hash_key, range_key) - message.get('test_hash').should.equal(Decimal(hash_key)) - message.get('test_range').should.equal(Decimal(range_key)) + message.get("test_hash").should.equal(Decimal(hash_key)) + message.get("test_range").should.equal(Decimal(range_key)) @mock_dynamodb2_deprecated def test_failed_overwrite(): - table = Table.create('messages', schema=[ - HashKey('id'), - RangeKey('range'), - ], throughput={ - 'read': 7, - 'write': 3, - }) + table = Table.create( + "messages", + schema=[HashKey("id"), RangeKey("range")], + throughput={"read": 7, "write": 3}, + ) - data1 = {'id': '123', 'range': 'abc', 'data': '678'} + data1 = {"id": "123", "range": "abc", "data": "678"} table.put_item(data=data1) - data2 = {'id': '123', 'range': 'abc', 'data': '345'} + data2 = {"id": "123", "range": "abc", "data": "345"} table.put_item(data=data2, overwrite=True) - data3 = {'id': '123', 'range': 'abc', 'data': '812'} + data3 = {"id": "123", "range": "abc", "data": "812"} table.put_item.when.called_with(data=data3).should.throw( - ConditionalCheckFailedException) + ConditionalCheckFailedException + ) - returned_item = table.lookup('123', 'abc') + returned_item = table.lookup("123", "abc") dict(returned_item).should.equal(data2) - data4 = {'id': '123', 'range': 'ghi', 'data': 812} + data4 = {"id": "123", "range": "ghi", "data": 812} table.put_item(data=data4) - returned_item = table.lookup('123', 'ghi') + returned_item = table.lookup("123", "ghi") dict(returned_item).should.equal(data4) @mock_dynamodb2_deprecated def test_conflicting_writes(): - table = Table.create('messages', schema=[ - HashKey('id'), - RangeKey('range'), - ]) + table = Table.create("messages", schema=[HashKey("id"), RangeKey("range")]) - item_data = {'id': '123', 'range': 'abc', 'data': '678'} + item_data = {"id": "123", "range": "abc", "data": "678"} item1 = Item(table, item_data) item2 = Item(table, item_data) item1.save() - item1['data'] = '579' - item2['data'] = '912' + item1["data"] = "579" + item2["data"] = "912" item1.save() item2.save.when.called_with().should.throw(ConditionalCheckFailedException) + """ boto3 """ @@ -982,464 +930,351 @@ boto3 @mock_dynamodb2 def test_boto3_conditions(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '456' - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '789' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) + table.put_item(Item={"forum_name": "the-key", "subject": "456"}) + table.put_item(Item={"forum_name": "the-key", "subject": "789"}) # Test a query returning all items results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('1'), + KeyConditionExpression=Key("forum_name").eq("the-key") & Key("subject").gt("1"), ScanIndexForward=True, ) expected = ["123", "456", "789"] - for index, item in enumerate(results['Items']): + for index, item in enumerate(results["Items"]): item["subject"].should.equal(expected[index]) # Return all items again, but in reverse results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('1'), + KeyConditionExpression=Key("forum_name").eq("the-key") & Key("subject").gt("1"), ScanIndexForward=False, ) - for index, item in enumerate(reversed(results['Items'])): + for index, item in enumerate(reversed(results["Items"])): item["subject"].should.equal(expected[index]) # Filter the subjects to only return some of the results results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('234'), + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").gt("234"), ConsistentRead=True, ) - results['Count'].should.equal(2) + results["Count"].should.equal(2) # Filter to return no results results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('9999') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").gt("9999") ) - results['Count'].should.equal(0) + results["Count"].should.equal(0) results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").begins_with('12') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").begins_with("12") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) results = table.query( - KeyConditionExpression=Key("subject").begins_with( - '7') & Key('forum_name').eq('the-key') + KeyConditionExpression=Key("subject").begins_with("7") + & Key("forum_name").eq("the-key") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").between('567', '890') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").between("567", "890") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) @mock_dynamodb2 def test_boto3_put_item_with_conditions(): import botocore - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) table.put_item( - Item={ - 'forum_name': 'the-key-2', - 'subject': '1234', - }, - ConditionExpression='attribute_not_exists(forum_name) AND attribute_not_exists(subject)' + Item={"forum_name": "the-key-2", "subject": "1234"}, + ConditionExpression="attribute_not_exists(forum_name) AND attribute_not_exists(subject)", ) table.put_item.when.called_with( - Item={ - 'forum_name': 'the-key', - 'subject': '123' - }, - ConditionExpression='attribute_not_exists(forum_name) AND attribute_not_exists(subject)' + Item={"forum_name": "the-key", "subject": "123"}, + ConditionExpression="attribute_not_exists(forum_name) AND attribute_not_exists(subject)", ).should.throw(botocore.exceptions.ClientError) table.put_item.when.called_with( - Item={ - 'forum_name': 'bogus-key', - 'subject': 'bogus', - 'test': '123' - }, - ConditionExpression='attribute_exists(forum_name) AND attribute_exists(subject)' + Item={"forum_name": "bogus-key", "subject": "bogus", "test": "123"}, + ConditionExpression="attribute_exists(forum_name) AND attribute_exists(subject)", ).should.throw(botocore.exceptions.ClientError) def _create_table_with_range_key(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'created', - 'AttributeType': 'N' } ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "N"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - return dynamodb.Table('users') + return dynamodb.Table("users") @mock_dynamodb2 def test_update_item_range_key_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'username': 'johndoe', - 'created': Decimal('3'), - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "username": "johndoe", + "created": Decimal("3"), + } + ) - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} table.update_item( Key=item_key, AttributeUpdates={ - 'username': { - 'Action': u'PUT', - 'Value': 'johndoe2' - }, - 'created': { - 'Action': u'PUT', - 'Value': Decimal('4'), - }, - 'mapfield': { - 'Action': u'PUT', - 'Value': {'key': 'value'}, - } + "username": {"Action": "PUT", "Value": "johndoe2"}, + "created": {"Action": "PUT", "Value": Decimal("4")}, + "mapfield": {"Action": "PUT", "Value": {"key": "value"}}, }, ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'username': "johndoe2", - 'forum_name': 'the-key', - 'subject': '123', - 'created': '4', - 'mapfield': {'key': 'value'}, - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "username": "johndoe2", + "forum_name": "the-key", + "subject": "123", + "created": "4", + "mapfield": {"key": "value"}, + } + ) @mock_dynamodb2 def test_update_item_does_not_exist_is_created(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} result = table.update_item( Key=item_key, AttributeUpdates={ - 'username': { - 'Action': u'PUT', - 'Value': 'johndoe2' - }, - 'created': { - 'Action': u'PUT', - 'Value': Decimal('4'), - }, - 'mapfield': { - 'Action': u'PUT', - 'Value': {'key': 'value'}, - } + "username": {"Action": "PUT", "Value": "johndoe2"}, + "created": {"Action": "PUT", "Value": Decimal("4")}, + "mapfield": {"Action": "PUT", "Value": {"key": "value"}}, }, - ReturnValues='ALL_OLD', + ReturnValues="ALL_OLD", ) - assert not result.get('Attributes') + assert not result.get("Attributes") - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'username': "johndoe2", - 'forum_name': 'the-key', - 'subject': '123', - 'created': '4', - 'mapfield': {'key': 'value'}, - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "username": "johndoe2", + "forum_name": "the-key", + "subject": "123", + "created": "4", + "mapfield": {"key": "value"}, + } + ) @mock_dynamodb2 def test_update_item_add_value(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'numeric_field': Decimal('-1'), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'numeric_field': { - 'Action': u'ADD', - 'Value': Decimal('2'), - }, - }, + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "numeric_field": Decimal("-1")} ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'numeric_field': '1', - 'forum_name': 'the-key', - 'subject': '123', - }) + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"numeric_field": {"Action": "ADD", "Value": Decimal("2")}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"numeric_field": "1", "forum_name": "the-key", "subject": "123"} + ) @mock_dynamodb2 def test_update_item_add_value_string_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'string_set': set(['str1', 'str2']), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'string_set': { - 'Action': u'ADD', - 'Value': set(['str3']), - }, - }, + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "string_set": set(["str1", "str2"]), + } + ) + + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"string_set": {"Action": "ADD", "Value": set(["str3"])}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "string_set": set(["str1", "str2", "str3"]), + "forum_name": "the-key", + "subject": "123", + } ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'string_set': set(['str1', 'str2', 'str3']), - 'forum_name': 'the-key', - 'subject': '123', - }) @mock_dynamodb2 def test_update_item_delete_value_string_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'string_set': set(['str1', 'str2']), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'string_set': { - 'Action': u'DELETE', - 'Value': set(['str2']), - }, - }, + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "string_set": set(["str1", "str2"]), + } + ) + + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"string_set": {"Action": "DELETE", "Value": set(["str2"])}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"string_set": set(["str1"]), "forum_name": "the-key", "subject": "123"} ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'string_set': set(['str1']), - 'forum_name': 'the-key', - 'subject': '123', - }) @mock_dynamodb2 def test_update_item_add_value_does_not_exist_is_created(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} table.update_item( Key=item_key, - AttributeUpdates={ - 'numeric_field': { - 'Action': u'ADD', - 'Value': Decimal('2'), - }, - }, + AttributeUpdates={"numeric_field": {"Action": "ADD", "Value": Decimal("2")}}, ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'numeric_field': '2', - 'forum_name': 'the-key', - 'subject': '123', - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"numeric_field": "2", "forum_name": "the-key", "subject": "123"} + ) @mock_dynamodb2 def test_update_item_with_expression(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'field': '1' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123", "field": "1"}) - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} - table.update_item( - Key=item_key, - UpdateExpression='SET field=2', + table.update_item(Key=item_key, UpdateExpression="SET field=2") + dict(table.get_item(Key=item_key)["Item"]).should.equal( + {"field": "2", "forum_name": "the-key", "subject": "123"} ) - dict(table.get_item(Key=item_key)['Item']).should.equal({ - 'field': '2', - 'forum_name': 'the-key', - 'subject': '123', - }) - table.update_item( - Key=item_key, - UpdateExpression='SET field = 3', + table.update_item(Key=item_key, UpdateExpression="SET field = 3") + dict(table.get_item(Key=item_key)["Item"]).should.equal( + {"field": "3", "forum_name": "the-key", "subject": "123"} ) - dict(table.get_item(Key=item_key)['Item']).should.equal({ - 'field': '3', - 'forum_name': 'the-key', - 'subject': '123', - }) + @mock_dynamodb2 def test_update_item_add_with_expression(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} current_item = { - 'forum_name': 'the-key', - 'subject': '123', - 'str_set': {'item1', 'item2', 'item3'}, - 'num_set': {1, 2, 3}, - 'num_val': 6 + "forum_name": "the-key", + "subject": "123", + "str_set": {"item1", "item2", "item3"}, + "num_set": {1, 2, 3}, + "num_val": 6, } # Put an entry in the DB to play with @@ -1448,83 +1283,143 @@ def test_update_item_add_with_expression(): # Update item to add a string value to a string set table.update_item( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': {'item4'} - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, ) - current_item['str_set'] = current_item['str_set'].union({'item4'}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["str_set"] = current_item["str_set"].union({"item4"}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + # Update item to add a string value to a non-existing set + # Should just create the set in the background + table.update_item( + Key=item_key, + UpdateExpression="ADD non_existing_str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, + ) + current_item["non_existing_str_set"] = {"item4"} + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to add a num value to a num set table.update_item( Key=item_key, - UpdateExpression='ADD num_set :v', - ExpressionAttributeValues={ - ':v': {6} - } + UpdateExpression="ADD num_set :v", + ExpressionAttributeValues={":v": {6}}, ) - current_item['num_set'] = current_item['num_set'].union({6}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_set"] = current_item["num_set"].union({6}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to add a value to a number value table.update_item( Key=item_key, - UpdateExpression='ADD num_val :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="ADD num_val :v", + ExpressionAttributeValues={":v": 20}, ) - current_item['num_val'] = current_item['num_val'] + 20 - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_val"] = current_item["num_val"] + 20 + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to add a number value to a string set, should raise Client Error table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": 20}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to add a number set to the string set, should raise a ClientError table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': { 20 } - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": {20}}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to update with a bad expression table.update_item.when.called_with( - Key=item_key, - UpdateExpression='ADD str_set bad_value' + Key=item_key, UpdateExpression="ADD str_set bad_value" ).should.have.raised(ClientError) # Attempt to add a string value instead of a string set table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': 'new_string' - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": "new_string"}, ).should.have.raised(ClientError) +@mock_dynamodb2 +def test_update_item_add_with_nested_sets(): + table = _create_table_with_range_key() + + item_key = {"forum_name": "the-key", "subject": "123"} + current_item = { + "forum_name": "the-key", + "subject": "123", + "nested": {"str_set": {"item1", "item2", "item3"}}, + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to add a string value to a nested string set + table.update_item( + Key=item_key, + UpdateExpression="ADD nested.str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, + ) + current_item["nested"]["str_set"] = current_item["nested"]["str_set"].union( + {"item4"} + ) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + # Update item to add a string value to a non-existing set + # Should just create the set in the background + table.update_item( + Key=item_key, + UpdateExpression="ADD #ns.#ne :v", + ExpressionAttributeNames={"#ns": "nested", "#ne": "non_existing_str_set"}, + ExpressionAttributeValues={":v": {"new_item"}}, + ) + current_item["nested"]["non_existing_str_set"] = {"new_item"} + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + +@mock_dynamodb2 +def test_update_item_delete_with_nested_sets(): + table = _create_table_with_range_key() + + item_key = {"forum_name": "the-key", "subject": "123"} + current_item = { + "forum_name": "the-key", + "subject": "123", + "nested": {"str_set": {"item1", "item2", "item3"}}, + } + + # Put an entry in the DB to play with + table.put_item(Item=current_item) + + # Update item to add a string value to a nested string set + table.update_item( + Key=item_key, + UpdateExpression="DELETE nested.str_set :v", + ExpressionAttributeValues={":v": {"item3"}}, + ) + current_item["nested"]["str_set"] = current_item["nested"]["str_set"].difference( + {"item3"} + ) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) + + @mock_dynamodb2 def test_update_item_delete_with_expression(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} current_item = { - 'forum_name': 'the-key', - 'subject': '123', - 'str_set': {'item1', 'item2', 'item3'}, - 'num_set': {1, 2, 3}, - 'num_val': 6 + "forum_name": "the-key", + "subject": "123", + "str_set": {"item1", "item2", "item3"}, + "num_set": {1, 2, 3}, + "num_val": 6, } # Put an entry in the DB to play with @@ -1533,49 +1428,40 @@ def test_update_item_delete_with_expression(): # Update item to delete a string value from a string set table.update_item( Key=item_key, - UpdateExpression='DELETE str_set :v', - ExpressionAttributeValues={ - ':v': {'item2'} - } + UpdateExpression="DELETE str_set :v", + ExpressionAttributeValues={":v": {"item2"}}, ) - current_item['str_set'] = current_item['str_set'].difference({'item2'}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["str_set"] = current_item["str_set"].difference({"item2"}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to delete a num value from a num set table.update_item( Key=item_key, - UpdateExpression='DELETE num_set :v', - ExpressionAttributeValues={ - ':v': {2} - } + UpdateExpression="DELETE num_set :v", + ExpressionAttributeValues={":v": {2}}, ) - current_item['num_set'] = current_item['num_set'].difference({2}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_set"] = current_item["num_set"].difference({2}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Try to delete on a number, this should fail table.update_item.when.called_with( Key=item_key, - UpdateExpression='DELETE num_val :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="DELETE num_val :v", + ExpressionAttributeValues={":v": 20}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Try to delete a string set from a number set table.update_item.when.called_with( Key=item_key, - UpdateExpression='DELETE num_set :v', - ExpressionAttributeValues={ - ':v': {'del_str'} - } + UpdateExpression="DELETE num_set :v", + ExpressionAttributeValues={":v": {"del_str"}}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to update with a bad expression table.update_item.when.called_with( - Key=item_key, - UpdateExpression='DELETE num_val badvalue' + Key=item_key, UpdateExpression="DELETE num_val badvalue" ).should.have.raised(ClientError) @@ -1583,378 +1469,309 @@ def test_update_item_delete_with_expression(): def test_boto3_query_gsi_range_comparison(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'username': 'johndoe', - 'created': 3, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '456', - 'username': 'johndoe', - 'created': 1, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '789', - 'username': 'johndoe', - 'created': 2, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '159', - 'username': 'janedoe', - 'created': 2, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '601', - 'username': 'janedoe', - 'created': 5, - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "username": "johndoe", + "created": 3, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "456", + "username": "johndoe", + "created": 1, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "789", + "username": "johndoe", + "created": 2, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "159", + "username": "janedoe", + "created": 2, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "601", + "username": "janedoe", + "created": 5, + } + ) # Test a query returning all johndoe items results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), ScanIndexForward=True, - IndexName='TestGSI', + IndexName="TestGSI", ) expected = ["456", "789", "123"] - for index, item in enumerate(results['Items']): + for index, item in enumerate(results["Items"]): item["subject"].should.equal(expected[index]) # Return all johndoe items again, but in reverse results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), ScanIndexForward=False, - IndexName='TestGSI', + IndexName="TestGSI", ) - for index, item in enumerate(reversed(results['Items'])): + for index, item in enumerate(reversed(results["Items"])): item["subject"].should.equal(expected[index]) # Filter the creation to only return some of the results # And reverse order of hash + range key results = table.query( - KeyConditionExpression=Key("created").gt( - 1) & Key('username').eq('johndoe'), + KeyConditionExpression=Key("created").gt(1) & Key("username").eq("johndoe"), ConsistentRead=True, - IndexName='TestGSI', + IndexName="TestGSI", ) - results['Count'].should.equal(2) + results["Count"].should.equal(2) # Filter to return no results results = table.query( - KeyConditionExpression=Key('username').eq( - 'janedoe') & Key("created").gt(9), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("janedoe") & Key("created").gt(9), + IndexName="TestGSI", ) - results['Count'].should.equal(0) + results["Count"].should.equal(0) results = table.query( - KeyConditionExpression=Key('username').eq( - 'janedoe') & Key("created").eq(5), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("janedoe") & Key("created").eq(5), + IndexName="TestGSI", ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) # Test range key sorting results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), + IndexName="TestGSI", ) - expected = [Decimal('1'), Decimal('2'), Decimal('3')] - for index, item in enumerate(results['Items']): + expected = [Decimal("1"), Decimal("2"), Decimal("3")] + for index, item in enumerate(results["Items"]): item["created"].should.equal(expected[index]) @mock_dynamodb2 def test_boto3_update_table_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - table.update(ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, - }) + table.update( + ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 11} + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.provisioned_throughput['ReadCapacityUnits'].should.equal(10) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(11) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(10) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(11) @mock_dynamodb2 def test_boto3_update_table_gsi_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - gsi_throughput['ReadCapacityUnits'].should.equal(3) - gsi_throughput['WriteCapacityUnits'].should.equal(4) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + gsi_throughput["ReadCapacityUnits"].should.equal(3) + gsi_throughput["WriteCapacityUnits"].should.equal(4) - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Update': { - 'IndexName': 'TestGSI', - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Update": { + "IndexName": "TestGSI", + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 11, + }, + } } - }, - }]) + ] + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") # Primary throughput has not changed - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - gsi_throughput['ReadCapacityUnits'].should.equal(10) - gsi_throughput['WriteCapacityUnits'].should.equal(11) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + gsi_throughput["ReadCapacityUnits"].should.equal(10) + gsi_throughput["WriteCapacityUnits"].should.equal(11) @mock_dynamodb2 def test_update_table_gsi_create(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Create': { - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', - }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Create": { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, + }, } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 } - }, - }]) + ] + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(1) - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - assert gsi_throughput['ReadCapacityUnits'].should.equal(3) - assert gsi_throughput['WriteCapacityUnits'].should.equal(4) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + assert gsi_throughput["ReadCapacityUnits"].should.equal(3) + assert gsi_throughput["WriteCapacityUnits"].should.equal(4) # Check update works - table.update(GlobalSecondaryIndexUpdates=[{ - 'Update': { - 'IndexName': 'TestGSI', - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Update": { + "IndexName": "TestGSI", + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 11, + }, + } } - }, - }]) - table = dynamodb.Table('users') + ] + ) + table = dynamodb.Table("users") - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - assert gsi_throughput['ReadCapacityUnits'].should.equal(10) - assert gsi_throughput['WriteCapacityUnits'].should.equal(11) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + assert gsi_throughput["ReadCapacityUnits"].should.equal(10) + assert gsi_throughput["WriteCapacityUnits"].should.equal(11) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Delete': { - 'IndexName': 'TestGSI', - }, - }]) + table.update(GlobalSecondaryIndexUpdates=[{"Delete": {"IndexName": "TestGSI"}}]) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) @mock_dynamodb2 def test_update_table_gsi_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(1) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Delete': { - 'IndexName': 'TestGSI', - }, - }]) + table.update(GlobalSecondaryIndexUpdates=[{"Delete": {"IndexName": "TestGSI"}}]) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) @@ -1962,140 +1779,131 @@ def test_update_table_gsi_throughput(): def test_query_pagination(): table = _create_table_with_range_key() for i in range(10): - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '{0}'.format(i), - 'username': 'johndoe', - 'created': Decimal('3'), - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "{0}".format(i), + "username": "johndoe", + "created": Decimal("3"), + } + ) - page1 = table.query( - KeyConditionExpression=Key('forum_name').eq('the-key'), - Limit=6 - ) - page1['Count'].should.equal(6) - page1['Items'].should.have.length_of(6) - page1.should.have.key('LastEvaluatedKey') + page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6) + page1["Count"].should.equal(6) + page1["Items"].should.have.length_of(6) + page1.should.have.key("LastEvaluatedKey") page2 = table.query( - KeyConditionExpression=Key('forum_name').eq('the-key'), + KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6, - ExclusiveStartKey=page1['LastEvaluatedKey'] + ExclusiveStartKey=page1["LastEvaluatedKey"], ) - page2['Count'].should.equal(4) - page2['Items'].should.have.length_of(4) - page2.should_not.have.key('LastEvaluatedKey') + page2["Count"].should.equal(4) + page2["Items"].should.have.length_of(4) + page2.should_not.have.key("LastEvaluatedKey") - results = page1['Items'] + page2['Items'] - subjects = set([int(r['subject']) for r in results]) + results = page1["Items"] + page2["Items"] + subjects = set([int(r["subject"]) for r in results]) subjects.should.equal(set(range(10))) @mock_dynamodb2 def test_scan_by_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', + TableName="test", KeySchema=[ - {'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'range_key', 'KeyType': 'RANGE'}, + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "range_key", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'range_key', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_range_key', 'AttributeType': 'S'}, - {'AttributeName': 'lsi_range_key', 'AttributeType': 'S'}, + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "range_key", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, + {"AttributeName": "gsi_range_key", "AttributeType": "S"}, + {"AttributeName": "lsi_range_key", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - {'AttributeName': 'gsi_col', 'KeyType': 'HASH'}, - {'AttributeName': 'gsi_range_key', 'KeyType': 'RANGE'}, + "IndexName": "test_gsi", + "KeySchema": [ + {"AttributeName": "gsi_col", "KeyType": "HASH"}, + {"AttributeName": "gsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, + } ], LocalSecondaryIndexes=[ { - 'IndexName': 'test_lsi', - 'KeySchema': [ - {'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'lsi_range_key', 'KeyType': 'RANGE'}, + "IndexName": "test_lsi", + "KeySchema": [ + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "lsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - }, - ] + "Projection": {"ProjectionType": "ALL"}, + } + ], ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': '1'}, - 'range_key': {'S': '1'}, - 'col1': {'S': 'val1'}, - 'gsi_col': {'S': '1'}, - 'gsi_range_key': {'S': '1'}, - 'lsi_range_key': {'S': '1'}, - } + "id": {"S": "1"}, + "range_key": {"S": "1"}, + "col1": {"S": "val1"}, + "gsi_col": {"S": "1"}, + "gsi_range_key": {"S": "1"}, + "lsi_range_key": {"S": "1"}, + }, ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': '1'}, - 'range_key': {'S': '2'}, - 'col1': {'S': 'val2'}, - 'gsi_col': {'S': '1'}, - 'gsi_range_key': {'S': '2'}, - 'lsi_range_key': {'S': '2'}, - } + "id": {"S": "1"}, + "range_key": {"S": "2"}, + "col1": {"S": "val2"}, + "gsi_col": {"S": "1"}, + "gsi_range_key": {"S": "2"}, + "lsi_range_key": {"S": "2"}, + }, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '3'}, - 'range_key': {'S': '1'}, - 'col1': {'S': 'val3'}, - } + TableName="test", + Item={"id": {"S": "3"}, "range_key": {"S": "1"}, "col1": {"S": "val3"}}, ) - res = dynamodb.scan(TableName='test') - assert res['Count'] == 3 - assert len(res['Items']) == 3 + res = dynamodb.scan(TableName="test") + assert res["Count"] == 3 + assert len(res["Items"]) == 3 - res = dynamodb.scan(TableName='test', IndexName='test_gsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_gsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_gsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['gsi_col']['S'] == '1' - assert last_eval_key['gsi_range_key']['S'] == '1' + res = dynamodb.scan(TableName="test", IndexName="test_gsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["gsi_col"]["S"] == "1" + assert last_eval_key["gsi_range_key"]["S"] == "1" - res = dynamodb.scan(TableName='test', IndexName='test_lsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_lsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_lsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['range_key']['S'] == '1' - assert last_eval_key['lsi_range_key']['S'] == '1' + res = dynamodb.scan(TableName="test", IndexName="test_lsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["range_key"]["S"] == "1" + assert last_eval_key["lsi_range_key"]["S"] == "1" diff --git a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py index 1880c7cab..08d7724f8 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py @@ -9,6 +9,7 @@ from boto.exception import JSONResponseError from moto import mock_dynamodb2, mock_dynamodb2_deprecated from tests.helpers import requires_boto_gte import botocore + try: from boto.dynamodb2.fields import HashKey from boto.dynamodb2.table import Table @@ -19,12 +20,9 @@ except ImportError: def create_table(): - table = Table.create('messages', schema=[ - HashKey('forum_name') - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", schema=[HashKey("forum_name")], throughput={"read": 10, "write": 10} + ) return table @@ -34,32 +32,31 @@ def create_table(): def test_create_table(): create_table() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"} ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10 + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'} - ], - 'ItemCount': 0, 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [], - 'LocalSecondaryIndexes': [] + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [{"KeyType": "HASH", "AttributeName": "forum_name"}], + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], + "LocalSecondaryIndexes": [], } } conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk" + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" ) - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @requires_boto_gte("2.9") @@ -69,11 +66,10 @@ def test_delete_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.list_tables()["TableNames"].should.have.length_of(1) - conn.delete_table('messages') + conn.delete_table("messages") conn.list_tables()["TableNames"].should.have.length_of(0) - conn.delete_table.when.called_with( - 'messages').should.throw(JSONResponseError) + conn.delete_table.when.called_with("messages").should.throw(JSONResponseError) @requires_boto_gte("2.9") @@ -83,10 +79,7 @@ def test_update_table_throughput(): table.throughput["read"].should.equal(10) table.throughput["write"].should.equal(10) - table.update(throughput={ - 'read': 5, - 'write': 6, - }) + table.update(throughput={"read": 5, "write": 6}) table.throughput["read"].should.equal(5) table.throughput["write"].should.equal(6) @@ -98,32 +91,34 @@ def test_item_add_and_describe_and_update(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) returned_item = table.get_item(forum_name="LOLCat Forum") returned_item.should_not.be.none - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } + ) - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.save(overwrite=True) - returned_item = table.get_item( - forum_name='LOLCat Forum' + returned_item = table.get_item(forum_name="LOLCat Forum") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @@ -132,25 +127,25 @@ def test_item_partial_save(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) returned_item = table.get_item(forum_name="LOLCat Forum") - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.partial_save() - returned_item = table.get_item( - forum_name='LOLCat Forum' + returned_item = table.get_item(forum_name="LOLCat Forum") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @@ -159,12 +154,12 @@ def test_item_put_without_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.put_item.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", item={ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - } + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + }, ).should.throw(JSONResponseError) @@ -174,8 +169,7 @@ def test_get_item_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.get_item.when.called_with( - table_name='undeclared-table', - key={"forum_name": {"S": "LOLCat Forum"}}, + table_name="undeclared-table", key={"forum_name": {"S": "LOLCat Forum"}} ).should.throw(JSONResponseError) @@ -185,10 +179,10 @@ def test_delete_item(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save() @@ -210,8 +204,7 @@ def test_delete_item_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.delete_item.when.called_with( - table_name='undeclared-table', - key={"forum_name": {"S": "LOLCat Forum"}}, + table_name="undeclared-table", key={"forum_name": {"S": "LOLCat Forum"}} ).should.throw(JSONResponseError) @@ -221,17 +214,17 @@ def test_query(): table = create_table() item_data = { - 'forum_name': 'the-key', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "the-key", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save(overwrite=True) table.count().should.equal(1) table = Table("messages") - results = table.query(forum_name__eq='the-key') + results = table.query(forum_name__eq="the-key") sum(1 for _ in results).should.equal(1) @@ -241,9 +234,13 @@ def test_query_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.query.when.called_with( - table_name='undeclared-table', - key_conditions={"forum_name": { - "ComparisonOperator": "EQ", "AttributeValueList": [{"S": "the-key"}]}} + table_name="undeclared-table", + key_conditions={ + "forum_name": { + "ComparisonOperator": "EQ", + "AttributeValueList": [{"S": "the-key"}], + } + }, ).should.throw(JSONResponseError) @@ -253,36 +250,36 @@ def test_scan(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' + item_data["forum_name"] = "the-key" item = Item(table, item_data) item.save() - item['forum_name'] = 'the-key2' + item["forum_name"] = "the-key2" item.save(overwrite=True) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item_data['forum_name'] = 'the-key3' + item_data["forum_name"] = "the-key3" item = Item(table, item_data) item.save() results = table.scan() sum(1 for _ in results).should.equal(3) - results = table.scan(SentBy__eq='User B') + results = table.scan(SentBy__eq="User B") sum(1 for _ in results).should.equal(1) - results = table.scan(Body__beginswith='http') + results = table.scan(Body__beginswith="http") sum(1 for _ in results).should.equal(3) results = table.scan(Ids__null=False) @@ -304,13 +301,11 @@ def test_scan_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(JSONResponseError) @@ -322,27 +317,28 @@ def test_write_batch(): table = create_table() with table.batch_write() as batch: - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '123', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - batch.put_item(data={ - 'forum_name': 'the-key2', - 'subject': '789', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "123", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) + batch.put_item( + data={ + "forum_name": "the-key2", + "subject": "789", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) table.count().should.equal(2) with table.batch_write() as batch: - batch.delete_item( - forum_name='the-key', - subject='789' - ) + batch.delete_item(forum_name="the-key", subject="789") table.count().should.equal(1) @@ -353,34 +349,31 @@ def test_batch_read(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key1' + item_data["forum_name"] = "the-key1" item = Item(table, item_data) item.save() item = Item(table, item_data) - item_data['forum_name'] = 'the-key2' + item_data["forum_name"] = "the-key2" item.save(overwrite=True) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } item = Item(table, item_data) - item_data['forum_name'] = 'another-key' + item_data["forum_name"] = "another-key" item.save(overwrite=True) results = table.batch_get( - keys=[ - {'forum_name': 'the-key1'}, - {'forum_name': 'another-key'}, - ] + keys=[{"forum_name": "the-key1"}, {"forum_name": "another-key"}] ) # Iterate through so that batch_item gets called @@ -393,132 +386,136 @@ def test_batch_read(): def test_get_key_fields(): table = create_table() kf = table.get_key_fields() - kf[0].should.equal('forum_name') + kf[0].should.equal("forum_name") @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_missing_item(): table = create_table() - table.get_item.when.called_with( - forum_name='missing').should.throw(ItemNotFound) + table.get_item.when.called_with(forum_name="missing").should.throw(ItemNotFound) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_special_item(): - table = Table.create('messages', schema=[ - HashKey('date-joined') - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("date-joined")], + throughput={"read": 10, "write": 10}, + ) - data = { - 'date-joined': 127549192, - 'SentBy': 'User A', - } + data = {"date-joined": 127549192, "SentBy": "User A"} table.put_item(data=data) - returned_item = table.get_item(**{'date-joined': 127549192}) + returned_item = table.get_item(**{"date-joined": 127549192}) dict(returned_item).should.equal(data) @mock_dynamodb2_deprecated def test_update_item_remove(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'SentBy': 'User A', - 'SentTo': 'User B', - } + data = {"username": "steve", "SentBy": "User A", "SentTo": "User B"} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} # Then remove the SentBy field - conn.update_item("messages", key_map, - update_expression="REMOVE SentBy, SentTo") + conn.update_item("messages", key_map, update_expression="REMOVE SentBy, SentTo") returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - }) + dict(returned_item).should.equal({"username": "steve"}) + + +@mock_dynamodb2_deprecated +def test_update_item_nested_remove(): + conn = boto.dynamodb2.connect_to_region("us-east-1") + table = Table.create("messages", schema=[HashKey("username")]) + + data = {"username": "steve", "Meta": {"FullName": "Steve Urkel"}} + table.put_item(data=data) + key_map = {"username": {"S": "steve"}} + + # Then remove the Meta.FullName field + conn.update_item("messages", key_map, update_expression="REMOVE Meta.FullName") + + returned_item = table.get_item(username="steve") + dict(returned_item).should.equal({"username": "steve", "Meta": {}}) + + +@mock_dynamodb2_deprecated +def test_update_item_double_nested_remove(): + conn = boto.dynamodb2.connect_to_region("us-east-1") + table = Table.create("messages", schema=[HashKey("username")]) + + data = {"username": "steve", "Meta": {"Name": {"First": "Steve", "Last": "Urkel"}}} + table.put_item(data=data) + key_map = {"username": {"S": "steve"}} + + # Then remove the Meta.FullName field + conn.update_item("messages", key_map, update_expression="REMOVE Meta.Name.First") + + returned_item = table.get_item(username="steve") + dict(returned_item).should.equal( + {"username": "steve", "Meta": {"Name": {"Last": "Urkel"}}} + ) @mock_dynamodb2_deprecated def test_update_item_set(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'SentBy': 'User A', - } + data = {"username": "steve", "SentBy": "User A"} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} - conn.update_item("messages", key_map, - update_expression="SET foo=bar, blah=baz REMOVE SentBy") + conn.update_item( + "messages", key_map, update_expression="SET foo=bar, blah=baz REMOVE SentBy" + ) returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - 'foo': 'bar', - 'blah': 'baz', - }) + dict(returned_item).should.equal({"username": "steve", "foo": "bar", "blah": "baz"}) @mock_dynamodb2_deprecated def test_failed_overwrite(): - table = Table.create('messages', schema=[ - HashKey('id'), - ], throughput={ - 'read': 7, - 'write': 3, - }) + table = Table.create( + "messages", schema=[HashKey("id")], throughput={"read": 7, "write": 3} + ) - data1 = {'id': '123', 'data': '678'} + data1 = {"id": "123", "data": "678"} table.put_item(data=data1) - data2 = {'id': '123', 'data': '345'} + data2 = {"id": "123", "data": "345"} table.put_item(data=data2, overwrite=True) - data3 = {'id': '123', 'data': '812'} + data3 = {"id": "123", "data": "812"} table.put_item.when.called_with(data=data3).should.throw( - ConditionalCheckFailedException) + ConditionalCheckFailedException + ) - returned_item = table.lookup('123') + returned_item = table.lookup("123") dict(returned_item).should.equal(data2) - data4 = {'id': '124', 'data': 812} + data4 = {"id": "124", "data": 812} table.put_item(data=data4) - returned_item = table.lookup('124') + returned_item = table.lookup("124") dict(returned_item).should.equal(data4) @mock_dynamodb2_deprecated def test_conflicting_writes(): - table = Table.create('messages', schema=[ - HashKey('id'), - ]) + table = Table.create("messages", schema=[HashKey("id")]) - item_data = {'id': '123', 'data': '678'} + item_data = {"id": "123", "data": "678"} item1 = Item(table, item_data) item2 = Item(table, item_data) item1.save() - item1['data'] = '579' - item2['data'] = '912' + item1["data"] = "579" + item2["data"] = "912" item1.save() item2.save.when.called_with().should.throw(ConditionalCheckFailedException) @@ -531,230 +528,178 @@ boto3 @mock_dynamodb2 def test_boto3_create_table(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") table = dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'username', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table.name.should.equal('users') + table.name.should.equal("users") def _create_user_table(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") table = dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'username', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - return dynamodb.Table('users') + return dynamodb.Table("users") @mock_dynamodb2 def test_boto3_conditions(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe'}) - table.put_item(Item={'username': 'janedoe'}) + table.put_item(Item={"username": "johndoe"}) + table.put_item(Item={"username": "janedoe"}) - response = table.query( - KeyConditionExpression=Key('username').eq('johndoe') - ) - response['Count'].should.equal(1) - response['Items'].should.have.length_of(1) - response['Items'][0].should.equal({"username": "johndoe"}) + response = table.query(KeyConditionExpression=Key("username").eq("johndoe")) + response["Count"].should.equal(1) + response["Items"].should.have.length_of(1) + response["Items"][0].should.equal({"username": "johndoe"}) @mock_dynamodb2 def test_boto3_put_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': ['bar'] - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "EQ", "AttributeValueList": ["bar"]}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_pass_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'whatever': { - 'ComparisonOperator': 'NULL', - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"whatever": {"ComparisonOperator": "NULL"}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_pass_because_expect_exists_by_compare_to_not_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'NOT_NULL', - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "NOT_NULL"}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_fail(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item.when.called_with( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'NE', - 'AttributeValueList': ['bar'] - } - }).should.throw(botocore.client.ClientError) + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "NE", "AttributeValueList": ["bar"]}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'Value': 'bar', - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"Value": "bar"}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail_because_expect_not_exists(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'Exists': False - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"Exists": False}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'ComparisonOperator': 'NULL', - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"ComparisonOperator": "NULL"}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'foo': { - 'Value': 'bar', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"foo": {"Value": "bar"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_not_exists(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'whatever': { - 'Exists': False, - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"whatever": {"Exists": False}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'whatever': { - 'ComparisonOperator': 'NULL', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"whatever": {"ComparisonOperator": "NULL"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_exists_by_compare_to_not_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'foo': { - 'ComparisonOperator': 'NOT_NULL', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"foo": {"ComparisonOperator": "NOT_NULL"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") @mock_dynamodb2 def test_boto3_update_settype_item_with_conditions(): class OrderedSet(set): """A set with predictable iteration order""" + def __init__(self, values): super(OrderedSet, self).__init__(values) self.__ordered_values = values @@ -763,143 +708,113 @@ def test_boto3_update_settype_item_with_conditions(): return iter(self.__ordered_values) table = _create_user_table() - table.put_item(Item={'username': 'johndoe'}) + table.put_item(Item={"username": "johndoe"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=:new_value', - ExpressionAttributeValues={ - ':new_value': OrderedSet(['hello', 'world']), - }, + Key={"username": "johndoe"}, + UpdateExpression="SET foo=:new_value", + ExpressionAttributeValues={":new_value": OrderedSet(["hello", "world"])}, ) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=:new_value', - ExpressionAttributeValues={ - ':new_value': set(['baz']), - }, + Key={"username": "johndoe"}, + UpdateExpression="SET foo=:new_value", + ExpressionAttributeValues={":new_value": set(["baz"])}, Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': [ - OrderedSet(['world', 'hello']), # Opposite order to original + "foo": { + "ComparisonOperator": "EQ", + "AttributeValueList": [ + OrderedSet(["world", "hello"]) # Opposite order to original ], } }, ) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal(set(['baz'])) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal(set(["baz"])) @mock_dynamodb2 def test_boto3_put_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': ['bar'] - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "EQ", "AttributeValueList": ["bar"]}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") @mock_dynamodb2 def test_scan_pagination(): table = _create_user_table() - expected_usernames = ['user{0}'.format(i) for i in range(10)] + expected_usernames = ["user{0}".format(i) for i in range(10)] for u in expected_usernames: - table.put_item(Item={'username': u}) + table.put_item(Item={"username": u}) page1 = table.scan(Limit=6) - page1['Count'].should.equal(6) - page1['Items'].should.have.length_of(6) - page1.should.have.key('LastEvaluatedKey') + page1["Count"].should.equal(6) + page1["Items"].should.have.length_of(6) + page1.should.have.key("LastEvaluatedKey") - page2 = table.scan(Limit=6, - ExclusiveStartKey=page1['LastEvaluatedKey']) - page2['Count'].should.equal(4) - page2['Items'].should.have.length_of(4) - page2.should_not.have.key('LastEvaluatedKey') + page2 = table.scan(Limit=6, ExclusiveStartKey=page1["LastEvaluatedKey"]) + page2["Count"].should.equal(4) + page2["Items"].should.have.length_of(4) + page2.should_not.have.key("LastEvaluatedKey") - results = page1['Items'] + page2['Items'] - usernames = set([r['username'] for r in results]) + results = page1["Items"] + page2["Items"] + usernames = set([r["username"] for r in results]) usernames.should.equal(set(expected_usernames)) @mock_dynamodb2 def test_scan_by_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_col', - 'KeyType': 'HASH' - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "gsi_col", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '1'}, - 'col1': {'S': 'val1'}, - 'gsi_col': {'S': 'gsi_val1'}, - } + TableName="test", + Item={"id": {"S": "1"}, "col1": {"S": "val1"}, "gsi_col": {"S": "gsi_val1"}}, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '2'}, - 'col1': {'S': 'val2'}, - 'gsi_col': {'S': 'gsi_val2'}, - } + TableName="test", + Item={"id": {"S": "2"}, "col1": {"S": "val2"}, "gsi_col": {"S": "gsi_val2"}}, ) - dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '3'}, - 'col1': {'S': 'val3'}, - } - ) + dynamodb.put_item(TableName="test", Item={"id": {"S": "3"}, "col1": {"S": "val3"}}) - res = dynamodb.scan(TableName='test') - assert res['Count'] == 3 - assert len(res['Items']) == 3 + res = dynamodb.scan(TableName="test") + assert res["Count"] == 3 + assert len(res["Items"]) == 3 - res = dynamodb.scan(TableName='test', IndexName='test_gsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_gsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_gsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['gsi_col']['S'] == 'gsi_val1' + res = dynamodb.scan(TableName="test", IndexName="test_gsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["gsi_col"]["S"] == "gsi_val1" diff --git a/tests/test_dynamodb2/test_server.py b/tests/test_dynamodb2/test_server.py index be94df0f4..880909fac 100644 --- a/tests/test_dynamodb2/test_server.py +++ b/tests/test_dynamodb2/test_server.py @@ -1,19 +1,19 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_table_list(): - backend = server.create_backend_app("dynamodb2") - test_client = backend.test_client() - res = test_client.get('/') - res.status_code.should.equal(404) - - headers = {'X-Amz-Target': 'TestTable.ListTables'} - res = test_client.get('/', headers=headers) - res.data.should.contain(b'TableNames') +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_table_list(): + backend = server.create_backend_app("dynamodb2") + test_client = backend.test_client() + res = test_client.get("/") + res.status_code.should.equal(404) + + headers = {"X-Amz-Target": "TestTable.ListTables"} + res = test_client.get("/", headers=headers) + res.data.should.contain(b"TableNames") diff --git a/tests/test_dynamodbstreams/test_dynamodbstreams.py b/tests/test_dynamodbstreams/test_dynamodbstreams.py index b60c21053..8fad0ff23 100644 --- a/tests/test_dynamodbstreams/test_dynamodbstreams.py +++ b/tests/test_dynamodbstreams/test_dynamodbstreams.py @@ -6,142 +6,190 @@ import boto3 from moto import mock_dynamodb2, mock_dynamodbstreams -class TestCore(): +class TestCore: stream_arn = None mocks = [] - + def setup(self): self.mocks = [mock_dynamodb2(), mock_dynamodbstreams()] for m in self.mocks: m.start() - + # create a table with a stream - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1}, + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' - } + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, ) - self.stream_arn = resp['TableDescription']['LatestStreamArn'] + self.stream_arn = resp["TableDescription"]["LatestStreamArn"] def teardown(self): - conn = boto3.client('dynamodb', region_name='us-east-1') - conn.delete_table(TableName='test-streams') + conn = boto3.client("dynamodb", region_name="us-east-1") + conn.delete_table(TableName="test-streams") self.stream_arn = None for m in self.mocks: m.stop() - def test_verify_stream(self): - conn = boto3.client('dynamodb', region_name='us-east-1') - resp = conn.describe_table(TableName='test-streams') - assert 'LatestStreamArn' in resp['Table'] + conn = boto3.client("dynamodb", region_name="us-east-1") + resp = conn.describe_table(TableName="test-streams") + assert "LatestStreamArn" in resp["Table"] def test_describe_stream(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - assert 'StreamDescription' in resp - desc = resp['StreamDescription'] - assert desc['StreamArn'] == self.stream_arn - assert desc['TableName'] == 'test-streams' + assert "StreamDescription" in resp + desc = resp["StreamDescription"] + assert desc["StreamArn"] == self.stream_arn + assert desc["TableName"] == "test-streams" def test_list_streams(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.list_streams() - assert resp['Streams'][0]['StreamArn'] == self.stream_arn + assert resp["Streams"][0]["StreamArn"] == self.stream_arn - resp = conn.list_streams(TableName='no-stream') - assert not resp['Streams'] + resp = conn.list_streams(TableName="no-stream") + assert not resp["Streams"] def test_get_shard_iterator(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON' + ShardIteratorType="TRIM_HORIZON", ) - assert 'ShardIterator' in resp - + assert "ShardIterator" in resp + + def test_get_shard_iterator_at_sequence_number(self): + conn = boto3.client("dynamodbstreams", region_name="us-east-1") + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType="AT_SEQUENCE_NUMBER", + SequenceNumber=resp["StreamDescription"]["Shards"][0][ + "SequenceNumberRange" + ]["StartingSequenceNumber"], + ) + assert "ShardIterator" in resp + + def test_get_shard_iterator_after_sequence_number(self): + conn = boto3.client("dynamodbstreams", region_name="us-east-1") + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType="AFTER_SEQUENCE_NUMBER", + SequenceNumber=resp["StreamDescription"]["Shards"][0][ + "SequenceNumberRange" + ]["StartingSequenceNumber"], + ) + assert "ShardIterator" in resp + def test_get_records_empty(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( - StreamArn=self.stream_arn, - ShardId=shard_id, - ShardIteratorType='LATEST' + StreamArn=self.stream_arn, ShardId=shard_id, ShardIteratorType="LATEST" ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert 'Records' in resp - assert len(resp['Records']) == 0 + assert "Records" in resp + assert len(resp["Records"]) == 0 def test_get_records_seq(self): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") conn.put_item( - TableName='test-streams', - Item={ - 'id': {'S': 'entry1'}, - 'first_col': {'S': 'foo'} - } + TableName="test-streams", + Item={"id": {"S": "entry1"}, "first_col": {"S": "foo"}}, ) conn.put_item( - TableName='test-streams', + TableName="test-streams", Item={ - 'id': {'S': 'entry1'}, - 'first_col': {'S': 'bar'}, - 'second_col': {'S': 'baz'} - } + "id": {"S": "entry1"}, + "first_col": {"S": "bar"}, + "second_col": {"S": "baz"}, + }, ) - conn.delete_item( - TableName='test-streams', - Key={'id': {'S': 'entry1'}} - ) - - conn = boto3.client('dynamodbstreams', region_name='us-east-1') - + conn.delete_item(TableName="test-streams", Key={"id": {"S": "entry1"}}) + + conn = boto3.client("dynamodbstreams", region_name="us-east-1") + resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON' + ShardIteratorType="TRIM_HORIZON", ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 3 - assert resp['Records'][0]['eventName'] == 'INSERT' - assert resp['Records'][1]['eventName'] == 'MODIFY' - assert resp['Records'][2]['eventName'] == 'DELETE' + assert len(resp["Records"]) == 3 + assert resp["Records"][0]["eventName"] == "INSERT" + assert resp["Records"][1]["eventName"] == "MODIFY" + assert resp["Records"][2]["eventName"] == "DELETE" + + sequence_number_modify = resp["Records"][1]["dynamodb"]["SequenceNumber"] # now try fetching from the next shard iterator, it should be # empty - resp = conn.get_records(ShardIterator=resp['NextShardIterator']) - assert len(resp['Records']) == 0 + resp = conn.get_records(ShardIterator=resp["NextShardIterator"]) + assert len(resp["Records"]) == 0 + + # check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType="AT_SEQUENCE_NUMBER", + SequenceNumber=sequence_number_modify, + ) + iterator_id = resp["ShardIterator"] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp["Records"]) == 2 + assert resp["Records"][0]["eventName"] == "MODIFY" + assert resp["Records"][1]["eventName"] == "DELETE" + + # check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType="AFTER_SEQUENCE_NUMBER", + SequenceNumber=sequence_number_modify, + ) + iterator_id = resp["ShardIterator"] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp["Records"]) == 1 + assert resp["Records"][0]["eventName"] == "DELETE" -class TestEdges(): +class TestEdges: mocks = [] def setup(self): @@ -153,82 +201,73 @@ class TestEdges(): for m in self.mocks: m.stop() - def test_enable_stream_on_table(self): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1} + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - assert 'StreamSpecification' not in resp['TableDescription'] - + assert "StreamSpecification" not in resp["TableDescription"] + resp = conn.update_table( - TableName='test-streams', - StreamSpecification={ - 'StreamViewType': 'KEYS_ONLY' - } + TableName="test-streams", + StreamSpecification={"StreamViewType": "KEYS_ONLY", "StreamEnabled": True}, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'KEYS_ONLY' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "KEYS_ONLY", } - assert 'LatestStreamLabel' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] # now try to enable it again with assert_raises(conn.exceptions.ResourceInUseException): resp = conn.update_table( - TableName='test-streams', + TableName="test-streams", StreamSpecification={ - 'StreamViewType': 'OLD_IMAGES' - } + "StreamViewType": "OLD_IMAGES", + "StreamEnabled": True, + }, ) - + def test_stream_with_range_key(self): - dyn = boto3.client('dynamodb', region_name='us-east-1') + dyn = boto3.client("dynamodb", region_name="us-east-1") resp = dyn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'color', 'KeyType': 'RANGE'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}, - {'AttributeName': 'color', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1}, - StreamSpecification={ - 'StreamViewType': 'NEW_IMAGES' - } + TableName="test-streams", + KeySchema=[ + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "color", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "color", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + StreamSpecification={"StreamViewType": "NEW_IMAGES", "StreamEnabled": True}, ) - stream_arn = resp['TableDescription']['LatestStreamArn'] + stream_arn = resp["TableDescription"]["LatestStreamArn"] - streams = boto3.client('dynamodbstreams', region_name='us-east-1') + streams = boto3.client("dynamodbstreams", region_name="us-east-1") resp = streams.describe_stream(StreamArn=stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] resp = streams.get_shard_iterator( - StreamArn=stream_arn, - ShardId=shard_id, - ShardIteratorType='LATEST' + StreamArn=stream_arn, ShardId=shard_id, ShardIteratorType="LATEST" ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] dyn.put_item( - TableName='test-streams', - Item={'id': {'S': 'row1'}, 'color': {'S': 'blue'}} + TableName="test-streams", Item={"id": {"S": "row1"}, "color": {"S": "blue"}} ) dyn.put_item( - TableName='test-streams', - Item={'id': {'S': 'row2'}, 'color': {'S': 'green'}} + TableName="test-streams", + Item={"id": {"S": "row2"}, "color": {"S": "green"}}, ) resp = streams.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 2 - assert resp['Records'][0]['eventName'] == 'INSERT' - assert resp['Records'][1]['eventName'] == 'INSERT' - + assert len(resp["Records"]) == 2 + assert resp["Records"][0]["eventName"] == "INSERT" + assert resp["Records"][1]["eventName"] == "INSERT" diff --git a/tests/test_ec2/helpers.py b/tests/test_ec2/helpers.py index 94c9c10cb..6dd281874 100644 --- a/tests/test_ec2/helpers.py +++ b/tests/test_ec2/helpers.py @@ -9,7 +9,8 @@ def rsa_check_private_key(private_key_material): assert isinstance(private_key_material, six.string_types) private_key = serialization.load_pem_private_key( - data=private_key_material.encode('ascii'), + data=private_key_material.encode("ascii"), backend=default_backend(), - password=None) + password=None, + ) assert isinstance(private_key, rsa.RSAPrivateKey) diff --git a/tests/test_ec2/test_account_attributes.py b/tests/test_ec2/test_account_attributes.py index 45ae09419..41c71def5 100644 --- a/tests/test_ec2/test_account_attributes.py +++ b/tests/test_ec2/test_account_attributes.py @@ -6,39 +6,32 @@ import sure # noqa @mock_ec2 def test_describe_account_attributes(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") response = conn.describe_account_attributes() - expected_attribute_values = [{ - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'vpc-max-security-groups-per-interface' - }, { - 'AttributeValues': [{ - 'AttributeValue': '20' - }], - 'AttributeName': 'max-instances' - }, { - 'AttributeValues': [{ - 'AttributeValue': 'EC2' - }, { - 'AttributeValue': 'VPC' - }], - 'AttributeName': 'supported-platforms' - }, { - 'AttributeValues': [{ - 'AttributeValue': 'none' - }], - 'AttributeName': 'default-vpc' - }, { - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'max-elastic-ips' - }, { - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'vpc-max-elastic-ips' - }] - response['AccountAttributes'].should.equal(expected_attribute_values) + expected_attribute_values = [ + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "vpc-max-security-groups-per-interface", + }, + { + "AttributeValues": [{"AttributeValue": "20"}], + "AttributeName": "max-instances", + }, + { + "AttributeValues": [{"AttributeValue": "EC2"}, {"AttributeValue": "VPC"}], + "AttributeName": "supported-platforms", + }, + { + "AttributeValues": [{"AttributeValue": "none"}], + "AttributeName": "default-vpc", + }, + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "max-elastic-ips", + }, + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "vpc-max-elastic-ips", + }, + ] + response["AccountAttributes"].should.equal(expected_attribute_values) diff --git a/tests/test_ec2/test_amis.py b/tests/test_ec2/test_amis.py index feff4a16c..f65352c7c 100644 --- a/tests/test_ec2/test_amis.py +++ b/tests/test_ec2/test_amis.py @@ -5,6 +5,7 @@ import boto.ec2 import boto3 from boto.exception import EC2ResponseError from botocore.exceptions import ClientError + # Ensure 'assert_raises' context manager support for Python 2.6 from nose.tools import assert_raises import sure # noqa @@ -16,22 +17,24 @@ from tests.helpers import requires_boto_gte @mock_ec2_deprecated def test_ami_create_and_delete(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") initial_ami_count = len(AMIS) conn.get_all_volumes().should.have.length_of(0) conn.get_all_snapshots().should.have.length_of(initial_ami_count) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: image_id = conn.create_image( - instance.id, "test-ami", "this is a test ami", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.id, "test-ami", "this is a test ami", dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateImage operation: Request would have succeeded, but DryRun flag is set" + ) image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") @@ -56,30 +59,36 @@ def test_ami_create_and_delete(): snapshots = conn.get_all_snapshots() snapshots.should.have.length_of(initial_ami_count + 1) - retrieved_image_snapshot_id = retrieved_image.block_device_mapping.current_value.snapshot_id + retrieved_image_snapshot_id = ( + retrieved_image.block_device_mapping.current_value.snapshot_id + ) [s.id for s in snapshots].should.contain(retrieved_image_snapshot_id) snapshot = [s for s in snapshots if s.id == retrieved_image_snapshot_id][0] snapshot.description.should.equal( - "Auto-created snapshot for AMI {0}".format(retrieved_image.id)) + "Auto-created snapshot for AMI {0}".format(retrieved_image.id) + ) # root device should be in AMI's block device mappings - root_mapping = retrieved_image.block_device_mapping.get(retrieved_image.root_device_name) + root_mapping = retrieved_image.block_device_mapping.get( + retrieved_image.root_device_name + ) root_mapping.should_not.be.none # Deregister with assert_raises(EC2ResponseError) as ex: success = conn.deregister_image(image_id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeregisterImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeregisterImage operation: Request would have succeeded, but DryRun flag is set" + ) success = conn.deregister_image(image_id) success.should.be.true with assert_raises(EC2ResponseError) as cm: conn.deregister_image(image_id) - cm.exception.code.should.equal('InvalidAMIID.NotFound') + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -93,11 +102,10 @@ def test_ami_copy(): conn.get_all_volumes().should.have.length_of(0) conn.get_all_snapshots().should.have.length_of(initial_ami_count) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - source_image_id = conn.create_image( - instance.id, "test-ami", "this is a test ami") + source_image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") instance.terminate() source_image = conn.get_all_images(image_ids=[source_image_id])[0] @@ -105,21 +113,29 @@ def test_ami_copy(): # the image_id to fetch the full info. with assert_raises(EC2ResponseError) as ex: copy_image_ref = conn.copy_image( - source_image.region.name, source_image.id, "test-copy-ami", "this is a test copy ami", - dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + source_image.region.name, + source_image.id, + "test-copy-ami", + "this is a test copy ami", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CopyImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CopyImage operation: Request would have succeeded, but DryRun flag is set" + ) copy_image_ref = conn.copy_image( - source_image.region.name, source_image.id, "test-copy-ami", "this is a test copy ami") + source_image.region.name, + source_image.id, + "test-copy-ami", + "this is a test copy ami", + ) copy_image_id = copy_image_ref.image_id copy_image = conn.get_all_images(image_ids=[copy_image_id])[0] copy_image.id.should.equal(copy_image_id) - copy_image.virtualization_type.should.equal( - source_image.virtualization_type) + copy_image.virtualization_type.should.equal(source_image.virtualization_type) copy_image.architecture.should.equal(source_image.architecture) copy_image.kernel_id.should.equal(source_image.kernel_id) copy_image.platform.should.equal(source_image.platform) @@ -131,30 +147,37 @@ def test_ami_copy(): conn.get_all_snapshots().should.have.length_of(initial_ami_count + 2) copy_image.block_device_mapping.current_value.snapshot_id.should_not.equal( - source_image.block_device_mapping.current_value.snapshot_id) + source_image.block_device_mapping.current_value.snapshot_id + ) # Copy from non-existent source ID. with assert_raises(EC2ResponseError) as cm: - conn.copy_image(source_image.region.name, 'ami-abcd1234', - "test-copy-ami", "this is a test copy ami") - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.copy_image( + source_image.region.name, + "ami-abcd1234", + "test-copy-ami", + "this is a test copy ami", + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Copy from non-existent source region. with assert_raises(EC2ResponseError) as cm: - invalid_region = 'us-east-1' if (source_image.region.name != - 'us-east-1') else 'us-west-1' - conn.copy_image(invalid_region, source_image.id, - "test-copy-ami", "this is a test copy ami") - cm.exception.code.should.equal('InvalidAMIID.NotFound') + invalid_region = ( + "us-east-1" if (source_image.region.name != "us-east-1") else "us-west-1" + ) + conn.copy_image( + invalid_region, source_image.id, "test-copy-ami", "this is a test copy ami" + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_copy_image_changes_owner_id(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") # this source AMI ID is from moto/ec2/resources/amis.json source_ami_id = "ami-03cf127a" @@ -168,7 +191,8 @@ def test_copy_image_changes_owner_id(): SourceImageId=source_ami_id, Name="new-image", Description="a copy of an image", - SourceRegion="us-east-1") + SourceRegion="us-east-1", + ) describe_resp = conn.describe_images(Owners=["self"]) describe_resp["Images"][0]["OwnerId"].should.equal(OWNER_ID) @@ -177,18 +201,19 @@ def test_copy_image_changes_owner_id(): @mock_ec2_deprecated def test_ami_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_vpc("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_all_images()[0] with assert_raises(EC2ResponseError) as ex: image.add_tag("a key", "some value", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) image.add_tag("a key", "some value") @@ -204,368 +229,374 @@ def test_ami_tagging(): @mock_ec2_deprecated def test_ami_create_from_missing_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") args = ["i-abcdefg", "test-ami", "this is a test ami"] with assert_raises(EC2ResponseError) as cm: conn.create_image(*args) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ami_pulls_attributes_from_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.modify_attribute("kernel", "test-kernel") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) - image.kernel_id.should.equal('test-kernel') + image.kernel_id.should.equal("test-kernel") @mock_ec2_deprecated def test_ami_filters(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservationA = conn.run_instances('ami-1234abcd') + reservationA = conn.run_instances("ami-1234abcd") instanceA = reservationA.instances[0] instanceA.modify_attribute("architecture", "i386") instanceA.modify_attribute("kernel", "k-1234abcd") instanceA.modify_attribute("platform", "windows") instanceA.modify_attribute("virtualization_type", "hvm") - imageA_id = conn.create_image( - instanceA.id, "test-ami-A", "this is a test ami") + imageA_id = conn.create_image(instanceA.id, "test-ami-A", "this is a test ami") imageA = conn.get_image(imageA_id) - reservationB = conn.run_instances('ami-abcd1234') + reservationB = conn.run_instances("ami-abcd1234") instanceB = reservationB.instances[0] instanceB.modify_attribute("architecture", "x86_64") instanceB.modify_attribute("kernel", "k-abcd1234") instanceB.modify_attribute("platform", "linux") instanceB.modify_attribute("virtualization_type", "paravirtual") - imageB_id = conn.create_image( - instanceB.id, "test-ami-B", "this is a test ami") + imageB_id = conn.create_image(instanceB.id, "test-ami-B", "this is a test ami") imageB = conn.get_image(imageB_id) imageB.set_launch_permissions(group_names=("all")) - amis_by_architecture = conn.get_all_images( - filters={'architecture': 'x86_64'}) + amis_by_architecture = conn.get_all_images(filters={"architecture": "x86_64"}) set([ami.id for ami in amis_by_architecture]).should.contain(imageB.id) len(amis_by_architecture).should.equal(35) - amis_by_kernel = conn.get_all_images(filters={'kernel-id': 'k-abcd1234'}) + amis_by_kernel = conn.get_all_images(filters={"kernel-id": "k-abcd1234"}) set([ami.id for ami in amis_by_kernel]).should.equal(set([imageB.id])) amis_by_virtualization = conn.get_all_images( - filters={'virtualization-type': 'paravirtual'}) - set([ami.id for ami in amis_by_virtualization] - ).should.contain(imageB.id) + filters={"virtualization-type": "paravirtual"} + ) + set([ami.id for ami in amis_by_virtualization]).should.contain(imageB.id) len(amis_by_virtualization).should.equal(3) - amis_by_platform = conn.get_all_images(filters={'platform': 'windows'}) + amis_by_platform = conn.get_all_images(filters={"platform": "windows"}) set([ami.id for ami in amis_by_platform]).should.contain(imageA.id) len(amis_by_platform).should.equal(24) - amis_by_id = conn.get_all_images(filters={'image-id': imageA.id}) + amis_by_id = conn.get_all_images(filters={"image-id": imageA.id}) set([ami.id for ami in amis_by_id]).should.equal(set([imageA.id])) - amis_by_state = conn.get_all_images(filters={'state': 'available'}) + amis_by_state = conn.get_all_images(filters={"state": "available"}) ami_ids_by_state = [ami.id for ami in amis_by_state] ami_ids_by_state.should.contain(imageA.id) ami_ids_by_state.should.contain(imageB.id) len(amis_by_state).should.equal(36) - amis_by_name = conn.get_all_images(filters={'name': imageA.name}) + amis_by_name = conn.get_all_images(filters={"name": imageA.name}) set([ami.id for ami in amis_by_name]).should.equal(set([imageA.id])) - amis_by_public = conn.get_all_images(filters={'is-public': 'true'}) + amis_by_public = conn.get_all_images(filters={"is-public": "true"}) set([ami.id for ami in amis_by_public]).should.contain(imageB.id) len(amis_by_public).should.equal(35) - amis_by_nonpublic = conn.get_all_images(filters={'is-public': 'false'}) + amis_by_nonpublic = conn.get_all_images(filters={"is-public": "false"}) set([ami.id for ami in amis_by_nonpublic]).should.contain(imageA.id) len(amis_by_nonpublic).should.equal(1) @mock_ec2_deprecated def test_ami_filtering_via_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - reservationA = conn.run_instances('ami-1234abcd') + reservationA = conn.run_instances("ami-1234abcd") instanceA = reservationA.instances[0] - imageA_id = conn.create_image( - instanceA.id, "test-ami-A", "this is a test ami") + imageA_id = conn.create_image(instanceA.id, "test-ami-A", "this is a test ami") imageA = conn.get_image(imageA_id) imageA.add_tag("a key", "some value") - reservationB = conn.run_instances('ami-abcd1234') + reservationB = conn.run_instances("ami-abcd1234") instanceB = reservationB.instances[0] - imageB_id = conn.create_image( - instanceB.id, "test-ami-B", "this is a test ami") + imageB_id = conn.create_image(instanceB.id, "test-ami-B", "this is a test ami") imageB = conn.get_image(imageB_id) imageB.add_tag("another key", "some other value") - amis_by_tagA = conn.get_all_images(filters={'tag:a key': 'some value'}) + amis_by_tagA = conn.get_all_images(filters={"tag:a key": "some value"}) set([ami.id for ami in amis_by_tagA]).should.equal(set([imageA.id])) - amis_by_tagB = conn.get_all_images( - filters={'tag:another key': 'some other value'}) + amis_by_tagB = conn.get_all_images(filters={"tag:another key": "some other value"}) set([ami.id for ami in amis_by_tagB]).should.equal(set([imageB.id])) @mock_ec2_deprecated def test_getting_missing_ami(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_image('ami-missing') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.get_image("ami-missing") + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_getting_malformed_ami(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_image('foo-missing') - cm.exception.code.should.equal('InvalidAMIID.Malformed') + conn.get_image("foo-missing") + cm.exception.code.should.equal("InvalidAMIID.Malformed") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ami_attribute_group_permissions(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) - ADD_GROUP_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'groups': 'all'} + ADD_GROUP_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "groups": "all", + } - REMOVE_GROUP_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'groups': 'all'} + REMOVE_GROUP_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "groups": "all", + } # Add 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_image_attribute( - **dict(ADD_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_image_attribute(**dict(ADD_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyImageAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyImageAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_image_attribute(**ADD_GROUP_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['groups'].should.have.length_of(1) - attributes.attrs['groups'].should.equal(['all']) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["groups"].should.have.length_of(1) + attributes.attrs["groups"].should.equal(["all"]) image = conn.get_image(image_id) image.is_public.should.equal(True) # Add is idempotent - conn.modify_image_attribute.when.called_with( - **ADD_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**ADD_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) # Remove 'all' group and confirm conn.modify_image_attribute(**REMOVE_GROUP_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove is idempotent - conn.modify_image_attribute.when.called_with( - **REMOVE_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**REMOVE_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) @mock_ec2_deprecated def test_ami_attribute_user_permissions(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) # Both str and int values should work. - USER1 = '123456789011' + USER1 = "123456789011" USER2 = 123456789022 - ADD_USERS_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'user_ids': [USER1, USER2]} + ADD_USERS_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "user_ids": [USER1, USER2], + } - REMOVE_USERS_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'user_ids': [USER1, USER2]} + REMOVE_USERS_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "user_ids": [USER1, USER2], + } - REMOVE_SINGLE_USER_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'user_ids': [USER1]} + REMOVE_SINGLE_USER_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "user_ids": [USER1], + } # Add multiple users and confirm conn.modify_image_attribute(**ADD_USERS_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(2) - set(attributes.attrs['user_ids']).should.equal( - set([str(USER1), str(USER2)])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(2) + set(attributes.attrs["user_ids"]).should.equal(set([str(USER1), str(USER2)])) image = conn.get_image(image_id) image.is_public.should.equal(False) # Add is idempotent - conn.modify_image_attribute.when.called_with( - **ADD_USERS_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**ADD_USERS_ARGS).should_not.throw( + EC2ResponseError + ) # Remove single user and confirm conn.modify_image_attribute(**REMOVE_SINGLE_USER_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(1) - set(attributes.attrs['user_ids']).should.equal(set([str(USER2)])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(1) + set(attributes.attrs["user_ids"]).should.equal(set([str(USER2)])) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove multiple users and confirm conn.modify_image_attribute(**REMOVE_USERS_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove is idempotent - conn.modify_image_attribute.when.called_with( - **REMOVE_USERS_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**REMOVE_USERS_ARGS).should_not.throw( + EC2ResponseError + ) @mock_ec2 def test_ami_describe_executable_users(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='TestImage', )['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="TestImage")["ImageId"] - USER1 = '123456789011' + USER1 = "123456789011" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER1])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images(ExecutableUsers=[USER1])["Images"] images.should.have.length_of(1) - images[0]['ImageId'].should.equal(image_id) + images[0]["ImageId"].should.equal(image_id) @mock_ec2 def test_ami_describe_executable_users_negative(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='TestImage')['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="TestImage")["ImageId"] - USER1 = '123456789011' - USER2 = '113355789012' + USER1 = "123456789011" + USER2 = "113355789012" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER2])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images(ExecutableUsers=[USER2])["Images"] images.should.have.length_of(0) @mock_ec2 def test_ami_describe_executable_users_and_filter(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='ImageToDelete', )['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="ImageToDelete")[ + "ImageId" + ] - USER1 = '123456789011' + USER1 = "123456789011" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER1], - Filters=[{'Name': 'state', 'Values': ['available']}])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images( + ExecutableUsers=[USER1], Filters=[{"Name": "state", "Values": ["available"]}] + )["Images"] images.should.have.length_of(1) - images[0]['ImageId'].should.equal(image_id) + images[0]["ImageId"].should.equal(image_id) @mock_ec2_deprecated @@ -575,49 +606,50 @@ def test_ami_attribute_user_and_group_permissions(): Just spot-check this -- input variations, idempotency, etc are validated via user-specific and group-specific tests above. """ - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) - USER1 = '123456789011' - USER2 = '123456789022' + USER1 = "123456789011" + USER2 = "123456789022" - ADD_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'groups': ['all'], - 'user_ids': [USER1, USER2]} + ADD_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "groups": ["all"], + "user_ids": [USER1, USER2], + } - REMOVE_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'groups': ['all'], - 'user_ids': [USER1, USER2]} + REMOVE_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "groups": ["all"], + "user_ids": [USER1, USER2], + } # Add and confirm conn.modify_image_attribute(**ADD_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(2) - set(attributes.attrs['user_ids']).should.equal(set([USER1, USER2])) - set(attributes.attrs['groups']).should.equal(set(['all'])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(2) + set(attributes.attrs["user_ids"]).should.equal(set([USER1, USER2])) + set(attributes.attrs["groups"]).should.equal(set(["all"])) image = conn.get_image(image_id) image.is_public.should.equal(True) # Remove and confirm conn.modify_image_attribute(**REMOVE_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) @@ -625,130 +657,138 @@ def test_ami_attribute_user_and_group_permissions(): @mock_ec2_deprecated def test_ami_attribute_error_cases(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Error: Add with group != 'all' with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - groups='everyone') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, attribute="launchPermission", operation="add", groups="everyone" + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that isn't an integer. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='12345678901A') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="12345678901A", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that is > length 12. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='1234567890123') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="1234567890123", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that is < length 12. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='12345678901') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="12345678901", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with one invalid user ID among other valid IDs, ensure no # partial changes. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids=['123456789011', 'foo', '123456789022']) - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids=["123456789011", "foo", "123456789022"], + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) # Error: Add with invalid image ID with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute("ami-abcd1234", - attribute='launchPermission', - operation='add', - groups='all') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.modify_image_attribute( + "ami-abcd1234", attribute="launchPermission", operation="add", groups="all" + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Remove with invalid image ID with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute("ami-abcd1234", - attribute='launchPermission', - operation='remove', - groups='all') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.modify_image_attribute( + "ami-abcd1234", + attribute="launchPermission", + operation="remove", + groups="all", + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_ami_describe_non_existent(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Valid pattern but non-existent id - img = ec2.Image('ami-abcd1234') + img = ec2.Image("ami-abcd1234") with assert_raises(ClientError): img.load() # Invalid ami pattern - img = ec2.Image('not_an_ami_id') + img = ec2.Image("not_an_ami_id") with assert_raises(ClientError): img.load() @mock_ec2 def test_ami_filter_wildcard(): - ec2_resource = boto3.resource('ec2', region_name='us-west-1') - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_resource = boto3.resource("ec2", region_name="us-west-1") + ec2_client = boto3.client("ec2", region_name="us-west-1") - instance = ec2_resource.create_instances(ImageId='ami-1234abcd', MinCount=1, MaxCount=1)[0] - instance.create_image(Name='test-image') + instance = ec2_resource.create_instances( + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 + )[0] + instance.create_image(Name="test-image") # create an image with the same owner but will not match the filter - instance.create_image(Name='not-matching-image') + instance.create_image(Name="not-matching-image") my_images = ec2_client.describe_images( - Owners=['111122223333'], - Filters=[{'Name': 'name', 'Values': ['test*']}] - )['Images'] + Owners=["111122223333"], Filters=[{"Name": "name", "Values": ["test*"]}] + )["Images"] my_images.should.have.length_of(1) @mock_ec2 def test_ami_filter_by_owner_id(): - client = boto3.client('ec2', region_name='us-east-1') + client = boto3.client("ec2", region_name="us-east-1") - ubuntu_id = '099720109477' + ubuntu_id = "099720109477" ubuntu_images = client.describe_images(Owners=[ubuntu_id]) all_images = client.describe_images() - ubuntu_ids = [ami['OwnerId'] for ami in ubuntu_images['Images']] - all_ids = [ami['OwnerId'] for ami in all_images['Images']] + ubuntu_ids = [ami["OwnerId"] for ami in ubuntu_images["Images"]] + all_ids = [ami["OwnerId"] for ami in all_images["Images"]] # Assert all ubuntu_ids are the same and one equals ubuntu_id assert all(ubuntu_ids) and ubuntu_ids[0] == ubuntu_id @@ -758,42 +798,42 @@ def test_ami_filter_by_owner_id(): @mock_ec2 def test_ami_filter_by_self(): - ec2_resource = boto3.resource('ec2', region_name='us-west-1') - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_resource = boto3.resource("ec2", region_name="us-west-1") + ec2_client = boto3.client("ec2", region_name="us-west-1") - my_images = ec2_client.describe_images(Owners=['self'])['Images'] + my_images = ec2_client.describe_images(Owners=["self"])["Images"] my_images.should.have.length_of(0) # Create a new image - instance = ec2_resource.create_instances(ImageId='ami-1234abcd', MinCount=1, MaxCount=1)[0] - instance.create_image(Name='test-image') + instance = ec2_resource.create_instances( + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 + )[0] + instance.create_image(Name="test-image") - my_images = ec2_client.describe_images(Owners=['self'])['Images'] + my_images = ec2_client.describe_images(Owners=["self"])["Images"] my_images.should.have.length_of(1) @mock_ec2 def test_ami_snapshots_have_correct_owner(): - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_client = boto3.client("ec2", region_name="us-west-1") images_response = ec2_client.describe_images() owner_id_to_snapshot_ids = {} - for image in images_response['Images']: - owner_id = image['OwnerId'] + for image in images_response["Images"]: + owner_id = image["OwnerId"] snapshot_ids = [ - block_device_mapping['Ebs']['SnapshotId'] - for block_device_mapping in image['BlockDeviceMappings'] + block_device_mapping["Ebs"]["SnapshotId"] + for block_device_mapping in image["BlockDeviceMappings"] ] existing_snapshot_ids = owner_id_to_snapshot_ids.get(owner_id, []) - owner_id_to_snapshot_ids[owner_id] = ( - existing_snapshot_ids + snapshot_ids - ) + owner_id_to_snapshot_ids[owner_id] = existing_snapshot_ids + snapshot_ids for owner_id in owner_id_to_snapshot_ids: snapshots_rseponse = ec2_client.describe_snapshots( SnapshotIds=owner_id_to_snapshot_ids[owner_id] ) - for snapshot in snapshots_rseponse['Snapshots']: - assert owner_id == snapshot['OwnerId'] + for snapshot in snapshots_rseponse["Snapshots"]: + assert owner_id == snapshot["OwnerId"] diff --git a/tests/test_ec2/test_availability_zones_and_regions.py b/tests/test_ec2/test_availability_zones_and_regions.py index 0c94687fa..349be7936 100644 --- a/tests/test_ec2/test_availability_zones_and_regions.py +++ b/tests/test_ec2/test_availability_zones_and_regions.py @@ -1,54 +1,54 @@ -from __future__ import unicode_literals -import boto -import boto.ec2 -import boto3 -import sure # noqa - -from moto import mock_ec2, mock_ec2_deprecated - - -@mock_ec2_deprecated -def test_describe_regions(): - conn = boto.connect_ec2('the_key', 'the_secret') - regions = conn.get_all_regions() - regions.should.have.length_of(16) - for region in regions: - region.endpoint.should.contain(region.name) - - -@mock_ec2_deprecated -def test_availability_zones(): - conn = boto.connect_ec2('the_key', 'the_secret') - regions = conn.get_all_regions() - for region in regions: - conn = boto.ec2.connect_to_region(region.name) - if conn is None: - continue - for zone in conn.get_all_zones(): - zone.name.should.contain(region.name) - - -@mock_ec2 -def test_boto3_describe_regions(): - ec2 = boto3.client('ec2', 'us-east-1') - resp = ec2.describe_regions() - resp['Regions'].should.have.length_of(16) - for rec in resp['Regions']: - rec['Endpoint'].should.contain(rec['RegionName']) - - test_region = 'us-east-1' - resp = ec2.describe_regions(RegionNames=[test_region]) - resp['Regions'].should.have.length_of(1) - resp['Regions'][0].should.have.key('RegionName').which.should.equal(test_region) - - -@mock_ec2 -def test_boto3_availability_zones(): - ec2 = boto3.client('ec2', 'us-east-1') - resp = ec2.describe_regions() - regions = [r['RegionName'] for r in resp['Regions']] - for region in regions: - conn = boto3.client('ec2', region) - resp = conn.describe_availability_zones() - for rec in resp['AvailabilityZones']: - rec['ZoneName'].should.contain(region) +from __future__ import unicode_literals +import boto +import boto.ec2 +import boto3 +import sure # noqa + +from moto import mock_ec2, mock_ec2_deprecated + + +@mock_ec2_deprecated +def test_describe_regions(): + conn = boto.connect_ec2("the_key", "the_secret") + regions = conn.get_all_regions() + regions.should.have.length_of(16) + for region in regions: + region.endpoint.should.contain(region.name) + + +@mock_ec2_deprecated +def test_availability_zones(): + conn = boto.connect_ec2("the_key", "the_secret") + regions = conn.get_all_regions() + for region in regions: + conn = boto.ec2.connect_to_region(region.name) + if conn is None: + continue + for zone in conn.get_all_zones(): + zone.name.should.contain(region.name) + + +@mock_ec2 +def test_boto3_describe_regions(): + ec2 = boto3.client("ec2", "us-east-1") + resp = ec2.describe_regions() + resp["Regions"].should.have.length_of(16) + for rec in resp["Regions"]: + rec["Endpoint"].should.contain(rec["RegionName"]) + + test_region = "us-east-1" + resp = ec2.describe_regions(RegionNames=[test_region]) + resp["Regions"].should.have.length_of(1) + resp["Regions"][0].should.have.key("RegionName").which.should.equal(test_region) + + +@mock_ec2 +def test_boto3_availability_zones(): + ec2 = boto3.client("ec2", "us-east-1") + resp = ec2.describe_regions() + regions = [r["RegionName"] for r in resp["Regions"]] + for region in regions: + conn = boto3.client("ec2", region) + resp = conn.describe_availability_zones() + for rec in resp["AvailabilityZones"]: + rec["ZoneName"].should.contain(region) diff --git a/tests/test_ec2/test_customer_gateways.py b/tests/test_ec2/test_customer_gateways.py index 82e316723..a676a2b5d 100644 --- a/tests/test_ec2/test_customer_gateways.py +++ b/tests/test_ec2/test_customer_gateways.py @@ -1,52 +1,49 @@ -from __future__ import unicode_literals -import boto -import sure # noqa -from nose.tools import assert_raises -from nose.tools import assert_false -from boto.exception import EC2ResponseError - -from moto import mock_ec2_deprecated - - -@mock_ec2_deprecated -def test_create_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') - - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) - customer_gateway.should_not.be.none - customer_gateway.id.should.match(r'cgw-\w+') - customer_gateway.type.should.equal('ipsec.1') - customer_gateway.bgp_asn.should.equal(65534) - customer_gateway.ip_address.should.equal('205.251.242.54') - - -@mock_ec2_deprecated -def test_describe_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) - cgws = conn.get_all_customer_gateways() - cgws.should.have.length_of(1) - cgws[0].id.should.match(customer_gateway.id) - - -@mock_ec2_deprecated -def test_delete_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') - - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) - customer_gateway.should_not.be.none - cgws = conn.get_all_customer_gateways() - cgws[0].id.should.match(customer_gateway.id) - deleted = conn.delete_customer_gateway(customer_gateway.id) - cgws = conn.get_all_customer_gateways() - cgws.should.have.length_of(0) - - -@mock_ec2_deprecated -def test_delete_customer_gateways_bad_id(): - conn = boto.connect_vpc('the_key', 'the_secret') - with assert_raises(EC2ResponseError) as cm: - conn.delete_customer_gateway('cgw-0123abcd') +from __future__ import unicode_literals +import boto +import sure # noqa +from nose.tools import assert_raises +from nose.tools import assert_false +from boto.exception import EC2ResponseError + +from moto import mock_ec2_deprecated + + +@mock_ec2_deprecated +def test_create_customer_gateways(): + conn = boto.connect_vpc("the_key", "the_secret") + + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) + customer_gateway.should_not.be.none + customer_gateway.id.should.match(r"cgw-\w+") + customer_gateway.type.should.equal("ipsec.1") + customer_gateway.bgp_asn.should.equal(65534) + customer_gateway.ip_address.should.equal("205.251.242.54") + + +@mock_ec2_deprecated +def test_describe_customer_gateways(): + conn = boto.connect_vpc("the_key", "the_secret") + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) + cgws = conn.get_all_customer_gateways() + cgws.should.have.length_of(1) + cgws[0].id.should.match(customer_gateway.id) + + +@mock_ec2_deprecated +def test_delete_customer_gateways(): + conn = boto.connect_vpc("the_key", "the_secret") + + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) + customer_gateway.should_not.be.none + cgws = conn.get_all_customer_gateways() + cgws[0].id.should.match(customer_gateway.id) + deleted = conn.delete_customer_gateway(customer_gateway.id) + cgws = conn.get_all_customer_gateways() + cgws.should.have.length_of(0) + + +@mock_ec2_deprecated +def test_delete_customer_gateways_bad_id(): + conn = boto.connect_vpc("the_key", "the_secret") + with assert_raises(EC2ResponseError) as cm: + conn.delete_customer_gateway("cgw-0123abcd") diff --git a/tests/test_ec2/test_dhcp_options.py b/tests/test_ec2/test_dhcp_options.py index 2aff803ae..4aaceaa07 100644 --- a/tests/test_ec2/test_dhcp_options.py +++ b/tests/test_ec2/test_dhcp_options.py @@ -1,333 +1,337 @@ -from __future__ import unicode_literals -# Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises -from nose.tools import assert_raises - -import boto3 -import boto -from boto.exception import EC2ResponseError - -import sure # noqa - -from moto import mock_ec2, mock_ec2_deprecated - -SAMPLE_DOMAIN_NAME = u'example.com' -SAMPLE_NAME_SERVERS = [u'10.0.0.6', u'10.0.0.7'] - - -@mock_ec2_deprecated -def test_dhcp_options_associate(): - """ associate dhcp option """ - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - vpc = conn.create_vpc("10.0.0.0/16") - - rval = conn.associate_dhcp_options(dhcp_options.id, vpc.id) - rval.should.be.equal(True) - - -@mock_ec2_deprecated -def test_dhcp_options_associate_invalid_dhcp_id(): - """ associate dhcp option bad dhcp options id """ - conn = boto.connect_vpc('the_key', 'the_secret') - vpc = conn.create_vpc("10.0.0.0/16") - - with assert_raises(EC2ResponseError) as cm: - conn.associate_dhcp_options("foo", vpc.id) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_dhcp_options_associate_invalid_vpc_id(): - """ associate dhcp option invalid vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - - with assert_raises(EC2ResponseError) as cm: - conn.associate_dhcp_options(dhcp_options.id, "foo") - cm.exception.code.should.equal('InvalidVpcID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_dhcp_options_delete_with_vpc(): - """Test deletion of dhcp options with vpc""" - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - dhcp_options_id = dhcp_options.id - vpc = conn.create_vpc("10.0.0.0/16") - - rval = conn.associate_dhcp_options(dhcp_options_id, vpc.id) - rval.should.be.equal(True) - - with assert_raises(EC2ResponseError) as cm: - conn.delete_dhcp_options(dhcp_options_id) - cm.exception.code.should.equal('DependencyViolation') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - vpc.delete() - - with assert_raises(EC2ResponseError) as cm: - conn.get_all_dhcp_options([dhcp_options_id]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_create_dhcp_options(): - """Create most basic dhcp option""" - conn = boto.connect_vpc('the_key', 'the_secret') - - dhcp_option = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - dhcp_option.options[u'domain-name'][0].should.be.equal(SAMPLE_DOMAIN_NAME) - dhcp_option.options[ - u'domain-name-servers'][0].should.be.equal(SAMPLE_NAME_SERVERS[0]) - dhcp_option.options[ - u'domain-name-servers'][1].should.be.equal(SAMPLE_NAME_SERVERS[1]) - - -@mock_ec2_deprecated -def test_create_dhcp_options_invalid_options(): - """Create invalid dhcp options""" - conn = boto.connect_vpc('the_key', 'the_secret') - servers = ["f", "f", "f", "f", "f"] - - with assert_raises(EC2ResponseError) as cm: - conn.create_dhcp_options(ntp_servers=servers) - cm.exception.code.should.equal('InvalidParameterValue') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - with assert_raises(EC2ResponseError) as cm: - conn.create_dhcp_options(netbios_node_type="0") - cm.exception.code.should.equal('InvalidParameterValue') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_describe_dhcp_options(): - """Test dhcp options lookup by id""" - conn = boto.connect_vpc('the_key', 'the_secret') - - dhcp_option = conn.create_dhcp_options() - dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) - dhcp_options.should.be.length_of(1) - - dhcp_options = conn.get_all_dhcp_options() - dhcp_options.should.be.length_of(1) - - -@mock_ec2_deprecated -def test_describe_dhcp_options_invalid_id(): - """get error on invalid dhcp_option_id lookup""" - conn = boto.connect_vpc('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.get_all_dhcp_options(["1"]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_delete_dhcp_options(): - """delete dhcp option""" - conn = boto.connect_vpc('the_key', 'the_secret') - - dhcp_option = conn.create_dhcp_options() - dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) - dhcp_options.should.be.length_of(1) - - conn.delete_dhcp_options(dhcp_option.id) # .should.be.equal(True) - - with assert_raises(EC2ResponseError) as cm: - conn.get_all_dhcp_options([dhcp_option.id]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_delete_dhcp_options_invalid_id(): - conn = boto.connect_vpc('the_key', 'the_secret') - - conn.create_dhcp_options() - - with assert_raises(EC2ResponseError) as cm: - conn.delete_dhcp_options("dopt-abcd1234") - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_delete_dhcp_options_malformed_id(): - conn = boto.connect_vpc('the_key', 'the_secret') - - conn.create_dhcp_options() - - with assert_raises(EC2ResponseError) as cm: - conn.delete_dhcp_options("foo-abcd1234") - cm.exception.code.should.equal('InvalidDhcpOptionsId.Malformed') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_dhcp_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_option = conn.create_dhcp_options() - - dhcp_option.add_tag("a key", "some value") - - tag = conn.get_all_tags()[0] - tag.name.should.equal("a key") - tag.value.should.equal("some value") - - # Refresh the DHCP options - dhcp_option = conn.get_all_dhcp_options()[0] - dhcp_option.tags.should.have.length_of(1) - dhcp_option.tags["a key"].should.equal("some value") - - -@mock_ec2_deprecated -def test_dhcp_options_get_by_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') - - dhcp1 = conn.create_dhcp_options('example.com', ['10.0.10.2']) - dhcp1.add_tag('Name', 'TestDhcpOptions1') - dhcp1.add_tag('test-tag', 'test-value') - - dhcp2 = conn.create_dhcp_options('example.com', ['10.0.20.2']) - dhcp2.add_tag('Name', 'TestDhcpOptions2') - dhcp2.add_tag('test-tag', 'test-value') - - filters = {'tag:Name': 'TestDhcpOptions1', 'tag:test-tag': 'test-value'} - dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) - - dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options[ - 'domain-name'][0].should.be.equal('example.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.10.2') - dhcp_options_sets[0].tags['Name'].should.equal('TestDhcpOptions1') - dhcp_options_sets[0].tags['test-tag'].should.equal('test-value') - - filters = {'tag:Name': 'TestDhcpOptions2', 'tag:test-tag': 'test-value'} - dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) - - dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options[ - 'domain-name'][0].should.be.equal('example.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.20.2') - dhcp_options_sets[0].tags['Name'].should.equal('TestDhcpOptions2') - dhcp_options_sets[0].tags['test-tag'].should.equal('test-value') - - filters = {'tag:test-tag': 'test-value'} - dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) - - dhcp_options_sets.should.have.length_of(2) - - -@mock_ec2_deprecated -def test_dhcp_options_get_by_id(): - conn = boto.connect_vpc('the_key', 'the_secret') - - dhcp1 = conn.create_dhcp_options('test1.com', ['10.0.10.2']) - dhcp1.add_tag('Name', 'TestDhcpOptions1') - dhcp1.add_tag('test-tag', 'test-value') - dhcp1_id = dhcp1.id - - dhcp2 = conn.create_dhcp_options('test2.com', ['10.0.20.2']) - dhcp2.add_tag('Name', 'TestDhcpOptions2') - dhcp2.add_tag('test-tag', 'test-value') - dhcp2_id = dhcp2.id - - dhcp_options_sets = conn.get_all_dhcp_options() - dhcp_options_sets.should.have.length_of(2) - - dhcp_options_sets = conn.get_all_dhcp_options( - filters={'dhcp-options-id': dhcp1_id}) - - dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options['domain-name'][0].should.be.equal('test1.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.10.2') - - dhcp_options_sets = conn.get_all_dhcp_options( - filters={'dhcp-options-id': dhcp2_id}) - - dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options['domain-name'][0].should.be.equal('test2.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.20.2') - - -@mock_ec2 -def test_dhcp_options_get_by_value_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.10.2']} - ]) - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.20.2']} - ]) - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.30.2']} - ]) - - filters = [{'Name': 'value', 'Values': ['10.0.10.2']}] - dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) - dhcp_options_sets.should.have.length_of(1) - - -@mock_ec2 -def test_dhcp_options_get_by_key_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.10.2']} - ]) - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.20.2']} - ]) - - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.30.2']} - ]) - - filters = [{'Name': 'key', 'Values': ['domain-name']}] - dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) - dhcp_options_sets.should.have.length_of(3) - - -@mock_ec2_deprecated -def test_dhcp_options_get_by_invalid_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') - - conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - filters = {'invalid-filter': 'invalid-value'} - - conn.get_all_dhcp_options.when.called_with( - filters=filters).should.throw(NotImplementedError) +from __future__ import unicode_literals + +# Ensure 'assert_raises' context manager support for Python 2.6 +import tests.backport_assert_raises +from nose.tools import assert_raises + +import boto3 +import boto +from boto.exception import EC2ResponseError + +import sure # noqa + +from moto import mock_ec2, mock_ec2_deprecated + +SAMPLE_DOMAIN_NAME = "example.com" +SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"] + + +@mock_ec2_deprecated +def test_dhcp_options_associate(): + """ associate dhcp option """ + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + vpc = conn.create_vpc("10.0.0.0/16") + + rval = conn.associate_dhcp_options(dhcp_options.id, vpc.id) + rval.should.be.equal(True) + + +@mock_ec2_deprecated +def test_dhcp_options_associate_invalid_dhcp_id(): + """ associate dhcp option bad dhcp options id """ + conn = boto.connect_vpc("the_key", "the_secret") + vpc = conn.create_vpc("10.0.0.0/16") + + with assert_raises(EC2ResponseError) as cm: + conn.associate_dhcp_options("foo", vpc.id) + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_dhcp_options_associate_invalid_vpc_id(): + """ associate dhcp option invalid vpc id """ + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + + with assert_raises(EC2ResponseError) as cm: + conn.associate_dhcp_options(dhcp_options.id, "foo") + cm.exception.code.should.equal("InvalidVpcID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_dhcp_options_delete_with_vpc(): + """Test deletion of dhcp options with vpc""" + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_options_id = dhcp_options.id + vpc = conn.create_vpc("10.0.0.0/16") + + rval = conn.associate_dhcp_options(dhcp_options_id, vpc.id) + rval.should.be.equal(True) + + with assert_raises(EC2ResponseError) as cm: + conn.delete_dhcp_options(dhcp_options_id) + cm.exception.code.should.equal("DependencyViolation") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + vpc.delete() + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_dhcp_options([dhcp_options_id]) + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_create_dhcp_options(): + """Create most basic dhcp option""" + conn = boto.connect_vpc("the_key", "the_secret") + + dhcp_option = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_option.options["domain-name"][0].should.be.equal(SAMPLE_DOMAIN_NAME) + dhcp_option.options["domain-name-servers"][0].should.be.equal( + SAMPLE_NAME_SERVERS[0] + ) + dhcp_option.options["domain-name-servers"][1].should.be.equal( + SAMPLE_NAME_SERVERS[1] + ) + + +@mock_ec2_deprecated +def test_create_dhcp_options_invalid_options(): + """Create invalid dhcp options""" + conn = boto.connect_vpc("the_key", "the_secret") + servers = ["f", "f", "f", "f", "f"] + + with assert_raises(EC2ResponseError) as cm: + conn.create_dhcp_options(ntp_servers=servers) + cm.exception.code.should.equal("InvalidParameterValue") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + with assert_raises(EC2ResponseError) as cm: + conn.create_dhcp_options(netbios_node_type="0") + cm.exception.code.should.equal("InvalidParameterValue") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_describe_dhcp_options(): + """Test dhcp options lookup by id""" + conn = boto.connect_vpc("the_key", "the_secret") + + dhcp_option = conn.create_dhcp_options() + dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) + dhcp_options.should.be.length_of(1) + + dhcp_options = conn.get_all_dhcp_options() + dhcp_options.should.be.length_of(1) + + +@mock_ec2_deprecated +def test_describe_dhcp_options_invalid_id(): + """get error on invalid dhcp_option_id lookup""" + conn = boto.connect_vpc("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_dhcp_options(["1"]) + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_delete_dhcp_options(): + """delete dhcp option""" + conn = boto.connect_vpc("the_key", "the_secret") + + dhcp_option = conn.create_dhcp_options() + dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) + dhcp_options.should.be.length_of(1) + + conn.delete_dhcp_options(dhcp_option.id) # .should.be.equal(True) + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_dhcp_options([dhcp_option.id]) + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_delete_dhcp_options_invalid_id(): + conn = boto.connect_vpc("the_key", "the_secret") + + conn.create_dhcp_options() + + with assert_raises(EC2ResponseError) as cm: + conn.delete_dhcp_options("dopt-abcd1234") + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_delete_dhcp_options_malformed_id(): + conn = boto.connect_vpc("the_key", "the_secret") + + conn.create_dhcp_options() + + with assert_raises(EC2ResponseError) as cm: + conn.delete_dhcp_options("foo-abcd1234") + cm.exception.code.should.equal("InvalidDhcpOptionsId.Malformed") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_dhcp_tagging(): + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_option = conn.create_dhcp_options() + + dhcp_option.add_tag("a key", "some value") + + tag = conn.get_all_tags()[0] + tag.name.should.equal("a key") + tag.value.should.equal("some value") + + # Refresh the DHCP options + dhcp_option = conn.get_all_dhcp_options()[0] + dhcp_option.tags.should.have.length_of(1) + dhcp_option.tags["a key"].should.equal("some value") + + +@mock_ec2_deprecated +def test_dhcp_options_get_by_tag(): + conn = boto.connect_vpc("the_key", "the_secret") + + dhcp1 = conn.create_dhcp_options("example.com", ["10.0.10.2"]) + dhcp1.add_tag("Name", "TestDhcpOptions1") + dhcp1.add_tag("test-tag", "test-value") + + dhcp2 = conn.create_dhcp_options("example.com", ["10.0.20.2"]) + dhcp2.add_tag("Name", "TestDhcpOptions2") + dhcp2.add_tag("test-tag", "test-value") + + filters = {"tag:Name": "TestDhcpOptions1", "tag:test-tag": "test-value"} + dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) + + dhcp_options_sets.should.have.length_of(1) + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("example.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.10.2") + dhcp_options_sets[0].tags["Name"].should.equal("TestDhcpOptions1") + dhcp_options_sets[0].tags["test-tag"].should.equal("test-value") + + filters = {"tag:Name": "TestDhcpOptions2", "tag:test-tag": "test-value"} + dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) + + dhcp_options_sets.should.have.length_of(1) + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("example.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.20.2") + dhcp_options_sets[0].tags["Name"].should.equal("TestDhcpOptions2") + dhcp_options_sets[0].tags["test-tag"].should.equal("test-value") + + filters = {"tag:test-tag": "test-value"} + dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) + + dhcp_options_sets.should.have.length_of(2) + + +@mock_ec2_deprecated +def test_dhcp_options_get_by_id(): + conn = boto.connect_vpc("the_key", "the_secret") + + dhcp1 = conn.create_dhcp_options("test1.com", ["10.0.10.2"]) + dhcp1.add_tag("Name", "TestDhcpOptions1") + dhcp1.add_tag("test-tag", "test-value") + dhcp1_id = dhcp1.id + + dhcp2 = conn.create_dhcp_options("test2.com", ["10.0.20.2"]) + dhcp2.add_tag("Name", "TestDhcpOptions2") + dhcp2.add_tag("test-tag", "test-value") + dhcp2_id = dhcp2.id + + dhcp_options_sets = conn.get_all_dhcp_options() + dhcp_options_sets.should.have.length_of(2) + + dhcp_options_sets = conn.get_all_dhcp_options(filters={"dhcp-options-id": dhcp1_id}) + + dhcp_options_sets.should.have.length_of(1) + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("test1.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.10.2") + + dhcp_options_sets = conn.get_all_dhcp_options(filters={"dhcp-options-id": dhcp2_id}) + + dhcp_options_sets.should.have.length_of(1) + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("test2.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.20.2") + + +@mock_ec2 +def test_dhcp_options_get_by_value_filter(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.10.2"]}, + ] + ) + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.20.2"]}, + ] + ) + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.30.2"]}, + ] + ) + + filters = [{"Name": "value", "Values": ["10.0.10.2"]}] + dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) + dhcp_options_sets.should.have.length_of(1) + + +@mock_ec2 +def test_dhcp_options_get_by_key_filter(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.10.2"]}, + ] + ) + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.20.2"]}, + ] + ) + + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.30.2"]}, + ] + ) + + filters = [{"Name": "key", "Values": ["domain-name"]}] + dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) + dhcp_options_sets.should.have.length_of(3) + + +@mock_ec2_deprecated +def test_dhcp_options_get_by_invalid_filter(): + conn = boto.connect_vpc("the_key", "the_secret") + + conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + filters = {"invalid-filter": "invalid-value"} + + conn.get_all_dhcp_options.when.called_with(filters=filters).should.throw( + NotImplementedError + ) diff --git a/tests/test_ec2/test_elastic_block_store.py b/tests/test_ec2/test_elastic_block_store.py index 9dbaa5ea6..3c7e17ec8 100644 --- a/tests/test_ec2/test_elastic_block_store.py +++ b/tests/test_ec2/test_elastic_block_store.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -32,10 +33,11 @@ def test_create_and_delete_volume(): with assert_raises(EC2ResponseError) as ex: volume.delete(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.delete() @@ -46,7 +48,7 @@ def test_create_and_delete_volume(): # Deleting something that was already deleted should throw an error with assert_raises(EC2ResponseError) as cm: volume.delete() - cm.exception.code.should.equal('InvalidVolume.NotFound') + cm.exception.code.should.equal("InvalidVolume.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -56,10 +58,11 @@ def test_create_encrypted_volume_dryrun(): conn = boto.ec2.connect_to_region("us-east-1") with assert_raises(EC2ResponseError) as ex: conn.create_volume(80, "us-east-1a", encrypted=True, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set" + ) @mock_ec2_deprecated @@ -69,10 +72,11 @@ def test_create_encrypted_volume(): with assert_raises(EC2ResponseError) as ex: conn.create_volume(80, "us-east-1a", encrypted=True, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set" + ) all_volumes = [vol for vol in conn.get_all_volumes() if vol.id == volume.id] all_volumes[0].encrypted.should.be(True) @@ -87,13 +91,13 @@ def test_filter_volume_by_id(): vol1 = conn.get_all_volumes(volume_ids=volume3.id) vol1.should.have.length_of(1) vol1[0].size.should.equal(20) - vol1[0].zone.should.equal('us-east-1c') + vol1[0].zone.should.equal("us-east-1c") vol2 = conn.get_all_volumes(volume_ids=[volume1.id, volume2.id]) vol2.should.have.length_of(2) with assert_raises(EC2ResponseError) as cm: - conn.get_all_volumes(volume_ids=['vol-does_not_exist']) - cm.exception.code.should.equal('InvalidVolume.NotFound') + conn.get_all_volumes(volume_ids=["vol-does_not_exist"]) + cm.exception.code.should.equal("InvalidVolume.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -102,7 +106,7 @@ def test_filter_volume_by_id(): def test_volume_filters(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.update() @@ -111,142 +115,155 @@ def test_volume_filters(): volume2 = conn.create_volume(36, "us-east-1b", encrypted=False) volume3 = conn.create_volume(20, "us-east-1c", encrypted=True) - snapshot = volume3.create_snapshot(description='testsnap') + snapshot = volume3.create_snapshot(description="testsnap") volume4 = conn.create_volume(25, "us-east-1a", snapshot=snapshot) - conn.create_tags([volume1.id], {'testkey1': 'testvalue1'}) - conn.create_tags([volume2.id], {'testkey2': 'testvalue2'}) + conn.create_tags([volume1.id], {"testkey1": "testvalue1"}) + conn.create_tags([volume2.id], {"testkey2": "testvalue2"}) volume1.update() volume2.update() volume3.update() volume4.update() - block_mapping = instance.block_device_mapping['/dev/sda1'] + block_mapping = instance.block_device_mapping["/dev/sda1"] - volume_ids = (volume1.id, volume2.id, volume3.id, volume4.id, block_mapping.volume_id) - - volumes_by_attach_time = conn.get_all_volumes( - filters={'attachment.attach-time': block_mapping.attach_time}) - set([vol.id for vol in volumes_by_attach_time] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_device = conn.get_all_volumes( - filters={'attachment.device': '/dev/sda1'}) - set([vol.id for vol in volumes_by_attach_device] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_instance_id = conn.get_all_volumes( - filters={'attachment.instance-id': instance.id}) - set([vol.id for vol in volumes_by_attach_instance_id] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_status = conn.get_all_volumes( - filters={'attachment.status': 'attached'}) - set([vol.id for vol in volumes_by_attach_status] - ).should.equal({block_mapping.volume_id}) - - volumes_by_create_time = conn.get_all_volumes( - filters={'create-time': volume4.create_time}) - set([vol.create_time for vol in volumes_by_create_time] - ).should.equal({volume4.create_time}) - - volumes_by_size = conn.get_all_volumes(filters={'size': volume2.size}) - set([vol.id for vol in volumes_by_size]).should.equal({volume2.id}) - - volumes_by_snapshot_id = conn.get_all_volumes( - filters={'snapshot-id': snapshot.id}) - set([vol.id for vol in volumes_by_snapshot_id] - ).should.equal({volume4.id}) - - volumes_by_status = conn.get_all_volumes(filters={'status': 'in-use'}) - set([vol.id for vol in volumes_by_status]).should.equal( - {block_mapping.volume_id}) - - volumes_by_id = conn.get_all_volumes(filters={'volume-id': volume1.id}) - set([vol.id for vol in volumes_by_id]).should.equal({volume1.id}) - - volumes_by_tag_key = conn.get_all_volumes(filters={'tag-key': 'testkey1'}) - set([vol.id for vol in volumes_by_tag_key]).should.equal({volume1.id}) - - volumes_by_tag_value = conn.get_all_volumes( - filters={'tag-value': 'testvalue1'}) - set([vol.id for vol in volumes_by_tag_value] - ).should.equal({volume1.id}) - - volumes_by_tag = conn.get_all_volumes( - filters={'tag:testkey1': 'testvalue1'}) - set([vol.id for vol in volumes_by_tag]).should.equal({volume1.id}) - - volumes_by_unencrypted = conn.get_all_volumes( - filters={'encrypted': 'false'}) - set([vol.id for vol in volumes_by_unencrypted if vol.id in volume_ids]).should.equal( - {block_mapping.volume_id, volume2.id} + volume_ids = ( + volume1.id, + volume2.id, + volume3.id, + volume4.id, + block_mapping.volume_id, ) - volumes_by_encrypted = conn.get_all_volumes(filters={'encrypted': 'true'}) + volumes_by_attach_time = conn.get_all_volumes( + filters={"attachment.attach-time": block_mapping.attach_time} + ) + set([vol.id for vol in volumes_by_attach_time]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_device = conn.get_all_volumes( + filters={"attachment.device": "/dev/sda1"} + ) + set([vol.id for vol in volumes_by_attach_device]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_instance_id = conn.get_all_volumes( + filters={"attachment.instance-id": instance.id} + ) + set([vol.id for vol in volumes_by_attach_instance_id]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_status = conn.get_all_volumes( + filters={"attachment.status": "attached"} + ) + set([vol.id for vol in volumes_by_attach_status]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_create_time = conn.get_all_volumes( + filters={"create-time": volume4.create_time} + ) + set([vol.create_time for vol in volumes_by_create_time]).should.equal( + {volume4.create_time} + ) + + volumes_by_size = conn.get_all_volumes(filters={"size": volume2.size}) + set([vol.id for vol in volumes_by_size]).should.equal({volume2.id}) + + volumes_by_snapshot_id = conn.get_all_volumes(filters={"snapshot-id": snapshot.id}) + set([vol.id for vol in volumes_by_snapshot_id]).should.equal({volume4.id}) + + volumes_by_status = conn.get_all_volumes(filters={"status": "in-use"}) + set([vol.id for vol in volumes_by_status]).should.equal({block_mapping.volume_id}) + + volumes_by_id = conn.get_all_volumes(filters={"volume-id": volume1.id}) + set([vol.id for vol in volumes_by_id]).should.equal({volume1.id}) + + volumes_by_tag_key = conn.get_all_volumes(filters={"tag-key": "testkey1"}) + set([vol.id for vol in volumes_by_tag_key]).should.equal({volume1.id}) + + volumes_by_tag_value = conn.get_all_volumes(filters={"tag-value": "testvalue1"}) + set([vol.id for vol in volumes_by_tag_value]).should.equal({volume1.id}) + + volumes_by_tag = conn.get_all_volumes(filters={"tag:testkey1": "testvalue1"}) + set([vol.id for vol in volumes_by_tag]).should.equal({volume1.id}) + + volumes_by_unencrypted = conn.get_all_volumes(filters={"encrypted": "false"}) + set( + [vol.id for vol in volumes_by_unencrypted if vol.id in volume_ids] + ).should.equal({block_mapping.volume_id, volume2.id}) + + volumes_by_encrypted = conn.get_all_volumes(filters={"encrypted": "true"}) set([vol.id for vol in volumes_by_encrypted if vol.id in volume_ids]).should.equal( {volume1.id, volume3.id, volume4.id} ) - volumes_by_availability_zone = conn.get_all_volumes(filters={'availability-zone': 'us-east-1b'}) - set([vol.id for vol in volumes_by_availability_zone if vol.id in volume_ids]).should.equal( - {volume2.id} + volumes_by_availability_zone = conn.get_all_volumes( + filters={"availability-zone": "us-east-1b"} ) + set( + [vol.id for vol in volumes_by_availability_zone if vol.id in volume_ids] + ).should.equal({volume2.id}) @mock_ec2_deprecated def test_volume_attach_and_detach(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] volume = conn.create_volume(80, "us-east-1a") volume.update() - volume.volume_state().should.equal('available') + volume.volume_state().should.equal("available") with assert_raises(EC2ResponseError) as ex: volume.attach(instance.id, "/dev/sdh", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AttachVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.attach(instance.id, "/dev/sdh") volume.update() - volume.volume_state().should.equal('in-use') - volume.attachment_state().should.equal('attached') + volume.volume_state().should.equal("in-use") + volume.attachment_state().should.equal("attached") volume.attach_data.instance_id.should.equal(instance.id) with assert_raises(EC2ResponseError) as ex: volume.detach(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DetachVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.detach() volume.update() - volume.volume_state().should.equal('available') + volume.volume_state().should.equal("available") with assert_raises(EC2ResponseError) as cm1: - volume.attach('i-1234abcd', "/dev/sdh") - cm1.exception.code.should.equal('InvalidInstanceID.NotFound') + volume.attach("i-1234abcd", "/dev/sdh") + cm1.exception.code.should.equal("InvalidInstanceID.NotFound") cm1.exception.status.should.equal(400) cm1.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm2: conn.detach_volume(volume.id, instance.id, "/dev/sdh") - cm2.exception.code.should.equal('InvalidAttachment.NotFound') + cm2.exception.code.should.equal("InvalidAttachment.NotFound") cm2.exception.status.should.equal(400) cm2.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm3: - conn.detach_volume(volume.id, 'i-1234abcd', "/dev/sdh") - cm3.exception.code.should.equal('InvalidInstanceID.NotFound') + conn.detach_volume(volume.id, "i-1234abcd", "/dev/sdh") + cm3.exception.code.should.equal("InvalidInstanceID.NotFound") cm3.exception.status.should.equal(400) cm3.exception.request_id.should_not.be.none @@ -257,19 +274,20 @@ def test_create_snapshot(): volume = conn.create_volume(80, "us-east-1a") with assert_raises(EC2ResponseError) as ex: - snapshot = volume.create_snapshot('a dryrun snapshot', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot = volume.create_snapshot("a dryrun snapshot", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set" + ) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") snapshots = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] snapshots.should.have.length_of(1) - snapshots[0].description.should.equal('a test snapshot') + snapshots[0].description.should.equal("a test snapshot") snapshots[0].start_time.should_not.be.none snapshots[0].encrypted.should.be(False) @@ -285,7 +303,7 @@ def test_create_snapshot(): # Deleting something that was already deleted should throw an error with assert_raises(EC2ResponseError) as cm: snapshot.delete() - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -294,13 +312,13 @@ def test_create_snapshot(): def test_create_encrypted_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a", encrypted=True) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") snapshots = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] snapshots.should.have.length_of(1) - snapshots[0].description.should.equal('a test snapshot') + snapshots[0].description.should.equal("a test snapshot") snapshots[0].start_time.should_not.be.none snapshots[0].encrypted.should.be(True) @@ -309,11 +327,11 @@ def test_create_encrypted_snapshot(): def test_filter_snapshot_by_id(): conn = boto.ec2.connect_to_region("us-east-1") volume1 = conn.create_volume(36, "us-east-1a") - snap1 = volume1.create_snapshot('a test snapshot 1') - volume2 = conn.create_volume(42, 'us-east-1a') - snap2 = volume2.create_snapshot('a test snapshot 2') - volume3 = conn.create_volume(84, 'us-east-1a') - snap3 = volume3.create_snapshot('a test snapshot 3') + snap1 = volume1.create_snapshot("a test snapshot 1") + volume2 = conn.create_volume(42, "us-east-1a") + snap2 = volume2.create_snapshot("a test snapshot 2") + volume3 = conn.create_volume(84, "us-east-1a") + snap3 = volume3.create_snapshot("a test snapshot 3") snapshots1 = conn.get_all_snapshots(snapshot_ids=snap2.id) snapshots1.should.have.length_of(1) snapshots1[0].volume_id.should.equal(volume2.id) @@ -326,8 +344,8 @@ def test_filter_snapshot_by_id(): s.region.name.should.equal(conn.region.name) with assert_raises(EC2ResponseError) as cm: - conn.get_all_snapshots(snapshot_ids=['snap-does_not_exist']) - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.get_all_snapshots(snapshot_ids=["snap-does_not_exist"]) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -338,67 +356,62 @@ def test_snapshot_filters(): volume1 = conn.create_volume(20, "us-east-1a", encrypted=False) volume2 = conn.create_volume(25, "us-east-1a", encrypted=True) - snapshot1 = volume1.create_snapshot(description='testsnapshot1') - snapshot2 = volume1.create_snapshot(description='testsnapshot2') - snapshot3 = volume2.create_snapshot(description='testsnapshot3') + snapshot1 = volume1.create_snapshot(description="testsnapshot1") + snapshot2 = volume1.create_snapshot(description="testsnapshot2") + snapshot3 = volume2.create_snapshot(description="testsnapshot3") - conn.create_tags([snapshot1.id], {'testkey1': 'testvalue1'}) - conn.create_tags([snapshot2.id], {'testkey2': 'testvalue2'}) + conn.create_tags([snapshot1.id], {"testkey1": "testvalue1"}) + conn.create_tags([snapshot2.id], {"testkey2": "testvalue2"}) snapshots_by_description = conn.get_all_snapshots( - filters={'description': 'testsnapshot1'}) - set([snap.id for snap in snapshots_by_description] - ).should.equal({snapshot1.id}) + filters={"description": "testsnapshot1"} + ) + set([snap.id for snap in snapshots_by_description]).should.equal({snapshot1.id}) - snapshots_by_id = conn.get_all_snapshots( - filters={'snapshot-id': snapshot1.id}) - set([snap.id for snap in snapshots_by_id] - ).should.equal({snapshot1.id}) + snapshots_by_id = conn.get_all_snapshots(filters={"snapshot-id": snapshot1.id}) + set([snap.id for snap in snapshots_by_id]).should.equal({snapshot1.id}) snapshots_by_start_time = conn.get_all_snapshots( - filters={'start-time': snapshot1.start_time}) - set([snap.start_time for snap in snapshots_by_start_time] - ).should.equal({snapshot1.start_time}) + filters={"start-time": snapshot1.start_time} + ) + set([snap.start_time for snap in snapshots_by_start_time]).should.equal( + {snapshot1.start_time} + ) - snapshots_by_volume_id = conn.get_all_snapshots( - filters={'volume-id': volume1.id}) - set([snap.id for snap in snapshots_by_volume_id] - ).should.equal({snapshot1.id, snapshot2.id}) + snapshots_by_volume_id = conn.get_all_snapshots(filters={"volume-id": volume1.id}) + set([snap.id for snap in snapshots_by_volume_id]).should.equal( + {snapshot1.id, snapshot2.id} + ) - snapshots_by_status = conn.get_all_snapshots( - filters={'status': 'completed'}) - ({snapshot1.id, snapshot2.id, snapshot3.id} - - {snap.id for snap in snapshots_by_status}).should.have.length_of(0) + snapshots_by_status = conn.get_all_snapshots(filters={"status": "completed"}) + ( + {snapshot1.id, snapshot2.id, snapshot3.id} + - {snap.id for snap in snapshots_by_status} + ).should.have.length_of(0) snapshots_by_volume_size = conn.get_all_snapshots( - filters={'volume-size': volume1.size}) - set([snap.id for snap in snapshots_by_volume_size] - ).should.equal({snapshot1.id, snapshot2.id}) + filters={"volume-size": volume1.size} + ) + set([snap.id for snap in snapshots_by_volume_size]).should.equal( + {snapshot1.id, snapshot2.id} + ) - snapshots_by_tag_key = conn.get_all_snapshots( - filters={'tag-key': 'testkey1'}) - set([snap.id for snap in snapshots_by_tag_key] - ).should.equal({snapshot1.id}) + snapshots_by_tag_key = conn.get_all_snapshots(filters={"tag-key": "testkey1"}) + set([snap.id for snap in snapshots_by_tag_key]).should.equal({snapshot1.id}) - snapshots_by_tag_value = conn.get_all_snapshots( - filters={'tag-value': 'testvalue1'}) - set([snap.id for snap in snapshots_by_tag_value] - ).should.equal({snapshot1.id}) + snapshots_by_tag_value = conn.get_all_snapshots(filters={"tag-value": "testvalue1"}) + set([snap.id for snap in snapshots_by_tag_value]).should.equal({snapshot1.id}) - snapshots_by_tag = conn.get_all_snapshots( - filters={'tag:testkey1': 'testvalue1'}) - set([snap.id for snap in snapshots_by_tag] - ).should.equal({snapshot1.id}) + snapshots_by_tag = conn.get_all_snapshots(filters={"tag:testkey1": "testvalue1"}) + set([snap.id for snap in snapshots_by_tag]).should.equal({snapshot1.id}) - snapshots_by_encrypted = conn.get_all_snapshots( - filters={'encrypted': 'true'}) - set([snap.id for snap in snapshots_by_encrypted] - ).should.equal({snapshot3.id}) + snapshots_by_encrypted = conn.get_all_snapshots(filters={"encrypted": "true"}) + set([snap.id for snap in snapshots_by_encrypted]).should.equal({snapshot3.id}) - snapshots_by_owner_id = conn.get_all_snapshots( - filters={'owner-id': OWNER_ID}) - set([snap.id for snap in snapshots_by_owner_id] - ).should.equal({snapshot1.id, snapshot2.id, snapshot3.id}) + snapshots_by_owner_id = conn.get_all_snapshots(filters={"owner-id": OWNER_ID}) + set([snap.id for snap in snapshots_by_owner_id]).should.equal( + {snapshot1.id, snapshot2.id, snapshot3.id} + ) @mock_ec2_deprecated @@ -411,119 +424,139 @@ def test_snapshot_attribute(): # Baseline attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') - attributes.name.should.equal('create_volume_permission') + snapshot.id, attribute="createVolumePermission" + ) + attributes.name.should.equal("create_volume_permission") attributes.attrs.should.have.length_of(0) - ADD_GROUP_ARGS = {'snapshot_id': snapshot.id, - 'attribute': 'createVolumePermission', - 'operation': 'add', - 'groups': 'all'} + ADD_GROUP_ARGS = { + "snapshot_id": snapshot.id, + "attribute": "createVolumePermission", + "operation": "add", + "groups": "all", + } - REMOVE_GROUP_ARGS = {'snapshot_id': snapshot.id, - 'attribute': 'createVolumePermission', - 'operation': 'remove', - 'groups': 'all'} + REMOVE_GROUP_ARGS = { + "snapshot_id": snapshot.id, + "attribute": "createVolumePermission", + "operation": "remove", + "groups": "all", + } # Add 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_snapshot_attribute( - **dict(ADD_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_snapshot_attribute(**dict(ADD_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_snapshot_attribute(**ADD_GROUP_ARGS) attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') - attributes.attrs['groups'].should.have.length_of(1) - attributes.attrs['groups'].should.equal(['all']) + snapshot.id, attribute="createVolumePermission" + ) + attributes.attrs["groups"].should.have.length_of(1) + attributes.attrs["groups"].should.equal(["all"]) # Add is idempotent - conn.modify_snapshot_attribute.when.called_with( - **ADD_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_snapshot_attribute.when.called_with(**ADD_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) # Remove 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_snapshot_attribute( - **dict(REMOVE_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_snapshot_attribute(**dict(REMOVE_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_snapshot_attribute(**REMOVE_GROUP_ARGS) attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') + snapshot.id, attribute="createVolumePermission" + ) attributes.attrs.should.have.length_of(0) # Remove is idempotent conn.modify_snapshot_attribute.when.called_with( - **REMOVE_GROUP_ARGS).should_not.throw(EC2ResponseError) + **REMOVE_GROUP_ARGS + ).should_not.throw(EC2ResponseError) # Error: Add with group != 'all' with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute(snapshot.id, - attribute='createVolumePermission', - operation='add', - groups='everyone') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_snapshot_attribute( + snapshot.id, + attribute="createVolumePermission", + operation="add", + groups="everyone", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with invalid snapshot ID with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute("snapshot-abcd1234", - attribute='createVolumePermission', - operation='add', - groups='all') - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.modify_snapshot_attribute( + "snapshot-abcd1234", + attribute="createVolumePermission", + operation="add", + groups="all", + ) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Remove with invalid snapshot ID with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute("snapshot-abcd1234", - attribute='createVolumePermission', - operation='remove', - groups='all') - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.modify_snapshot_attribute( + "snapshot-abcd1234", + attribute="createVolumePermission", + operation="remove", + groups="all", + ) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add or remove with user ID instead of group - conn.modify_snapshot_attribute.when.called_with(snapshot.id, - attribute='createVolumePermission', - operation='add', - user_ids=['user']).should.throw(NotImplementedError) - conn.modify_snapshot_attribute.when.called_with(snapshot.id, - attribute='createVolumePermission', - operation='remove', - user_ids=['user']).should.throw(NotImplementedError) + conn.modify_snapshot_attribute.when.called_with( + snapshot.id, + attribute="createVolumePermission", + operation="add", + user_ids=["user"], + ).should.throw(NotImplementedError) + conn.modify_snapshot_attribute.when.called_with( + snapshot.id, + attribute="createVolumePermission", + operation="remove", + user_ids=["user"], + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_create_volume_from_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a") - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") with assert_raises(EC2ResponseError) as ex: - snapshot = volume.create_snapshot('a test snapshot', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot = volume.create_snapshot("a test snapshot", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set" + ) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") - new_volume = snapshot.create_volume('us-east-1a') + new_volume = snapshot.create_volume("us-east-1a") new_volume.size.should.equal(80) new_volume.snapshot_id.should.equal(snapshot.id) @@ -533,11 +566,11 @@ def test_create_volume_from_encrypted_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a", encrypted=True) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") - new_volume = snapshot.create_volume('us-east-1a') + new_volume = snapshot.create_volume("us-east-1a") new_volume.size.should.equal(80) new_volume.snapshot_id.should.equal(snapshot.id) new_volume.encrypted.should.be(True) @@ -553,131 +586,133 @@ def test_modify_attribute_blockDeviceMapping(): """ conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: - instance.modify_attribute('blockDeviceMapping', { - '/dev/sda1': True}, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.modify_attribute( + "blockDeviceMapping", {"/dev/sda1": True}, dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceAttribute operation: Request would have succeeded, but DryRun flag is set" + ) - instance.modify_attribute('blockDeviceMapping', {'/dev/sda1': True}) + instance.modify_attribute("blockDeviceMapping", {"/dev/sda1": True}) instance = ec2_backends[conn.region.name].get_instance(instance.id) - instance.block_device_mapping.should.have.key('/dev/sda1') - instance.block_device_mapping[ - '/dev/sda1'].delete_on_termination.should.be(True) + instance.block_device_mapping.should.have.key("/dev/sda1") + instance.block_device_mapping["/dev/sda1"].delete_on_termination.should.be(True) @mock_ec2_deprecated def test_volume_tag_escaping(): conn = boto.ec2.connect_to_region("us-east-1") - vol = conn.create_volume(10, 'us-east-1a') - snapshot = conn.create_snapshot(vol.id, 'Desc') + vol = conn.create_volume(10, "us-east-1a") + snapshot = conn.create_snapshot(vol.id, "Desc") with assert_raises(EC2ResponseError) as ex: - snapshot.add_tags({'key': ''}, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot.add_tags({"key": ""}, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) snaps = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] - dict(snaps[0].tags).should_not.be.equal( - {'key': ''}) + dict(snaps[0].tags).should_not.be.equal({"key": ""}) - snapshot.add_tags({'key': ''}) + snapshot.add_tags({"key": ""}) snaps = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] - dict(snaps[0].tags).should.equal({'key': ''}) + dict(snaps[0].tags).should.equal({"key": ""}) @mock_ec2 def test_volume_property_hidden_when_no_tags_exist(): - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") - volume_response = ec2_client.create_volume( - Size=10, - AvailabilityZone='us-east-1a' - ) + volume_response = ec2_client.create_volume(Size=10, AvailabilityZone="us-east-1a") - volume_response.get('Tags').should.equal(None) + volume_response.get("Tags").should.equal(None) @freeze_time @mock_ec2 def test_copy_snapshot(): - ec2_client = boto3.client('ec2', region_name='eu-west-1') - dest_ec2_client = boto3.client('ec2', region_name='eu-west-2') + ec2_client = boto3.client("ec2", region_name="eu-west-1") + dest_ec2_client = boto3.client("ec2", region_name="eu-west-2") - volume_response = ec2_client.create_volume( - AvailabilityZone='eu-west-1a', Size=10 - ) + volume_response = ec2_client.create_volume(AvailabilityZone="eu-west-1a", Size=10) create_snapshot_response = ec2_client.create_snapshot( - VolumeId=volume_response['VolumeId'] + VolumeId=volume_response["VolumeId"] ) copy_snapshot_response = dest_ec2_client.copy_snapshot( - SourceSnapshotId=create_snapshot_response['SnapshotId'], - SourceRegion="eu-west-1" + SourceSnapshotId=create_snapshot_response["SnapshotId"], + SourceRegion="eu-west-1", ) - ec2 = boto3.resource('ec2', region_name='eu-west-1') - dest_ec2 = boto3.resource('ec2', region_name='eu-west-2') + ec2 = boto3.resource("ec2", region_name="eu-west-1") + dest_ec2 = boto3.resource("ec2", region_name="eu-west-2") - source = ec2.Snapshot(create_snapshot_response['SnapshotId']) - dest = dest_ec2.Snapshot(copy_snapshot_response['SnapshotId']) + source = ec2.Snapshot(create_snapshot_response["SnapshotId"]) + dest = dest_ec2.Snapshot(copy_snapshot_response["SnapshotId"]) - attribs = ['data_encryption_key_id', 'encrypted', - 'kms_key_id', 'owner_alias', 'owner_id', - 'progress', 'state', 'state_message', - 'tags', 'volume_id', 'volume_size'] + attribs = [ + "data_encryption_key_id", + "encrypted", + "kms_key_id", + "owner_alias", + "owner_id", + "progress", + "state", + "state_message", + "tags", + "volume_id", + "volume_size", + ] for attrib in attribs: getattr(source, attrib).should.equal(getattr(dest, attrib)) # Copy from non-existent source ID. with assert_raises(ClientError) as cm: - create_snapshot_error = ec2_client.create_snapshot( - VolumeId='vol-abcd1234' - ) - cm.exception.response['Error']['Code'].should.equal('InvalidVolume.NotFound') - cm.exception.response['Error']['Message'].should.equal("The volume 'vol-abcd1234' does not exist.") - cm.exception.response['ResponseMetadata']['RequestId'].should_not.be.none - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + create_snapshot_error = ec2_client.create_snapshot(VolumeId="vol-abcd1234") + cm.exception.response["Error"]["Code"].should.equal("InvalidVolume.NotFound") + cm.exception.response["Error"]["Message"].should.equal( + "The volume 'vol-abcd1234' does not exist." + ) + cm.exception.response["ResponseMetadata"]["RequestId"].should_not.be.none + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) # Copy from non-existent source region. with assert_raises(ClientError) as cm: copy_snapshot_response = dest_ec2_client.copy_snapshot( - SourceSnapshotId=create_snapshot_response['SnapshotId'], - SourceRegion="eu-west-2" + SourceSnapshotId=create_snapshot_response["SnapshotId"], + SourceRegion="eu-west-2", ) - cm.exception.response['Error']['Code'].should.equal('InvalidSnapshot.NotFound') - cm.exception.response['Error']['Message'].should.be.none - cm.exception.response['ResponseMetadata']['RequestId'].should_not.be.none - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.response["Error"]["Code"].should.equal("InvalidSnapshot.NotFound") + cm.exception.response["Error"]["Message"].should.be.none + cm.exception.response["ResponseMetadata"]["RequestId"].should_not.be.none + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + @mock_ec2 def test_search_for_many_snapshots(): - ec2_client = boto3.client('ec2', region_name='eu-west-1') + ec2_client = boto3.client("ec2", region_name="eu-west-1") - volume_response = ec2_client.create_volume( - AvailabilityZone='eu-west-1a', Size=10 - ) + volume_response = ec2_client.create_volume(AvailabilityZone="eu-west-1a", Size=10) snapshot_ids = [] for i in range(1, 20): create_snapshot_response = ec2_client.create_snapshot( - VolumeId=volume_response['VolumeId'] + VolumeId=volume_response["VolumeId"] ) - snapshot_ids.append(create_snapshot_response['SnapshotId']) + snapshot_ids.append(create_snapshot_response["SnapshotId"]) - snapshots_response = ec2_client.describe_snapshots( - SnapshotIds=snapshot_ids - ) + snapshots_response = ec2_client.describe_snapshots(SnapshotIds=snapshot_ids) - assert len(snapshots_response['Snapshots']) == len(snapshot_ids) + assert len(snapshots_response["Snapshots"]) == len(snapshot_ids) diff --git a/tests/test_ec2/test_elastic_ip_addresses.py b/tests/test_ec2/test_elastic_ip_addresses.py index 3fad7fd3c..886cdff56 100644 --- a/tests/test_ec2/test_elastic_ip_addresses.py +++ b/tests/test_ec2/test_elastic_ip_addresses.py @@ -1,514 +1,539 @@ -from __future__ import unicode_literals -# Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises -from nose.tools import assert_raises - -import boto -import boto3 -from boto.exception import EC2ResponseError -import six - -import sure # noqa - -from moto import mock_ec2, mock_ec2_deprecated - -import logging - - -@mock_ec2_deprecated -def test_eip_allocate_classic(): - """Allocate/release Classic EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as ex: - standard = conn.allocate_address(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set') - - standard = conn.allocate_address() - standard.should.be.a(boto.ec2.address.Address) - standard.public_ip.should.be.a(six.text_type) - standard.instance_id.should.be.none - standard.domain.should.be.equal("standard") - - with assert_raises(EC2ResponseError) as ex: - standard.release(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set') - - standard.release() - standard.should_not.be.within(conn.get_all_addresses()) - - -@mock_ec2_deprecated -def test_eip_allocate_vpc(): - """Allocate/release VPC EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as ex: - vpc = conn.allocate_address(domain="vpc", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set') - - vpc = conn.allocate_address(domain="vpc") - vpc.should.be.a(boto.ec2.address.Address) - vpc.domain.should.be.equal("vpc") - logging.debug("vpc alloc_id:".format(vpc.allocation_id)) - vpc.release() - -@mock_ec2 -def test_specific_eip_allocate_vpc(): - """Allocate VPC EIP with specific address""" - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - - vpc = client.allocate_address(Domain="vpc", Address="127.38.43.222") - vpc['Domain'].should.be.equal("vpc") - vpc['PublicIp'].should.be.equal("127.38.43.222") - logging.debug("vpc alloc_id:".format(vpc['AllocationId'])) - - -@mock_ec2_deprecated -def test_eip_allocate_invalid_domain(): - """Allocate EIP invalid domain""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.allocate_address(domain="bogus") - cm.exception.code.should.equal('InvalidParameterValue') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_eip_associate_classic(): - """Associate/Disassociate EIP to classic instance""" - conn = boto.connect_ec2('the_key', 'the_secret') - - reservation = conn.run_instances('ami-1234abcd') - instance = reservation.instances[0] - - eip = conn.allocate_address() - eip.instance_id.should.be.none - - with assert_raises(EC2ResponseError) as cm: - conn.associate_address(public_ip=eip.public_ip) - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - with assert_raises(EC2ResponseError) as ex: - conn.associate_address(instance_id=instance.id, - public_ip=eip.public_ip, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AssociateAddress operation: Request would have succeeded, but DryRun flag is set') - - conn.associate_address(instance_id=instance.id, public_ip=eip.public_ip) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(instance.id) - - with assert_raises(EC2ResponseError) as ex: - conn.disassociate_address(public_ip=eip.public_ip, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DisAssociateAddress operation: Request would have succeeded, but DryRun flag is set') - - conn.disassociate_address(public_ip=eip.public_ip) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(u'') - eip.release() - eip.should_not.be.within(conn.get_all_addresses()) - eip = None - - instance.terminate() - - -@mock_ec2_deprecated -def test_eip_associate_vpc(): - """Associate/Disassociate EIP to VPC instance""" - conn = boto.connect_ec2('the_key', 'the_secret') - - reservation = conn.run_instances('ami-1234abcd') - instance = reservation.instances[0] - - eip = conn.allocate_address(domain='vpc') - eip.instance_id.should.be.none - - with assert_raises(EC2ResponseError) as cm: - conn.associate_address(allocation_id=eip.allocation_id) - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - conn.associate_address(instance_id=instance.id, - allocation_id=eip.allocation_id) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(instance.id) - conn.disassociate_address(association_id=eip.association_id) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(u'') - eip.association_id.should.be.none - - with assert_raises(EC2ResponseError) as ex: - eip.release(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set') - - eip.release() - eip = None - - instance.terminate() - - -@mock_ec2 -def test_eip_boto3_vpc_association(): - """Associate EIP to VPC instance in a new subnet with boto3""" - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc_res = client.create_vpc(CidrBlock='10.0.0.0/24') - subnet_res = client.create_subnet( - VpcId=vpc_res['Vpc']['VpcId'], CidrBlock='10.0.0.0/24') - instance = service.create_instances(**{ - 'InstanceType': 't2.micro', - 'ImageId': 'ami-test', - 'MinCount': 1, - 'MaxCount': 1, - 'SubnetId': subnet_res['Subnet']['SubnetId'] - })[0] - allocation_id = client.allocate_address(Domain='vpc')['AllocationId'] - address = service.VpcAddress(allocation_id) - address.load() - address.association_id.should.be.none - address.instance_id.should.be.empty - address.network_interface_id.should.be.empty - association_id = client.associate_address( - InstanceId=instance.id, - AllocationId=allocation_id, - AllowReassociation=False) - instance.load() - address.reload() - address.association_id.should_not.be.none - instance.public_ip_address.should_not.be.none - instance.public_dns_name.should_not.be.none - address.network_interface_id.should.equal(instance.network_interfaces_attribute[0].get('NetworkInterfaceId')) - address.public_ip.should.equal(instance.public_ip_address) - address.instance_id.should.equal(instance.id) - - client.disassociate_address(AssociationId=address.association_id) - instance.reload() - address.reload() - instance.public_ip_address.should.be.none - address.network_interface_id.should.be.empty - address.association_id.should.be.none - address.instance_id.should.be.empty - - -@mock_ec2_deprecated -def test_eip_associate_network_interface(): - """Associate/Disassociate EIP to NIC""" - conn = boto.connect_vpc('the_key', 'the_secret') - vpc = conn.create_vpc("10.0.0.0/16") - subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") - eni = conn.create_network_interface(subnet.id) - - eip = conn.allocate_address(domain='vpc') - eip.network_interface_id.should.be.none - - with assert_raises(EC2ResponseError) as cm: - conn.associate_address(network_interface_id=eni.id) - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - conn.associate_address(network_interface_id=eni.id, - allocation_id=eip.allocation_id) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.network_interface_id.should.be.equal(eni.id) - - conn.disassociate_address(association_id=eip.association_id) - # no .update() on address ): - eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.network_interface_id.should.be.equal(u'') - eip.association_id.should.be.none - eip.release() - eip = None - - -@mock_ec2_deprecated -def test_eip_reassociate(): - """reassociate EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - reservation = conn.run_instances('ami-1234abcd', min_count=2) - instance1, instance2 = reservation.instances - - eip = conn.allocate_address() - conn.associate_address(instance_id=instance1.id, public_ip=eip.public_ip) - - # Same ID is idempotent - conn.associate_address(instance_id=instance1.id, public_ip=eip.public_ip) - - # Different ID detects resource association - with assert_raises(EC2ResponseError) as cm: - conn.associate_address( - instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=False) - cm.exception.code.should.equal('Resource.AlreadyAssociated') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - conn.associate_address.when.called_with( - instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=True).should_not.throw(EC2ResponseError) - - eip.release() - eip = None - - instance1.terminate() - instance2.terminate() - - -@mock_ec2_deprecated -def test_eip_reassociate_nic(): - """reassociate EIP""" - conn = boto.connect_vpc('the_key', 'the_secret') - - vpc = conn.create_vpc("10.0.0.0/16") - subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") - eni1 = conn.create_network_interface(subnet.id) - eni2 = conn.create_network_interface(subnet.id) - - eip = conn.allocate_address() - conn.associate_address(network_interface_id=eni1.id, - public_ip=eip.public_ip) - - # Same ID is idempotent - conn.associate_address(network_interface_id=eni1.id, - public_ip=eip.public_ip) - - # Different ID detects resource association - with assert_raises(EC2ResponseError) as cm: - conn.associate_address( - network_interface_id=eni2.id, public_ip=eip.public_ip) - cm.exception.code.should.equal('Resource.AlreadyAssociated') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - conn.associate_address.when.called_with( - network_interface_id=eni2.id, public_ip=eip.public_ip, allow_reassociation=True).should_not.throw(EC2ResponseError) - - eip.release() - eip = None - - -@mock_ec2_deprecated -def test_eip_associate_invalid_args(): - """Associate EIP, invalid args """ - conn = boto.connect_ec2('the_key', 'the_secret') - - reservation = conn.run_instances('ami-1234abcd') - instance = reservation.instances[0] - - eip = conn.allocate_address() - - with assert_raises(EC2ResponseError) as cm: - conn.associate_address(instance_id=instance.id) - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - instance.terminate() - - -@mock_ec2_deprecated -def test_eip_disassociate_bogus_association(): - """Disassociate bogus EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.disassociate_address(association_id="bogus") - cm.exception.code.should.equal('InvalidAssociationID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_eip_release_bogus_eip(): - """Release bogus EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.release_address(allocation_id="bogus") - cm.exception.code.should.equal('InvalidAllocationID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_eip_disassociate_arg_error(): - """Invalid arguments disassociate address""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.disassociate_address() - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_eip_release_arg_error(): - """Invalid arguments release address""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.release_address() - cm.exception.code.should.equal('MissingParameter') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_eip_describe(): - """Listing of allocated Elastic IP Addresses.""" - conn = boto.connect_ec2('the_key', 'the_secret') - eips = [] - number_of_classic_ips = 2 - number_of_vpc_ips = 2 - - # allocate some IPs - for _ in range(number_of_classic_ips): - eips.append(conn.allocate_address()) - for _ in range(number_of_vpc_ips): - eips.append(conn.allocate_address(domain='vpc')) - len(eips).should.be.equal(number_of_classic_ips + number_of_vpc_ips) - - # Can we find each one individually? - for eip in eips: - if eip.allocation_id: - lookup_addresses = conn.get_all_addresses( - allocation_ids=[eip.allocation_id]) - else: - lookup_addresses = conn.get_all_addresses( - addresses=[eip.public_ip]) - len(lookup_addresses).should.be.equal(1) - lookup_addresses[0].public_ip.should.be.equal(eip.public_ip) - - # Can we find first two when we search for them? - lookup_addresses = conn.get_all_addresses( - addresses=[eips[0].public_ip, eips[1].public_ip]) - len(lookup_addresses).should.be.equal(2) - lookup_addresses[0].public_ip.should.be.equal(eips[0].public_ip) - lookup_addresses[1].public_ip.should.be.equal(eips[1].public_ip) - - # Release all IPs - for eip in eips: - eip.release() - len(conn.get_all_addresses()).should.be.equal(0) - - -@mock_ec2_deprecated -def test_eip_describe_none(): - """Error when search for bogus IP""" - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.get_all_addresses(addresses=["256.256.256.256"]) - cm.exception.code.should.equal('InvalidAddress.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2 -def test_eip_filters(): - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc_res = client.create_vpc(CidrBlock='10.0.0.0/24') - subnet_res = client.create_subnet( - VpcId=vpc_res['Vpc']['VpcId'], CidrBlock='10.0.0.0/24') - - def create_inst_with_eip(): - instance = service.create_instances(**{ - 'InstanceType': 't2.micro', - 'ImageId': 'ami-test', - 'MinCount': 1, - 'MaxCount': 1, - 'SubnetId': subnet_res['Subnet']['SubnetId'] - })[0] - allocation_id = client.allocate_address(Domain='vpc')['AllocationId'] - _ = client.associate_address( - InstanceId=instance.id, - AllocationId=allocation_id, - AllowReassociation=False) - instance.load() - address = service.VpcAddress(allocation_id) - address.load() - return instance, address - - inst1, eip1 = create_inst_with_eip() - inst2, eip2 = create_inst_with_eip() - inst3, eip3 = create_inst_with_eip() - - # Param search by AllocationId - addresses = list(service.vpc_addresses.filter(AllocationIds=[eip2.allocation_id])) - len(addresses).should.be.equal(1) - addresses[0].public_ip.should.equal(eip2.public_ip) - inst2.public_ip_address.should.equal(addresses[0].public_ip) - - # Param search by PublicIp - addresses = list(service.vpc_addresses.filter(PublicIps=[eip3.public_ip])) - len(addresses).should.be.equal(1) - addresses[0].public_ip.should.equal(eip3.public_ip) - inst3.public_ip_address.should.equal(addresses[0].public_ip) - - # Param search by Filter - def check_vpc_filter_valid(filter_name, filter_values): - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': filter_name, - 'Values': filter_values}])) - len(addresses).should.equal(2) - ips = [addr.public_ip for addr in addresses] - set(ips).should.equal(set([eip1.public_ip, eip2.public_ip])) - ips.should.contain(inst1.public_ip_address) - - def check_vpc_filter_invalid(filter_name): - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': filter_name, - 'Values': ['dummy1', 'dummy2']}])) - len(addresses).should.equal(0) - - def check_vpc_filter(filter_name, filter_values): - check_vpc_filter_valid(filter_name, filter_values) - check_vpc_filter_invalid(filter_name) - - check_vpc_filter('allocation-id', [eip1.allocation_id, eip2.allocation_id]) - check_vpc_filter('association-id', [eip1.association_id, eip2.association_id]) - check_vpc_filter('instance-id', [inst1.id, inst2.id]) - check_vpc_filter( - 'network-interface-id', - [inst1.network_interfaces_attribute[0].get('NetworkInterfaceId'), - inst2.network_interfaces_attribute[0].get('NetworkInterfaceId')]) - check_vpc_filter( - 'private-ip-address', - [inst1.network_interfaces_attribute[0].get('PrivateIpAddress'), - inst2.network_interfaces_attribute[0].get('PrivateIpAddress')]) - check_vpc_filter('public-ip', [inst1.public_ip_address, inst2.public_ip_address]) - - # all the ips are in a VPC - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': 'domain', 'Values': ['vpc']}])) - len(addresses).should.equal(3) +from __future__ import unicode_literals + +# Ensure 'assert_raises' context manager support for Python 2.6 +import tests.backport_assert_raises +from nose.tools import assert_raises + +import boto +import boto3 +from boto.exception import EC2ResponseError +import six + +import sure # noqa + +from moto import mock_ec2, mock_ec2_deprecated + +import logging + + +@mock_ec2_deprecated +def test_eip_allocate_classic(): + """Allocate/release Classic EIP""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as ex: + standard = conn.allocate_address(dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + standard = conn.allocate_address() + standard.should.be.a(boto.ec2.address.Address) + standard.public_ip.should.be.a(six.text_type) + standard.instance_id.should.be.none + standard.domain.should.be.equal("standard") + + with assert_raises(EC2ResponseError) as ex: + standard.release(dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + standard.release() + standard.should_not.be.within(conn.get_all_addresses()) + + +@mock_ec2_deprecated +def test_eip_allocate_vpc(): + """Allocate/release VPC EIP""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as ex: + vpc = conn.allocate_address(domain="vpc", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + vpc = conn.allocate_address(domain="vpc") + vpc.should.be.a(boto.ec2.address.Address) + vpc.domain.should.be.equal("vpc") + logging.debug("vpc alloc_id:".format(vpc.allocation_id)) + vpc.release() + + +@mock_ec2 +def test_specific_eip_allocate_vpc(): + """Allocate VPC EIP with specific address""" + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + + vpc = client.allocate_address(Domain="vpc", Address="127.38.43.222") + vpc["Domain"].should.be.equal("vpc") + vpc["PublicIp"].should.be.equal("127.38.43.222") + logging.debug("vpc alloc_id:".format(vpc["AllocationId"])) + + +@mock_ec2_deprecated +def test_eip_allocate_invalid_domain(): + """Allocate EIP invalid domain""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.allocate_address(domain="bogus") + cm.exception.code.should.equal("InvalidParameterValue") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_eip_associate_classic(): + """Associate/Disassociate EIP to classic instance""" + conn = boto.connect_ec2("the_key", "the_secret") + + reservation = conn.run_instances("ami-1234abcd") + instance = reservation.instances[0] + + eip = conn.allocate_address() + eip.instance_id.should.be.none + + with assert_raises(EC2ResponseError) as cm: + conn.associate_address(public_ip=eip.public_ip) + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + with assert_raises(EC2ResponseError) as ex: + conn.associate_address( + instance_id=instance.id, public_ip=eip.public_ip, dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the AssociateAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + conn.associate_address(instance_id=instance.id, public_ip=eip.public_ip) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.instance_id.should.be.equal(instance.id) + + with assert_raises(EC2ResponseError) as ex: + conn.disassociate_address(public_ip=eip.public_ip, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the DisAssociateAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + conn.disassociate_address(public_ip=eip.public_ip) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.instance_id.should.be.equal("") + eip.release() + eip.should_not.be.within(conn.get_all_addresses()) + eip = None + + instance.terminate() + + +@mock_ec2_deprecated +def test_eip_associate_vpc(): + """Associate/Disassociate EIP to VPC instance""" + conn = boto.connect_ec2("the_key", "the_secret") + + reservation = conn.run_instances("ami-1234abcd") + instance = reservation.instances[0] + + eip = conn.allocate_address(domain="vpc") + eip.instance_id.should.be.none + + with assert_raises(EC2ResponseError) as cm: + conn.associate_address(allocation_id=eip.allocation_id) + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + conn.associate_address(instance_id=instance.id, allocation_id=eip.allocation_id) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.instance_id.should.be.equal(instance.id) + conn.disassociate_address(association_id=eip.association_id) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.instance_id.should.be.equal("") + eip.association_id.should.be.none + + with assert_raises(EC2ResponseError) as ex: + eip.release(dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set" + ) + + eip.release() + eip = None + + instance.terminate() + + +@mock_ec2 +def test_eip_boto3_vpc_association(): + """Associate EIP to VPC instance in a new subnet with boto3""" + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc_res = client.create_vpc(CidrBlock="10.0.0.0/24") + subnet_res = client.create_subnet( + VpcId=vpc_res["Vpc"]["VpcId"], CidrBlock="10.0.0.0/24" + ) + instance = service.create_instances( + **{ + "InstanceType": "t2.micro", + "ImageId": "ami-test", + "MinCount": 1, + "MaxCount": 1, + "SubnetId": subnet_res["Subnet"]["SubnetId"], + } + )[0] + allocation_id = client.allocate_address(Domain="vpc")["AllocationId"] + address = service.VpcAddress(allocation_id) + address.load() + address.association_id.should.be.none + address.instance_id.should.be.empty + address.network_interface_id.should.be.empty + association_id = client.associate_address( + InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False + ) + instance.load() + address.reload() + address.association_id.should_not.be.none + instance.public_ip_address.should_not.be.none + instance.public_dns_name.should_not.be.none + address.network_interface_id.should.equal( + instance.network_interfaces_attribute[0].get("NetworkInterfaceId") + ) + address.public_ip.should.equal(instance.public_ip_address) + address.instance_id.should.equal(instance.id) + + client.disassociate_address(AssociationId=address.association_id) + instance.reload() + address.reload() + instance.public_ip_address.should.be.none + address.network_interface_id.should.be.empty + address.association_id.should.be.none + address.instance_id.should.be.empty + + +@mock_ec2_deprecated +def test_eip_associate_network_interface(): + """Associate/Disassociate EIP to NIC""" + conn = boto.connect_vpc("the_key", "the_secret") + vpc = conn.create_vpc("10.0.0.0/16") + subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") + eni = conn.create_network_interface(subnet.id) + + eip = conn.allocate_address(domain="vpc") + eip.network_interface_id.should.be.none + + with assert_raises(EC2ResponseError) as cm: + conn.associate_address(network_interface_id=eni.id) + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + conn.associate_address(network_interface_id=eni.id, allocation_id=eip.allocation_id) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.network_interface_id.should.be.equal(eni.id) + + conn.disassociate_address(association_id=eip.association_id) + # no .update() on address ): + eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] + eip.network_interface_id.should.be.equal("") + eip.association_id.should.be.none + eip.release() + eip = None + + +@mock_ec2_deprecated +def test_eip_reassociate(): + """reassociate EIP""" + conn = boto.connect_ec2("the_key", "the_secret") + + reservation = conn.run_instances("ami-1234abcd", min_count=2) + instance1, instance2 = reservation.instances + + eip = conn.allocate_address() + conn.associate_address(instance_id=instance1.id, public_ip=eip.public_ip) + + # Same ID is idempotent + conn.associate_address(instance_id=instance1.id, public_ip=eip.public_ip) + + # Different ID detects resource association + with assert_raises(EC2ResponseError) as cm: + conn.associate_address( + instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=False + ) + cm.exception.code.should.equal("Resource.AlreadyAssociated") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + conn.associate_address.when.called_with( + instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=True + ).should_not.throw(EC2ResponseError) + + eip.release() + eip = None + + instance1.terminate() + instance2.terminate() + + +@mock_ec2_deprecated +def test_eip_reassociate_nic(): + """reassociate EIP""" + conn = boto.connect_vpc("the_key", "the_secret") + + vpc = conn.create_vpc("10.0.0.0/16") + subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") + eni1 = conn.create_network_interface(subnet.id) + eni2 = conn.create_network_interface(subnet.id) + + eip = conn.allocate_address() + conn.associate_address(network_interface_id=eni1.id, public_ip=eip.public_ip) + + # Same ID is idempotent + conn.associate_address(network_interface_id=eni1.id, public_ip=eip.public_ip) + + # Different ID detects resource association + with assert_raises(EC2ResponseError) as cm: + conn.associate_address(network_interface_id=eni2.id, public_ip=eip.public_ip) + cm.exception.code.should.equal("Resource.AlreadyAssociated") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + conn.associate_address.when.called_with( + network_interface_id=eni2.id, public_ip=eip.public_ip, allow_reassociation=True + ).should_not.throw(EC2ResponseError) + + eip.release() + eip = None + + +@mock_ec2_deprecated +def test_eip_associate_invalid_args(): + """Associate EIP, invalid args """ + conn = boto.connect_ec2("the_key", "the_secret") + + reservation = conn.run_instances("ami-1234abcd") + instance = reservation.instances[0] + + eip = conn.allocate_address() + + with assert_raises(EC2ResponseError) as cm: + conn.associate_address(instance_id=instance.id) + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + instance.terminate() + + +@mock_ec2_deprecated +def test_eip_disassociate_bogus_association(): + """Disassociate bogus EIP""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.disassociate_address(association_id="bogus") + cm.exception.code.should.equal("InvalidAssociationID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_eip_release_bogus_eip(): + """Release bogus EIP""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.release_address(allocation_id="bogus") + cm.exception.code.should.equal("InvalidAllocationID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_eip_disassociate_arg_error(): + """Invalid arguments disassociate address""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.disassociate_address() + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_eip_release_arg_error(): + """Invalid arguments release address""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.release_address() + cm.exception.code.should.equal("MissingParameter") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_eip_describe(): + """Listing of allocated Elastic IP Addresses.""" + conn = boto.connect_ec2("the_key", "the_secret") + eips = [] + number_of_classic_ips = 2 + number_of_vpc_ips = 2 + + # allocate some IPs + for _ in range(number_of_classic_ips): + eips.append(conn.allocate_address()) + for _ in range(number_of_vpc_ips): + eips.append(conn.allocate_address(domain="vpc")) + len(eips).should.be.equal(number_of_classic_ips + number_of_vpc_ips) + + # Can we find each one individually? + for eip in eips: + if eip.allocation_id: + lookup_addresses = conn.get_all_addresses( + allocation_ids=[eip.allocation_id] + ) + else: + lookup_addresses = conn.get_all_addresses(addresses=[eip.public_ip]) + len(lookup_addresses).should.be.equal(1) + lookup_addresses[0].public_ip.should.be.equal(eip.public_ip) + + # Can we find first two when we search for them? + lookup_addresses = conn.get_all_addresses( + addresses=[eips[0].public_ip, eips[1].public_ip] + ) + len(lookup_addresses).should.be.equal(2) + lookup_addresses[0].public_ip.should.be.equal(eips[0].public_ip) + lookup_addresses[1].public_ip.should.be.equal(eips[1].public_ip) + + # Release all IPs + for eip in eips: + eip.release() + len(conn.get_all_addresses()).should.be.equal(0) + + +@mock_ec2_deprecated +def test_eip_describe_none(): + """Error when search for bogus IP""" + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.get_all_addresses(addresses=["256.256.256.256"]) + cm.exception.code.should.equal("InvalidAddress.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2 +def test_eip_filters(): + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc_res = client.create_vpc(CidrBlock="10.0.0.0/24") + subnet_res = client.create_subnet( + VpcId=vpc_res["Vpc"]["VpcId"], CidrBlock="10.0.0.0/24" + ) + + def create_inst_with_eip(): + instance = service.create_instances( + **{ + "InstanceType": "t2.micro", + "ImageId": "ami-test", + "MinCount": 1, + "MaxCount": 1, + "SubnetId": subnet_res["Subnet"]["SubnetId"], + } + )[0] + allocation_id = client.allocate_address(Domain="vpc")["AllocationId"] + _ = client.associate_address( + InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False + ) + instance.load() + address = service.VpcAddress(allocation_id) + address.load() + return instance, address + + inst1, eip1 = create_inst_with_eip() + inst2, eip2 = create_inst_with_eip() + inst3, eip3 = create_inst_with_eip() + + # Param search by AllocationId + addresses = list(service.vpc_addresses.filter(AllocationIds=[eip2.allocation_id])) + len(addresses).should.be.equal(1) + addresses[0].public_ip.should.equal(eip2.public_ip) + inst2.public_ip_address.should.equal(addresses[0].public_ip) + + # Param search by PublicIp + addresses = list(service.vpc_addresses.filter(PublicIps=[eip3.public_ip])) + len(addresses).should.be.equal(1) + addresses[0].public_ip.should.equal(eip3.public_ip) + inst3.public_ip_address.should.equal(addresses[0].public_ip) + + # Param search by Filter + def check_vpc_filter_valid(filter_name, filter_values): + addresses = list( + service.vpc_addresses.filter( + Filters=[{"Name": filter_name, "Values": filter_values}] + ) + ) + len(addresses).should.equal(2) + ips = [addr.public_ip for addr in addresses] + set(ips).should.equal(set([eip1.public_ip, eip2.public_ip])) + ips.should.contain(inst1.public_ip_address) + + def check_vpc_filter_invalid(filter_name): + addresses = list( + service.vpc_addresses.filter( + Filters=[{"Name": filter_name, "Values": ["dummy1", "dummy2"]}] + ) + ) + len(addresses).should.equal(0) + + def check_vpc_filter(filter_name, filter_values): + check_vpc_filter_valid(filter_name, filter_values) + check_vpc_filter_invalid(filter_name) + + check_vpc_filter("allocation-id", [eip1.allocation_id, eip2.allocation_id]) + check_vpc_filter("association-id", [eip1.association_id, eip2.association_id]) + check_vpc_filter("instance-id", [inst1.id, inst2.id]) + check_vpc_filter( + "network-interface-id", + [ + inst1.network_interfaces_attribute[0].get("NetworkInterfaceId"), + inst2.network_interfaces_attribute[0].get("NetworkInterfaceId"), + ], + ) + check_vpc_filter( + "private-ip-address", + [ + inst1.network_interfaces_attribute[0].get("PrivateIpAddress"), + inst2.network_interfaces_attribute[0].get("PrivateIpAddress"), + ], + ) + check_vpc_filter("public-ip", [inst1.public_ip_address, inst2.public_ip_address]) + + # all the ips are in a VPC + addresses = list( + service.vpc_addresses.filter(Filters=[{"Name": "domain", "Values": ["vpc"]}]) + ) + len(addresses).should.equal(3) diff --git a/tests/test_ec2/test_elastic_network_interfaces.py b/tests/test_ec2/test_elastic_network_interfaces.py index 05b45fda9..4e502586e 100644 --- a/tests/test_ec2/test_elastic_network_interfaces.py +++ b/tests/test_ec2/test_elastic_network_interfaces.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -19,16 +20,17 @@ import json @mock_ec2_deprecated def test_elastic_network_interfaces(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") with assert_raises(EC2ResponseError) as ex: eni = conn.create_network_interface(subnet.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) eni = conn.create_network_interface(subnet.id) @@ -37,14 +39,15 @@ def test_elastic_network_interfaces(): eni = all_enis[0] eni.groups.should.have.length_of(0) eni.private_ip_addresses.should.have.length_of(1) - eni.private_ip_addresses[0].private_ip_address.startswith('10.').should.be.true + eni.private_ip_addresses[0].private_ip_address.startswith("10.").should.be.true with assert_raises(EC2ResponseError) as ex: conn.delete_network_interface(eni.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.delete_network_interface(eni.id) @@ -53,25 +56,25 @@ def test_elastic_network_interfaces(): with assert_raises(EC2ResponseError) as cm: conn.delete_network_interface(eni.id) - cm.exception.error_code.should.equal('InvalidNetworkInterfaceID.NotFound') + cm.exception.error_code.should.equal("InvalidNetworkInterfaceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_elastic_network_interfaces_subnet_validation(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.create_network_interface("subnet-abcd1234") - cm.exception.error_code.should.equal('InvalidSubnetID.NotFound') + cm.exception.error_code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_elastic_network_interfaces_with_private_ip(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") private_ip = "54.0.0.1" @@ -89,15 +92,18 @@ def test_elastic_network_interfaces_with_private_ip(): @mock_ec2_deprecated def test_elastic_network_interfaces_with_groups(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) conn.create_network_interface( - subnet.id, groups=[security_group1.id, security_group2.id]) + subnet.id, groups=[security_group1.id, security_group2.id] + ) all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(1) @@ -105,19 +111,22 @@ def test_elastic_network_interfaces_with_groups(): eni = all_enis[0] eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) @requires_boto_gte("2.12.0") @mock_ec2_deprecated def test_elastic_network_interfaces_modify_attribute(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) conn.create_network_interface(subnet.id, groups=[security_group1.id]) all_enis = conn.get_all_network_interfaces() @@ -129,14 +138,15 @@ def test_elastic_network_interfaces_modify_attribute(): with assert_raises(EC2ResponseError) as ex: conn.modify_network_interface_attribute( - eni.id, 'groupset', [security_group2.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + eni.id, "groupset", [security_group2.id], dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) - conn.modify_network_interface_attribute( - eni.id, 'groupset', [security_group2.id]) + conn.modify_network_interface_attribute(eni.id, "groupset", [security_group2.id]) all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(1) @@ -148,20 +158,22 @@ def test_elastic_network_interfaces_modify_attribute(): @mock_ec2_deprecated def test_elastic_network_interfaces_filtering(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) eni1 = conn.create_network_interface( - subnet.id, groups=[security_group1.id, security_group2.id]) - eni2 = conn.create_network_interface( - subnet.id, groups=[security_group1.id]) - eni3 = conn.create_network_interface(subnet.id, description='test description') + subnet.id, groups=[security_group1.id, security_group2.id] + ) + eni2 = conn.create_network_interface(subnet.id, groups=[security_group1.id]) + eni3 = conn.create_network_interface(subnet.id, description="test description") all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(3) @@ -173,280 +185,322 @@ def test_elastic_network_interfaces_filtering(): # Filter by ENI ID enis_by_id = conn.get_all_network_interfaces( - filters={'network-interface-id': eni1.id}) + filters={"network-interface-id": eni1.id} + ) enis_by_id.should.have.length_of(1) set([eni.id for eni in enis_by_id]).should.equal(set([eni1.id])) # Filter by Security Group enis_by_group = conn.get_all_network_interfaces( - filters={'group-id': security_group1.id}) + filters={"group-id": security_group1.id} + ) enis_by_group.should.have.length_of(2) set([eni.id for eni in enis_by_group]).should.equal(set([eni1.id, eni2.id])) # Filter by ENI ID and Security Group enis_by_group = conn.get_all_network_interfaces( - filters={'network-interface-id': eni1.id, 'group-id': security_group1.id}) + filters={"network-interface-id": eni1.id, "group-id": security_group1.id} + ) enis_by_group.should.have.length_of(1) set([eni.id for eni in enis_by_group]).should.equal(set([eni1.id])) # Filter by Description enis_by_description = conn.get_all_network_interfaces( - filters={'description': eni3.description }) + filters={"description": eni3.description} + ) enis_by_description.should.have.length_of(1) enis_by_description[0].description.should.equal(eni3.description) # Unsupported filter conn.get_all_network_interfaces.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2 def test_elastic_network_interfaces_get_by_tag_name(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) with assert_raises(ClientError) as ex: - eni1.create_tags(Tags=[{'Key': 'Name', 'Value': 'eni1'}], DryRun=True) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + eni1.create_tags(Tags=[{"Key": "Name", "Value": "eni1"}], DryRun=True) + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) - eni1.create_tags(Tags=[{'Key': 'Name', 'Value': 'eni1'}]) + eni1.create_tags(Tags=[{"Key": "Name", "Value": "eni1"}]) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'tag:Name', 'Values': ['eni1']}] + filters = [{"Name": "tag:Name", "Values": ["eni1"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'tag:Name', 'Values': ['wrong-name']}] + filters = [{"Name": "tag:Name", "Values": ["wrong-name"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_availability_zone(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet1 = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.1.0/24', AvailabilityZone='us-west-2b') + VpcId=vpc.id, CidrBlock="10.0.1.0/24", AvailabilityZone="us-west-2b" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet1.id, PrivateIpAddress='10.0.0.15') + SubnetId=subnet1.id, PrivateIpAddress="10.0.0.15" + ) eni2 = ec2.create_network_interface( - SubnetId=subnet2.id, PrivateIpAddress='10.0.1.15') + SubnetId=subnet2.id, PrivateIpAddress="10.0.1.15" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id, eni2.id]) - filters = [{'Name': 'availability-zone', 'Values': ['us-west-2a']}] + filters = [{"Name": "availability-zone", "Values": ["us-west-2a"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'availability-zone', 'Values': ['us-west-2c']}] + filters = [{"Name": "availability-zone", "Values": ["us-west-2c"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_private_ip(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'private-ip-address', 'Values': ['10.0.10.5']}] + filters = [{"Name": "private-ip-address", "Values": ["10.0.10.5"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'private-ip-address', 'Values': ['10.0.10.10']}] + filters = [{"Name": "private-ip-address", "Values": ["10.0.10.10"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) - filters = [{'Name': 'addresses.private-ip-address', 'Values': ['10.0.10.5']}] + filters = [{"Name": "addresses.private-ip-address", "Values": ["10.0.10.5"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'addresses.private-ip-address', 'Values': ['10.0.10.10']}] + filters = [{"Name": "addresses.private-ip-address", "Values": ["10.0.10.10"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_vpc_id(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'vpc-id', 'Values': [subnet.vpc_id]}] + filters = [{"Name": "vpc-id", "Values": [subnet.vpc_id]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'vpc-id', 'Values': ['vpc-aaaa1111']}] + filters = [{"Name": "vpc-id", "Values": ["vpc-aaaa1111"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_subnet_id(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'subnet-id', 'Values': [subnet.id]}] + filters = [{"Name": "subnet-id", "Values": [subnet.id]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'subnet-id', 'Values': ['subnet-aaaa1111']}] + filters = [{"Name": "subnet-id", "Values": ["subnet-aaaa1111"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_description(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5", Description="test interface" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'description', 'Values': [eni1.description]}] + filters = [{"Name": "description", "Values": [eni1.description]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'description', 'Values': ['bad description']}] + filters = [{"Name": "description", "Values": ["bad description"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_describe_network_interfaces_with_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5", Description="test interface" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) # Filter by network-interface-id response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'network-interface-id', 'Values': [eni1.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "network-interface-id", "Values": [eni1.id]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'network-interface-id', 'Values': ['bad-id']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "network-interface-id", "Values": ["bad-id"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by private-ip-address response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "private-ip-address", "Values": [eni1.private_ip_address]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': ['11.11.11.11']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "private-ip-address", "Values": ["11.11.11.11"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by sunet-id response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'subnet-id', 'Values': [eni1.subnet.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "subnet-id", "Values": [eni1.subnet.id]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'subnet-id', 'Values': ['sn-bad-id']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "subnet-id", "Values": ["sn-bad-id"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by description response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'description', 'Values': [eni1.description]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "description", "Values": [eni1.description]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'description', 'Values': ['bad description']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "description", "Values": ["bad description"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by multiple filters response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]}, - {'Name': 'network-interface-id', 'Values': [eni1.id]}, - {'Name': 'subnet-id', 'Values': [eni1.subnet.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[ + {"Name": "private-ip-address", "Values": [eni1.private_ip_address]}, + {"Name": "network-interface-id", "Values": [eni1.id]}, + {"Name": "subnet-id", "Values": [eni1.subnet.id]}, + ] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) @mock_ec2_deprecated @@ -455,19 +509,19 @@ def test_elastic_network_interfaces_cloudformation(): template = vpc_eni.template template_json = json.dumps(template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") eni = ec2_conn.get_all_network_interfaces()[0] eni.private_ip_addresses.should.have.length_of(1) stack = conn.describe_stacks()[0] resources = stack.describe_resources() - cfn_eni = [resource for resource in resources if resource.resource_type == - 'AWS::EC2::NetworkInterface'][0] + cfn_eni = [ + resource + for resource in resources + if resource.resource_type == "AWS::EC2::NetworkInterface" + ][0] cfn_eni.physical_resource_id.should.equal(eni.id) outputs = {output.key: output.value for output in stack.outputs} - outputs['ENIIpAddress'].should.equal(eni.private_ip_addresses[0].private_ip_address) + outputs["ENIIpAddress"].should.equal(eni.private_ip_addresses[0].private_ip_address) diff --git a/tests/test_ec2/test_general.py b/tests/test_ec2/test_general.py index 7249af6a2..7b8f3bd53 100644 --- a/tests/test_ec2/test_general.py +++ b/tests/test_ec2/test_general.py @@ -1,42 +1,41 @@ -from __future__ import unicode_literals -# Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises -from nose.tools import assert_raises - -import boto -import boto3 -from boto.exception import EC2ResponseError -import sure # noqa - -from moto import mock_ec2_deprecated, mock_ec2 - - -@mock_ec2_deprecated -def test_console_output(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') - instance_id = reservation.instances[0].id - output = conn.get_console_output(instance_id) - output.output.should_not.equal(None) - - -@mock_ec2_deprecated -def test_console_output_without_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') - - with assert_raises(EC2ResponseError) as cm: - conn.get_console_output('i-1234abcd') - cm.exception.code.should.equal('InvalidInstanceID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2 -def test_console_output_boto3(): - conn = boto3.resource('ec2', 'us-east-1') - instances = conn.create_instances(ImageId='ami-1234abcd', - MinCount=1, - MaxCount=1) - - output = instances[0].console_output() - output.get('Output').should_not.equal(None) +from __future__ import unicode_literals + +# Ensure 'assert_raises' context manager support for Python 2.6 +import tests.backport_assert_raises +from nose.tools import assert_raises + +import boto +import boto3 +from boto.exception import EC2ResponseError +import sure # noqa + +from moto import mock_ec2_deprecated, mock_ec2 + + +@mock_ec2_deprecated +def test_console_output(): + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") + instance_id = reservation.instances[0].id + output = conn.get_console_output(instance_id) + output.output.should_not.equal(None) + + +@mock_ec2_deprecated +def test_console_output_without_instance(): + conn = boto.connect_ec2("the_key", "the_secret") + + with assert_raises(EC2ResponseError) as cm: + conn.get_console_output("i-1234abcd") + cm.exception.code.should.equal("InvalidInstanceID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2 +def test_console_output_boto3(): + conn = boto3.resource("ec2", "us-east-1") + instances = conn.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) + + output = instances[0].console_output() + output.get("Output").should_not.equal(None) diff --git a/tests/test_ec2/test_instances.py b/tests/test_ec2/test_instances.py index a83384709..041bc8c85 100644 --- a/tests/test_ec2/test_instances.py +++ b/tests/test_ec2/test_instances.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 from botocore.exceptions import ClientError @@ -30,13 +31,14 @@ def add_servers(ami_id, count): @mock_ec2_deprecated def test_add_servers(): - add_servers('ami-1234abcd', 2) + add_servers("ami-1234abcd", 2) conn = boto.connect_ec2() reservations = conn.get_all_instances() assert len(reservations) == 2 instance1 = reservations[0].instances[0] - assert instance1.image_id == 'ami-1234abcd' + assert instance1.image_id == "ami-1234abcd" + ############################################ @@ -47,17 +49,18 @@ def test_instance_launch_and_terminate(): conn = boto.ec2.connect_to_region("us-east-1") with assert_raises(EC2ResponseError) as ex: - reservation = conn.run_instances('ami-1234abcd', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + reservation = conn.run_instances("ami-1234abcd", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RunInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RunInstance operation: Request would have succeeded, but DryRun flag is set" + ) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") reservation.should.be.a(Reservation) reservation.instances.should.have.length_of(1) instance = reservation.instances[0] - instance.state.should.equal('pending') + instance.state.should.equal("pending") reservations = conn.get_all_instances() reservations.should.have.length_of(1) @@ -66,47 +69,46 @@ def test_instance_launch_and_terminate(): instances.should.have.length_of(1) instance = instances[0] instance.id.should.equal(instance.id) - instance.state.should.equal('running') + instance.state.should.equal("running") instance.launch_time.should.equal("2014-01-01T05:00:00.000Z") instance.vpc_id.should.equal(None) - instance.placement.should.equal('us-east-1a') + instance.placement.should.equal("us-east-1a") root_device_name = instance.root_device_name - instance.block_device_mapping[ - root_device_name].status.should.equal('in-use') + instance.block_device_mapping[root_device_name].status.should.equal("in-use") volume_id = instance.block_device_mapping[root_device_name].volume_id - volume_id.should.match(r'vol-\w+') + volume_id.should.match(r"vol-\w+") volume = conn.get_all_volumes(volume_ids=[volume_id])[0] volume.attach_data.instance_id.should.equal(instance.id) - volume.status.should.equal('in-use') + volume.status.should.equal("in-use") with assert_raises(EC2ResponseError) as ex: conn.terminate_instances([instance.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the TerminateInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the TerminateInstance operation: Request would have succeeded, but DryRun flag is set" + ) conn.terminate_instances([instance.id]) reservations = conn.get_all_instances() instance = reservations[0].instances[0] - instance.state.should.equal('terminated') + instance.state.should.equal("terminated") @mock_ec2_deprecated def test_terminate_empty_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.terminate_instances.when.called_with( - []).should.throw(EC2ResponseError) + conn = boto.connect_ec2("the_key", "the_secret") + conn.terminate_instances.when.called_with([]).should.throw(EC2ResponseError) @freeze_time("2014-01-01 05:00:00") @mock_ec2_deprecated def test_instance_attach_volume(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] vol1 = conn.create_volume(size=36, zone=conn.region.name) @@ -124,20 +126,22 @@ def test_instance_attach_volume(): instance.block_device_mapping.should.have.length_of(3) - for v in conn.get_all_volumes(volume_ids=[instance.block_device_mapping['/dev/sdc1'].volume_id]): + for v in conn.get_all_volumes( + volume_ids=[instance.block_device_mapping["/dev/sdc1"].volume_id] + ): v.attach_data.instance_id.should.equal(instance.id) # can do due to freeze_time decorator. v.attach_data.attach_time.should.equal(instance.launch_time) # can do due to freeze_time decorator. v.create_time.should.equal(instance.launch_time) v.region.name.should.equal(instance.region.name) - v.status.should.equal('in-use') + v.status.should.equal("in-use") @mock_ec2_deprecated def test_get_instances_by_id(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=2) + reservation = conn.run_instances("ami-1234abcd", min_count=2) instance1, instance2 = reservation.instances reservations = conn.get_all_instances(instance_ids=[instance1.id]) @@ -146,8 +150,7 @@ def test_get_instances_by_id(): reservation.instances.should.have.length_of(1) reservation.instances[0].id.should.equal(instance1.id) - reservations = conn.get_all_instances( - instance_ids=[instance1.id, instance2.id]) + reservations = conn.get_all_instances(instance_ids=[instance1.id, instance2.id]) reservations.should.have.length_of(1) reservation = reservations[0] reservation.instances.should.have.length_of(2) @@ -157,78 +160,64 @@ def test_get_instances_by_id(): # Call get_all_instances with a bad id should raise an error with assert_raises(EC2ResponseError) as cm: conn.get_all_instances(instance_ids=[instance1.id, "i-1234abcd"]) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_get_paginated_instances(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") for i in range(100): - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1) + conn.create_instances(ImageId=image_id, MinCount=1, MaxCount=1) resp = client.describe_instances(MaxResults=50) - reservations = resp['Reservations'] + reservations = resp["Reservations"] reservations.should.have.length_of(50) - next_token = resp['NextToken'] + next_token = resp["NextToken"] next_token.should_not.be.none resp2 = client.describe_instances(NextToken=next_token) - reservations.extend(resp2['Reservations']) + reservations.extend(resp2["Reservations"]) reservations.should.have.length_of(100) - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @mock_ec2 def test_create_with_tags(): - ec2 = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.client("ec2", region_name="us-west-2") instances = ec2.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - assert 'Tags' in instances['Instances'][0] - len(instances['Instances'][0]['Tags']).should.equal(3) + assert "Tags" in instances["Instances"][0] + len(instances["Instances"][0]["Tags"]).should.equal(3) @mock_ec2_deprecated def test_get_instances_filtering_by_state(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances conn.terminate_instances([instance1.id]) - reservations = conn.get_all_instances( - filters={'instance-state-name': 'running'}) + reservations = conn.get_all_instances(filters={"instance-state-name": "running"}) reservations.should.have.length_of(1) # Since we terminated instance1, only instance2 and instance3 should be # returned @@ -236,13 +225,15 @@ def test_get_instances_filtering_by_state(): set(instance_ids).should.equal(set([instance2.id, instance3.id])) reservations = conn.get_all_instances( - [instance2.id], filters={'instance-state-name': 'running'}) + [instance2.id], filters={"instance-state-name": "running"} + ) reservations.should.have.length_of(1) instance_ids = [instance.id for instance in reservations[0].instances] instance_ids.should.equal([instance2.id]) reservations = conn.get_all_instances( - [instance2.id], filters={'instance-state-name': 'terminated'}) + [instance2.id], filters={"instance-state-name": "terminated"} + ) list(reservations).should.equal([]) # get_all_instances should still return all 3 @@ -250,60 +241,58 @@ def test_get_instances_filtering_by_state(): reservations[0].instances.should.have.length_of(3) conn.get_all_instances.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_get_instances_filtering_by_instance_id(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - reservations = conn.get_all_instances( - filters={'instance-id': instance1.id}) + reservations = conn.get_all_instances(filters={"instance-id": instance1.id}) # get_all_instances should return just instance1 reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) reservations = conn.get_all_instances( - filters={'instance-id': [instance1.id, instance2.id]}) + filters={"instance-id": [instance1.id, instance2.id]} + ) # get_all_instances should return two reservations[0].instances.should.have.length_of(2) - reservations = conn.get_all_instances( - filters={'instance-id': 'non-existing-id'}) + reservations = conn.get_all_instances(filters={"instance-id": "non-existing-id"}) reservations.should.have.length_of(0) @mock_ec2_deprecated def test_get_instances_filtering_by_instance_type(): conn = boto.connect_ec2() - reservation1 = conn.run_instances('ami-1234abcd', instance_type='m1.small') + reservation1 = conn.run_instances("ami-1234abcd", instance_type="m1.small") instance1 = reservation1.instances[0] - reservation2 = conn.run_instances('ami-1234abcd', instance_type='m1.small') + reservation2 = conn.run_instances("ami-1234abcd", instance_type="m1.small") instance2 = reservation2.instances[0] - reservation3 = conn.run_instances('ami-1234abcd', instance_type='t1.micro') + reservation3 = conn.run_instances("ami-1234abcd", instance_type="t1.micro") instance3 = reservation3.instances[0] - reservations = conn.get_all_instances( - filters={'instance-type': 'm1.small'}) + reservations = conn.get_all_instances(filters={"instance-type": "m1.small"}) # get_all_instances should return instance1,2 reservations.should.have.length_of(2) reservations[0].instances.should.have.length_of(1) reservations[1].instances.should.have.length_of(1) - instance_ids = [reservations[0].instances[0].id, - reservations[1].instances[0].id] + instance_ids = [reservations[0].instances[0].id, reservations[1].instances[0].id] set(instance_ids).should.equal(set([instance1.id, instance2.id])) - reservations = conn.get_all_instances( - filters={'instance-type': 't1.micro'}) + reservations = conn.get_all_instances(filters={"instance-type": "t1.micro"}) # get_all_instances should return one reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance3.id) reservations = conn.get_all_instances( - filters={'instance-type': ['t1.micro', 'm1.small']}) + filters={"instance-type": ["t1.micro", "m1.small"]} + ) reservations.should.have.length_of(3) reservations[0].instances.should.have.length_of(1) reservations[1].instances.should.have.length_of(1) @@ -313,10 +302,9 @@ def test_get_instances_filtering_by_instance_type(): reservations[1].instances[0].id, reservations[2].instances[0].id, ] - set(instance_ids).should.equal( - set([instance1.id, instance2.id, instance3.id])) + set(instance_ids).should.equal(set([instance1.id, instance2.id, instance3.id])) - reservations = conn.get_all_instances(filters={'instance-type': 'bogus'}) + reservations = conn.get_all_instances(filters={"instance-type": "bogus"}) # bogus instance-type should return none reservations.should.have.length_of(0) @@ -324,19 +312,21 @@ def test_get_instances_filtering_by_instance_type(): @mock_ec2_deprecated def test_get_instances_filtering_by_reason_code(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances instance1.stop() instance2.terminate() reservations = conn.get_all_instances( - filters={'state-reason-code': 'Client.UserInitiatedShutdown'}) + filters={"state-reason-code": "Client.UserInitiatedShutdown"} + ) # get_all_instances should return instance1 and instance2 reservations[0].instances.should.have.length_of(2) set([instance1.id, instance2.id]).should.equal( - set([i.id for i in reservations[0].instances])) + set([i.id for i in reservations[0].instances]) + ) - reservations = conn.get_all_instances(filters={'state-reason-code': ''}) + reservations = conn.get_all_instances(filters={"state-reason-code": ""}) # get_all_instances should return instance 3 reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance3.id) @@ -345,15 +335,18 @@ def test_get_instances_filtering_by_reason_code(): @mock_ec2_deprecated def test_get_instances_filtering_by_source_dest_check(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=2) + reservation = conn.run_instances("ami-1234abcd", min_count=2) instance1, instance2 = reservation.instances conn.modify_instance_attribute( - instance1.id, attribute='sourceDestCheck', value=False) + instance1.id, attribute="sourceDestCheck", value=False + ) source_dest_check_false = conn.get_all_instances( - filters={'source-dest-check': 'false'}) + filters={"source-dest-check": "false"} + ) source_dest_check_true = conn.get_all_instances( - filters={'source-dest-check': 'true'}) + filters={"source-dest-check": "true"} + ) source_dest_check_false[0].instances.should.have.length_of(1) source_dest_check_false[0].instances[0].id.should.equal(instance1.id) @@ -364,27 +357,25 @@ def test_get_instances_filtering_by_source_dest_check(): @mock_ec2_deprecated def test_get_instances_filtering_by_vpc_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc1 = conn.create_vpc("10.0.0.0/16") subnet1 = conn.create_subnet(vpc1.id, "10.0.0.0/27") - reservation1 = conn.run_instances( - 'ami-1234abcd', min_count=1, subnet_id=subnet1.id) + reservation1 = conn.run_instances("ami-1234abcd", min_count=1, subnet_id=subnet1.id) instance1 = reservation1.instances[0] vpc2 = conn.create_vpc("10.1.0.0/16") subnet2 = conn.create_subnet(vpc2.id, "10.1.0.0/27") - reservation2 = conn.run_instances( - 'ami-1234abcd', min_count=1, subnet_id=subnet2.id) + reservation2 = conn.run_instances("ami-1234abcd", min_count=1, subnet_id=subnet2.id) instance2 = reservation2.instances[0] - reservations1 = conn.get_all_instances(filters={'vpc-id': vpc1.id}) + reservations1 = conn.get_all_instances(filters={"vpc-id": vpc1.id}) reservations1.should.have.length_of(1) reservations1[0].instances.should.have.length_of(1) reservations1[0].instances[0].id.should.equal(instance1.id) reservations1[0].instances[0].vpc_id.should.equal(vpc1.id) reservations1[0].instances[0].subnet_id.should.equal(subnet1.id) - reservations2 = conn.get_all_instances(filters={'vpc-id': vpc2.id}) + reservations2 = conn.get_all_instances(filters={"vpc-id": vpc2.id}) reservations2.should.have.length_of(1) reservations2[0].instances.should.have.length_of(1) reservations2[0].instances[0].id.should.equal(instance2.id) @@ -395,111 +386,105 @@ def test_get_instances_filtering_by_vpc_id(): @mock_ec2_deprecated def test_get_instances_filtering_by_architecture(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=1) + reservation = conn.run_instances("ami-1234abcd", min_count=1) instance = reservation.instances - reservations = conn.get_all_instances(filters={'architecture': 'x86_64'}) + reservations = conn.get_all_instances(filters={"architecture": "x86_64"}) # get_all_instances should return the instance reservations[0].instances.should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_image_id(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") + conn.create_instances(ImageId=image_id, MinCount=1, MaxCount=1) - reservations = client.describe_instances(Filters=[{'Name': 'image-id', - 'Values': [image_id]}])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "image-id", "Values": [image_id]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_private_dns(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - PrivateIpAddress='10.0.0.1') - reservations = client.describe_instances(Filters=[ - {'Name': 'private-dns-name', 'Values': ['ip-10-0-0-1.ec2.internal']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") + conn.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1, PrivateIpAddress="10.0.0.1" + ) + reservations = client.describe_instances( + Filters=[{"Name": "private-dns-name", "Values": ["ip-10-0-0-1.ec2.internal"]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_ni_private_dns(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-west-2') - conn = boto3.resource('ec2', 'us-west-2') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - PrivateIpAddress='10.0.0.1') - reservations = client.describe_instances(Filters=[ - {'Name': 'network-interface.private-dns-name', 'Values': ['ip-10-0-0-1.us-west-2.compute.internal']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-west-2") + conn = boto3.resource("ec2", "us-west-2") + conn.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1, PrivateIpAddress="10.0.0.1" + ) + reservations = client.describe_instances( + Filters=[ + { + "Name": "network-interface.private-dns-name", + "Values": ["ip-10-0-0-1.us-west-2.compute.internal"], + } + ] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_instance_group_name(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - client.create_security_group( - Description='test', - GroupName='test_sg' + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + client.create_security_group(Description="test", GroupName="test_sg") + client.run_instances( + ImageId=image_id, MinCount=1, MaxCount=1, SecurityGroups=["test_sg"] ) - client.run_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - SecurityGroups=['test_sg']) - reservations = client.describe_instances(Filters=[ - {'Name': 'instance.group-name', 'Values': ['test_sg']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "instance.group-name", "Values": ["test_sg"]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_instance_group_id(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - create_sg = client.create_security_group( - Description='test', - GroupName='test_sg' + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + create_sg = client.create_security_group(Description="test", GroupName="test_sg") + group_id = create_sg["GroupId"] + client.run_instances( + ImageId=image_id, MinCount=1, MaxCount=1, SecurityGroups=["test_sg"] ) - group_id = create_sg['GroupId'] - client.run_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - SecurityGroups=['test_sg']) - reservations = client.describe_instances(Filters=[ - {'Name': 'instance.group-id', 'Values': [group_id]} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "instance.group-id", "Values": [group_id]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2_deprecated def test_get_instances_filtering_by_tag(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1', 'value1') - instance1.add_tag('tag2', 'value2') - instance2.add_tag('tag1', 'value1') - instance2.add_tag('tag2', 'wrong value') - instance3.add_tag('tag2', 'value2') + instance1.add_tag("tag1", "value1") + instance1.add_tag("tag2", "value2") + instance2.add_tag("tag1", "value1") + instance2.add_tag("tag2", "wrong value") + instance3.add_tag("tag2", "value2") - reservations = conn.get_all_instances(filters={'tag:tag0': 'value0'}) + reservations = conn.get_all_instances(filters={"tag:tag0": "value0"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag:tag1': 'value1'}) + reservations = conn.get_all_instances(filters={"tag:tag1": "value1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) @@ -507,21 +492,22 @@ def test_get_instances_filtering_by_tag(): reservations[0].instances[1].id.should.equal(instance2.id) reservations = conn.get_all_instances( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) # get_all_instances should return the instance with both tag values reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) reservations = conn.get_all_instances( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) # get_all_instances should return the instance with both tag values reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) - reservations = conn.get_all_instances( - filters={'tag:tag2': ['value2', 'bogus']}) + reservations = conn.get_all_instances(filters={"tag:tag2": ["value2", "bogus"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -533,27 +519,26 @@ def test_get_instances_filtering_by_tag(): @mock_ec2_deprecated def test_get_instances_filtering_by_tag_value(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1', 'value1') - instance1.add_tag('tag2', 'value2') - instance2.add_tag('tag1', 'value1') - instance2.add_tag('tag2', 'wrong value') - instance3.add_tag('tag2', 'value2') + instance1.add_tag("tag1", "value1") + instance1.add_tag("tag2", "value2") + instance2.add_tag("tag1", "value1") + instance2.add_tag("tag2", "wrong value") + instance3.add_tag("tag2", "value2") - reservations = conn.get_all_instances(filters={'tag-value': 'value0'}) + reservations = conn.get_all_instances(filters={"tag-value": "value0"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag-value': 'value1'}) + reservations = conn.get_all_instances(filters={"tag-value": "value1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) reservations[0].instances[0].id.should.equal(instance1.id) reservations[0].instances[1].id.should.equal(instance2.id) - reservations = conn.get_all_instances( - filters={'tag-value': ['value2', 'value1']}) + reservations = conn.get_all_instances(filters={"tag-value": ["value2", "value1"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -562,8 +547,7 @@ def test_get_instances_filtering_by_tag_value(): reservations[0].instances[1].id.should.equal(instance2.id) reservations[0].instances[2].id.should.equal(instance3.id) - reservations = conn.get_all_instances( - filters={'tag-value': ['value2', 'bogus']}) + reservations = conn.get_all_instances(filters={"tag-value": ["value2", "bogus"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -575,27 +559,26 @@ def test_get_instances_filtering_by_tag_value(): @mock_ec2_deprecated def test_get_instances_filtering_by_tag_name(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1') - instance1.add_tag('tag2') - instance2.add_tag('tag1') - instance2.add_tag('tag2X') - instance3.add_tag('tag3') + instance1.add_tag("tag1") + instance1.add_tag("tag2") + instance2.add_tag("tag1") + instance2.add_tag("tag2X") + instance3.add_tag("tag3") - reservations = conn.get_all_instances(filters={'tag-key': 'tagX'}) + reservations = conn.get_all_instances(filters={"tag-key": "tagX"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag-key': 'tag1'}) + reservations = conn.get_all_instances(filters={"tag-key": "tag1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) reservations[0].instances[0].id.should.equal(instance1.id) reservations[0].instances[1].id.should.equal(instance2.id) - reservations = conn.get_all_instances( - filters={'tag-key': ['tag1', 'tag3']}) + reservations = conn.get_all_instances(filters={"tag-key": ["tag1", "tag3"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -607,8 +590,8 @@ def test_get_instances_filtering_by_tag_name(): @mock_ec2_deprecated def test_instance_start_and_stop(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', min_count=2) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", min_count=2) instances = reservation.instances instances.should.have.length_of(2) @@ -616,103 +599,111 @@ def test_instance_start_and_stop(): with assert_raises(EC2ResponseError) as ex: stopped_instances = conn.stop_instances(instance_ids, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the StopInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the StopInstance operation: Request would have succeeded, but DryRun flag is set" + ) stopped_instances = conn.stop_instances(instance_ids) for instance in stopped_instances: - instance.state.should.equal('stopping') + instance.state.should.equal("stopping") with assert_raises(EC2ResponseError) as ex: - started_instances = conn.start_instances( - [instances[0].id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + started_instances = conn.start_instances([instances[0].id], dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the StartInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the StartInstance operation: Request would have succeeded, but DryRun flag is set" + ) started_instances = conn.start_instances([instances[0].id]) - started_instances[0].state.should.equal('pending') + started_instances[0].state.should.equal("pending") @mock_ec2_deprecated def test_instance_reboot(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.reboot(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RebootInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RebootInstance operation: Request would have succeeded, but DryRun flag is set" + ) instance.reboot() - instance.state.should.equal('pending') + instance.state.should.equal("pending") @mock_ec2_deprecated def test_instance_attribute_instance_type(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("instanceType", "m1.small", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceType operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceType operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("instanceType", "m1.small") instance_attribute = instance.get_attribute("instanceType") instance_attribute.should.be.a(InstanceAttribute) - instance_attribute.get('instanceType').should.equal("m1.small") + instance_attribute.get("instanceType").should.equal("m1.small") @mock_ec2_deprecated def test_modify_instance_attribute_security_groups(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - sg_id = conn.create_security_group('test security group', 'this is a test security group').id - sg_id2 = conn.create_security_group('test security group 2', 'this is a test security group 2').id + sg_id = conn.create_security_group( + "test security group", "this is a test security group" + ).id + sg_id2 = conn.create_security_group( + "test security group 2", "this is a test security group 2" + ).id with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("groupSet", [sg_id, sg_id2], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("groupSet", [sg_id, sg_id2]) instance_attribute = instance.get_attribute("groupSet") instance_attribute.should.be.a(InstanceAttribute) - group_list = instance_attribute.get('groupSet') + group_list = instance_attribute.get("groupSet") any(g.id == sg_id for g in group_list).should.be.ok any(g.id == sg_id2 for g in group_list).should.be.ok @mock_ec2_deprecated def test_instance_attribute_user_data(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: - instance.modify_attribute( - "userData", "this is my user data", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.modify_attribute("userData", "this is my user data", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyUserData operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyUserData operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("userData", "this is my user data") @@ -723,12 +714,12 @@ def test_instance_attribute_user_data(): @mock_ec2_deprecated def test_instance_attribute_source_dest_check(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] # Default value is true - instance.sourceDestCheck.should.equal('true') + instance.sourceDestCheck.should.equal("true") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -738,15 +729,16 @@ def test_instance_attribute_source_dest_check(): with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("sourceDestCheck", False, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySourceDestCheck operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySourceDestCheck operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("sourceDestCheck", False) instance.update() - instance.sourceDestCheck.should.equal('false') + instance.sourceDestCheck.should.equal("false") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -756,7 +748,7 @@ def test_instance_attribute_source_dest_check(): instance.modify_attribute("sourceDestCheck", True) instance.update() - instance.sourceDestCheck.should.equal('true') + instance.sourceDestCheck.should.equal("true") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -766,33 +758,32 @@ def test_instance_attribute_source_dest_check(): @mock_ec2_deprecated def test_user_data_with_run_instance(): user_data = b"some user data" - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', user_data=user_data) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", user_data=user_data) instance = reservation.instances[0] instance_attribute = instance.get_attribute("userData") instance_attribute.should.be.a(InstanceAttribute) - retrieved_user_data = instance_attribute.get("userData").encode('utf-8') + retrieved_user_data = instance_attribute.get("userData").encode("utf-8") decoded_user_data = base64.decodestring(retrieved_user_data) decoded_user_data.should.equal(b"some user data") @mock_ec2_deprecated def test_run_instance_with_security_group_name(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - group = conn.create_security_group( - 'group1', "some description", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + group = conn.create_security_group("group1", "some description", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) - group = conn.create_security_group('group1', "some description") + group = conn.create_security_group("group1", "some description") - reservation = conn.run_instances('ami-1234abcd', - security_groups=['group1']) + reservation = conn.run_instances("ami-1234abcd", security_groups=["group1"]) instance = reservation.instances[0] instance.groups[0].id.should.equal(group.id) @@ -801,10 +792,9 @@ def test_run_instance_with_security_group_name(): @mock_ec2_deprecated def test_run_instance_with_security_group_id(): - conn = boto.connect_ec2('the_key', 'the_secret') - group = conn.create_security_group('group1', "some description") - reservation = conn.run_instances('ami-1234abcd', - security_group_ids=[group.id]) + conn = boto.connect_ec2("the_key", "the_secret") + group = conn.create_security_group("group1", "some description") + reservation = conn.run_instances("ami-1234abcd", security_group_ids=[group.id]) instance = reservation.instances[0] instance.groups[0].id.should.equal(group.id) @@ -813,8 +803,8 @@ def test_run_instance_with_security_group_id(): @mock_ec2_deprecated def test_run_instance_with_instance_type(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', instance_type="t1.micro") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", instance_type="t1.micro") instance = reservation.instances[0] instance.instance_type.should.equal("t1.micro") @@ -823,7 +813,7 @@ def test_run_instance_with_instance_type(): @mock_ec2_deprecated def test_run_instance_with_default_placement(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.placement.should.equal("us-east-1a") @@ -831,8 +821,8 @@ def test_run_instance_with_default_placement(): @mock_ec2_deprecated def test_run_instance_with_placement(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', placement="us-east-1b") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", placement="us-east-1b") instance = reservation.instances[0] instance.placement.should.equal("us-east-1b") @@ -840,11 +830,14 @@ def test_run_instance_with_placement(): @mock_ec2 def test_run_instance_with_subnet_boto3(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") ip_networks = [ - (ipaddress.ip_network('10.0.0.0/16'), ipaddress.ip_network('10.0.99.0/24')), - (ipaddress.ip_network('192.168.42.0/24'), ipaddress.ip_network('192.168.42.0/25')) + (ipaddress.ip_network("10.0.0.0/16"), ipaddress.ip_network("10.0.99.0/24")), + ( + ipaddress.ip_network("192.168.42.0/24"), + ipaddress.ip_network("192.168.42.0/25"), + ), ] # Tests instances are created with the correct IPs @@ -853,115 +846,104 @@ def test_run_instance_with_subnet_boto3(): CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] resp = client.run_instances( - ImageId='ami-1234abcd', - MaxCount=1, - MinCount=1, - SubnetId=subnet_id + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id ) - instance = resp['Instances'][0] - instance['SubnetId'].should.equal(subnet_id) + instance = resp["Instances"][0] + instance["SubnetId"].should.equal(subnet_id) - priv_ipv4 = ipaddress.ip_address(six.text_type(instance['PrivateIpAddress'])) + priv_ipv4 = ipaddress.ip_address(six.text_type(instance["PrivateIpAddress"])) subnet_cidr.should.contain(priv_ipv4) @mock_ec2 def test_run_instance_with_specified_private_ipv4(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") - vpc_cidr = ipaddress.ip_network('192.168.42.0/24') - subnet_cidr = ipaddress.ip_network('192.168.42.0/25') + vpc_cidr = ipaddress.ip_network("192.168.42.0/24") + subnet_cidr = ipaddress.ip_network("192.168.42.0/25") resp = client.create_vpc( CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] resp = client.run_instances( - ImageId='ami-1234abcd', + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id, - PrivateIpAddress='192.168.42.5' + PrivateIpAddress="192.168.42.5", ) - instance = resp['Instances'][0] - instance['SubnetId'].should.equal(subnet_id) - instance['PrivateIpAddress'].should.equal('192.168.42.5') + instance = resp["Instances"][0] + instance["SubnetId"].should.equal(subnet_id) + instance["PrivateIpAddress"].should.equal("192.168.42.5") @mock_ec2 def test_run_instance_mapped_public_ipv4(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") - vpc_cidr = ipaddress.ip_network('192.168.42.0/24') - subnet_cidr = ipaddress.ip_network('192.168.42.0/25') + vpc_cidr = ipaddress.ip_network("192.168.42.0/24") + subnet_cidr = ipaddress.ip_network("192.168.42.0/25") resp = client.create_vpc( CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] client.modify_subnet_attribute( - SubnetId=subnet_id, - MapPublicIpOnLaunch={'Value': True} + SubnetId=subnet_id, MapPublicIpOnLaunch={"Value": True} ) resp = client.run_instances( - ImageId='ami-1234abcd', - MaxCount=1, - MinCount=1, - SubnetId=subnet_id + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id ) - instance = resp['Instances'][0] - instance.should.contain('PublicDnsName') - instance.should.contain('PublicIpAddress') - len(instance['PublicDnsName']).should.be.greater_than(0) - len(instance['PublicIpAddress']).should.be.greater_than(0) + instance = resp["Instances"][0] + instance.should.contain("PublicDnsName") + instance.should.contain("PublicIpAddress") + len(instance["PublicDnsName"]).should.be.greater_than(0) + len(instance["PublicIpAddress"]).should.be.greater_than(0) @mock_ec2_deprecated def test_run_instance_with_nic_autocreated(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) private_ip = "10.0.0.1" - reservation = conn.run_instances('ami-1234abcd', subnet_id=subnet.id, - security_groups=[security_group1.name], - security_group_ids=[security_group2.id], - private_ip_address=private_ip) + reservation = conn.run_instances( + "ami-1234abcd", + subnet_id=subnet.id, + security_groups=[security_group1.name], + security_group_ids=[security_group2.id], + private_ip_address=private_ip, + ) instance = reservation.instances[0] all_enis = conn.get_all_network_interfaces() @@ -974,39 +956,52 @@ def test_run_instance_with_nic_autocreated(): instance.subnet_id.should.equal(subnet.id) instance.groups.should.have.length_of(2) set([group.id for group in instance.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) eni.subnet_id.should.equal(subnet.id) eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) eni.private_ip_addresses.should.have.length_of(1) eni.private_ip_addresses[0].private_ip_address.should.equal(private_ip) @mock_ec2_deprecated def test_run_instance_with_nic_preexisting(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) private_ip = "54.0.0.1" eni = conn.create_network_interface( - subnet.id, private_ip, groups=[security_group1.id]) + subnet.id, private_ip, groups=[security_group1.id] + ) # Boto requires NetworkInterfaceCollection of NetworkInterfaceSpecifications... # annoying, but generates the desired querystring. - from boto.ec2.networkinterface import NetworkInterfaceSpecification, NetworkInterfaceCollection + from boto.ec2.networkinterface import ( + NetworkInterfaceSpecification, + NetworkInterfaceCollection, + ) + interface = NetworkInterfaceSpecification( - network_interface_id=eni.id, device_index=0) + network_interface_id=eni.id, device_index=0 + ) interfaces = NetworkInterfaceCollection(interface) # end Boto objects - reservation = conn.run_instances('ami-1234abcd', network_interfaces=interfaces, - security_group_ids=[security_group2.id]) + reservation = conn.run_instances( + "ami-1234abcd", + network_interfaces=interfaces, + security_group_ids=[security_group2.id], + ) instance = reservation.instances[0] instance.subnet_id.should.equal(subnet.id) @@ -1021,26 +1016,29 @@ def test_run_instance_with_nic_preexisting(): instance_eni.subnet_id.should.equal(subnet.id) instance_eni.groups.should.have.length_of(2) set([group.id for group in instance_eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) instance_eni.private_ip_addresses.should.have.length_of(1) - instance_eni.private_ip_addresses[ - 0].private_ip_address.should.equal(private_ip) + instance_eni.private_ip_addresses[0].private_ip_address.should.equal(private_ip) @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_instance_with_nic_attach_detach(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) reservation = conn.run_instances( - 'ami-1234abcd', security_group_ids=[security_group1.id]) + "ami-1234abcd", security_group_ids=[security_group1.id] + ) instance = reservation.instances[0] eni = conn.create_network_interface(subnet.id, groups=[security_group2.id]) @@ -1049,17 +1047,16 @@ def test_instance_with_nic_attach_detach(): instance.interfaces.should.have.length_of(1) eni.groups.should.have.length_of(1) - set([group.id for group in eni.groups]).should.equal( - set([security_group2.id])) + set([group.id for group in eni.groups]).should.equal(set([security_group2.id])) # Attach with assert_raises(EC2ResponseError) as ex: - conn.attach_network_interface( - eni.id, instance.id, device_index=1, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.attach_network_interface(eni.id, instance.id, device_index=1, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AttachNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.attach_network_interface(eni.id, instance.id, device_index=1) @@ -1070,21 +1067,23 @@ def test_instance_with_nic_attach_detach(): instance_eni.id.should.equal(eni.id) instance_eni.groups.should.have.length_of(2) set([group.id for group in instance_eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) - eni = conn.get_all_network_interfaces( - filters={'network-interface-id': eni.id})[0] + eni = conn.get_all_network_interfaces(filters={"network-interface-id": eni.id})[0] eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) # Detach with assert_raises(EC2ResponseError) as ex: conn.detach_network_interface(instance_eni.attachment.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DetachNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.detach_network_interface(instance_eni.attachment.id) @@ -1092,35 +1091,35 @@ def test_instance_with_nic_attach_detach(): instance.update() instance.interfaces.should.have.length_of(1) - eni = conn.get_all_network_interfaces( - filters={'network-interface-id': eni.id})[0] + eni = conn.get_all_network_interfaces(filters={"network-interface-id": eni.id})[0] eni.groups.should.have.length_of(1) - set([group.id for group in eni.groups]).should.equal( - set([security_group2.id])) + set([group.id for group in eni.groups]).should.equal(set([security_group2.id])) # Detach with invalid attachment ID with assert_raises(EC2ResponseError) as cm: - conn.detach_network_interface('eni-attach-1234abcd') - cm.exception.code.should.equal('InvalidAttachmentID.NotFound') + conn.detach_network_interface("eni-attach-1234abcd") + cm.exception.code.should.equal("InvalidAttachmentID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ec2_classic_has_public_ip_address(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] instance.ip_address.should_not.equal(None) - instance.public_dns_name.should.contain(instance.ip_address.replace('.', '-')) + instance.public_dns_name.should.contain(instance.ip_address.replace(".", "-")) instance.private_ip_address.should_not.equal(None) - instance.private_dns_name.should.contain(instance.private_ip_address.replace('.', '-')) + instance.private_dns_name.should.contain( + instance.private_ip_address.replace(".", "-") + ) @mock_ec2_deprecated def test_run_instance_with_keypair(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] instance.key_name.should.equal("keypair_name") @@ -1128,32 +1127,32 @@ def test_run_instance_with_keypair(): @mock_ec2_deprecated def test_describe_instance_status_no_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") all_status = conn.get_all_instance_status() len(all_status).should.equal(0) @mock_ec2_deprecated def test_describe_instance_status_with_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + conn.run_instances("ami-1234abcd", key_name="keypair_name") all_status = conn.get_all_instance_status() len(all_status).should.equal(1) - all_status[0].instance_status.status.should.equal('ok') - all_status[0].system_status.status.should.equal('ok') + all_status[0].instance_status.status.should.equal("ok") + all_status[0].system_status.status.should.equal("ok") @mock_ec2_deprecated def test_describe_instance_status_with_instance_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") # We want to filter based on this one - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] # This is just to setup the test - conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn.run_instances("ami-1234abcd", key_name="keypair_name") all_status = conn.get_all_instance_status(instance_ids=[instance.id]) len(all_status).should.equal(1) @@ -1162,7 +1161,7 @@ def test_describe_instance_status_with_instance_filter(): # Call get_all_instance_status with a bad id should raise an error with assert_raises(EC2ResponseError) as cm: conn.get_all_instance_status(instance_ids=[instance.id, "i-1234abcd"]) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -1170,8 +1169,8 @@ def test_describe_instance_status_with_instance_filter(): @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_describe_instance_status_with_non_running_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', min_count=3) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances instance1.stop() instance2.terminate() @@ -1179,40 +1178,41 @@ def test_describe_instance_status_with_non_running_instances(): all_running_status = conn.get_all_instance_status() all_running_status.should.have.length_of(1) all_running_status[0].id.should.equal(instance3.id) - all_running_status[0].state_name.should.equal('running') + all_running_status[0].state_name.should.equal("running") all_status = conn.get_all_instance_status(include_all_instances=True) all_status.should.have.length_of(3) status1 = next((s for s in all_status if s.id == instance1.id), None) - status1.state_name.should.equal('stopped') + status1.state_name.should.equal("stopped") status2 = next((s for s in all_status if s.id == instance2.id), None) - status2.state_name.should.equal('terminated') + status2.state_name.should.equal("terminated") status3 = next((s for s in all_status if s.id == instance3.id), None) - status3.state_name.should.equal('running') + status3.state_name.should.equal("running") @mock_ec2_deprecated def test_get_instance_by_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - conn.run_instances('ami-1234abcd') + conn.run_instances("ami-1234abcd") instance = conn.get_only_instances()[0] - security_group = conn.create_security_group('test', 'test') + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as ex: - conn.modify_instance_attribute(instance.id, "groupSet", [ - security_group.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_instance_attribute( + instance.id, "groupSet", [security_group.id], dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set" + ) - conn.modify_instance_attribute( - instance.id, "groupSet", [security_group.id]) + conn.modify_instance_attribute(instance.id, "groupSet", [security_group.id]) security_group_instances = security_group.instances() @@ -1222,38 +1222,31 @@ def test_get_instance_by_security_group(): @mock_ec2 def test_modify_delete_on_termination(): - ec2_client = boto3.resource('ec2', region_name='us-west-1') - result = ec2_client.create_instances(ImageId='ami-12345678', MinCount=1, MaxCount=1) + ec2_client = boto3.resource("ec2", region_name="us-west-1") + result = ec2_client.create_instances(ImageId="ami-12345678", MinCount=1, MaxCount=1) instance = result[0] instance.load() - instance.block_device_mappings[0]['Ebs']['DeleteOnTermination'].should.be(False) + instance.block_device_mappings[0]["Ebs"]["DeleteOnTermination"].should.be(False) instance.modify_attribute( - BlockDeviceMappings=[{ - 'DeviceName': '/dev/sda1', - 'Ebs': {'DeleteOnTermination': True} - }] + BlockDeviceMappings=[ + {"DeviceName": "/dev/sda1", "Ebs": {"DeleteOnTermination": True}} + ] ) instance.load() - instance.block_device_mappings[0]['Ebs']['DeleteOnTermination'].should.be(True) + instance.block_device_mappings[0]["Ebs"]["DeleteOnTermination"].should.be(True) + @mock_ec2 def test_create_instance_ebs_optimized(): - ec2_resource = boto3.resource('ec2', region_name='eu-west-1') + ec2_resource = boto3.resource("ec2", region_name="eu-west-1") instance = ec2_resource.create_instances( - ImageId = 'ami-12345678', - MaxCount = 1, - MinCount = 1, - EbsOptimized = True, + ImageId="ami-12345678", MaxCount=1, MinCount=1, EbsOptimized=True )[0] instance.load() instance.ebs_optimized.should.be(True) - instance.modify_attribute( - EbsOptimized={ - 'Value': False - } - ) + instance.modify_attribute(EbsOptimized={"Value": False}) instance.load() instance.ebs_optimized.should.be(False) @@ -1261,34 +1254,55 @@ def test_create_instance_ebs_optimized(): @mock_ec2 def test_run_multiple_instances_in_same_command(): instance_count = 4 - client = boto3.client('ec2', region_name='us-east-1') - client.run_instances(ImageId='ami-1234abcd', - MinCount=instance_count, - MaxCount=instance_count) - reservations = client.describe_instances()['Reservations'] + client = boto3.client("ec2", region_name="us-east-1") + client.run_instances( + ImageId="ami-1234abcd", MinCount=instance_count, MaxCount=instance_count + ) + reservations = client.describe_instances()["Reservations"] - reservations[0]['Instances'].should.have.length_of(instance_count) + reservations[0]["Instances"].should.have.length_of(instance_count) - instances = reservations[0]['Instances'] + instances = reservations[0]["Instances"] for i in range(0, instance_count): - instances[i]['AmiLaunchIndex'].should.be(i) + instances[i]["AmiLaunchIndex"].should.be(i) @mock_ec2 def test_describe_instance_attribute(): - client = boto3.client('ec2', region_name='us-east-1') + client = boto3.client("ec2", region_name="us-east-1") security_group_id = client.create_security_group( - GroupName='test security group', Description='this is a test security group')['GroupId'] - client.run_instances(ImageId='ami-1234abcd', - MinCount=1, - MaxCount=1, - SecurityGroupIds=[security_group_id]) - instance_id = client.describe_instances()['Reservations'][0]['Instances'][0]['InstanceId'] + GroupName="test security group", Description="this is a test security group" + )["GroupId"] + client.run_instances( + ImageId="ami-1234abcd", + MinCount=1, + MaxCount=1, + SecurityGroupIds=[security_group_id], + ) + instance_id = client.describe_instances()["Reservations"][0]["Instances"][0][ + "InstanceId" + ] - valid_instance_attributes = ['instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination', 'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping', 'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'] + valid_instance_attributes = [ + "instanceType", + "kernel", + "ramdisk", + "userData", + "disableApiTermination", + "instanceInitiatedShutdownBehavior", + "rootDeviceName", + "blockDeviceMapping", + "productCodes", + "sourceDestCheck", + "groupSet", + "ebsOptimized", + "sriovNetSupport", + ] for valid_instance_attribute in valid_instance_attributes: - response = client.describe_instance_attribute(InstanceId=instance_id, Attribute=valid_instance_attribute) + response = client.describe_instance_attribute( + InstanceId=instance_id, Attribute=valid_instance_attribute + ) if valid_instance_attribute == "groupSet": response.should.have.key("Groups") response["Groups"].should.have.length_of(1) @@ -1297,12 +1311,22 @@ def test_describe_instance_attribute(): response.should.have.key("UserData") response["UserData"].should.be.empty - invalid_instance_attributes = ['abc', 'Kernel', 'RamDisk', 'userdata', 'iNsTaNcEtYpE'] + invalid_instance_attributes = [ + "abc", + "Kernel", + "RamDisk", + "userdata", + "iNsTaNcEtYpE", + ] for invalid_instance_attribute in invalid_instance_attributes: with assert_raises(ClientError) as ex: - client.describe_instance_attribute(InstanceId=instance_id, Attribute=invalid_instance_attribute) - ex.exception.response['Error']['Code'].should.equal('InvalidParameterValue') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - message = 'Value ({invalid_instance_attribute}) for parameter attribute is invalid. Unknown attribute.'.format(invalid_instance_attribute=invalid_instance_attribute) - ex.exception.response['Error']['Message'].should.equal(message) + client.describe_instance_attribute( + InstanceId=instance_id, Attribute=invalid_instance_attribute + ) + ex.exception.response["Error"]["Code"].should.equal("InvalidParameterValue") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + message = "Value ({invalid_instance_attribute}) for parameter attribute is invalid. Unknown attribute.".format( + invalid_instance_attribute=invalid_instance_attribute + ) + ex.exception.response["Error"]["Message"].should.equal(message) diff --git a/tests/test_ec2/test_internet_gateways.py b/tests/test_ec2/test_internet_gateways.py index 1f010223c..5941643cf 100644 --- a/tests/test_ec2/test_internet_gateways.py +++ b/tests/test_ec2/test_internet_gateways.py @@ -1,269 +1,271 @@ -from __future__ import unicode_literals -# Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises -from nose.tools import assert_raises - -import re - -import boto -from boto.exception import EC2ResponseError - -import sure # noqa - -from moto import mock_ec2_deprecated - - -VPC_CIDR = "10.0.0.0/16" -BAD_VPC = "vpc-deadbeef" -BAD_IGW = "igw-deadbeef" - - -@mock_ec2_deprecated -def test_igw_create(): - """ internet gateway create """ - conn = boto.connect_vpc('the_key', 'the_secret') - - conn.get_all_internet_gateways().should.have.length_of(0) - - with assert_raises(EC2ResponseError) as ex: - igw = conn.create_internet_gateway(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateInternetGateway operation: Request would have succeeded, but DryRun flag is set') - - igw = conn.create_internet_gateway() - conn.get_all_internet_gateways().should.have.length_of(1) - igw.id.should.match(r'igw-[0-9a-f]+') - - igw = conn.get_all_internet_gateways()[0] - igw.attachments.should.have.length_of(0) - - -@mock_ec2_deprecated -def test_igw_attach(): - """ internet gateway attach """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - - with assert_raises(EC2ResponseError) as ex: - conn.attach_internet_gateway(igw.id, vpc.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachInternetGateway operation: Request would have succeeded, but DryRun flag is set') - - conn.attach_internet_gateway(igw.id, vpc.id) - - igw = conn.get_all_internet_gateways()[0] - igw.attachments[0].vpc_id.should.be.equal(vpc.id) - - -@mock_ec2_deprecated -def test_igw_attach_bad_vpc(): - """ internet gateway fail to attach w/ bad vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - - with assert_raises(EC2ResponseError) as cm: - conn.attach_internet_gateway(igw.id, BAD_VPC) - cm.exception.code.should.equal('InvalidVpcID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_attach_twice(): - """ internet gateway fail to attach twice """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc1 = conn.create_vpc(VPC_CIDR) - vpc2 = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw.id, vpc1.id) - - with assert_raises(EC2ResponseError) as cm: - conn.attach_internet_gateway(igw.id, vpc2.id) - cm.exception.code.should.equal('Resource.AlreadyAssociated') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_detach(): - """ internet gateway detach""" - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw.id, vpc.id) - - with assert_raises(EC2ResponseError) as ex: - conn.detach_internet_gateway(igw.id, vpc.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachInternetGateway operation: Request would have succeeded, but DryRun flag is set') - - conn.detach_internet_gateway(igw.id, vpc.id) - igw = conn.get_all_internet_gateways()[0] - igw.attachments.should.have.length_of(0) - - -@mock_ec2_deprecated -def test_igw_detach_wrong_vpc(): - """ internet gateway fail to detach w/ wrong vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc1 = conn.create_vpc(VPC_CIDR) - vpc2 = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw.id, vpc1.id) - - with assert_raises(EC2ResponseError) as cm: - conn.detach_internet_gateway(igw.id, vpc2.id) - cm.exception.code.should.equal('Gateway.NotAttached') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_detach_invalid_vpc(): - """ internet gateway fail to detach w/ invalid vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw.id, vpc.id) - - with assert_raises(EC2ResponseError) as cm: - conn.detach_internet_gateway(igw.id, BAD_VPC) - cm.exception.code.should.equal('Gateway.NotAttached') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_detach_unattached(): - """ internet gateway fail to detach unattached """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - - with assert_raises(EC2ResponseError) as cm: - conn.detach_internet_gateway(igw.id, vpc.id) - cm.exception.code.should.equal('Gateway.NotAttached') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_delete(): - """ internet gateway delete""" - conn = boto.connect_vpc('the_key', 'the_secret') - vpc = conn.create_vpc(VPC_CIDR) - conn.get_all_internet_gateways().should.have.length_of(0) - igw = conn.create_internet_gateway() - conn.get_all_internet_gateways().should.have.length_of(1) - - with assert_raises(EC2ResponseError) as ex: - conn.delete_internet_gateway(igw.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') - ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteInternetGateway operation: Request would have succeeded, but DryRun flag is set') - - conn.delete_internet_gateway(igw.id) - conn.get_all_internet_gateways().should.have.length_of(0) - - -@mock_ec2_deprecated -def test_igw_delete_attached(): - """ internet gateway fail to delete attached """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw.id, vpc.id) - - with assert_raises(EC2ResponseError) as cm: - conn.delete_internet_gateway(igw.id) - cm.exception.code.should.equal('DependencyViolation') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_desribe(): - """ internet gateway fetch by id """ - conn = boto.connect_vpc('the_key', 'the_secret') - igw = conn.create_internet_gateway() - igw_by_search = conn.get_all_internet_gateways([igw.id])[0] - igw.id.should.equal(igw_by_search.id) - - -@mock_ec2_deprecated -def test_igw_describe_bad_id(): - """ internet gateway fail to fetch by bad id """ - conn = boto.connect_vpc('the_key', 'the_secret') - with assert_raises(EC2ResponseError) as cm: - conn.get_all_internet_gateways([BAD_IGW]) - cm.exception.code.should.equal('InvalidInternetGatewayID.NotFound') - cm.exception.status.should.equal(400) - cm.exception.request_id.should_not.be.none - - -@mock_ec2_deprecated -def test_igw_filter_by_vpc_id(): - """ internet gateway filter by vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') - - igw1 = conn.create_internet_gateway() - igw2 = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw1.id, vpc.id) - - result = conn.get_all_internet_gateways( - filters={"attachment.vpc-id": vpc.id}) - result.should.have.length_of(1) - result[0].id.should.equal(igw1.id) - - -@mock_ec2_deprecated -def test_igw_filter_by_tags(): - """ internet gateway filter by vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') - - igw1 = conn.create_internet_gateway() - igw2 = conn.create_internet_gateway() - igw1.add_tag("tests", "yes") - - result = conn.get_all_internet_gateways(filters={"tag:tests": "yes"}) - result.should.have.length_of(1) - result[0].id.should.equal(igw1.id) - - -@mock_ec2_deprecated -def test_igw_filter_by_internet_gateway_id(): - """ internet gateway filter by internet gateway id """ - conn = boto.connect_vpc('the_key', 'the_secret') - - igw1 = conn.create_internet_gateway() - igw2 = conn.create_internet_gateway() - - result = conn.get_all_internet_gateways( - filters={"internet-gateway-id": igw1.id}) - result.should.have.length_of(1) - result[0].id.should.equal(igw1.id) - - -@mock_ec2_deprecated -def test_igw_filter_by_attachment_state(): - """ internet gateway filter by attachment state """ - conn = boto.connect_vpc('the_key', 'the_secret') - - igw1 = conn.create_internet_gateway() - igw2 = conn.create_internet_gateway() - vpc = conn.create_vpc(VPC_CIDR) - conn.attach_internet_gateway(igw1.id, vpc.id) - - result = conn.get_all_internet_gateways( - filters={"attachment.state": "available"}) - result.should.have.length_of(1) - result[0].id.should.equal(igw1.id) +from __future__ import unicode_literals + +# Ensure 'assert_raises' context manager support for Python 2.6 +import tests.backport_assert_raises +from nose.tools import assert_raises + +import re + +import boto +from boto.exception import EC2ResponseError + +import sure # noqa + +from moto import mock_ec2_deprecated + + +VPC_CIDR = "10.0.0.0/16" +BAD_VPC = "vpc-deadbeef" +BAD_IGW = "igw-deadbeef" + + +@mock_ec2_deprecated +def test_igw_create(): + """ internet gateway create """ + conn = boto.connect_vpc("the_key", "the_secret") + + conn.get_all_internet_gateways().should.have.length_of(0) + + with assert_raises(EC2ResponseError) as ex: + igw = conn.create_internet_gateway(dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the CreateInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) + + igw = conn.create_internet_gateway() + conn.get_all_internet_gateways().should.have.length_of(1) + igw.id.should.match(r"igw-[0-9a-f]+") + + igw = conn.get_all_internet_gateways()[0] + igw.attachments.should.have.length_of(0) + + +@mock_ec2_deprecated +def test_igw_attach(): + """ internet gateway attach """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + + with assert_raises(EC2ResponseError) as ex: + conn.attach_internet_gateway(igw.id, vpc.id, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the AttachInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) + + conn.attach_internet_gateway(igw.id, vpc.id) + + igw = conn.get_all_internet_gateways()[0] + igw.attachments[0].vpc_id.should.be.equal(vpc.id) + + +@mock_ec2_deprecated +def test_igw_attach_bad_vpc(): + """ internet gateway fail to attach w/ bad vpc """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + + with assert_raises(EC2ResponseError) as cm: + conn.attach_internet_gateway(igw.id, BAD_VPC) + cm.exception.code.should.equal("InvalidVpcID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_attach_twice(): + """ internet gateway fail to attach twice """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc1 = conn.create_vpc(VPC_CIDR) + vpc2 = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw.id, vpc1.id) + + with assert_raises(EC2ResponseError) as cm: + conn.attach_internet_gateway(igw.id, vpc2.id) + cm.exception.code.should.equal("Resource.AlreadyAssociated") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_detach(): + """ internet gateway detach""" + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw.id, vpc.id) + + with assert_raises(EC2ResponseError) as ex: + conn.detach_internet_gateway(igw.id, vpc.id, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the DetachInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) + + conn.detach_internet_gateway(igw.id, vpc.id) + igw = conn.get_all_internet_gateways()[0] + igw.attachments.should.have.length_of(0) + + +@mock_ec2_deprecated +def test_igw_detach_wrong_vpc(): + """ internet gateway fail to detach w/ wrong vpc """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc1 = conn.create_vpc(VPC_CIDR) + vpc2 = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw.id, vpc1.id) + + with assert_raises(EC2ResponseError) as cm: + conn.detach_internet_gateway(igw.id, vpc2.id) + cm.exception.code.should.equal("Gateway.NotAttached") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_detach_invalid_vpc(): + """ internet gateway fail to detach w/ invalid vpc """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw.id, vpc.id) + + with assert_raises(EC2ResponseError) as cm: + conn.detach_internet_gateway(igw.id, BAD_VPC) + cm.exception.code.should.equal("Gateway.NotAttached") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_detach_unattached(): + """ internet gateway fail to detach unattached """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + + with assert_raises(EC2ResponseError) as cm: + conn.detach_internet_gateway(igw.id, vpc.id) + cm.exception.code.should.equal("Gateway.NotAttached") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_delete(): + """ internet gateway delete""" + conn = boto.connect_vpc("the_key", "the_secret") + vpc = conn.create_vpc(VPC_CIDR) + conn.get_all_internet_gateways().should.have.length_of(0) + igw = conn.create_internet_gateway() + conn.get_all_internet_gateways().should.have.length_of(1) + + with assert_raises(EC2ResponseError) as ex: + conn.delete_internet_gateway(igw.id, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") + ex.exception.status.should.equal(400) + ex.exception.message.should.equal( + "An error occurred (DryRunOperation) when calling the DeleteInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) + + conn.delete_internet_gateway(igw.id) + conn.get_all_internet_gateways().should.have.length_of(0) + + +@mock_ec2_deprecated +def test_igw_delete_attached(): + """ internet gateway fail to delete attached """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw.id, vpc.id) + + with assert_raises(EC2ResponseError) as cm: + conn.delete_internet_gateway(igw.id) + cm.exception.code.should.equal("DependencyViolation") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_desribe(): + """ internet gateway fetch by id """ + conn = boto.connect_vpc("the_key", "the_secret") + igw = conn.create_internet_gateway() + igw_by_search = conn.get_all_internet_gateways([igw.id])[0] + igw.id.should.equal(igw_by_search.id) + + +@mock_ec2_deprecated +def test_igw_describe_bad_id(): + """ internet gateway fail to fetch by bad id """ + conn = boto.connect_vpc("the_key", "the_secret") + with assert_raises(EC2ResponseError) as cm: + conn.get_all_internet_gateways([BAD_IGW]) + cm.exception.code.should.equal("InvalidInternetGatewayID.NotFound") + cm.exception.status.should.equal(400) + cm.exception.request_id.should_not.be.none + + +@mock_ec2_deprecated +def test_igw_filter_by_vpc_id(): + """ internet gateway filter by vpc id """ + conn = boto.connect_vpc("the_key", "the_secret") + + igw1 = conn.create_internet_gateway() + igw2 = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw1.id, vpc.id) + + result = conn.get_all_internet_gateways(filters={"attachment.vpc-id": vpc.id}) + result.should.have.length_of(1) + result[0].id.should.equal(igw1.id) + + +@mock_ec2_deprecated +def test_igw_filter_by_tags(): + """ internet gateway filter by vpc id """ + conn = boto.connect_vpc("the_key", "the_secret") + + igw1 = conn.create_internet_gateway() + igw2 = conn.create_internet_gateway() + igw1.add_tag("tests", "yes") + + result = conn.get_all_internet_gateways(filters={"tag:tests": "yes"}) + result.should.have.length_of(1) + result[0].id.should.equal(igw1.id) + + +@mock_ec2_deprecated +def test_igw_filter_by_internet_gateway_id(): + """ internet gateway filter by internet gateway id """ + conn = boto.connect_vpc("the_key", "the_secret") + + igw1 = conn.create_internet_gateway() + igw2 = conn.create_internet_gateway() + + result = conn.get_all_internet_gateways(filters={"internet-gateway-id": igw1.id}) + result.should.have.length_of(1) + result[0].id.should.equal(igw1.id) + + +@mock_ec2_deprecated +def test_igw_filter_by_attachment_state(): + """ internet gateway filter by attachment state """ + conn = boto.connect_vpc("the_key", "the_secret") + + igw1 = conn.create_internet_gateway() + igw2 = conn.create_internet_gateway() + vpc = conn.create_vpc(VPC_CIDR) + conn.attach_internet_gateway(igw1.id, vpc.id) + + result = conn.get_all_internet_gateways(filters={"attachment.state": "available"}) + result.should.have.length_of(1) + result[0].id.should.equal(igw1.id) diff --git a/tests/test_ec2/test_key_pairs.py b/tests/test_ec2/test_key_pairs.py index dfe6eabdf..d632c2478 100644 --- a/tests/test_ec2/test_key_pairs.py +++ b/tests/test_ec2/test_key_pairs.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -47,116 +48,119 @@ moto@github.com""" @mock_ec2_deprecated def test_key_pairs_empty(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") assert len(conn.get_all_key_pairs()) == 0 @mock_ec2_deprecated def test_key_pairs_invalid_id(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_all_key_pairs('foo') - cm.exception.code.should.equal('InvalidKeyPair.NotFound') + conn.get_all_key_pairs("foo") + cm.exception.code.should.equal("InvalidKeyPair.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_create(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.create_key_pair('foo', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.create_key_pair("foo", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - kp = conn.create_key_pair('foo') + kp = conn.create_key_pair("foo") rsa_check_private_key(kp.material) kps = conn.get_all_key_pairs() assert len(kps) == 1 - assert kps[0].name == 'foo' + assert kps[0].name == "foo" @mock_ec2_deprecated def test_key_pairs_create_two(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - kp1 = conn.create_key_pair('foo') + kp1 = conn.create_key_pair("foo") rsa_check_private_key(kp1.material) - kp2 = conn.create_key_pair('bar') + kp2 = conn.create_key_pair("bar") rsa_check_private_key(kp2.material) assert kp1.material != kp2.material kps = conn.get_all_key_pairs() kps.should.have.length_of(2) - assert {i.name for i in kps} == {'foo', 'bar'} + assert {i.name for i in kps} == {"foo", "bar"} - kps = conn.get_all_key_pairs('foo') + kps = conn.get_all_key_pairs("foo") kps.should.have.length_of(1) - kps[0].name.should.equal('foo') + kps[0].name.should.equal("foo") @mock_ec2_deprecated def test_key_pairs_create_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.create_key_pair('foo') + conn = boto.connect_ec2("the_key", "the_secret") + conn.create_key_pair("foo") assert len(conn.get_all_key_pairs()) == 1 with assert_raises(EC2ResponseError) as cm: - conn.create_key_pair('foo') - cm.exception.code.should.equal('InvalidKeyPair.Duplicate') + conn.create_key_pair("foo") + cm.exception.code.should.equal("InvalidKeyPair.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_delete_no_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") assert len(conn.get_all_key_pairs()) == 0 - r = conn.delete_key_pair('foo') + r = conn.delete_key_pair("foo") r.should.be.ok @mock_ec2_deprecated def test_key_pairs_delete_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.create_key_pair('foo') + conn = boto.connect_ec2("the_key", "the_secret") + conn.create_key_pair("foo") with assert_raises(EC2ResponseError) as ex: - r = conn.delete_key_pair('foo', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + r = conn.delete_key_pair("foo", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - r = conn.delete_key_pair('foo') + r = conn.delete_key_pair("foo") r.should.be.ok assert len(conn.get_all_key_pairs()) == 0 @mock_ec2_deprecated def test_key_pairs_import(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - kp1 = conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH) - assert kp1.name == 'foo' + kp1 = conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH) + assert kp1.name == "foo" assert kp1.fingerprint == RSA_PUBLIC_KEY_FINGERPRINT - kp2 = conn.import_key_pair('foo2', RSA_PUBLIC_KEY_RFC4716) - assert kp2.name == 'foo2' + kp2 = conn.import_key_pair("foo2", RSA_PUBLIC_KEY_RFC4716) + assert kp2.name == "foo2" assert kp2.fingerprint == RSA_PUBLIC_KEY_FINGERPRINT kps = conn.get_all_key_pairs() @@ -167,58 +171,51 @@ def test_key_pairs_import(): @mock_ec2_deprecated def test_key_pairs_import_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - kp = conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH) - assert kp.name == 'foo' + conn = boto.connect_ec2("the_key", "the_secret") + kp = conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH) + assert kp.name == "foo" assert len(conn.get_all_key_pairs()) == 1 with assert_raises(EC2ResponseError) as cm: - conn.create_key_pair('foo') - cm.exception.code.should.equal('InvalidKeyPair.Duplicate') + conn.create_key_pair("foo") + cm.exception.code.should.equal("InvalidKeyPair.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_invalid(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', b'') - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", b"") + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', b'garbage') - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", b"garbage") + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', DSA_PUBLIC_KEY_OPENSSH) - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", DSA_PUBLIC_KEY_OPENSSH) + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") @mock_ec2_deprecated def test_key_pair_filters(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - _ = conn.create_key_pair('kpfltr1') - kp2 = conn.create_key_pair('kpfltr2') - kp3 = conn.create_key_pair('kpfltr3') + _ = conn.create_key_pair("kpfltr1") + kp2 = conn.create_key_pair("kpfltr2") + kp3 = conn.create_key_pair("kpfltr3") - kp_by_name = conn.get_all_key_pairs( - filters={'key-name': 'kpfltr2'}) - set([kp.name for kp in kp_by_name] - ).should.equal(set([kp2.name])) + kp_by_name = conn.get_all_key_pairs(filters={"key-name": "kpfltr2"}) + set([kp.name for kp in kp_by_name]).should.equal(set([kp2.name])) - kp_by_name = conn.get_all_key_pairs( - filters={'fingerprint': kp3.fingerprint}) - set([kp.name for kp in kp_by_name] - ).should.equal(set([kp3.name])) + kp_by_name = conn.get_all_key_pairs(filters={"fingerprint": kp3.fingerprint}) + set([kp.name for kp in kp_by_name]).should.equal(set([kp3.name])) diff --git a/tests/test_ec2/test_launch_templates.py b/tests/test_ec2/test_launch_templates.py index 87e1d3986..4c37818d1 100644 --- a/tests/test_ec2/test_launch_templates.py +++ b/tests/test_ec2/test_launch_templates.py @@ -13,16 +13,14 @@ def test_launch_template_create(): resp = cli.create_launch_template( LaunchTemplateName="test-template", - # the absolute minimum needed to create a template without other resources LaunchTemplateData={ - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [{"Key": "test", "Value": "value"}], + } + ] }, ) @@ -36,18 +34,18 @@ def test_launch_template_create(): cli.create_launch_template( LaunchTemplateName="test-template", LaunchTemplateData={ - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [{"Key": "test", "Value": "value"}], + } + ] }, ) str(ex.exception).should.equal( - 'An error occurred (InvalidLaunchTemplateName.AlreadyExistsException) when calling the CreateLaunchTemplate operation: Launch template name already in use.') + "An error occurred (InvalidLaunchTemplateName.AlreadyExistsException) when calling the CreateLaunchTemplate operation: Launch template name already in use." + ) @mock_ec2 @@ -55,29 +53,22 @@ def test_describe_launch_template_versions(): template_data = { "ImageId": "ami-abc123", "DisableApiTermination": False, - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], - "SecurityGroupIds": [ - "sg-1234", - "sg-ab5678", + "TagSpecifications": [ + {"ResourceType": "instance", "Tags": [{"Key": "test", "Value": "value"}]} ], + "SecurityGroupIds": ["sg-1234", "sg-ab5678"], } cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData=template_data) + LaunchTemplateName="test-template", LaunchTemplateData=template_data + ) # test using name resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=['1']) + LaunchTemplateName="test-template", Versions=["1"] + ) templ = resp["LaunchTemplateVersions"][0]["LaunchTemplateData"] templ.should.equal(template_data) @@ -85,7 +76,8 @@ def test_describe_launch_template_versions(): # test using id resp = cli.describe_launch_template_versions( LaunchTemplateId=create_resp["LaunchTemplate"]["LaunchTemplateId"], - Versions=['1']) + Versions=["1"], + ) templ = resp["LaunchTemplateVersions"][0]["LaunchTemplateData"] templ.should.equal(template_data) @@ -96,22 +88,21 @@ def test_create_launch_template_version(): cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) version_resp = cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) version_resp.should.have.key("LaunchTemplateVersion") version = version_resp["LaunchTemplateVersion"] version["DefaultVersion"].should.equal(False) - version["LaunchTemplateId"].should.equal(create_resp["LaunchTemplate"]["LaunchTemplateId"]) + version["LaunchTemplateId"].should.equal( + create_resp["LaunchTemplate"]["LaunchTemplateId"] + ) version["VersionDescription"].should.equal("new ami") version["VersionNumber"].should.equal(2) @@ -121,22 +112,21 @@ def test_create_launch_template_version_by_id(): cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) version_resp = cli.create_launch_template_version( LaunchTemplateId=create_resp["LaunchTemplate"]["LaunchTemplateId"], - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) version_resp.should.have.key("LaunchTemplateVersion") version = version_resp["LaunchTemplateVersion"] version["DefaultVersion"].should.equal(False) - version["LaunchTemplateId"].should.equal(create_resp["LaunchTemplate"]["LaunchTemplateId"]) + version["LaunchTemplateId"].should.equal( + create_resp["LaunchTemplate"]["LaunchTemplateId"] + ) version["VersionDescription"].should.equal("new ami") version["VersionNumber"].should.equal(2) @@ -146,24 +136,24 @@ def test_describe_launch_template_versions_with_multiple_versions(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) - resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template") + resp = cli.describe_launch_template_versions(LaunchTemplateName="test-template") resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-abc123") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-abc123" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) @mock_ec2 @@ -171,32 +161,32 @@ def test_describe_launch_template_versions_with_versions_option(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=["2", "3"]) + LaunchTemplateName="test-template", Versions=["2", "3"] + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -204,32 +194,32 @@ def test_describe_launch_template_versions_with_min(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MinVersion="2") + LaunchTemplateName="test-template", MinVersion="2" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -237,32 +227,32 @@ def test_describe_launch_template_versions_with_max(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MaxVersion="2") + LaunchTemplateName="test-template", MaxVersion="2" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-abc123") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-abc123" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) @mock_ec2 @@ -270,40 +260,38 @@ def test_describe_launch_template_versions_with_min_and_max(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-345abc" - }, - VersionDescription="new ami, because why not") + LaunchTemplateData={"ImageId": "ami-345abc"}, + VersionDescription="new ami, because why not", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MinVersion="2", - MaxVersion="3") + LaunchTemplateName="test-template", MinVersion="2", MaxVersion="3" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -312,17 +300,14 @@ def test_describe_launch_templates(): lt_ids = [] r = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) lt_ids.append(r["LaunchTemplate"]["LaunchTemplateId"]) r = cli.create_launch_template( LaunchTemplateName="test-template2", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateData={"ImageId": "ami-abc123"}, + ) lt_ids.append(r["LaunchTemplate"]["LaunchTemplateId"]) # general call, all templates @@ -334,7 +319,8 @@ def test_describe_launch_templates(): # filter by names resp = cli.describe_launch_templates( - LaunchTemplateNames=["test-template2", "test-template"]) + LaunchTemplateNames=["test-template2", "test-template"] + ) resp.should.have.key("LaunchTemplates") resp["LaunchTemplates"].should.have.length_of(2) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("test-template2") @@ -353,34 +339,31 @@ def test_describe_launch_templates_with_filters(): cli = boto3.client("ec2", region_name="us-east-1") r = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_tags( Resources=[r["LaunchTemplate"]["LaunchTemplateId"]], Tags=[ {"Key": "tag1", "Value": "a value"}, {"Key": "another-key", "Value": "this value"}, - ]) + ], + ) cli.create_launch_template( - LaunchTemplateName="no-tags", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="no-tags", LaunchTemplateData={"ImageId": "ami-abc123"} + ) - resp = cli.describe_launch_templates(Filters=[{ - "Name": "tag:tag1", "Values": ["a value"] - }]) + resp = cli.describe_launch_templates( + Filters=[{"Name": "tag:tag1", "Values": ["a value"]}] + ) resp["LaunchTemplates"].should.have.length_of(1) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("test-template") - resp = cli.describe_launch_templates(Filters=[{ - "Name": "launch-template-name", "Values": ["no-tags"] - }]) + resp = cli.describe_launch_templates( + Filters=[{"Name": "launch-template-name", "Values": ["no-tags"]}] + ) resp["LaunchTemplates"].should.have.length_of(1) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("no-tags") @@ -392,24 +375,18 @@ def test_create_launch_template_with_tag_spec(): cli.create_launch_template( LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"}, - TagSpecifications=[{ - "ResourceType": "instance", - "Tags": [ - {"Key": "key", "Value": "value"} - ] - }], + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "key", "Value": "value"}]} + ], ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=["1"]) + LaunchTemplateName="test-template", Versions=["1"] + ) version = resp["LaunchTemplateVersions"][0] version["LaunchTemplateData"].should.have.key("TagSpecifications") version["LaunchTemplateData"]["TagSpecifications"].should.have.length_of(1) - version["LaunchTemplateData"]["TagSpecifications"][0].should.equal({ - "ResourceType": "instance", - "Tags": [ - {"Key": "key", "Value": "value"} - ] - }) + version["LaunchTemplateData"]["TagSpecifications"][0].should.equal( + {"ResourceType": "instance", "Tags": [{"Key": "key", "Value": "value"}]} + ) diff --git a/tests/test_ec2/test_nat_gateway.py b/tests/test_ec2/test_nat_gateway.py index 310ae2c3a..fd8c721be 100644 --- a/tests/test_ec2/test_nat_gateway.py +++ b/tests/test_ec2/test_nat_gateway.py @@ -1,109 +1,225 @@ -from __future__ import unicode_literals -import boto3 -import sure # noqa -from moto import mock_ec2 - - -@mock_ec2 -def test_describe_nat_gateways(): - conn = boto3.client('ec2', 'us-east-1') - - response = conn.describe_nat_gateways() - - response['NatGateways'].should.have.length_of(0) - - -@mock_ec2 -def test_create_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] - subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', - ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] - - response = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, - ) - - response['NatGateway']['VpcId'].should.equal(vpc_id) - response['NatGateway']['SubnetId'].should.equal(subnet_id) - response['NatGateway']['State'].should.equal('available') - - -@mock_ec2 -def test_delete_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] - subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', - ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] - - nat_gateway = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, - ) - nat_gateway_id = nat_gateway['NatGateway']['NatGatewayId'] - response = conn.delete_nat_gateway(NatGatewayId=nat_gateway_id) - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'NatGatewayId': nat_gateway_id, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '741fc8ab-6ebe-452b-b92b-example' - } - }) - - -@mock_ec2 -def test_create_and_describe_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] - subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', - ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] - - create_response = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, - ) - nat_gateway_id = create_response['NatGateway']['NatGatewayId'] - describe_response = conn.describe_nat_gateways() - - enis = conn.describe_network_interfaces()['NetworkInterfaces'] - eni_id = enis[0]['NetworkInterfaceId'] - public_ip = conn.describe_addresses(AllocationIds=[allocation_id])[ - 'Addresses'][0]['PublicIp'] - - describe_response['NatGateways'].should.have.length_of(1) - describe_response['NatGateways'][0][ - 'NatGatewayId'].should.equal(nat_gateway_id) - describe_response['NatGateways'][0]['State'].should.equal('available') - describe_response['NatGateways'][0]['SubnetId'].should.equal(subnet_id) - describe_response['NatGateways'][0]['VpcId'].should.equal(vpc_id) - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['AllocationId'].should.equal(allocation_id) - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['NetworkInterfaceId'].should.equal(eni_id) - assert describe_response['NatGateways'][0][ - 'NatGatewayAddresses'][0]['PrivateIp'].startswith('10.') - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['PublicIp'].should.equal(public_ip) +from __future__ import unicode_literals +import boto3 +import sure # noqa +from moto import mock_ec2 + + +@mock_ec2 +def test_describe_nat_gateways(): + conn = boto3.client("ec2", "us-east-1") + + response = conn.describe_nat_gateways() + + response["NatGateways"].should.have.length_of(0) + + +@mock_ec2 +def test_create_nat_gateway(): + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] + + response = conn.create_nat_gateway(SubnetId=subnet_id, AllocationId=allocation_id) + + response["NatGateway"]["VpcId"].should.equal(vpc_id) + response["NatGateway"]["SubnetId"].should.equal(subnet_id) + response["NatGateway"]["State"].should.equal("available") + + +@mock_ec2 +def test_delete_nat_gateway(): + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] + + nat_gateway = conn.create_nat_gateway( + SubnetId=subnet_id, AllocationId=allocation_id + ) + nat_gateway_id = nat_gateway["NatGateway"]["NatGatewayId"] + response = conn.delete_nat_gateway(NatGatewayId=nat_gateway_id) + + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "NatGatewayId": nat_gateway_id, + "ResponseMetadata": { + "HTTPStatusCode": 200, + "RequestId": "741fc8ab-6ebe-452b-b92b-example", + }, + } + ) + + +@mock_ec2 +def test_create_and_describe_nat_gateway(): + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] + + create_response = conn.create_nat_gateway( + SubnetId=subnet_id, AllocationId=allocation_id + ) + nat_gateway_id = create_response["NatGateway"]["NatGatewayId"] + describe_response = conn.describe_nat_gateways() + + enis = conn.describe_network_interfaces()["NetworkInterfaces"] + eni_id = enis[0]["NetworkInterfaceId"] + public_ip = conn.describe_addresses(AllocationIds=[allocation_id])["Addresses"][0][ + "PublicIp" + ] + + describe_response["NatGateways"].should.have.length_of(1) + describe_response["NatGateways"][0]["NatGatewayId"].should.equal(nat_gateway_id) + describe_response["NatGateways"][0]["State"].should.equal("available") + describe_response["NatGateways"][0]["SubnetId"].should.equal(subnet_id) + describe_response["NatGateways"][0]["VpcId"].should.equal(vpc_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "AllocationId" + ].should.equal(allocation_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "NetworkInterfaceId" + ].should.equal(eni_id) + assert describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "PrivateIp" + ].startswith("10.") + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "PublicIp" + ].should.equal(public_ip) + + +@mock_ec2 +def test_describe_nat_gateway_filter_by_net_gateway_id_and_state(): + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + subnet = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] + + create_response = conn.create_nat_gateway( + SubnetId=subnet_id, AllocationId=allocation_id + ) + nat_gateway_id = create_response["NatGateway"]["NatGatewayId"] + + describe_response = conn.describe_nat_gateways( + Filters=[ + {"Name": "nat-gateway-id", "Values": ["non-existent-id"]}, + {"Name": "state", "Values": ["available"]}, + ] + ) + describe_response["NatGateways"].should.have.length_of(0) + + describe_response = conn.describe_nat_gateways( + Filters=[ + {"Name": "nat-gateway-id", "Values": [nat_gateway_id]}, + {"Name": "state", "Values": ["available"]}, + ] + ) + + describe_response["NatGateways"].should.have.length_of(1) + describe_response["NatGateways"][0]["NatGatewayId"].should.equal(nat_gateway_id) + describe_response["NatGateways"][0]["State"].should.equal("available") + describe_response["NatGateways"][0]["SubnetId"].should.equal(subnet_id) + describe_response["NatGateways"][0]["VpcId"].should.equal(vpc_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "AllocationId" + ].should.equal(allocation_id) + + +@mock_ec2 +def test_describe_nat_gateway_filter_by_subnet_id(): + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] + subnet_1 = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + subnet_2 = conn.create_subnet( + VpcId=vpc_id, CidrBlock="10.0.2.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id_1 = conn.allocate_address(Domain="vpc")["AllocationId"] + allocation_id_2 = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id_1 = subnet_1["Subnet"]["SubnetId"] + subnet_id_2 = subnet_2["Subnet"]["SubnetId"] + + create_response_1 = conn.create_nat_gateway( + SubnetId=subnet_id_1, AllocationId=allocation_id_1 + ) + # create_response_2 = + conn.create_nat_gateway(SubnetId=subnet_id_2, AllocationId=allocation_id_2) + nat_gateway_id_1 = create_response_1["NatGateway"]["NatGatewayId"] + # nat_gateway_id_2 = create_response_2["NatGateway"]["NatGatewayId"] + + describe_response = conn.describe_nat_gateways() + describe_response["NatGateways"].should.have.length_of(2) + + describe_response = conn.describe_nat_gateways( + Filters=[{"Name": "subnet-id", "Values": [subnet_id_1]}] + ) + describe_response["NatGateways"].should.have.length_of(1) + describe_response["NatGateways"][0]["NatGatewayId"].should.equal(nat_gateway_id_1) + describe_response["NatGateways"][0]["State"].should.equal("available") + describe_response["NatGateways"][0]["SubnetId"].should.equal(subnet_id_1) + describe_response["NatGateways"][0]["VpcId"].should.equal(vpc_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "AllocationId" + ].should.equal(allocation_id_1) + + +@mock_ec2 +def test_describe_nat_gateway_filter_vpc_id(): + conn = boto3.client("ec2", "us-east-1") + vpc_1 = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id_1 = vpc_1["Vpc"]["VpcId"] + vpc_2 = conn.create_vpc(CidrBlock="10.1.0.0/16") + vpc_id_2 = vpc_2["Vpc"]["VpcId"] + subnet_1 = conn.create_subnet( + VpcId=vpc_id_1, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" + ) + subnet_2 = conn.create_subnet( + VpcId=vpc_id_2, CidrBlock="10.1.1.0/27", AvailabilityZone="us-east-1a" + ) + allocation_id_1 = conn.allocate_address(Domain="vpc")["AllocationId"] + allocation_id_2 = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id_1 = subnet_1["Subnet"]["SubnetId"] + subnet_id_2 = subnet_2["Subnet"]["SubnetId"] + + create_response_1 = conn.create_nat_gateway( + SubnetId=subnet_id_1, AllocationId=allocation_id_1 + ) + conn.create_nat_gateway(SubnetId=subnet_id_2, AllocationId=allocation_id_2) + nat_gateway_id_1 = create_response_1["NatGateway"]["NatGatewayId"] + + describe_response = conn.describe_nat_gateways() + describe_response["NatGateways"].should.have.length_of(2) + + describe_response = conn.describe_nat_gateways( + Filters=[{"Name": "vpc-id", "Values": [vpc_id_1]}] + ) + describe_response["NatGateways"].should.have.length_of(1) + describe_response["NatGateways"][0]["NatGatewayId"].should.equal(nat_gateway_id_1) + describe_response["NatGateways"][0]["State"].should.equal("available") + describe_response["NatGateways"][0]["SubnetId"].should.equal(subnet_id_1) + describe_response["NatGateways"][0]["VpcId"].should.equal(vpc_id_1) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "AllocationId" + ].should.equal(allocation_id_1) diff --git a/tests/test_ec2/test_network_acls.py b/tests/test_ec2/test_network_acls.py index 1c69624bf..fb62f7178 100644 --- a/tests/test_ec2/test_network_acls.py +++ b/tests/test_ec2/test_network_acls.py @@ -10,7 +10,7 @@ from moto import mock_ec2_deprecated, mock_ec2 @mock_ec2_deprecated def test_default_network_acl_created_with_vpc(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(2) @@ -18,7 +18,7 @@ def test_default_network_acl_created_with_vpc(): @mock_ec2_deprecated def test_network_acls(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) all_network_acls = conn.get_all_network_acls() @@ -27,7 +27,7 @@ def test_network_acls(): @mock_ec2_deprecated def test_new_subnet_associates_with_default_network_acl(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.get_all_vpcs()[0] subnet = conn.create_subnet(vpc.id, "172.31.112.0/20") @@ -41,88 +41,100 @@ def test_new_subnet_associates_with_default_network_acl(): @mock_ec2_deprecated def test_network_acl_entries(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) network_acl_entry = conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(3) - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(1) - entries[0].rule_number.should.equal('110') - entries[0].protocol.should.equal('6') - entries[0].rule_action.should.equal('ALLOW') + entries[0].rule_number.should.equal("110") + entries[0].protocol.should.equal("6") + entries[0].rule_action.should.equal("ALLOW") @mock_ec2_deprecated def test_delete_network_acl_entry(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' - ) - conn.delete_network_acl_entry( - network_acl.id, 110, False + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) + conn.delete_network_acl_entry(network_acl.id, 110, False) all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(0) @mock_ec2_deprecated def test_replace_network_acl_entry(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) conn.replace_network_acl_entry( - network_acl.id, 110, -1, - 'DENY', '0.0.0.0/0', False, - port_range_from='22', - port_range_to='22' + network_acl.id, + 110, + -1, + "DENY", + "0.0.0.0/0", + False, + port_range_from="22", + port_range_to="22", ) all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(1) - entries[0].rule_number.should.equal('110') - entries[0].protocol.should.equal('-1') - entries[0].rule_action.should.equal('DENY') + entries[0].rule_number.should.equal("110") + entries[0].protocol.should.equal("-1") + entries[0].rule_action.should.equal("DENY") + @mock_ec2_deprecated def test_associate_new_network_acl_with_subnet(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") network_acl = conn.create_network_acl(vpc.id) @@ -132,8 +144,7 @@ def test_associate_new_network_acl_with_subnet(): all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(3) - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) test_network_acl.associations.should.have.length_of(1) test_network_acl.associations[0].subnet_id.should.equal(subnet.id) @@ -141,7 +152,7 @@ def test_associate_new_network_acl_with_subnet(): @mock_ec2_deprecated def test_delete_network_acl(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") network_acl = conn.create_network_acl(vpc.id) @@ -161,7 +172,7 @@ def test_delete_network_acl(): @mock_ec2_deprecated def test_network_acl_tagging(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) @@ -172,46 +183,45 @@ def test_network_acl_tagging(): tag.value.should.equal("some value") all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) test_network_acl.tags.should.have.length_of(1) test_network_acl.tags["a key"].should.equal("some value") @mock_ec2 def test_new_subnet_in_new_vpc_associates_with_default_network_acl(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - new_vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + new_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") new_vpc.reload() - subnet = ec2.create_subnet(VpcId=new_vpc.id, CidrBlock='10.0.0.0/24') + subnet = ec2.create_subnet(VpcId=new_vpc.id, CidrBlock="10.0.0.0/24") subnet.reload() new_vpcs_default_network_acl = next(iter(new_vpc.network_acls.all()), None) new_vpcs_default_network_acl.reload() new_vpcs_default_network_acl.vpc_id.should.equal(new_vpc.id) new_vpcs_default_network_acl.associations.should.have.length_of(1) - new_vpcs_default_network_acl.associations[0]['SubnetId'].should.equal(subnet.id) + new_vpcs_default_network_acl.associations[0]["SubnetId"].should.equal(subnet.id) @mock_ec2 def test_default_network_acl_default_entries(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok default_network_acl.entries.should.have.length_of(4) unique_entries = [] for entry in default_network_acl.entries: - entry['CidrBlock'].should.equal('0.0.0.0/0') - entry['Protocol'].should.equal('-1') - entry['RuleNumber'].should.be.within([100, 32767]) - entry['RuleAction'].should.be.within(['allow', 'deny']) - assert type(entry['Egress']) is bool - if entry['RuleAction'] == 'allow': - entry['RuleNumber'].should.be.equal(100) + entry["CidrBlock"].should.equal("0.0.0.0/0") + entry["Protocol"].should.equal("-1") + entry["RuleNumber"].should.be.within([100, 32767]) + entry["RuleAction"].should.be.within(["allow", "deny"]) + assert type(entry["Egress"]) is bool + if entry["RuleAction"] == "allow": + entry["RuleNumber"].should.be.equal(100) else: - entry['RuleNumber'].should.be.equal(32767) + entry["RuleNumber"].should.be.equal(32767) if entry not in unique_entries: unique_entries.append(entry) @@ -220,33 +230,48 @@ def test_default_network_acl_default_entries(): @mock_ec2 def test_delete_default_network_acl_default_entry(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok default_network_acl.entries.should.have.length_of(4) first_default_network_acl_entry = default_network_acl.entries[0] - default_network_acl.delete_entry(Egress=first_default_network_acl_entry['Egress'], - RuleNumber=first_default_network_acl_entry['RuleNumber']) + default_network_acl.delete_entry( + Egress=first_default_network_acl_entry["Egress"], + RuleNumber=first_default_network_acl_entry["RuleNumber"], + ) default_network_acl.entries.should.have.length_of(3) @mock_ec2 def test_duplicate_network_acl_entry(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok rule_number = 200 egress = True - default_network_acl.create_entry(CidrBlock="0.0.0.0/0", Egress=egress, Protocol="-1", RuleAction="allow", RuleNumber=rule_number) + default_network_acl.create_entry( + CidrBlock="0.0.0.0/0", + Egress=egress, + Protocol="-1", + RuleAction="allow", + RuleNumber=rule_number, + ) with assert_raises(ClientError) as ex: - default_network_acl.create_entry(CidrBlock="10.0.0.0/0", Egress=egress, Protocol="-1", RuleAction="deny", RuleNumber=rule_number) + default_network_acl.create_entry( + CidrBlock="10.0.0.0/0", + Egress=egress, + Protocol="-1", + RuleAction="deny", + RuleNumber=rule_number, + ) str(ex.exception).should.equal( "An error occurred (NetworkAclEntryAlreadyExists) when calling the CreateNetworkAclEntry " - "operation: The network acl entry identified by {} already exists.".format(rule_number)) - - + "operation: The network acl entry identified by {} already exists.".format( + rule_number + ) + ) diff --git a/tests/test_ec2/test_regions.py b/tests/test_ec2/test_regions.py index f94c78eaf..551b739f2 100644 --- a/tests/test_ec2/test_regions.py +++ b/tests/test_ec2/test_regions.py @@ -7,38 +7,41 @@ from moto import mock_ec2_deprecated, mock_autoscaling_deprecated, mock_elb_depr from moto.ec2 import ec2_backends + def test_use_boto_regions(): boto_regions = {r.name for r in boto.ec2.regions()} moto_regions = set(ec2_backends) moto_regions.should.equal(boto_regions) + def add_servers_to_region(ami_id, count, region): conn = boto.ec2.connect_to_region(region) for index in range(count): conn.run_instances(ami_id) + @mock_ec2_deprecated def test_add_servers_to_a_single_region(): - region = 'ap-northeast-1' - add_servers_to_region('ami-1234abcd', 1, region) - add_servers_to_region('ami-5678efgh', 1, region) + region = "ap-northeast-1" + add_servers_to_region("ami-1234abcd", 1, region) + add_servers_to_region("ami-5678efgh", 1, region) conn = boto.ec2.connect_to_region(region) reservations = conn.get_all_instances() len(reservations).should.equal(2) reservations.sort(key=lambda x: x.instances[0].image_id) - reservations[0].instances[0].image_id.should.equal('ami-1234abcd') - reservations[1].instances[0].image_id.should.equal('ami-5678efgh') + reservations[0].instances[0].image_id.should.equal("ami-1234abcd") + reservations[1].instances[0].image_id.should.equal("ami-5678efgh") @mock_ec2_deprecated def test_add_servers_to_multiple_regions(): - region1 = 'us-east-1' - region2 = 'ap-northeast-1' - add_servers_to_region('ami-1234abcd', 1, region1) - add_servers_to_region('ami-5678efgh', 1, region2) + region1 = "us-east-1" + region2 = "ap-northeast-1" + add_servers_to_region("ami-1234abcd", 1, region1) + add_servers_to_region("ami-5678efgh", 1, region2) us_conn = boto.ec2.connect_to_region(region1) ap_conn = boto.ec2.connect_to_region(region2) @@ -48,33 +51,35 @@ def test_add_servers_to_multiple_regions(): len(us_reservations).should.equal(1) len(ap_reservations).should.equal(1) - us_reservations[0].instances[0].image_id.should.equal('ami-1234abcd') - ap_reservations[0].instances[0].image_id.should.equal('ami-5678efgh') + us_reservations[0].instances[0].image_id.should.equal("ami-1234abcd") + ap_reservations[0].instances[0].image_id.should.equal("ami-5678efgh") @mock_autoscaling_deprecated @mock_elb_deprecated def test_create_autoscaling_group(): - elb_conn = boto.ec2.elb.connect_to_region('us-east-1') + elb_conn = boto.ec2.elb.connect_to_region("us-east-1") elb_conn.create_load_balancer( - 'us_test_lb', zones=[], listeners=[(80, 8080, 'http')]) - elb_conn = boto.ec2.elb.connect_to_region('ap-northeast-1') + "us_test_lb", zones=[], listeners=[(80, 8080, "http")] + ) + elb_conn = boto.ec2.elb.connect_to_region("ap-northeast-1") elb_conn.create_load_balancer( - 'ap_test_lb', zones=[], listeners=[(80, 8080, 'http')]) + "ap_test_lb", zones=[], listeners=[(80, 8080, "http")] + ) - us_conn = boto.ec2.autoscale.connect_to_region('us-east-1') + us_conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = boto.ec2.autoscale.LaunchConfiguration( - name='us_tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="us_tester", image_id="ami-abcd1234", instance_type="m1.small" ) x = us_conn.create_launch_configuration(config) - us_subnet_id = list(ec2_backends['us-east-1'].subnets['us-east-1c'].keys())[0] - ap_subnet_id = list(ec2_backends['ap-northeast-1'].subnets['ap-northeast-1a'].keys())[0] + us_subnet_id = list(ec2_backends["us-east-1"].subnets["us-east-1c"].keys())[0] + ap_subnet_id = list( + ec2_backends["ap-northeast-1"].subnets["ap-northeast-1a"].keys() + )[0] group = boto.ec2.autoscale.AutoScalingGroup( - name='us_tester_group', - availability_zones=['us-east-1c'], + name="us_tester_group", + availability_zones=["us-east-1c"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -89,17 +94,15 @@ def test_create_autoscaling_group(): ) us_conn.create_auto_scaling_group(group) - ap_conn = boto.ec2.autoscale.connect_to_region('ap-northeast-1') + ap_conn = boto.ec2.autoscale.connect_to_region("ap-northeast-1") config = boto.ec2.autoscale.LaunchConfiguration( - name='ap_tester', - image_id='ami-efgh5678', - instance_type='m1.small', + name="ap_tester", image_id="ami-efgh5678", instance_type="m1.small" ) ap_conn.create_launch_configuration(config) group = boto.ec2.autoscale.AutoScalingGroup( - name='ap_tester_group', - availability_zones=['ap-northeast-1a'], + name="ap_tester_group", + availability_zones=["ap-northeast-1a"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -118,33 +121,35 @@ def test_create_autoscaling_group(): len(ap_conn.get_all_groups()).should.equal(1) us_group = us_conn.get_all_groups()[0] - us_group.name.should.equal('us_tester_group') - list(us_group.availability_zones).should.equal(['us-east-1c']) + us_group.name.should.equal("us_tester_group") + list(us_group.availability_zones).should.equal(["us-east-1c"]) us_group.desired_capacity.should.equal(2) us_group.max_size.should.equal(2) us_group.min_size.should.equal(2) us_group.vpc_zone_identifier.should.equal(us_subnet_id) - us_group.launch_config_name.should.equal('us_tester') + us_group.launch_config_name.should.equal("us_tester") us_group.default_cooldown.should.equal(60) us_group.health_check_period.should.equal(100) us_group.health_check_type.should.equal("EC2") list(us_group.load_balancers).should.equal(["us_test_lb"]) us_group.placement_group.should.equal("us_test_placement") list(us_group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + ["OldestInstance", "NewestInstance"] + ) ap_group = ap_conn.get_all_groups()[0] - ap_group.name.should.equal('ap_tester_group') - list(ap_group.availability_zones).should.equal(['ap-northeast-1a']) + ap_group.name.should.equal("ap_tester_group") + list(ap_group.availability_zones).should.equal(["ap-northeast-1a"]) ap_group.desired_capacity.should.equal(2) ap_group.max_size.should.equal(2) ap_group.min_size.should.equal(2) ap_group.vpc_zone_identifier.should.equal(ap_subnet_id) - ap_group.launch_config_name.should.equal('ap_tester') + ap_group.launch_config_name.should.equal("ap_tester") ap_group.default_cooldown.should.equal(60) ap_group.health_check_period.should.equal(100) ap_group.health_check_type.should.equal("EC2") list(ap_group.load_balancers).should.equal(["ap_test_lb"]) ap_group.placement_group.should.equal("ap_test_placement") list(ap_group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + ["OldestInstance", "NewestInstance"] + ) diff --git a/tests/test_ec2/test_route_tables.py b/tests/test_ec2/test_route_tables.py index de33b3f7a..dfb3292b6 100644 --- a/tests/test_ec2/test_route_tables.py +++ b/tests/test_ec2/test_route_tables.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -15,10 +16,10 @@ from tests.helpers import requires_boto_gte @mock_ec2_deprecated def test_route_tables_defaults(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(1) main_route_table = all_route_tables[0] @@ -28,23 +29,23 @@ def test_route_tables_defaults(): routes.should.have.length_of(1) local_route = routes[0] - local_route.gateway_id.should.equal('local') - local_route.state.should.equal('active') + local_route.gateway_id.should.equal("local") + local_route.state.should.equal("active") local_route.destination_cidr_block.should.equal(vpc.cidr_block) vpc.delete() - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(0) @mock_ec2_deprecated def test_route_tables_additional(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(2) all_route_tables[0].vpc_id.should.equal(vpc.id) all_route_tables[1].vpc_id.should.equal(vpc.id) @@ -56,31 +57,31 @@ def test_route_tables_additional(): routes.should.have.length_of(1) local_route = routes[0] - local_route.gateway_id.should.equal('local') - local_route.state.should.equal('active') + local_route.gateway_id.should.equal("local") + local_route.state.should.equal("active") local_route.destination_cidr_block.should.equal(vpc.cidr_block) with assert_raises(EC2ResponseError) as cm: conn.delete_vpc(vpc.id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none conn.delete_route_table(route_table.id) - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(1) with assert_raises(EC2ResponseError) as cm: conn.delete_route_table("rtb-1234abcd") - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_route_tables_filters_standard(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc1 = conn.create_vpc("10.0.0.0/16") route_table1 = conn.create_route_table(vpc1.id) @@ -92,39 +93,39 @@ def test_route_tables_filters_standard(): all_route_tables.should.have.length_of(5) # Filter by main route table - main_route_tables = conn.get_all_route_tables( - filters={'association.main': 'true'}) + main_route_tables = conn.get_all_route_tables(filters={"association.main": "true"}) main_route_tables.should.have.length_of(3) - main_route_table_ids = [ - route_table.id for route_table in main_route_tables] + main_route_table_ids = [route_table.id for route_table in main_route_tables] main_route_table_ids.should_not.contain(route_table1.id) main_route_table_ids.should_not.contain(route_table2.id) # Filter by VPC - vpc1_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc1.id}) + vpc1_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc1.id}) vpc1_route_tables.should.have.length_of(2) - vpc1_route_table_ids = [ - route_table.id for route_table in vpc1_route_tables] + vpc1_route_table_ids = [route_table.id for route_table in vpc1_route_tables] vpc1_route_table_ids.should.contain(route_table1.id) vpc1_route_table_ids.should_not.contain(route_table2.id) # Filter by VPC and main route table vpc2_main_route_tables = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc2.id}) + filters={"association.main": "true", "vpc-id": vpc2.id} + ) vpc2_main_route_tables.should.have.length_of(1) vpc2_main_route_table_ids = [ - route_table.id for route_table in vpc2_main_route_tables] + route_table.id for route_table in vpc2_main_route_tables + ] vpc2_main_route_table_ids.should_not.contain(route_table1.id) vpc2_main_route_table_ids.should_not.contain(route_table2.id) # Unsupported filter conn.get_all_route_tables.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_route_tables_filters_associations(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet1 = conn.create_subnet(vpc.id, "10.0.0.0/24") @@ -142,21 +143,24 @@ def test_route_tables_filters_associations(): # Filter by association ID association1_route_tables = conn.get_all_route_tables( - filters={'association.route-table-association-id': association_id1}) + filters={"association.route-table-association-id": association_id1} + ) association1_route_tables.should.have.length_of(1) association1_route_tables[0].id.should.equal(route_table1.id) association1_route_tables[0].associations.should.have.length_of(2) # Filter by route table ID route_table2_route_tables = conn.get_all_route_tables( - filters={'association.route-table-id': route_table2.id}) + filters={"association.route-table-id": route_table2.id} + ) route_table2_route_tables.should.have.length_of(1) route_table2_route_tables[0].id.should.equal(route_table2.id) route_table2_route_tables[0].associations.should.have.length_of(1) # Filter by subnet ID subnet_route_tables = conn.get_all_route_tables( - filters={'association.subnet-id': subnet1.id}) + filters={"association.subnet-id": subnet1.id} + ) subnet_route_tables.should.have.length_of(1) subnet_route_tables[0].id.should.equal(route_table1.id) association1_route_tables[0].associations.should.have.length_of(2) @@ -164,7 +168,7 @@ def test_route_tables_filters_associations(): @mock_ec2_deprecated def test_route_table_associations(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") route_table = conn.create_route_table(vpc.id) @@ -189,14 +193,13 @@ def test_route_table_associations(): route_table.associations[0].subnet_id.should.equal(subnet.id) # Associate is idempotent - association_id_idempotent = conn.associate_route_table( - route_table.id, subnet.id) + association_id_idempotent = conn.associate_route_table(route_table.id, subnet.id) association_id_idempotent.should.equal(association_id) # Error: Attempt delete associated route table. with assert_raises(EC2ResponseError) as cm: conn.delete_route_table(route_table.id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -210,21 +213,21 @@ def test_route_table_associations(): # Error: Disassociate with invalid association ID with assert_raises(EC2ResponseError) as cm: conn.disassociate_route_table(association_id) - cm.exception.code.should.equal('InvalidAssociationID.NotFound') + cm.exception.code.should.equal("InvalidAssociationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Associate with invalid subnet ID with assert_raises(EC2ResponseError) as cm: conn.associate_route_table(route_table.id, "subnet-1234abcd") - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Associate with invalid route table ID with assert_raises(EC2ResponseError) as cm: conn.associate_route_table("rtb-1234abcd", subnet.id) - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -236,7 +239,7 @@ def test_route_table_replace_route_table_association(): Note: Boto has deprecated replace_route_table_assocation (which returns status) and now uses replace_route_table_assocation_with_assoc (which returns association ID). """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") route_table1 = conn.create_route_table(vpc.id) @@ -267,7 +270,8 @@ def test_route_table_replace_route_table_association(): # Replace Association association_id2 = conn.replace_route_table_association_with_assoc( - association_id1, route_table2.id) + association_id1, route_table2.id + ) # Refresh route_table1 = conn.get_all_route_tables(route_table1.id)[0] @@ -284,120 +288,128 @@ def test_route_table_replace_route_table_association(): # Replace Association is idempotent association_id_idempotent = conn.replace_route_table_association_with_assoc( - association_id2, route_table2.id) + association_id2, route_table2.id + ) association_id_idempotent.should.equal(association_id2) # Error: Replace association with invalid association ID with assert_raises(EC2ResponseError) as cm: conn.replace_route_table_association_with_assoc( - "rtbassoc-1234abcd", route_table1.id) - cm.exception.code.should.equal('InvalidAssociationID.NotFound') + "rtbassoc-1234abcd", route_table1.id + ) + cm.exception.code.should.equal("InvalidAssociationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Replace association with invalid route table ID with assert_raises(EC2ResponseError) as cm: - conn.replace_route_table_association_with_assoc( - association_id2, "rtb-1234abcd") - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + conn.replace_route_table_association_with_assoc(association_id2, "rtb-1234abcd") + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_route_table_get_by_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - vpc = conn.create_vpc('10.0.0.0/16') + vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) - route_table.add_tag('Name', 'TestRouteTable') + route_table.add_tag("Name", "TestRouteTable") - route_tables = conn.get_all_route_tables( - filters={'tag:Name': 'TestRouteTable'}) + route_tables = conn.get_all_route_tables(filters={"tag:Name": "TestRouteTable"}) route_tables.should.have.length_of(1) route_tables[0].vpc_id.should.equal(vpc.id) route_tables[0].id.should.equal(route_table.id) route_tables[0].tags.should.have.length_of(1) - route_tables[0].tags['Name'].should.equal('TestRouteTable') + route_tables[0].tags["Name"].should.equal("TestRouteTable") @mock_ec2 def test_route_table_get_by_tag_boto3(): - ec2 = boto3.resource('ec2', region_name='eu-central-1') + ec2 = boto3.resource("ec2", region_name="eu-central-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") route_table = ec2.create_route_table(VpcId=vpc.id) - route_table.create_tags(Tags=[{'Key': 'Name', 'Value': 'TestRouteTable'}]) + route_table.create_tags(Tags=[{"Key": "Name", "Value": "TestRouteTable"}]) - filters = [{'Name': 'tag:Name', 'Values': ['TestRouteTable']}] + filters = [{"Name": "tag:Name", "Values": ["TestRouteTable"]}] route_tables = list(ec2.route_tables.filter(Filters=filters)) route_tables.should.have.length_of(1) route_tables[0].vpc_id.should.equal(vpc.id) route_tables[0].id.should.equal(route_table.id) route_tables[0].tags.should.have.length_of(1) - route_tables[0].tags[0].should.equal( - {'Key': 'Name', 'Value': 'TestRouteTable'}) + route_tables[0].tags[0].should.equal({"Key": "Name", "Value": "TestRouteTable"}) @mock_ec2_deprecated def test_routes_additional(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - main_route_table = conn.get_all_route_tables(filters={'vpc-id': vpc.id})[0] + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[0] local_route = main_route_table.routes[0] igw = conn.create_internet_gateway() ROUTE_CIDR = "10.0.0.4/24" conn.create_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) - main_route_table = conn.get_all_route_tables( - filters={'vpc-id': vpc.id})[0] # Refresh route table + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[ + 0 + ] # Refresh route table main_route_table.routes.should.have.length_of(2) new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] new_route.gateway_id.should.equal(igw.id) new_route.instance_id.should.be.none - new_route.state.should.equal('active') + new_route.state.should.equal("active") new_route.destination_cidr_block.should.equal(ROUTE_CIDR) conn.delete_route(main_route_table.id, ROUTE_CIDR) - main_route_table = conn.get_all_route_tables( - filters={'vpc-id': vpc.id})[0] # Refresh route table + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[ + 0 + ] # Refresh route table main_route_table.routes.should.have.length_of(1) new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(0) with assert_raises(EC2ResponseError) as cm: conn.delete_route(main_route_table.id, ROUTE_CIDR) - cm.exception.code.should.equal('InvalidRoute.NotFound') + cm.exception.code.should.equal("InvalidRoute.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_routes_replace(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] local_route = main_route_table.routes[0] ROUTE_CIDR = "10.0.0.4/24" # Various route targets igw = conn.create_internet_gateway() - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] # Create initial route @@ -407,17 +419,19 @@ def test_routes_replace(): def get_target_route(): route_table = conn.get_all_route_tables(main_route_table.id)[0] routes = [ - route for route in route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] routes.should.have.length_of(1) return routes[0] - conn.replace_route(main_route_table.id, ROUTE_CIDR, - instance_id=instance.id) + conn.replace_route(main_route_table.id, ROUTE_CIDR, instance_id=instance.id) target_route = get_target_route() target_route.gateway_id.should.be.none target_route.instance_id.should.equal(instance.id) - target_route.state.should.equal('active') + target_route.state.should.equal("active") target_route.destination_cidr_block.should.equal(ROUTE_CIDR) conn.replace_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) @@ -425,12 +439,12 @@ def test_routes_replace(): target_route = get_target_route() target_route.gateway_id.should.equal(igw.id) target_route.instance_id.should.be.none - target_route.state.should.equal('active') + target_route.state.should.equal("active") target_route.destination_cidr_block.should.equal(ROUTE_CIDR) with assert_raises(EC2ResponseError) as cm: - conn.replace_route('rtb-1234abcd', ROUTE_CIDR, gateway_id=igw.id) - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + conn.replace_route("rtb-1234abcd", ROUTE_CIDR, gateway_id=igw.id) + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -438,7 +452,7 @@ def test_routes_replace(): @requires_boto_gte("2.19.0") @mock_ec2_deprecated def test_routes_not_supported(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables()[0] local_route = main_route_table.routes[0] @@ -447,42 +461,49 @@ def test_routes_not_supported(): # Create conn.create_route.when.called_with( - main_route_table.id, ROUTE_CIDR, interface_id='eni-1234abcd').should.throw(NotImplementedError) + main_route_table.id, ROUTE_CIDR, interface_id="eni-1234abcd" + ).should.throw(NotImplementedError) # Replace igw = conn.create_internet_gateway() conn.create_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) conn.replace_route.when.called_with( - main_route_table.id, ROUTE_CIDR, interface_id='eni-1234abcd').should.throw(NotImplementedError) + main_route_table.id, ROUTE_CIDR, interface_id="eni-1234abcd" + ).should.throw(NotImplementedError) @requires_boto_gte("2.34.0") @mock_ec2_deprecated def test_routes_vpc_peering_connection(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] local_route = main_route_table.routes[0] ROUTE_CIDR = "10.0.0.4/24" peer_vpc = conn.create_vpc("11.0.0.0/16") vpc_pcx = conn.create_vpc_peering_connection(vpc.id, peer_vpc.id) - conn.create_route(main_route_table.id, ROUTE_CIDR, - vpc_peering_connection_id=vpc_pcx.id) + conn.create_route( + main_route_table.id, ROUTE_CIDR, vpc_peering_connection_id=vpc_pcx.id + ) # Refresh route table main_route_table = conn.get_all_route_tables(main_route_table.id)[0] new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] new_route.gateway_id.should.be.none new_route.instance_id.should.be.none new_route.vpc_peering_connection_id.should.equal(vpc_pcx.id) - new_route.state.should.equal('blackhole') + new_route.state.should.equal("blackhole") new_route.destination_cidr_block.should.equal(ROUTE_CIDR) @@ -490,10 +511,11 @@ def test_routes_vpc_peering_connection(): @mock_ec2_deprecated def test_routes_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] ROUTE_CIDR = "10.0.0.4/24" vpn_gw = conn.create_vpn_gateway(type="ipsec.1") @@ -502,7 +524,10 @@ def test_routes_vpn_gateway(): main_route_table = conn.get_all_route_tables(main_route_table.id)[0] new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] @@ -514,7 +539,7 @@ def test_routes_vpn_gateway(): @mock_ec2_deprecated def test_network_acl_tagging(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) @@ -525,17 +550,16 @@ def test_network_acl_tagging(): tag.value.should.equal("some value") all_route_tables = conn.get_all_route_tables() - test_route_table = next(na for na in all_route_tables - if na.id == route_table.id) + test_route_table = next(na for na in all_route_tables if na.id == route_table.id) test_route_table.tags.should.have.length_of(1) test_route_table.tags["a key"].should.equal("some value") @mock_ec2 def test_create_route_with_invalid_destination_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok @@ -546,9 +570,51 @@ def test_create_route_with_invalid_destination_cidr_block_parameter(): vpc.attach_internet_gateway(InternetGatewayId=internet_gateway.id) internet_gateway.reload() - destination_cidr_block = '1000.1.0.0/20' + destination_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: - route = route_table.create_route(DestinationCidrBlock=destination_cidr_block, GatewayId=internet_gateway.id) + route = route_table.create_route( + DestinationCidrBlock=destination_cidr_block, GatewayId=internet_gateway.id + ) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateRoute " - "operation: Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(destination_cidr_block)) \ No newline at end of file + "operation: Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format( + destination_cidr_block + ) + ) + + +@mock_ec2 +def test_describe_route_tables_with_nat_gateway(): + ec2 = boto3.client("ec2", region_name="us-west-1") + vpc_id = ec2.create_vpc(CidrBlock="192.168.0.0/23")["Vpc"]["VpcId"] + igw_id = ec2.create_internet_gateway()["InternetGateway"]["InternetGatewayId"] + ec2.attach_internet_gateway(VpcId=vpc_id, InternetGatewayId=igw_id) + az = ec2.describe_availability_zones()["AvailabilityZones"][0]["ZoneName"] + sn_id = ec2.create_subnet( + AvailabilityZone=az, CidrBlock="192.168.0.0/24", VpcId=vpc_id + )["Subnet"]["SubnetId"] + route_table_id = ec2.create_route_table(VpcId=vpc_id)["RouteTable"]["RouteTableId"] + ec2.associate_route_table(SubnetId=sn_id, RouteTableId=route_table_id) + alloc_id = ec2.allocate_address(Domain="vpc")["AllocationId"] + nat_gw_id = ec2.create_nat_gateway(SubnetId=sn_id, AllocationId=alloc_id)[ + "NatGateway" + ]["NatGatewayId"] + ec2.create_route( + DestinationCidrBlock="0.0.0.0/0", + NatGatewayId=nat_gw_id, + RouteTableId=route_table_id, + ) + + route_table = ec2.describe_route_tables( + Filters=[{"Name": "route-table-id", "Values": [route_table_id]}] + )["RouteTables"][0] + nat_gw_routes = [ + route + for route in route_table["Routes"] + if route["DestinationCidrBlock"] == "0.0.0.0/0" + ] + + nat_gw_routes.should.have.length_of(1) + nat_gw_routes[0]["DestinationCidrBlock"].should.equal("0.0.0.0/0") + nat_gw_routes[0]["NatGatewayId"].should.equal(nat_gw_id) + nat_gw_routes[0]["State"].should.equal("active") diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index c09b1e8f4..bb9c8f52a 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -17,27 +17,31 @@ from moto import mock_ec2, mock_ec2_deprecated @mock_ec2_deprecated def test_create_and_describe_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: security_group = conn.create_security_group( - 'test security group', 'this is a test security group', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + "test security group", "this is a test security group", dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) security_group = conn.create_security_group( - 'test security group', 'this is a test security group') + "test security group", "this is a test security group" + ) - security_group.name.should.equal('test security group') - security_group.description.should.equal('this is a test security group') + security_group.name.should.equal("test security group") + security_group.description.should.equal("this is a test security group") # Trying to create another group with the same name should throw an error with assert_raises(EC2ResponseError) as cm: conn.create_security_group( - 'test security group', 'this is a test security group') - cm.exception.code.should.equal('InvalidGroup.Duplicate') + "test security group", "this is a test security group" + ) + cm.exception.code.should.equal("InvalidGroup.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -50,18 +54,18 @@ def test_create_and_describe_security_group(): @mock_ec2_deprecated def test_create_security_group_without_description_raises_error(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.create_security_group('test security group', '') - cm.exception.code.should.equal('MissingParameter') + conn.create_security_group("test security group", "") + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_default_security_group(): - conn = boto.ec2.connect_to_region('us-east-1') + conn = boto.ec2.connect_to_region("us-east-1") groups = conn.get_all_security_groups() groups.should.have.length_of(2) groups[0].name.should.equal("default") @@ -69,43 +73,47 @@ def test_default_security_group(): @mock_ec2_deprecated def test_create_and_describe_vpc_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" security_group = conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id=vpc_id) + "test security group", "this is a test security group", vpc_id=vpc_id + ) security_group.vpc_id.should.equal(vpc_id) - security_group.name.should.equal('test security group') - security_group.description.should.equal('this is a test security group') + security_group.name.should.equal("test security group") + security_group.description.should.equal("this is a test security group") # Trying to create another group with the same name in the same VPC should # throw an error with assert_raises(EC2ResponseError) as cm: conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id) - cm.exception.code.should.equal('InvalidGroup.Duplicate') + "test security group", "this is a test security group", vpc_id + ) + cm.exception.code.should.equal("InvalidGroup.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - all_groups = conn.get_all_security_groups(filters={'vpc_id': [vpc_id]}) + all_groups = conn.get_all_security_groups(filters={"vpc_id": [vpc_id]}) all_groups[0].vpc_id.should.equal(vpc_id) all_groups.should.have.length_of(1) - all_groups[0].name.should.equal('test security group') + all_groups[0].name.should.equal("test security group") @mock_ec2_deprecated def test_create_two_security_groups_with_same_name_in_different_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' - vpc_id2 = 'vpc-5300000d' + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" + vpc_id2 = "vpc-5300000d" conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id) + "test security group", "this is a test security group", vpc_id + ) conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id2) + "test security group", "this is a test security group", vpc_id2 + ) all_groups = conn.get_all_security_groups() @@ -117,28 +125,29 @@ def test_create_two_security_groups_with_same_name_in_different_vpc(): @mock_ec2_deprecated def test_deleting_security_groups(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group1 = conn.create_security_group('test1', 'test1') - conn.create_security_group('test2', 'test2') + conn = boto.connect_ec2("the_key", "the_secret") + security_group1 = conn.create_security_group("test1", "test1") + conn.create_security_group("test2", "test2") conn.get_all_security_groups().should.have.length_of(4) # Deleting a group that doesn't exist should throw an error with assert_raises(EC2ResponseError) as cm: - conn.delete_security_group('foobar') - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.delete_security_group("foobar") + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Delete by name with assert_raises(EC2ResponseError) as ex: - conn.delete_security_group('test2', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.delete_security_group("test2", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) - conn.delete_security_group('test2') + conn.delete_security_group("test2") conn.get_all_security_groups().should.have.length_of(3) # Delete by group id @@ -148,9 +157,9 @@ def test_deleting_security_groups(): @mock_ec2_deprecated def test_delete_security_group_in_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") vpc_id = "vpc-12345" - security_group1 = conn.create_security_group('test1', 'test1', vpc_id) + security_group1 = conn.create_security_group("test1", "test1", vpc_id) # this should not throw an exception conn.delete_security_group(group_id=security_group1.id) @@ -158,87 +167,130 @@ def test_delete_security_group_in_vpc(): @mock_ec2_deprecated def test_authorize_ip_range_and_revoke(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as ex: success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the GrantSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the GrantSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set" + ) success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32" + ) assert success.should.be.true - security_group = conn.get_all_security_groups(groupnames=['test'])[0] + security_group = conn.get_all_security_groups(groupnames=["test"])[0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].cidr_ip.should.equal("123.123.123.123/32") + security_group.rules[0].grants[0].cidr_ip.should.equal("123.123.123.123/32") # Wrong Cidr should throw error with assert_raises(EC2ResponseError) as cm: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.122/32") - cm.exception.code.should.equal('InvalidPermission.NotFound') + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.122/32", + ) + cm.exception.code.should.equal("InvalidPermission.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Actually revoke with assert_raises(EC2ResponseError) as ex: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RevokeSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RevokeSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set" + ) - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.123/32") + security_group.revoke( + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32" + ) security_group = conn.get_all_security_groups()[0] security_group.rules.should.have.length_of(0) # Test for egress as well egress_security_group = conn.create_security_group( - 'testegress', 'testegress', vpc_id='vpc-3432589') + "testegress", "testegress", vpc_id="vpc-3432589" + ) with assert_raises(EC2ResponseError) as ex: success = conn.authorize_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the GrantSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the GrantSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set" + ) success = conn.authorize_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + ) assert success.should.be.true - egress_security_group = conn.get_all_security_groups( - groupnames='testegress')[0] + egress_security_group = conn.get_all_security_groups(groupnames="testegress")[0] # There are two egress rules associated with the security group: # the default outbound rule and the new one int(egress_security_group.rules_egress[1].to_port).should.equal(2222) - egress_security_group.rules_egress[1].grants[ - 0].cidr_ip.should.equal("123.123.123.123/32") + egress_security_group.rules_egress[1].grants[0].cidr_ip.should.equal( + "123.123.123.123/32" + ) # Wrong Cidr should throw error egress_security_group.revoke.when.called_with( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.122/32").should.throw(EC2ResponseError) + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.122/32" + ).should.throw(EC2ResponseError) # Actually revoke with assert_raises(EC2ResponseError) as ex: conn.revoke_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RevokeSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RevokeSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set" + ) conn.revoke_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + ) egress_security_group = conn.get_all_security_groups()[0] # There is still the default outbound rule @@ -247,55 +299,69 @@ def test_authorize_ip_range_and_revoke(): @mock_ec2_deprecated def test_authorize_other_group_and_revoke(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') - other_security_group = conn.create_security_group('other', 'other') - wrong_group = conn.create_security_group('wrong', 'wrong') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") + other_security_group = conn.create_security_group("other", "other") + wrong_group = conn.create_security_group("wrong", "wrong") success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) assert success.should.be.true security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test'][0] + group for group in conn.get_all_security_groups() if group.name == "test" + ][0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].group_id.should.equal(other_security_group.id) + security_group.rules[0].grants[0].group_id.should.equal(other_security_group.id) # Wrong source group should throw error with assert_raises(EC2ResponseError) as cm: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", src_group=wrong_group) - cm.exception.code.should.equal('InvalidPermission.NotFound') + security_group.revoke( + ip_protocol="tcp", from_port="22", to_port="2222", src_group=wrong_group + ) + cm.exception.code.should.equal("InvalidPermission.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Actually revoke - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", src_group=other_security_group) + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test'][0] + group for group in conn.get_all_security_groups() if group.name == "test" + ][0] security_group.rules.should.have.length_of(0) @mock_ec2 def test_authorize_other_group_egress_and_revoke(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg01 = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) sg02 = ec2.create_security_group( - GroupName='sg02', Description='Test security group sg02', VpcId=vpc.id) + GroupName="sg02", Description="Test security group sg02", VpcId=vpc.id + ) ip_permission = { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'GroupName': 'sg02', 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [ + {"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id} + ], + "IpRanges": [], } sg01.authorize_egress(IpPermissions=[ip_permission]) @@ -308,32 +374,41 @@ def test_authorize_other_group_egress_and_revoke(): @mock_ec2_deprecated def test_authorize_group_in_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") vpc_id = "vpc-12345" # create 2 groups in a vpc - security_group = conn.create_security_group('test1', 'test1', vpc_id) - other_security_group = conn.create_security_group('test2', 'test2', vpc_id) + security_group = conn.create_security_group("test1", "test1", vpc_id) + other_security_group = conn.create_security_group("test2", "test2", vpc_id) success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) success.should.be.true # Check that the rule is accurate security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test1'][0] + group for group in conn.get_all_security_groups() if group.name == "test1" + ][0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].group_id.should.equal(other_security_group.id) + security_group.rules[0].grants[0].group_id.should.equal(other_security_group.id) # Now remove the rule success = security_group.revoke( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) success.should.be.true # And check that it gets revoked security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test1'][0] + group for group in conn.get_all_security_groups() if group.name == "test1" + ][0] security_group.rules.should.have.length_of(0) @@ -341,31 +416,32 @@ def test_authorize_group_in_vpc(): def test_get_all_security_groups(): conn = boto.connect_ec2() sg1 = conn.create_security_group( - name='test1', description='test1', vpc_id='vpc-mjm05d27') - conn.create_security_group(name='test2', description='test2') + name="test1", description="test1", vpc_id="vpc-mjm05d27" + ) + conn.create_security_group(name="test2", description="test2") - resp = conn.get_all_security_groups(groupnames=['test1']) + resp = conn.get_all_security_groups(groupnames=["test1"]) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) with assert_raises(EC2ResponseError) as cm: - conn.get_all_security_groups(groupnames=['does_not_exist']) - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.get_all_security_groups(groupnames=["does_not_exist"]) + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'vpc-id': ['vpc-mjm05d27']}) + resp = conn.get_all_security_groups(filters={"vpc-id": ["vpc-mjm05d27"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'vpc_id': ['vpc-mjm05d27']}) + resp = conn.get_all_security_groups(filters={"vpc_id": ["vpc-mjm05d27"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'description': ['test1']}) + resp = conn.get_all_security_groups(filters={"description": ["test1"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) @@ -375,12 +451,13 @@ def test_get_all_security_groups(): @mock_ec2_deprecated def test_authorize_bad_cidr_throws_invalid_parameter_value(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as cm: security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123") - cm.exception.code.should.equal('InvalidParameterValue') + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123" + ) + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -394,10 +471,11 @@ def test_security_group_tagging(): with assert_raises(EC2ResponseError) as ex: sg.add_tag("Test", "Tag", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) sg.add_tag("Test", "Tag") @@ -416,20 +494,19 @@ def test_security_group_tag_filtering(): sg = conn.create_security_group("test-sg", "Test SG") sg.add_tag("test-tag", "test-value") - groups = conn.get_all_security_groups( - filters={"tag:test-tag": "test-value"}) + groups = conn.get_all_security_groups(filters={"tag:test-tag": "test-value"}) groups.should.have.length_of(1) @mock_ec2_deprecated def test_authorize_all_protocols_with_no_port_specification(): conn = boto.connect_ec2() - sg = conn.create_security_group('test', 'test') + sg = conn.create_security_group("test", "test") - success = sg.authorize(ip_protocol='-1', cidr_ip='0.0.0.0/0') + success = sg.authorize(ip_protocol="-1", cidr_ip="0.0.0.0/0") success.should.be.true - sg = conn.get_all_security_groups('test')[0] + sg = conn.get_all_security_groups("test")[0] sg.rules[0].from_port.should.equal(None) sg.rules[0].to_port.should.equal(None) @@ -437,63 +514,68 @@ def test_authorize_all_protocols_with_no_port_specification(): @mock_ec2_deprecated def test_sec_group_rule_limit(): ec2_conn = boto.connect_ec2() - sg = ec2_conn.create_security_group('test', 'test') - other_sg = ec2_conn.create_security_group('test_2', 'test_other') + sg = ec2_conn.create_security_group("test", "test") + other_sg = ec2_conn.create_security_group("test_2", "test_other") # INGRESS with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(110)]) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(110)], + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") sg.rules.should.be.empty # authorize a rule targeting a different sec group (because this count too) success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) success.should.be.true # fill the rules up the limit success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(99)]) + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(99)], + ) success.should.be.true # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', cidr_ip=['100.0.0.0/0']) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip=["100.0.0.0/0"] + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # EGRESS # authorize a rule targeting a different sec group (because this count too) ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) # fill the rules up the limit # remember that by default, when created a sec group contains 1 egress rule # so our other_sg rule + 98 CIDR IP rules + 1 by default == 100 the limit for i in range(98): ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='{0}.0.0.0/0'.format(i)) + group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) + ) # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='101.0.0.0/0') - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip="101.0.0.0/0" + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") @mock_ec2_deprecated @@ -501,87 +583,93 @@ def test_sec_group_rule_limit_vpc(): ec2_conn = boto.connect_ec2() vpc_conn = boto.connect_vpc() - vpc = vpc_conn.create_vpc('10.0.0.0/16') + vpc = vpc_conn.create_vpc("10.0.0.0/16") - sg = ec2_conn.create_security_group('test', 'test', vpc_id=vpc.id) - other_sg = ec2_conn.create_security_group('test_2', 'test', vpc_id=vpc.id) + sg = ec2_conn.create_security_group("test", "test", vpc_id=vpc.id) + other_sg = ec2_conn.create_security_group("test_2", "test", vpc_id=vpc.id) # INGRESS with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(110)]) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(110)], + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") sg.rules.should.be.empty # authorize a rule targeting a different sec group (because this count too) success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) success.should.be.true # fill the rules up the limit success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(49)]) + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(49)], + ) # verify that we cannot authorize past the limit for a CIDR IP success.should.be.true with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', cidr_ip=['100.0.0.0/0']) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip=["100.0.0.0/0"] + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # EGRESS # authorize a rule targeting a different sec group (because this count too) ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) # fill the rules up the limit # remember that by default, when created a sec group contains 1 egress rule # so our other_sg rule + 48 CIDR IP rules + 1 by default == 50 the limit for i in range(48): ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='{0}.0.0.0/0'.format(i)) + group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) + ) # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='50.0.0.0/0') - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip="50.0.0.0/0" + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") -''' +""" Boto3 -''' +""" @mock_ec2 def test_add_same_rule_twice_throws_error(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg = ec2.create_security_group( - GroupName='sg1', Description='Test security group sg1', VpcId=vpc.id) + GroupName="sg1", Description="Test security group sg1", VpcId=vpc.id + ) ip_permissions = [ { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'IpRanges': [{"CidrIp": "1.2.3.4/32"}] - }, + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "IpRanges": [{"CidrIp": "1.2.3.4/32"}], + } ] sg.authorize_ingress(IpPermissions=ip_permissions) @@ -591,82 +679,89 @@ def test_add_same_rule_twice_throws_error(): @mock_ec2 def test_security_group_tagging_boto3(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") sg = conn.create_security_group(GroupName="test-sg", Description="Test SG") with assert_raises(ClientError) as ex: - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}], DryRun=True) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + conn.create_tags( + Resources=[sg["GroupId"]], + Tags=[{"Key": "Test", "Value": "Tag"}], + DryRun=True, + ) + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}]) + conn.create_tags(Resources=[sg["GroupId"]], Tags=[{"Key": "Test", "Value": "Tag"}]) describe = conn.describe_security_groups( - Filters=[{'Name': 'tag-value', 'Values': ['Tag']}]) - tag = describe["SecurityGroups"][0]['Tags'][0] - tag['Value'].should.equal("Tag") - tag['Key'].should.equal("Test") + Filters=[{"Name": "tag-value", "Values": ["Tag"]}] + ) + tag = describe["SecurityGroups"][0]["Tags"][0] + tag["Value"].should.equal("Tag") + tag["Key"].should.equal("Test") @mock_ec2 def test_security_group_wildcard_tag_filter_boto3(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") sg = conn.create_security_group(GroupName="test-sg", Description="Test SG") - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}]) + conn.create_tags(Resources=[sg["GroupId"]], Tags=[{"Key": "Test", "Value": "Tag"}]) describe = conn.describe_security_groups( - Filters=[{'Name': 'tag-value', 'Values': ['*']}]) + Filters=[{"Name": "tag-value", "Values": ["*"]}] + ) - tag = describe["SecurityGroups"][0]['Tags'][0] - tag['Value'].should.equal("Tag") - tag['Key'].should.equal("Test") + tag = describe["SecurityGroups"][0]["Tags"][0] + tag["Value"].should.equal("Tag") + tag["Key"].should.equal("Test") @mock_ec2 def test_authorize_and_revoke_in_bulk(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg01 = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) sg02 = ec2.create_security_group( - GroupName='sg02', Description='Test security group sg02', VpcId=vpc.id) + GroupName="sg02", Description="Test security group sg02", VpcId=vpc.id + ) sg03 = ec2.create_security_group( - GroupName='sg03', Description='Test security group sg03') + GroupName="sg03", Description="Test security group sg03" + ) ip_permissions = [ { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'GroupName': 'sg02', - 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [ + {"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id} + ], + "IpRanges": [], }, { - 'IpProtocol': 'tcp', - 'FromPort': 27018, - 'ToPort': 27018, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27018, + "ToPort": 27018, + "UserIdGroupPairs": [{"GroupId": sg02.id, "UserId": sg02.owner_id}], + "IpRanges": [], }, { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupName': 'sg03', 'UserId': sg03.owner_id}], - 'IpRanges': [] - } + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [{"GroupName": "sg03", "UserId": sg03.owner_id}], + "IpRanges": [], + }, ] expected_ip_permissions = copy.deepcopy(ip_permissions) - expected_ip_permissions[1]['UserIdGroupPairs'][0]['GroupName'] = 'sg02' - expected_ip_permissions[2]['UserIdGroupPairs'][0]['GroupId'] = sg03.id + expected_ip_permissions[1]["UserIdGroupPairs"][0]["GroupName"] = "sg02" + expected_ip_permissions[2]["UserIdGroupPairs"][0]["GroupId"] = sg03.id sg01.authorize_ingress(IpPermissions=ip_permissions) sg01.ip_permissions.should.have.length_of(3) @@ -691,11 +786,13 @@ def test_authorize_and_revoke_in_bulk(): @mock_ec2 def test_security_group_ingress_without_multirule(): - ec2 = boto3.resource('ec2', 'ca-central-1') - sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + ec2 = boto3.resource("ec2", "ca-central-1") + sg = ec2.create_security_group(Description="Test SG", GroupName="test-sg") assert len(sg.ip_permissions) == 0 - sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + sg.authorize_ingress( + CidrIp="192.168.0.1/32", FromPort=22, ToPort=22, IpProtocol="tcp" + ) # Fails assert len(sg.ip_permissions) == 1 @@ -703,11 +800,13 @@ def test_security_group_ingress_without_multirule(): @mock_ec2 def test_security_group_ingress_without_multirule_after_reload(): - ec2 = boto3.resource('ec2', 'ca-central-1') - sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + ec2 = boto3.resource("ec2", "ca-central-1") + sg = ec2.create_security_group(Description="Test SG", GroupName="test-sg") assert len(sg.ip_permissions) == 0 - sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + sg.authorize_ingress( + CidrIp="192.168.0.1/32", FromPort=22, ToPort=22, IpProtocol="tcp" + ) # Also Fails sg_after = ec2.SecurityGroup(sg.id) @@ -716,22 +815,51 @@ def test_security_group_ingress_without_multirule_after_reload(): @mock_ec2_deprecated def test_get_all_security_groups_filter_with_same_vpc_id(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' - security_group = conn.create_security_group( - 'test1', 'test1', vpc_id=vpc_id) - security_group2 = conn.create_security_group( - 'test2', 'test2', vpc_id=vpc_id) + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" + security_group = conn.create_security_group("test1", "test1", vpc_id=vpc_id) + security_group2 = conn.create_security_group("test2", "test2", vpc_id=vpc_id) security_group.vpc_id.should.equal(vpc_id) security_group2.vpc_id.should.equal(vpc_id) security_groups = conn.get_all_security_groups( - group_ids=[security_group.id], filters={'vpc-id': [vpc_id]}) + group_ids=[security_group.id], filters={"vpc-id": [vpc_id]} + ) security_groups.should.have.length_of(1) with assert_raises(EC2ResponseError) as cm: - conn.get_all_security_groups(group_ids=['does_not_exist']) - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.get_all_security_groups(group_ids=["does_not_exist"]) + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none + + +@mock_ec2 +def test_revoke_security_group_egress(): + ec2 = boto3.resource("ec2", "us-east-1") + sg = ec2.create_security_group(Description="Test SG", GroupName="test-sg") + + sg.ip_permissions_egress.should.equal( + [ + { + "IpProtocol": "-1", + "IpRanges": [{"CidrIp": "0.0.0.0/0"}], + "UserIdGroupPairs": [], + } + ] + ) + + sg.revoke_egress( + IpPermissions=[ + { + "FromPort": 0, + "IpProtocol": "-1", + "IpRanges": [{"CidrIp": "0.0.0.0/0"},], + "ToPort": 123, + }, + ] + ) + + sg.reload() + sg.ip_permissions_egress.should.have.length_of(0) diff --git a/tests/test_ec2/test_server.py b/tests/test_ec2/test_server.py index dc5657144..f09146b2a 100644 --- a/tests/test_ec2/test_server.py +++ b/tests/test_ec2/test_server.py @@ -1,26 +1,25 @@ -from __future__ import unicode_literals -import re -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_ec2_server_get(): - backend = server.create_backend_app("ec2") - test_client = backend.test_client() - - res = test_client.get( - '/?Action=RunInstances&ImageId=ami-60a54009', - headers={"Host": "ec2.us-east-1.amazonaws.com"} - ) - - groups = re.search("(.*)", - res.data.decode('utf-8')) - instance_id = groups.groups()[0] - - res = test_client.get('/?Action=DescribeInstances') - res.data.decode('utf-8').should.contain(instance_id) +from __future__ import unicode_literals +import re +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_ec2_server_get(): + backend = server.create_backend_app("ec2") + test_client = backend.test_client() + + res = test_client.get( + "/?Action=RunInstances&ImageId=ami-60a54009", + headers={"Host": "ec2.us-east-1.amazonaws.com"}, + ) + + groups = re.search("(.*)", res.data.decode("utf-8")) + instance_id = groups.groups()[0] + + res = test_client.get("/?Action=DescribeInstances") + res.data.decode("utf-8").should.contain(instance_id) diff --git a/tests/test_ec2/test_spot_fleet.py b/tests/test_ec2/test_spot_fleet.py index 6221d633f..87b2f6291 100644 --- a/tests/test_ec2/test_spot_fleet.py +++ b/tests/test_ec2/test_spot_fleet.py @@ -4,384 +4,376 @@ import boto3 import sure # noqa from moto import mock_ec2 +from moto.core import ACCOUNT_ID def get_subnet_id(conn): - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] return subnet_id def spot_config(subnet_id, allocation_strategy="lowestPrice"): return { - 'ClientToken': 'string', - 'SpotPrice': '0.12', - 'TargetCapacity': 6, - 'IamFleetRole': 'arn:aws:iam::123456789012:role/fleet', - 'LaunchSpecifications': [{ - 'ImageId': 'ami-123', - 'KeyName': 'my-key', - 'SecurityGroups': [ - { - 'GroupId': 'sg-123' - }, - ], - 'UserData': 'some user data', - 'InstanceType': 't2.small', - 'BlockDeviceMappings': [ - { - 'VirtualName': 'string', - 'DeviceName': 'string', - 'Ebs': { - 'SnapshotId': 'string', - 'VolumeSize': 123, - 'DeleteOnTermination': True | False, - 'VolumeType': 'standard', - 'Iops': 123, - 'Encrypted': True | False + "ClientToken": "string", + "SpotPrice": "0.12", + "TargetCapacity": 6, + "IamFleetRole": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID), + "LaunchSpecifications": [ + { + "ImageId": "ami-123", + "KeyName": "my-key", + "SecurityGroups": [{"GroupId": "sg-123"}], + "UserData": "some user data", + "InstanceType": "t2.small", + "BlockDeviceMappings": [ + { + "VirtualName": "string", + "DeviceName": "string", + "Ebs": { + "SnapshotId": "string", + "VolumeSize": 123, + "DeleteOnTermination": True | False, + "VolumeType": "standard", + "Iops": 123, + "Encrypted": True | False, }, - 'NoDevice': 'string' + "NoDevice": "string", + } + ], + "Monitoring": {"Enabled": True}, + "SubnetId": subnet_id, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID) }, - ], - 'Monitoring': { - 'Enabled': True + "EbsOptimized": False, + "WeightedCapacity": 2.0, + "SpotPrice": "0.13", }, - 'SubnetId': subnet_id, - 'IamInstanceProfile': { - 'Arn': 'arn:aws:iam::123456789012:role/fleet' - }, - 'EbsOptimized': False, - 'WeightedCapacity': 2.0, - 'SpotPrice': '0.13', - }, { - 'ImageId': 'ami-123', - 'KeyName': 'my-key', - 'SecurityGroups': [ - { - 'GroupId': 'sg-123' + { + "ImageId": "ami-123", + "KeyName": "my-key", + "SecurityGroups": [{"GroupId": "sg-123"}], + "UserData": "some user data", + "InstanceType": "t2.large", + "Monitoring": {"Enabled": True}, + "SubnetId": subnet_id, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID) }, - ], - 'UserData': 'some user data', - 'InstanceType': 't2.large', - 'Monitoring': { - 'Enabled': True + "EbsOptimized": False, + "WeightedCapacity": 4.0, + "SpotPrice": "10.00", }, - 'SubnetId': subnet_id, - 'IamInstanceProfile': { - 'Arn': 'arn:aws:iam::123456789012:role/fleet' - }, - 'EbsOptimized': False, - 'WeightedCapacity': 4.0, - 'SpotPrice': '10.00', - }], - 'AllocationStrategy': allocation_strategy, - 'FulfilledCapacity': 6, + ], + "AllocationStrategy": allocation_strategy, + "FulfilledCapacity": 6, } @mock_ec2 def test_create_spot_fleet_with_lowest_price(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - spot_fleet_config['SpotPrice'].should.equal('0.12') - spot_fleet_config['TargetCapacity'].should.equal(6) - spot_fleet_config['IamFleetRole'].should.equal( - 'arn:aws:iam::123456789012:role/fleet') - spot_fleet_config['AllocationStrategy'].should.equal('lowestPrice') - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + spot_fleet_config["SpotPrice"].should.equal("0.12") + spot_fleet_config["TargetCapacity"].should.equal(6) + spot_fleet_config["IamFleetRole"].should.equal( + "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID) + ) + spot_fleet_config["AllocationStrategy"].should.equal("lowestPrice") + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec = spot_fleet_config['LaunchSpecifications'][0] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec = spot_fleet_config["LaunchSpecifications"][0] - launch_spec['EbsOptimized'].should.equal(False) - launch_spec['SecurityGroups'].should.equal([{"GroupId": "sg-123"}]) - launch_spec['IamInstanceProfile'].should.equal( - {"Arn": "arn:aws:iam::123456789012:role/fleet"}) - launch_spec['ImageId'].should.equal("ami-123") - launch_spec['InstanceType'].should.equal("t2.small") - launch_spec['KeyName'].should.equal("my-key") - launch_spec['Monitoring'].should.equal({"Enabled": True}) - launch_spec['SpotPrice'].should.equal("0.13") - launch_spec['SubnetId'].should.equal(subnet_id) - launch_spec['UserData'].should.equal("some user data") - launch_spec['WeightedCapacity'].should.equal(2.0) + launch_spec["EbsOptimized"].should.equal(False) + launch_spec["SecurityGroups"].should.equal([{"GroupId": "sg-123"}]) + launch_spec["IamInstanceProfile"].should.equal( + {"Arn": "arn:aws:iam::{}:role/fleet".format(ACCOUNT_ID)} + ) + launch_spec["ImageId"].should.equal("ami-123") + launch_spec["InstanceType"].should.equal("t2.small") + launch_spec["KeyName"].should.equal("my-key") + launch_spec["Monitoring"].should.equal({"Enabled": True}) + launch_spec["SpotPrice"].should.equal("0.13") + launch_spec["SubnetId"].should.equal(subnet_id) + launch_spec["UserData"].should.equal("some user data") + launch_spec["WeightedCapacity"].should.equal(2.0) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) @mock_ec2 def test_create_diversified_spot_fleet(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) - diversified_config = spot_config( - subnet_id, allocation_strategy='diversified') + diversified_config = spot_config(subnet_id, allocation_strategy="diversified") - spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=diversified_config - ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=diversified_config) + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(2) - instance_types = set([instance['InstanceType'] for instance in instances]) + instance_types = set([instance["InstanceType"] for instance in instances]) instance_types.should.equal(set(["t2.small", "t2.large"])) - instances[0]['InstanceId'].should.contain("i-") + instances[0]["InstanceId"].should.contain("i-") @mock_ec2 def test_create_spot_fleet_request_with_tag_spec(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) tag_spec = [ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'tag-1', - 'Value': 'foo', - }, - { - 'Key': 'tag-2', - 'Value': 'bar', - }, - ] - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "tag-1", "Value": "foo"}, + {"Key": "tag-2", "Value": "bar"}, + ], + } ] config = spot_config(subnet_id) - config['LaunchSpecifications'][0]['TagSpecifications'] = tag_spec - spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=config - ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + config["LaunchSpecifications"][0]["TagSpecifications"] = tag_spec + spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=config) + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] - spot_fleet_config = spot_fleet_requests[0]['SpotFleetRequestConfig'] - spot_fleet_config['LaunchSpecifications'][0]['TagSpecifications'][0][ - 'ResourceType'].should.equal('instance') - for tag in tag_spec[0]['Tags']: - spot_fleet_config['LaunchSpecifications'][0]['TagSpecifications'][0]['Tags'].should.contain(tag) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] + spot_fleet_config = spot_fleet_requests[0]["SpotFleetRequestConfig"] + spot_fleet_config["LaunchSpecifications"][0]["TagSpecifications"][0][ + "ResourceType" + ].should.equal("instance") + for tag in tag_spec[0]["Tags"]: + spot_fleet_config["LaunchSpecifications"][0]["TagSpecifications"][0][ + "Tags" + ].should.contain(tag) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = conn.describe_instances(InstanceIds=[i['InstanceId'] for i in instance_res['ActiveInstances']]) - for instance in instances['Reservations'][0]['Instances']: - for tag in tag_spec[0]['Tags']: - instance['Tags'].should.contain(tag) + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = conn.describe_instances( + InstanceIds=[i["InstanceId"] for i in instance_res["ActiveInstances"]] + ) + for instance in instances["Reservations"][0]["Instances"]: + for tag in tag_spec[0]["Tags"]: + instance["Tags"].should.contain(tag) @mock_ec2 def test_cancel_spot_fleet_request(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] conn.cancel_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=True) + SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=True + ) spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(0) @mock_ec2 def test_modify_spot_fleet_request_up(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=20) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=20) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(10) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(20) - spot_fleet_config['FulfilledCapacity'].should.equal(20.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(20) + spot_fleet_config["FulfilledCapacity"].should.equal(20.0) @mock_ec2 def test_modify_spot_fleet_request_up_diversified(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config( - subnet_id, allocation_strategy='diversified'), + SpotFleetRequestConfig=spot_config(subnet_id, allocation_strategy="diversified") ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=19) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=19) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(7) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(19) - spot_fleet_config['FulfilledCapacity'].should.equal(20.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(19) + spot_fleet_config["FulfilledCapacity"].should.equal(20.0) @mock_ec2 def test_modify_spot_fleet_request_down_no_terminate(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1, ExcessCapacityTerminationPolicy="noTermination") + SpotFleetRequestId=spot_fleet_id, + TargetCapacity=1, + ExcessCapacityTerminationPolicy="noTermination", + ) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) @mock_ec2 def test_modify_spot_fleet_request_down_odd(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=7) - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=5) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=7) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=5) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(5) - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(5) + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) @mock_ec2 def test_modify_spot_fleet_request_down(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=1) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(2.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(2.0) @mock_ec2 def test_modify_spot_fleet_request_down_no_terminate_after_custom_terminate(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] - conn.terminate_instances(InstanceIds=[i['InstanceId'] for i in instances[1:]]) + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] + conn.terminate_instances(InstanceIds=[i["InstanceId"] for i in instances[1:]]) conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1, ExcessCapacityTerminationPolicy="noTermination") + SpotFleetRequestId=spot_fleet_id, + TargetCapacity=1, + ExcessCapacityTerminationPolicy="noTermination", + ) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(2.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(2.0) @mock_ec2 def test_create_spot_fleet_without_spot_price(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) # remove prices to force a fallback to ondemand price spot_config_without_price = spot_config(subnet_id) - del spot_config_without_price['SpotPrice'] - for spec in spot_config_without_price['LaunchSpecifications']: - del spec['SpotPrice'] + del spot_config_without_price["SpotPrice"] + for spec in spot_config_without_price["LaunchSpecifications"]: + del spec["SpotPrice"] - spot_fleet_id = conn.request_spot_fleet(SpotFleetRequestConfig=spot_config_without_price)['SpotFleetRequestId'] + spot_fleet_id = conn.request_spot_fleet( + SpotFleetRequestConfig=spot_config_without_price + )["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec1 = spot_fleet_config['LaunchSpecifications'][0] - launch_spec2 = spot_fleet_config['LaunchSpecifications'][1] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec1 = spot_fleet_config["LaunchSpecifications"][0] + launch_spec2 = spot_fleet_config["LaunchSpecifications"][1] # AWS will figure out the price - assert 'SpotPrice' not in launch_spec1 - assert 'SpotPrice' not in launch_spec2 + assert "SpotPrice" not in launch_spec1 + assert "SpotPrice" not in launch_spec2 diff --git a/tests/test_ec2/test_spot_instances.py b/tests/test_ec2/test_spot_instances.py index ab08d392c..cfc95bb82 100644 --- a/tests/test_ec2/test_spot_instances.py +++ b/tests/test_ec2/test_spot_instances.py @@ -16,14 +16,15 @@ from moto.core.utils import iso_8601_datetime_with_milliseconds @mock_ec2 def test_request_spot_instances(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] - conn.create_security_group(GroupName='group1', Description='description') - conn.create_security_group(GroupName='group2', Description='description') + conn.create_security_group(GroupName="group1", Description="description") + conn.create_security_group(GroupName="group2", Description="description") start_dt = datetime.datetime(2013, 1, 1).replace(tzinfo=pytz.utc) end_dt = datetime.datetime(2013, 1, 2).replace(tzinfo=pytz.utc) @@ -32,78 +33,79 @@ def test_request_spot_instances(): with assert_raises(ClientError) as ex: request = conn.request_spot_instances( - SpotPrice="0.5", InstanceCount=1, Type='one-time', - ValidFrom=start, ValidUntil=end, LaunchGroup="the-group", - AvailabilityZoneGroup='my-group', + SpotPrice="0.5", + InstanceCount=1, + Type="one-time", + ValidFrom=start, + ValidUntil=end, + LaunchGroup="the-group", + AvailabilityZoneGroup="my-group", LaunchSpecification={ - "ImageId": 'ami-abcd1234', + "ImageId": "ami-abcd1234", "KeyName": "test", - "SecurityGroups": ['group1', 'group2'], + "SecurityGroups": ["group1", "group2"], "UserData": "some test data", - "InstanceType": 'm1.small', - "Placement": { - "AvailabilityZone": 'us-east-1c', - }, + "InstanceType": "m1.small", + "Placement": {"AvailabilityZone": "us-east-1c"}, "KernelId": "test-kernel", "RamdiskId": "test-ramdisk", - "Monitoring": { - "Enabled": True, - }, + "Monitoring": {"Enabled": True}, "SubnetId": subnet_id, }, DryRun=True, ) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the RequestSpotInstance operation: Request would have succeeded, but DryRun flag is set') + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the RequestSpotInstance operation: Request would have succeeded, but DryRun flag is set" + ) request = conn.request_spot_instances( - SpotPrice="0.5", InstanceCount=1, Type='one-time', - ValidFrom=start, ValidUntil=end, LaunchGroup="the-group", - AvailabilityZoneGroup='my-group', + SpotPrice="0.5", + InstanceCount=1, + Type="one-time", + ValidFrom=start, + ValidUntil=end, + LaunchGroup="the-group", + AvailabilityZoneGroup="my-group", LaunchSpecification={ - "ImageId": 'ami-abcd1234', + "ImageId": "ami-abcd1234", "KeyName": "test", - "SecurityGroups": ['group1', 'group2'], + "SecurityGroups": ["group1", "group2"], "UserData": "some test data", - "InstanceType": 'm1.small', - "Placement": { - "AvailabilityZone": 'us-east-1c', - }, + "InstanceType": "m1.small", + "Placement": {"AvailabilityZone": "us-east-1c"}, "KernelId": "test-kernel", "RamdiskId": "test-ramdisk", - "Monitoring": { - "Enabled": True, - }, + "Monitoring": {"Enabled": True}, "SubnetId": subnet_id, }, ) - requests = conn.describe_spot_instance_requests()['SpotInstanceRequests'] + requests = conn.describe_spot_instance_requests()["SpotInstanceRequests"] requests.should.have.length_of(1) request = requests[0] - request['State'].should.equal("open") - request['SpotPrice'].should.equal("0.5") - request['Type'].should.equal('one-time') - request['ValidFrom'].should.equal(start_dt) - request['ValidUntil'].should.equal(end_dt) - request['LaunchGroup'].should.equal("the-group") - request['AvailabilityZoneGroup'].should.equal('my-group') + request["State"].should.equal("open") + request["SpotPrice"].should.equal("0.5") + request["Type"].should.equal("one-time") + request["ValidFrom"].should.equal(start_dt) + request["ValidUntil"].should.equal(end_dt) + request["LaunchGroup"].should.equal("the-group") + request["AvailabilityZoneGroup"].should.equal("my-group") - launch_spec = request['LaunchSpecification'] - security_group_names = [group['GroupName'] - for group in launch_spec['SecurityGroups']] - set(security_group_names).should.equal(set(['group1', 'group2'])) + launch_spec = request["LaunchSpecification"] + security_group_names = [ + group["GroupName"] for group in launch_spec["SecurityGroups"] + ] + set(security_group_names).should.equal(set(["group1", "group2"])) - launch_spec['ImageId'].should.equal('ami-abcd1234') - launch_spec['KeyName'].should.equal("test") - launch_spec['InstanceType'].should.equal('m1.small') - launch_spec['KernelId'].should.equal("test-kernel") - launch_spec['RamdiskId'].should.equal("test-ramdisk") - launch_spec['SubnetId'].should.equal(subnet_id) + launch_spec["ImageId"].should.equal("ami-abcd1234") + launch_spec["KeyName"].should.equal("test") + launch_spec["InstanceType"].should.equal("m1.small") + launch_spec["KernelId"].should.equal("test-kernel") + launch_spec["RamdiskId"].should.equal("test-ramdisk") + launch_spec["SubnetId"].should.equal(subnet_id) @mock_ec2 @@ -111,58 +113,55 @@ def test_request_spot_instances_default_arguments(): """ Test that moto set the correct default arguments """ - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") request = conn.request_spot_instances( - SpotPrice="0.5", - LaunchSpecification={ - "ImageId": 'ami-abcd1234', - } + SpotPrice="0.5", LaunchSpecification={"ImageId": "ami-abcd1234"} ) - requests = conn.describe_spot_instance_requests()['SpotInstanceRequests'] + requests = conn.describe_spot_instance_requests()["SpotInstanceRequests"] requests.should.have.length_of(1) request = requests[0] - request['State'].should.equal("open") - request['SpotPrice'].should.equal("0.5") - request['Type'].should.equal('one-time') - request.shouldnt.contain('ValidFrom') - request.shouldnt.contain('ValidUntil') - request.shouldnt.contain('LaunchGroup') - request.shouldnt.contain('AvailabilityZoneGroup') + request["State"].should.equal("open") + request["SpotPrice"].should.equal("0.5") + request["Type"].should.equal("one-time") + request.shouldnt.contain("ValidFrom") + request.shouldnt.contain("ValidUntil") + request.shouldnt.contain("LaunchGroup") + request.shouldnt.contain("AvailabilityZoneGroup") - launch_spec = request['LaunchSpecification'] + launch_spec = request["LaunchSpecification"] - security_group_names = [group['GroupName'] - for group in launch_spec['SecurityGroups']] + security_group_names = [ + group["GroupName"] for group in launch_spec["SecurityGroups"] + ] security_group_names.should.equal(["default"]) - launch_spec['ImageId'].should.equal('ami-abcd1234') - request.shouldnt.contain('KeyName') - launch_spec['InstanceType'].should.equal('m1.small') - request.shouldnt.contain('KernelId') - request.shouldnt.contain('RamdiskId') - request.shouldnt.contain('SubnetId') + launch_spec["ImageId"].should.equal("ami-abcd1234") + request.shouldnt.contain("KeyName") + launch_spec["InstanceType"].should.equal("m1.small") + request.shouldnt.contain("KernelId") + request.shouldnt.contain("RamdiskId") + request.shouldnt.contain("SubnetId") @mock_ec2_deprecated def test_cancel_spot_instance_request(): conn = boto.connect_ec2() - conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) + conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) with assert_raises(EC2ResponseError) as ex: conn.cancel_spot_instance_requests([requests[0].id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CancelSpotInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CancelSpotInstance operation: Request would have succeeded, but DryRun flag is set" + ) conn.cancel_spot_instance_requests([requests[0].id]) @@ -177,9 +176,7 @@ def test_request_spot_instances_fulfilled(): """ conn = boto.ec2.connect_to_region("us-east-1") - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) @@ -187,7 +184,7 @@ def test_request_spot_instances_fulfilled(): request.state.should.equal("open") - get_model('SpotInstanceRequest', 'us-east-1')[0].state = 'active' + get_model("SpotInstanceRequest", "us-east-1")[0].state = "active" requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) @@ -203,18 +200,16 @@ def test_tag_spot_instance_request(): """ conn = boto.connect_ec2() - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request[0].add_tag('tag1', 'value1') - request[0].add_tag('tag2', 'value2') + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request[0].add_tag("tag1", "value1") + request[0].add_tag("tag2", "value2") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) request = requests[0] tag_dict = dict(request.tags) - tag_dict.should.equal({'tag1': 'value1', 'tag2': 'value2'}) + tag_dict.should.equal({"tag1": "value1", "tag2": "value2"}) @mock_ec2_deprecated @@ -224,45 +219,38 @@ def test_get_all_spot_instance_requests_filtering(): """ conn = boto.connect_ec2() - request1 = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request2 = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request1[0].add_tag('tag1', 'value1') - request1[0].add_tag('tag2', 'value2') - request2[0].add_tag('tag1', 'value1') - request2[0].add_tag('tag2', 'wrong') + request1 = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request2 = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request1[0].add_tag("tag1", "value1") + request1[0].add_tag("tag2", "value2") + request2[0].add_tag("tag1", "value1") + request2[0].add_tag("tag2", "wrong") - requests = conn.get_all_spot_instance_requests(filters={'state': 'active'}) + requests = conn.get_all_spot_instance_requests(filters={"state": "active"}) requests.should.have.length_of(0) - requests = conn.get_all_spot_instance_requests(filters={'state': 'open'}) + requests = conn.get_all_spot_instance_requests(filters={"state": "open"}) requests.should.have.length_of(3) - requests = conn.get_all_spot_instance_requests( - filters={'tag:tag1': 'value1'}) + requests = conn.get_all_spot_instance_requests(filters={"tag:tag1": "value1"}) requests.should.have.length_of(2) requests = conn.get_all_spot_instance_requests( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) requests.should.have.length_of(1) @mock_ec2_deprecated def test_request_spot_instances_setting_instance_id(): conn = boto.ec2.connect_to_region("us-east-1") - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234') + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") - req = get_model('SpotInstanceRequest', 'us-east-1')[0] - req.state = 'active' - req.instance_id = 'i-12345678' + req = get_model("SpotInstanceRequest", "us-east-1")[0] + req.state = "active" + req.instance_id = "i-12345678" request = conn.get_all_spot_instance_requests()[0] - assert request.state == 'active' - assert request.instance_id == 'i-12345678' + assert request.state == "active" + assert request.instance_id == "i-12345678" diff --git a/tests/test_ec2/test_subnets.py b/tests/test_ec2/test_subnets.py index 38c36f682..7bb57aab4 100644 --- a/tests/test_ec2/test_subnets.py +++ b/tests/test_ec2/test_subnets.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises # noqa from nose.tools import assert_raises @@ -10,14 +11,15 @@ from boto.exception import EC2ResponseError from botocore.exceptions import ParamValidationError, ClientError import json import sure # noqa +import random from moto import mock_cloudformation_deprecated, mock_ec2, mock_ec2_deprecated @mock_ec2_deprecated def test_subnets(): - ec2 = boto.connect_ec2('the_key', 'the_secret') - conn = boto.connect_vpc('the_key', 'the_secret') + ec2 = boto.connect_ec2("the_key", "the_secret") + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") @@ -31,25 +33,25 @@ def test_subnets(): with assert_raises(EC2ResponseError) as cm: conn.delete_subnet(subnet.id) - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_subnet_create_vpc_validation(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.create_subnet("vpc-abcd1234", "10.0.0.0/18") - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_subnet_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") @@ -67,31 +69,31 @@ def test_subnet_tagging(): @mock_ec2_deprecated def test_subnet_should_have_proper_availability_zone_set(): - conn = boto.vpc.connect_to_region('us-west-1') + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1b') - subnetA.availability_zone.should.equal('us-west-1b') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1b") + subnetA.availability_zone.should.equal("us-west-1b") @mock_ec2 def test_default_subnet(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_vpc = list(ec2.vpcs.all())[0] - default_vpc.cidr_block.should.equal('172.31.0.0/16') + default_vpc.cidr_block.should.equal("172.31.0.0/16") default_vpc.reload() default_vpc.is_default.should.be.ok subnet = ec2.create_subnet( - VpcId=default_vpc.id, CidrBlock='172.31.48.0/20', AvailabilityZone='us-west-1a') + VpcId=default_vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone="us-west-1a" + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok @mock_ec2_deprecated def test_non_default_subnet(): - vpc_cli = boto.vpc.connect_to_region('us-west-1') + vpc_cli = boto.vpc.connect_to_region("us-west-1") # Create the non default VPC vpc = vpc_cli.create_vpc("10.0.0.0/16") @@ -99,34 +101,36 @@ def test_non_default_subnet(): subnet = vpc_cli.create_subnet(vpc.id, "10.0.0.0/24") subnet = vpc_cli.get_all_subnets(subnet_ids=[subnet.id])[0] - subnet.mapPublicIpOnLaunch.should.equal('false') + subnet.mapPublicIpOnLaunch.should.equal("false") @mock_ec2 def test_boto3_non_default_subnet(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok @mock_ec2 def test_modify_subnet_attribute_public_ip_on_launch(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") # Get the default VPC vpc = list(ec2.vpcs.all())[0] subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone="us-west-1a" + ) # 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action subnet.reload() @@ -135,26 +139,29 @@ def test_modify_subnet_attribute_public_ip_on_launch(): subnet.map_public_ip_on_launch.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': False}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": False} + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': True}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True} + ) subnet.reload() subnet.map_public_ip_on_launch.should.be.ok @mock_ec2 def test_modify_subnet_attribute_assign_ipv6_address_on_creation(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") # Get the default VPC vpc = list(ec2.vpcs.all())[0] subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='172.31.112.0/20', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="172.31.112.0/20", AvailabilityZone="us-west-1a" + ) # 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action subnet.reload() @@ -163,41 +170,46 @@ def test_modify_subnet_attribute_assign_ipv6_address_on_creation(): subnet.assign_ipv6_address_on_creation.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, AssignIpv6AddressOnCreation={'Value': False}) + SubnetId=subnet.id, AssignIpv6AddressOnCreation={"Value": False} + ) subnet.reload() subnet.assign_ipv6_address_on_creation.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, AssignIpv6AddressOnCreation={'Value': True}) + SubnetId=subnet.id, AssignIpv6AddressOnCreation={"Value": True} + ) subnet.reload() subnet.assign_ipv6_address_on_creation.should.be.ok @mock_ec2 def test_modify_subnet_attribute_validation(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) with assert_raises(ParamValidationError): client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": "invalid"} + ) @mock_ec2_deprecated def test_subnet_get_by_id(): - ec2 = boto.ec2.connect_to_region('us-west-1') - conn = boto.vpc.connect_to_region('us-west-1') + ec2 = boto.ec2.connect_to_region("us-west-1") + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1a") vpcB = conn.create_vpc("10.0.0.0/16") subnetB1 = conn.create_subnet( - vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB.id, "10.0.0.0/24", availability_zone="us-west-1a" + ) subnetB2 = conn.create_subnet( - vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + vpcB.id, "10.0.1.0/24", availability_zone="us-west-1b" + ) subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id]) subnets_by_id.should.have.length_of(2) @@ -206,85 +218,91 @@ def test_subnet_get_by_id(): subnetB1.id.should.be.within(subnets_by_id) with assert_raises(EC2ResponseError) as cm: - conn.get_all_subnets(subnet_ids=['subnet-does_not_exist']) - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + conn.get_all_subnets(subnet_ids=["subnet-does_not_exist"]) + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_get_subnets_filtering(): - ec2 = boto.ec2.connect_to_region('us-west-1') - conn = boto.vpc.connect_to_region('us-west-1') + ec2 = boto.ec2.connect_to_region("us-west-1") + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1a") vpcB = conn.create_vpc("10.0.0.0/16") subnetB1 = conn.create_subnet( - vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB.id, "10.0.0.0/24", availability_zone="us-west-1a" + ) subnetB2 = conn.create_subnet( - vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + vpcB.id, "10.0.1.0/24", availability_zone="us-west-1b" + ) all_subnets = conn.get_all_subnets() all_subnets.should.have.length_of(3 + len(ec2.get_all_zones())) # Filter by VPC ID - subnets_by_vpc = conn.get_all_subnets(filters={'vpc-id': vpcB.id}) + subnets_by_vpc = conn.get_all_subnets(filters={"vpc-id": vpcB.id}) subnets_by_vpc.should.have.length_of(2) set([subnet.id for subnet in subnets_by_vpc]).should.equal( - set([subnetB1.id, subnetB2.id])) + set([subnetB1.id, subnetB2.id]) + ) # Filter by CIDR variations - subnets_by_cidr1 = conn.get_all_subnets(filters={'cidr': "10.0.0.0/24"}) + subnets_by_cidr1 = conn.get_all_subnets(filters={"cidr": "10.0.0.0/24"}) subnets_by_cidr1.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr1] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr1]).should.equal( + set([subnetA.id, subnetB1.id]) + ) - subnets_by_cidr2 = conn.get_all_subnets( - filters={'cidr-block': "10.0.0.0/24"}) + subnets_by_cidr2 = conn.get_all_subnets(filters={"cidr-block": "10.0.0.0/24"}) subnets_by_cidr2.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr2] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr2]).should.equal( + set([subnetA.id, subnetB1.id]) + ) - subnets_by_cidr3 = conn.get_all_subnets( - filters={'cidrBlock': "10.0.0.0/24"}) + subnets_by_cidr3 = conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"}) subnets_by_cidr3.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr3] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr3]).should.equal( + set([subnetA.id, subnetB1.id]) + ) # Filter by VPC ID and CIDR subnets_by_vpc_and_cidr = conn.get_all_subnets( - filters={'vpc-id': vpcB.id, 'cidr': "10.0.0.0/24"}) + filters={"vpc-id": vpcB.id, "cidr": "10.0.0.0/24"} + ) subnets_by_vpc_and_cidr.should.have.length_of(1) - set([subnet.id for subnet in subnets_by_vpc_and_cidr] - ).should.equal(set([subnetB1.id])) + set([subnet.id for subnet in subnets_by_vpc_and_cidr]).should.equal( + set([subnetB1.id]) + ) # Filter by subnet ID - subnets_by_id = conn.get_all_subnets(filters={'subnet-id': subnetA.id}) + subnets_by_id = conn.get_all_subnets(filters={"subnet-id": subnetA.id}) subnets_by_id.should.have.length_of(1) set([subnet.id for subnet in subnets_by_id]).should.equal(set([subnetA.id])) # Filter by availabilityZone subnets_by_az = conn.get_all_subnets( - filters={'availabilityZone': 'us-west-1a', 'vpc-id': vpcB.id}) + filters={"availabilityZone": "us-west-1a", "vpc-id": vpcB.id} + ) subnets_by_az.should.have.length_of(1) - set([subnet.id for subnet in subnets_by_az] - ).should.equal(set([subnetB1.id])) + set([subnet.id for subnet in subnets_by_az]).should.equal(set([subnetB1.id])) # Filter by defaultForAz - subnets_by_az = conn.get_all_subnets(filters={'defaultForAz': "true"}) + subnets_by_az = conn.get_all_subnets(filters={"defaultForAz": "true"}) subnets_by_az.should.have.length_of(len(conn.get_all_zones())) # Unsupported filter conn.get_all_subnets.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated @mock_cloudformation_deprecated def test_subnet_tags_through_cloudformation(): - vpc_conn = boto.vpc.connect_to_region('us-west-1') + vpc_conn = boto.vpc.connect_to_region("us-west-1") vpc = vpc_conn.create_vpc("10.0.0.0/16") subnet_template = { @@ -296,151 +314,288 @@ def test_subnet_tags_through_cloudformation(): "VpcId": vpc.id, "CidrBlock": "10.0.0.0/24", "AvailabilityZone": "us-west-1b", - "Tags": [{ - "Key": "foo", - "Value": "bar", - }, { - "Key": "blah", - "Value": "baz", - }] - } + "Tags": [ + {"Key": "foo", "Value": "bar"}, + {"Key": "blah", "Value": "baz"}, + ], + }, } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-west-1") template_json = json.dumps(subnet_template) - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) - subnet = vpc_conn.get_all_subnets(filters={'cidrBlock': '10.0.0.0/24'})[0] + subnet = vpc_conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"})[0] subnet.tags["foo"].should.equal("bar") subnet.tags["blah"].should.equal("baz") @mock_ec2 def test_create_subnet_response_fields(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = client.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a')['Subnet'] + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + )["Subnet"] - subnet.should.have.key('AvailabilityZone') - subnet.should.have.key('AvailabilityZoneId') - subnet.should.have.key('AvailableIpAddressCount') - subnet.should.have.key('CidrBlock') - subnet.should.have.key('State') - subnet.should.have.key('SubnetId') - subnet.should.have.key('VpcId') - subnet.shouldnt.have.key('Tags') - subnet.should.have.key('DefaultForAz').which.should.equal(False) - subnet.should.have.key('MapPublicIpOnLaunch').which.should.equal(False) - subnet.should.have.key('OwnerId') - subnet.should.have.key('AssignIpv6AddressOnCreation').which.should.equal(False) + subnet.should.have.key("AvailabilityZone") + subnet.should.have.key("AvailabilityZoneId") + subnet.should.have.key("AvailableIpAddressCount") + subnet.should.have.key("CidrBlock") + subnet.should.have.key("State") + subnet.should.have.key("SubnetId") + subnet.should.have.key("VpcId") + subnet.shouldnt.have.key("Tags") + subnet.should.have.key("DefaultForAz").which.should.equal(False) + subnet.should.have.key("MapPublicIpOnLaunch").which.should.equal(False) + subnet.should.have.key("OwnerId") + subnet.should.have.key("AssignIpv6AddressOnCreation").which.should.equal(False) - subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format(region=subnet['AvailabilityZone'][0:-1], - owner_id=subnet['OwnerId'], - subnet_id=subnet['SubnetId']) - subnet.should.have.key('SubnetArn').which.should.equal(subnet_arn) - subnet.should.have.key('Ipv6CidrBlockAssociationSet').which.should.equal([]) + subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format( + region=subnet["AvailabilityZone"][0:-1], + owner_id=subnet["OwnerId"], + subnet_id=subnet["SubnetId"], + ) + subnet.should.have.key("SubnetArn").which.should.equal(subnet_arn) + subnet.should.have.key("Ipv6CidrBlockAssociationSet").which.should.equal([]) @mock_ec2 def test_describe_subnet_response_fields(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet_object = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) - subnets = client.describe_subnets(SubnetIds=[subnet_object.id])['Subnets'] + subnets = client.describe_subnets(SubnetIds=[subnet_object.id])["Subnets"] subnets.should.have.length_of(1) subnet = subnets[0] - subnet.should.have.key('AvailabilityZone') - subnet.should.have.key('AvailabilityZoneId') - subnet.should.have.key('AvailableIpAddressCount') - subnet.should.have.key('CidrBlock') - subnet.should.have.key('State') - subnet.should.have.key('SubnetId') - subnet.should.have.key('VpcId') - subnet.shouldnt.have.key('Tags') - subnet.should.have.key('DefaultForAz').which.should.equal(False) - subnet.should.have.key('MapPublicIpOnLaunch').which.should.equal(False) - subnet.should.have.key('OwnerId') - subnet.should.have.key('AssignIpv6AddressOnCreation').which.should.equal(False) + subnet.should.have.key("AvailabilityZone") + subnet.should.have.key("AvailabilityZoneId") + subnet.should.have.key("AvailableIpAddressCount") + subnet.should.have.key("CidrBlock") + subnet.should.have.key("State") + subnet.should.have.key("SubnetId") + subnet.should.have.key("VpcId") + subnet.shouldnt.have.key("Tags") + subnet.should.have.key("DefaultForAz").which.should.equal(False) + subnet.should.have.key("MapPublicIpOnLaunch").which.should.equal(False) + subnet.should.have.key("OwnerId") + subnet.should.have.key("AssignIpv6AddressOnCreation").which.should.equal(False) - subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format(region=subnet['AvailabilityZone'][0:-1], - owner_id=subnet['OwnerId'], - subnet_id=subnet['SubnetId']) - subnet.should.have.key('SubnetArn').which.should.equal(subnet_arn) - subnet.should.have.key('Ipv6CidrBlockAssociationSet').which.should.equal([]) + subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format( + region=subnet["AvailabilityZone"][0:-1], + owner_id=subnet["OwnerId"], + subnet_id=subnet["SubnetId"], + ) + subnet.should.have.key("SubnetArn").which.should.equal(subnet_arn) + subnet.should.have.key("Ipv6CidrBlockAssociationSet").which.should.equal([]) @mock_ec2 def test_create_subnet_with_invalid_availability_zone(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") - subnet_availability_zone = 'asfasfas' + subnet_availability_zone = "asfasfas" with assert_raises(ClientError) as ex: subnet = client.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone=subnet_availability_zone) + VpcId=vpc.id, + CidrBlock="10.0.0.0/24", + AvailabilityZone=subnet_availability_zone, + ) assert str(ex.exception).startswith( "An error occurred (InvalidParameterValue) when calling the CreateSubnet " - "operation: Value ({}) for parameter availabilityZone is invalid. Subnets can currently only be created in the following availability zones: ".format(subnet_availability_zone)) + "operation: Value ({}) for parameter availabilityZone is invalid. Subnets can currently only be created in the following availability zones: ".format( + subnet_availability_zone + ) + ) @mock_ec2 def test_create_subnet_with_invalid_cidr_range(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '10.1.0.0/20' + subnet_cidr_block = "10.1.0.0/20" with assert_raises(ClientError) as ex: subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidSubnet.Range) when calling the CreateSubnet " - "operation: The CIDR '{}' is invalid.".format(subnet_cidr_block)) + "operation: The CIDR '{}' is invalid.".format(subnet_cidr_block) + ) @mock_ec2 def test_create_subnet_with_invalid_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '1000.1.0.0/20' + subnet_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateSubnet " - "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(subnet_cidr_block)) + "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + subnet_cidr_block + ) + ) @mock_ec2 def test_create_subnets_with_overlapping_cidr_blocks(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '10.0.0.0/24' + subnet_cidr_block = "10.0.0.0/24" with assert_raises(ClientError) as ex: subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidSubnet.Conflict) when calling the CreateSubnet " - "operation: The CIDR '{}' conflicts with another subnet".format(subnet_cidr_block)) + "operation: The CIDR '{}' conflicts with another subnet".format( + subnet_cidr_block + ) + ) + + +@mock_ec2 +def test_available_ip_addresses_in_subnet(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + cidr_range_addresses = [ + ("10.0.0.0/16", 65531), + ("10.0.0.0/17", 32763), + ("10.0.0.0/18", 16379), + ("10.0.0.0/19", 8187), + ("10.0.0.0/20", 4091), + ("10.0.0.0/21", 2043), + ("10.0.0.0/22", 1019), + ("10.0.0.0/23", 507), + ("10.0.0.0/24", 251), + ("10.0.0.0/25", 123), + ("10.0.0.0/26", 59), + ("10.0.0.0/27", 27), + ("10.0.0.0/28", 11), + ] + for (cidr, expected_count) in cidr_range_addresses: + validate_subnet_details(client, vpc, cidr, expected_count) + + +@mock_ec2 +def test_available_ip_addresses_in_subnet_with_enis(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + # Verify behaviour for various CIDR ranges (...) + # Don't try to assign ENIs to /27 and /28, as there are not a lot of IP addresses to go around + cidr_range_addresses = [ + ("10.0.0.0/16", 65531), + ("10.0.0.0/17", 32763), + ("10.0.0.0/18", 16379), + ("10.0.0.0/19", 8187), + ("10.0.0.0/20", 4091), + ("10.0.0.0/21", 2043), + ("10.0.0.0/22", 1019), + ("10.0.0.0/23", 507), + ("10.0.0.0/24", 251), + ("10.0.0.0/25", 123), + ("10.0.0.0/26", 59), + ] + for (cidr, expected_count) in cidr_range_addresses: + validate_subnet_details_after_creating_eni(client, vpc, cidr, expected_count) + + +def validate_subnet_details(client, vpc, cidr, expected_ip_address_count): + subnet = client.create_subnet( + VpcId=vpc.id, CidrBlock=cidr, AvailabilityZone="us-west-1b" + )["Subnet"] + subnet["AvailableIpAddressCount"].should.equal(expected_ip_address_count) + client.delete_subnet(SubnetId=subnet["SubnetId"]) + + +def validate_subnet_details_after_creating_eni( + client, vpc, cidr, expected_ip_address_count +): + subnet = client.create_subnet( + VpcId=vpc.id, CidrBlock=cidr, AvailabilityZone="us-west-1b" + )["Subnet"] + # Create a random number of Elastic Network Interfaces + nr_of_eni_to_create = random.randint(0, 5) + ip_addresses_assigned = 0 + enis_created = [] + for i in range(0, nr_of_eni_to_create): + # Create a random number of IP addresses per ENI + nr_of_ip_addresses = random.randint(1, 5) + if nr_of_ip_addresses == 1: + # Pick the first available IP address (First 4 are reserved by AWS) + private_address = "10.0.0." + str(ip_addresses_assigned + 4) + eni = client.create_network_interface( + SubnetId=subnet["SubnetId"], PrivateIpAddress=private_address + )["NetworkInterface"] + enis_created.append(eni) + ip_addresses_assigned = ip_addresses_assigned + 1 + else: + # Assign a list of IP addresses + private_addresses = [ + "10.0.0." + str(4 + ip_addresses_assigned + i) + for i in range(0, nr_of_ip_addresses) + ] + eni = client.create_network_interface( + SubnetId=subnet["SubnetId"], + PrivateIpAddresses=[ + {"PrivateIpAddress": address} for address in private_addresses + ], + )["NetworkInterface"] + enis_created.append(eni) + ip_addresses_assigned = ip_addresses_assigned + nr_of_ip_addresses + 1 # + # Verify that the nr of available IP addresses takes these ENIs into account + updated_subnet = client.describe_subnets(SubnetIds=[subnet["SubnetId"]])["Subnets"][ + 0 + ] + private_addresses = [ + eni["PrivateIpAddress"] for eni in enis_created if eni["PrivateIpAddress"] + ] + for eni in enis_created: + private_addresses.extend( + [address["PrivateIpAddress"] for address in eni["PrivateIpAddresses"]] + ) + error_msg = ( + "Nr of IP addresses for Subnet with CIDR {0} is incorrect. Expected: {1}, Actual: {2}. " + "Addresses: {3}" + ) + with sure.ensure( + error_msg, + cidr, + str(expected_ip_address_count), + updated_subnet["AvailableIpAddressCount"], + str(private_addresses), + ): + updated_subnet["AvailableIpAddressCount"].should.equal( + expected_ip_address_count - ip_addresses_assigned + ) + # Clean up, as we have to create a few more subnets that shouldn't interfere with each other + for eni in enis_created: + client.delete_network_interface(NetworkInterfaceId=eni["NetworkInterfaceId"]) + client.delete_subnet(SubnetId=subnet["SubnetId"]) diff --git a/tests/test_ec2/test_tags.py b/tests/test_ec2/test_tags.py index 2294979ba..29d2cb1e3 100644 --- a/tests/test_ec2/test_tags.py +++ b/tests/test_ec2/test_tags.py @@ -16,21 +16,23 @@ from nose.tools import assert_raises @mock_ec2_deprecated def test_add_tag(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.add_tag("a key", "some value", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) instance.add_tag("a key", "some value") chain = itertools.chain.from_iterable existing_instances = list( - chain([res.instances for res in conn.get_all_instances()])) + chain([res.instances for res in conn.get_all_instances()]) + ) existing_instances.should.have.length_of(1) existing_instance = existing_instances[0] existing_instance.tags["a key"].should.equal("some value") @@ -38,8 +40,8 @@ def test_add_tag(): @mock_ec2_deprecated def test_remove_tag(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some value") @@ -51,10 +53,11 @@ def test_remove_tag(): with assert_raises(EC2ResponseError) as ex: instance.remove_tag("a key", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteTags operation: Request would have succeeded, but DryRun flag is set" + ) instance.remove_tag("a key") conn.get_all_tags().should.have.length_of(0) @@ -66,8 +69,8 @@ def test_remove_tag(): @mock_ec2_deprecated def test_get_all_tags(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some value") @@ -80,8 +83,8 @@ def test_get_all_tags(): @mock_ec2_deprecated def test_get_all_tags_with_special_characters(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some<> value") @@ -94,47 +97,50 @@ def test_get_all_tags_with_special_characters(): @mock_ec2_deprecated def test_create_tags(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - tag_dict = {'a key': 'some value', - 'another key': 'some other value', - 'blank key': ''} + tag_dict = { + "a key": "some value", + "another key": "some other value", + "blank key": "", + } with assert_raises(EC2ResponseError) as ex: conn.create_tags(instance.id, tag_dict, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) conn.create_tags(instance.id, tag_dict) tags = conn.get_all_tags() - set([key for key in tag_dict]).should.equal( - set([tag.name for tag in tags])) + set([key for key in tag_dict]).should.equal(set([tag.name for tag in tags])) set([tag_dict[key] for key in tag_dict]).should.equal( - set([tag.value for tag in tags])) + set([tag.value for tag in tags]) + ) @mock_ec2_deprecated def test_tag_limit_exceeded(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] tag_dict = {} for i in range(51): - tag_dict['{0:02d}'.format(i + 1)] = '' + tag_dict["{0:02d}".format(i + 1)] = "" with assert_raises(EC2ResponseError) as cm: conn.create_tags(instance.id, tag_dict) - cm.exception.code.should.equal('TagLimitExceeded') + cm.exception.code.should.equal("TagLimitExceeded") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none instance.add_tag("a key", "a value") with assert_raises(EC2ResponseError) as cm: conn.create_tags(instance.id, tag_dict) - cm.exception.code.should.equal('TagLimitExceeded') + cm.exception.code.should.equal("TagLimitExceeded") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -147,158 +153,158 @@ def test_tag_limit_exceeded(): @mock_ec2_deprecated def test_invalid_parameter_tag_null(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as cm: instance.add_tag("a key", None) - cm.exception.code.should.equal('InvalidParameterValue') + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_invalid_id(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.create_tags('ami-blah', {'key': 'tag'}) - cm.exception.code.should.equal('InvalidID') + conn.create_tags("ami-blah", {"key": "tag"}) + cm.exception.code.should.equal("InvalidID") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm: - conn.create_tags('blah-blah', {'key': 'tag'}) - cm.exception.code.should.equal('InvalidID') + conn.create_tags("blah-blah", {"key": "tag"}) + cm.exception.code.should.equal("InvalidID") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_get_all_tags_resource_id_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'resource-id': instance.id}) + tags = conn.get_all_tags(filters={"resource-id": instance.id}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") - tags = conn.get_all_tags(filters={'resource-id': image_id}) + tags = conn.get_all_tags(filters={"resource-id": image_id}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(image_id) - tag.res_type.should.equal('image') + tag.res_type.should.equal("image") tag.name.should.equal("an image key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_resource_type_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'resource-type': 'instance'}) + tags = conn.get_all_tags(filters={"resource-type": "instance"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") - tags = conn.get_all_tags(filters={'resource-type': 'image'}) + tags = conn.get_all_tags(filters={"resource-type": "image"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(image_id) - tag.res_type.should.equal('image') + tag.res_type.should.equal("image") tag.name.should.equal("an image key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_key_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'key': 'an instance key'}) + tags = conn.get_all_tags(filters={"key": "an instance key"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_value_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") - reservation_b = conn.run_instances('ami-1234abcd') + reservation_b = conn.run_instances("ami-1234abcd") instance_b = reservation_b.instances[0] instance_b.add_tag("an instance key", "some other value") - reservation_c = conn.run_instances('ami-1234abcd') + reservation_c = conn.run_instances("ami-1234abcd") instance_c = reservation_c.instances[0] instance_c.add_tag("an instance key", "other value*") - reservation_d = conn.run_instances('ami-1234abcd') + reservation_d = conn.run_instances("ami-1234abcd") instance_d = reservation_d.instances[0] instance_d.add_tag("an instance key", "other value**") - reservation_e = conn.run_instances('ami-1234abcd') + reservation_e = conn.run_instances("ami-1234abcd") instance_e = reservation_e.instances[0] instance_e.add_tag("an instance key", "other value*?") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'value': 'some value'}) + tags = conn.get_all_tags(filters={"value": "some value"}) tags.should.have.length_of(2) - tags = conn.get_all_tags(filters={'value': 'some*value'}) + tags = conn.get_all_tags(filters={"value": "some*value"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*some*value'}) + tags = conn.get_all_tags(filters={"value": "*some*value"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*some*value*'}) + tags = conn.get_all_tags(filters={"value": "*some*value*"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*value\*'}) + tags = conn.get_all_tags(filters={"value": "*value\*"}) tags.should.have.length_of(1) - tags = conn.get_all_tags(filters={'value': '*value\*\*'}) + tags = conn.get_all_tags(filters={"value": "*value\*\*"}) tags.should.have.length_of(1) - tags = conn.get_all_tags(filters={'value': '*value\*\?'}) + tags = conn.get_all_tags(filters={"value": "*value\*\?"}) tags.should.have.length_of(1) @mock_ec2_deprecated def test_retrieved_instances_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") reservation.should.be.a(Reservation) reservation.instances.should.have.length_of(1) instance = reservation.instances[0] @@ -324,10 +330,10 @@ def test_retrieved_instances_must_contain_their_tags(): @mock_ec2_deprecated def test_retrieved_volumes_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") volume = conn.create_volume(80, "us-east-1a") all_volumes = conn.get_all_volumes() @@ -347,11 +353,12 @@ def test_retrieved_volumes_must_contain_their_tags(): @mock_ec2_deprecated def test_retrieved_snapshots_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2(aws_access_key_id='the_key', - aws_secret_access_key='the_secret') + conn = boto.connect_ec2( + aws_access_key_id="the_key", aws_secret_access_key="the_secret" + ) volume = conn.create_volume(80, "eu-west-1a") snapshot = conn.create_snapshot(volume.id) conn.create_tags([snapshot.id], tags_to_be_set) @@ -370,113 +377,94 @@ def test_retrieved_snapshots_must_contain_their_tags(): @mock_ec2_deprecated def test_filter_instances_by_wildcard_tags(): - conn = boto.connect_ec2(aws_access_key_id='the_key', - aws_secret_access_key='the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2( + aws_access_key_id="the_key", aws_secret_access_key="the_secret" + ) + reservation = conn.run_instances("ami-1234abcd") instance_a = reservation.instances[0] instance_a.add_tag("Key1", "Value1") - reservation_b = conn.run_instances('ami-1234abcd') + reservation_b = conn.run_instances("ami-1234abcd") instance_b = reservation_b.instances[0] instance_b.add_tag("Key1", "Value2") - reservations = conn.get_all_instances(filters={'tag:Key1': 'Value*'}) + reservations = conn.get_all_instances(filters={"tag:Key1": "Value*"}) reservations.should.have.length_of(2) - reservations = conn.get_all_instances(filters={'tag-key': 'Key*'}) + reservations = conn.get_all_instances(filters={"tag-key": "Key*"}) reservations.should.have.length_of(2) - reservations = conn.get_all_instances(filters={'tag-value': 'Value*'}) + reservations = conn.get_all_instances(filters={"tag-value": "Value*"}) reservations.should.have.length_of(2) @mock_ec2 def test_create_volume_with_tags(): - client = boto3.client('ec2', 'us-west-2') + client = boto3.client("ec2", "us-west-2") response = client.create_volume( - AvailabilityZone='us-west-2', - Encrypted=False, - Size=40, - TagSpecifications=[ - { - 'ResourceType': 'volume', - 'Tags': [ - { - 'Key': 'TEST_TAG', - 'Value': 'TEST_VALUE' - } - ], - } - ] - ) - - assert response['Tags'][0]['Key'] == 'TEST_TAG' - - -@mock_ec2 -def test_create_snapshot_with_tags(): - client = boto3.client('ec2', 'us-west-2') - volume_id = client.create_volume( - AvailabilityZone='us-west-2', + AvailabilityZone="us-west-2", Encrypted=False, Size=40, TagSpecifications=[ { - 'ResourceType': 'volume', - 'Tags': [ - { - 'Key': 'TEST_TAG', - 'Value': 'TEST_VALUE' - } - ], + "ResourceType": "volume", + "Tags": [{"Key": "TEST_TAG", "Value": "TEST_VALUE"}], } - ] - )['VolumeId'] + ], + ) + + assert response["Tags"][0]["Key"] == "TEST_TAG" + + +@mock_ec2 +def test_create_snapshot_with_tags(): + client = boto3.client("ec2", "us-west-2") + volume_id = client.create_volume( + AvailabilityZone="us-west-2", + Encrypted=False, + Size=40, + TagSpecifications=[ + { + "ResourceType": "volume", + "Tags": [{"Key": "TEST_TAG", "Value": "TEST_VALUE"}], + } + ], + )["VolumeId"] snapshot = client.create_snapshot( VolumeId=volume_id, TagSpecifications=[ { - 'ResourceType': 'snapshot', - 'Tags': [ - { - 'Key': 'TEST_SNAPSHOT_TAG', - 'Value': 'TEST_SNAPSHOT_VALUE' - } - ], + "ResourceType": "snapshot", + "Tags": [{"Key": "TEST_SNAPSHOT_TAG", "Value": "TEST_SNAPSHOT_VALUE"}], } - ] + ], ) - expected_tags = [{ - 'Key': 'TEST_SNAPSHOT_TAG', - 'Value': 'TEST_SNAPSHOT_VALUE' - }] + expected_tags = [{"Key": "TEST_SNAPSHOT_TAG", "Value": "TEST_SNAPSHOT_VALUE"}] - assert snapshot['Tags'] == expected_tags + assert snapshot["Tags"] == expected_tags @mock_ec2 def test_create_tag_empty_resource(): # create ec2 client in us-west-1 - client = boto3.client('ec2', region_name='us-west-1') + client = boto3.client("ec2", region_name="us-west-1") # create tag with empty resource with assert_raises(ClientError) as ex: - client.create_tags( - Resources=[], - Tags=[{'Key': 'Value'}] - ) - ex.exception.response['Error']['Code'].should.equal('MissingParameter') - ex.exception.response['Error']['Message'].should.equal('The request must contain the parameter resourceIdSet') + client.create_tags(Resources=[], Tags=[{"Key": "Value"}]) + ex.exception.response["Error"]["Code"].should.equal("MissingParameter") + ex.exception.response["Error"]["Message"].should.equal( + "The request must contain the parameter resourceIdSet" + ) @mock_ec2 def test_delete_tag_empty_resource(): # create ec2 client in us-west-1 - client = boto3.client('ec2', region_name='us-west-1') + client = boto3.client("ec2", region_name="us-west-1") # delete tag with empty resource with assert_raises(ClientError) as ex: - client.delete_tags( - Resources=[], - Tags=[{'Key': 'Value'}] - ) - ex.exception.response['Error']['Code'].should.equal('MissingParameter') - ex.exception.response['Error']['Message'].should.equal('The request must contain the parameter resourceIdSet') + client.delete_tags(Resources=[], Tags=[{"Key": "Value"}]) + ex.exception.response["Error"]["Code"].should.equal("MissingParameter") + ex.exception.response["Error"]["Message"].should.equal( + "The request must contain the parameter resourceIdSet" + ) diff --git a/tests/test_ec2/test_utils.py b/tests/test_ec2/test_utils.py index 49192dc79..75e3953bf 100644 --- a/tests/test_ec2/test_utils.py +++ b/tests/test_ec2/test_utils.py @@ -5,8 +5,8 @@ from .helpers import rsa_check_private_key def test_random_key_pair(): key_pair = utils.random_key_pair() - rsa_check_private_key(key_pair['material']) + rsa_check_private_key(key_pair["material"]) # AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1 # fingerprints with 59 characters. - assert len(key_pair['fingerprint']) == 47 + assert len(key_pair["fingerprint"]) == 47 diff --git a/tests/test_ec2/test_virtual_private_gateways.py b/tests/test_ec2/test_virtual_private_gateways.py index a57bdc59f..bb944df0b 100644 --- a/tests/test_ec2/test_virtual_private_gateways.py +++ b/tests/test_ec2/test_virtual_private_gateways.py @@ -7,54 +7,51 @@ from moto import mock_ec2_deprecated @mock_ec2_deprecated def test_virtual_private_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway.should_not.be.none - vpn_gateway.id.should.match(r'vgw-\w+') - vpn_gateway.type.should.equal('ipsec.1') - vpn_gateway.state.should.equal('available') - vpn_gateway.availability_zone.should.equal('us-east-1a') + vpn_gateway.id.should.match(r"vgw-\w+") + vpn_gateway.type.should.equal("ipsec.1") + vpn_gateway.state.should.equal("available") + vpn_gateway.availability_zone.should.equal("us-east-1a") @mock_ec2_deprecated def test_describe_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vgws = conn.get_all_vpn_gateways() vgws.should.have.length_of(1) gateway = vgws[0] - gateway.id.should.match(r'vgw-\w+') + gateway.id.should.match(r"vgw-\w+") gateway.id.should.equal(vpn_gateway.id) - vpn_gateway.type.should.equal('ipsec.1') - vpn_gateway.state.should.equal('available') - vpn_gateway.availability_zone.should.equal('us-east-1a') + vpn_gateway.type.should.equal("ipsec.1") + vpn_gateway.state.should.equal("available") + vpn_gateway.availability_zone.should.equal("us-east-1a") @mock_ec2_deprecated def test_vpn_gateway_vpc_attachment(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") - conn.attach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments attachments.should.have.length_of(1) attachments[0].vpc_id.should.equal(vpc.id) - attachments[0].state.should.equal('attached') + attachments[0].state.should.equal("attached") @mock_ec2_deprecated def test_delete_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") conn.delete_vpn_gateway(vpn_gateway.id) vgws = conn.get_all_vpn_gateways() @@ -63,8 +60,8 @@ def test_delete_vpn_gateway(): @mock_ec2_deprecated def test_vpn_gateway_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway.add_tag("a key", "some value") tag = conn.get_all_tags()[0] @@ -80,25 +77,19 @@ def test_vpn_gateway_tagging(): @mock_ec2_deprecated def test_detach_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") - conn.attach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments attachments.should.have.length_of(1) attachments[0].vpc_id.should.equal(vpc.id) - attachments[0].state.should.equal('attached') + attachments[0].state.should.equal("attached") - conn.detach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.detach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments diff --git a/tests/test_ec2/test_vpc_peering.py b/tests/test_ec2/test_vpc_peering.py index edfbfb3c2..fc1646961 100644 --- a/tests/test_ec2/test_vpc_peering.py +++ b/tests/test_ec2/test_vpc_peering.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -17,12 +18,12 @@ from tests.helpers import requires_boto_gte @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") peer_vpc = conn.create_vpc("11.0.0.0/16") vpc_pcx = conn.create_vpc_peering_connection(vpc.id, peer_vpc.id) - vpc_pcx._status.code.should.equal('initiating-request') + vpc_pcx._status.code.should.equal("initiating-request") return vpc_pcx @@ -30,39 +31,39 @@ def test_vpc_peering_connections(): @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_get_all(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() - vpc_pcx._status.code.should.equal('initiating-request') + vpc_pcx._status.code.should.equal("initiating-request") all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('pending-acceptance') + all_vpc_pcxs[0]._status.code.should.equal("pending-acceptance") @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_accept(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() vpc_pcx = conn.accept_vpc_peering_connection(vpc_pcx.id) - vpc_pcx._status.code.should.equal('active') + vpc_pcx._status.code.should.equal("active") with assert_raises(EC2ResponseError) as cm: conn.reject_vpc_peering_connection(vpc_pcx.id) - cm.exception.code.should.equal('InvalidStateTransition') + cm.exception.code.should.equal("InvalidStateTransition") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('active') + all_vpc_pcxs[0]._status.code.should.equal("active") @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_reject(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() verdict = conn.reject_vpc_peering_connection(vpc_pcx.id) @@ -70,19 +71,19 @@ def test_vpc_peering_connections_reject(): with assert_raises(EC2ResponseError) as cm: conn.accept_vpc_peering_connection(vpc_pcx.id) - cm.exception.code.should.equal('InvalidStateTransition') + cm.exception.code.should.equal("InvalidStateTransition") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('rejected') + all_vpc_pcxs[0]._status.code.should.equal("rejected") @requires_boto_gte("2.32.1") @mock_ec2_deprecated def test_vpc_peering_connections_delete(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() verdict = vpc_pcx.delete() @@ -90,11 +91,11 @@ def test_vpc_peering_connections_delete(): all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('deleted') + all_vpc_pcxs[0]._status.code.should.equal("deleted") with assert_raises(EC2ResponseError) as cm: conn.delete_vpc_peering_connection("pcx-1234abcd") - cm.exception.code.should.equal('InvalidVpcPeeringConnectionId.NotFound') + cm.exception.code.should.equal("InvalidVpcPeeringConnectionId.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -102,17 +103,15 @@ def test_vpc_peering_connections_delete(): @mock_ec2 def test_vpc_peering_connections_cross_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) - vpc_pcx_usw1.status['Code'].should.equal('initiating-request') + vpc_pcx_usw1.status["Code"].should.equal("initiating-request") vpc_pcx_usw1.requester_vpc.id.should.equal(vpc_usw1.id) vpc_pcx_usw1.accepter_vpc.id.should.equal(vpc_apn1.id) # test cross region vpc peering connection exist @@ -125,35 +124,32 @@ def test_vpc_peering_connections_cross_region(): @mock_ec2 def test_vpc_peering_connections_cross_region_fail(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering wrong region with no vpc with assert_raises(ClientError) as cm: ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-2') - cm.exception.response['Error']['Code'].should.equal('InvalidVpcID.NotFound') + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-2" + ) + cm.exception.response["Error"]["Code"].should.equal("InvalidVpcID.NotFound") @mock_ec2 def test_vpc_peering_connections_cross_region_accept(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # accept peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") acp_pcx_apn1 = ec2_apn1.accept_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -163,27 +159,25 @@ def test_vpc_peering_connections_cross_region_accept(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - acp_pcx_apn1['VpcPeeringConnection']['Status']['Code'].should.equal('active') - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('active') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('active') + acp_pcx_apn1["VpcPeeringConnection"]["Status"]["Code"].should.equal("active") + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("active") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("active") @mock_ec2 def test_vpc_peering_connections_cross_region_reject(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") rej_pcx_apn1 = ec2_apn1.reject_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -193,27 +187,25 @@ def test_vpc_peering_connections_cross_region_reject(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - rej_pcx_apn1['Return'].should.equal(True) - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('rejected') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('rejected') + rej_pcx_apn1["Return"].should.equal(True) + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("rejected") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("rejected") @mock_ec2 def test_vpc_peering_connections_cross_region_delete(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") del_pcx_apn1 = ec2_apn1.delete_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -223,61 +215,57 @@ def test_vpc_peering_connections_cross_region_delete(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - del_pcx_apn1['Return'].should.equal(True) - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('deleted') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('deleted') + del_pcx_apn1["Return"].should.equal(True) + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("deleted") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("deleted") @mock_ec2 def test_vpc_peering_connections_cross_region_accept_wrong_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # accept wrong peering from us-west-1 which will raise error - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") with assert_raises(ClientError) as cm: - ec2_usw1.accept_vpc_peering_connection( - VpcPeeringConnectionId=vpc_pcx_usw1.id - ) - cm.exception.response['Error']['Code'].should.equal('OperationNotPermitted') - exp_msg = 'Incorrect region ({0}) specified for this request.VPC ' \ - 'peering connection {1} must be ' \ - 'accepted in region {2}'.format('us-west-1', vpc_pcx_usw1.id, 'ap-northeast-1') - cm.exception.response['Error']['Message'].should.equal(exp_msg) + ec2_usw1.accept_vpc_peering_connection(VpcPeeringConnectionId=vpc_pcx_usw1.id) + cm.exception.response["Error"]["Code"].should.equal("OperationNotPermitted") + exp_msg = ( + "Incorrect region ({0}) specified for this request.VPC " + "peering connection {1} must be " + "accepted in region {2}".format("us-west-1", vpc_pcx_usw1.id, "ap-northeast-1") + ) + cm.exception.response["Error"]["Message"].should.equal(exp_msg) @mock_ec2 def test_vpc_peering_connections_cross_region_reject_wrong_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject wrong peering from us-west-1 which will raise error - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") with assert_raises(ClientError) as cm: - ec2_usw1.reject_vpc_peering_connection( - VpcPeeringConnectionId=vpc_pcx_usw1.id - ) - cm.exception.response['Error']['Code'].should.equal('OperationNotPermitted') - exp_msg = 'Incorrect region ({0}) specified for this request.VPC ' \ - 'peering connection {1} must be accepted or ' \ - 'rejected in region {2}'.format('us-west-1', vpc_pcx_usw1.id, 'ap-northeast-1') - cm.exception.response['Error']['Message'].should.equal(exp_msg) + ec2_usw1.reject_vpc_peering_connection(VpcPeeringConnectionId=vpc_pcx_usw1.id) + cm.exception.response["Error"]["Code"].should.equal("OperationNotPermitted") + exp_msg = ( + "Incorrect region ({0}) specified for this request.VPC " + "peering connection {1} must be accepted or " + "rejected in region {2}".format("us-west-1", vpc_pcx_usw1.id, "ap-northeast-1") + ) + cm.exception.response["Error"]["Message"].should.equal(exp_msg) diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index ad17deb3c..1bc3ddd98 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises # flake8: noqa +import tests.backport_assert_raises # noqa from nose.tools import assert_raises from moto.ec2.exceptions import EC2ClientError from botocore.exceptions import ClientError @@ -12,15 +13,15 @@ import sure # noqa from moto import mock_ec2, mock_ec2_deprecated -SAMPLE_DOMAIN_NAME = u'example.com' -SAMPLE_NAME_SERVERS = [u'10.0.0.6', u'10.0.0.7'] +SAMPLE_DOMAIN_NAME = "example.com" +SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"] @mock_ec2_deprecated def test_vpcs(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpc.cidr_block.should.equal('10.0.0.0/16') + vpc.cidr_block.should.equal("10.0.0.0/16") all_vpcs = conn.get_all_vpcs() all_vpcs.should.have.length_of(2) @@ -32,58 +33,56 @@ def test_vpcs(): with assert_raises(EC2ResponseError) as cm: conn.delete_vpc("vpc-1234abcd") - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_vpc_defaults(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") conn.get_all_vpcs().should.have.length_of(2) conn.get_all_route_tables().should.have.length_of(2) - conn.get_all_security_groups( - filters={'vpc-id': [vpc.id]}).should.have.length_of(1) + conn.get_all_security_groups(filters={"vpc-id": [vpc.id]}).should.have.length_of(1) vpc.delete() conn.get_all_vpcs().should.have.length_of(1) conn.get_all_route_tables().should.have.length_of(1) - conn.get_all_security_groups( - filters={'vpc-id': [vpc.id]}).should.have.length_of(0) + conn.get_all_security_groups(filters={"vpc-id": [vpc.id]}).should.have.length_of(0) @mock_ec2_deprecated def test_vpc_isdefault_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - conn.get_all_vpcs(filters={'isDefault': 'true'}).should.have.length_of(1) + conn.get_all_vpcs(filters={"isDefault": "true"}).should.have.length_of(1) vpc.delete() - conn.get_all_vpcs(filters={'isDefault': 'true'}).should.have.length_of(1) + conn.get_all_vpcs(filters={"isDefault": "true"}).should.have.length_of(1) @mock_ec2_deprecated def test_multiple_vpcs_default_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.create_vpc("10.8.0.0/16") conn.create_vpc("10.0.0.0/16") conn.create_vpc("192.168.0.0/16") conn.get_all_vpcs().should.have.length_of(4) - vpc = conn.get_all_vpcs(filters={'isDefault': 'true'}) + vpc = conn.get_all_vpcs(filters={"isDefault": "true"}) vpc.should.have.length_of(1) - vpc[0].cidr_block.should.equal('172.31.0.0/16') + vpc[0].cidr_block.should.equal("172.31.0.0/16") @mock_ec2_deprecated def test_vpc_state_available_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.1.0.0/16") - conn.get_all_vpcs(filters={'state': 'available'}).should.have.length_of(3) + conn.get_all_vpcs(filters={"state": "available"}).should.have.length_of(3) vpc.delete() - conn.get_all_vpcs(filters={'state': 'available'}).should.have.length_of(2) + conn.get_all_vpcs(filters={"state": "available"}).should.have.length_of(2) @mock_ec2_deprecated @@ -116,8 +115,8 @@ def test_vpc_get_by_id(): vpc2.id.should.be.within(vpc_ids) with assert_raises(EC2ResponseError) as cm: - conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist']) - cm.exception.code.should.equal('InvalidVpcID.NotFound') + conn.get_all_vpcs(vpc_ids=["vpc-does_not_exist"]) + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -129,7 +128,7 @@ def test_vpc_get_by_cidr_block(): vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") - vpcs = conn.get_all_vpcs(filters={'cidr': '10.0.0.0/16'}) + vpcs = conn.get_all_vpcs(filters={"cidr": "10.0.0.0/16"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -139,8 +138,7 @@ def test_vpc_get_by_cidr_block(): @mock_ec2_deprecated def test_vpc_get_by_dhcp_options_id(): conn = boto.connect_vpc() - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) vpc1 = conn.create_vpc("10.0.0.0/16") vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") @@ -148,7 +146,7 @@ def test_vpc_get_by_dhcp_options_id(): conn.associate_dhcp_options(dhcp_options.id, vpc1.id) conn.associate_dhcp_options(dhcp_options.id, vpc2.id) - vpcs = conn.get_all_vpcs(filters={'dhcp-options-id': dhcp_options.id}) + vpcs = conn.get_all_vpcs(filters={"dhcp-options-id": dhcp_options.id}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -162,11 +160,11 @@ def test_vpc_get_by_tag(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc2.add_tag('Name', 'TestVPC') - vpc3.add_tag('Name', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc2.add_tag("Name", "TestVPC") + vpc3.add_tag("Name", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag:Name': 'TestVPC'}) + vpcs = conn.get_all_vpcs(filters={"tag:Name": "TestVPC"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -180,13 +178,13 @@ def test_vpc_get_by_tag_key_superset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-key': 'Name'}) + vpcs = conn.get_all_vpcs(filters={"tag-key": "Name"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -200,13 +198,13 @@ def test_vpc_get_by_tag_key_subset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Test', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Test", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-key': ['Name', 'Key']}) + vpcs = conn.get_all_vpcs(filters={"tag-key": ["Name", "Key"]}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -220,13 +218,13 @@ def test_vpc_get_by_tag_value_superset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-value': 'TestVPC'}) + vpcs = conn.get_all_vpcs(filters={"tag-value": "TestVPC"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -240,12 +238,12 @@ def test_vpc_get_by_tag_value_subset(): vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-value': ['TestVPC', 'TestVPC2']}) + vpcs = conn.get_all_vpcs(filters={"tag-value": ["TestVPC", "TestVPC2"]}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -254,117 +252,116 @@ def test_vpc_get_by_tag_value_subset(): @mock_ec2 def test_default_vpc(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC default_vpc = list(ec2.vpcs.all())[0] - default_vpc.cidr_block.should.equal('172.31.0.0/16') - default_vpc.instance_tenancy.should.equal('default') + default_vpc.cidr_block.should.equal("172.31.0.0/16") + default_vpc.instance_tenancy.should.equal("default") default_vpc.reload() default_vpc.is_default.should.be.ok # Test default values for VPC attributes - response = default_vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = default_vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - response = default_vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').should.be.ok + response = default_vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").should.be.ok @mock_ec2 def test_non_default_vpc(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - this already exists when backend instantiated! - #ec2.create_vpc(CidrBlock='172.31.0.0/16') + # ec2.create_vpc(CidrBlock='172.31.0.0/16') # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok # Test default instance_tenancy - vpc.instance_tenancy.should.equal('default') + vpc.instance_tenancy.should.equal("default") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").shouldnt.be.ok # Check Primary CIDR Block Associations cidr_block_association_set = next(iter(vpc.cidr_block_association_set), None) - cidr_block_association_set['CidrBlockState']['State'].should.equal('associated') - cidr_block_association_set['CidrBlock'].should.equal(vpc.cidr_block) - cidr_block_association_set['AssociationId'].should.contain('vpc-cidr-assoc') + cidr_block_association_set["CidrBlockState"]["State"].should.equal("associated") + cidr_block_association_set["CidrBlock"].should.equal(vpc.cidr_block) + cidr_block_association_set["AssociationId"].should.contain("vpc-cidr-assoc") @mock_ec2 def test_vpc_dedicated_tenancy(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16', InstanceTenancy='dedicated') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="dedicated") vpc.reload() vpc.is_default.shouldnt.be.ok - vpc.instance_tenancy.should.equal('dedicated') + vpc.instance_tenancy.should.equal("dedicated") @mock_ec2 def test_vpc_modify_enable_dns_support(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - vpc.modify_attribute(EnableDnsSupport={'Value': False}) + vpc.modify_attribute(EnableDnsSupport={"Value": False}) - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").shouldnt.be.ok @mock_ec2 def test_vpc_modify_enable_dns_hostnames(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").shouldnt.be.ok - vpc.modify_attribute(EnableDnsHostnames={'Value': True}) + vpc.modify_attribute(EnableDnsHostnames={"Value": True}) - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").should.be.ok @mock_ec2_deprecated def test_vpc_associate_dhcp_options(): conn = boto.connect_vpc() - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) vpc = conn.create_vpc("10.0.0.0/16") conn.associate_dhcp_options(dhcp_options.id, vpc.id) @@ -375,117 +372,206 @@ def test_vpc_associate_dhcp_options(): @mock_ec2 def test_associate_vpc_ipv4_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24") # Associate/Extend vpc CIDR range up to 5 ciders for i in range(43, 47): - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.{}.0/24'.format(i)) - response['CidrBlockAssociation']['CidrBlockState']['State'].should.equal('associating') - response['CidrBlockAssociation']['CidrBlock'].should.equal('10.10.{}.0/24'.format(i)) - response['CidrBlockAssociation']['AssociationId'].should.contain('vpc-cidr-assoc') + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, CidrBlock="10.10.{}.0/24".format(i) + ) + response["CidrBlockAssociation"]["CidrBlockState"]["State"].should.equal( + "associating" + ) + response["CidrBlockAssociation"]["CidrBlock"].should.equal( + "10.10.{}.0/24".format(i) + ) + response["CidrBlockAssociation"]["AssociationId"].should.contain( + "vpc-cidr-assoc" + ) # Check all associations exist vpc = ec2.Vpc(vpc.id) vpc.cidr_block_association_set.should.have.length_of(5) - vpc.cidr_block_association_set[2]['CidrBlockState']['State'].should.equal('associated') - vpc.cidr_block_association_set[4]['CidrBlockState']['State'].should.equal('associated') + vpc.cidr_block_association_set[2]["CidrBlockState"]["State"].should.equal( + "associated" + ) + vpc.cidr_block_association_set[4]["CidrBlockState"]["State"].should.equal( + "associated" + ) # Check error on adding 6th association. with assert_raises(ClientError) as ex: - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.50.0/22') + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, CidrBlock="10.10.50.0/22" + ) str(ex.exception).should.equal( "An error occurred (CidrLimitExceeded) when calling the AssociateVpcCidrBlock " - "operation: This network '{}' has met its maximum number of allowed CIDRs: 5".format(vpc.id)) + "operation: This network '{}' has met its maximum number of allowed CIDRs: 5".format( + vpc.id + ) + ) + @mock_ec2 def test_disassociate_vpc_ipv4_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.43.0/24') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock="10.10.43.0/24") # Remove an extended cidr block vpc = ec2.Vpc(vpc.id) - non_default_assoc_cidr_block = next(iter([x for x in vpc.cidr_block_association_set if vpc.cidr_block != x['CidrBlock']]), None) - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=non_default_assoc_cidr_block['AssociationId']) - response['CidrBlockAssociation']['CidrBlockState']['State'].should.equal('disassociating') - response['CidrBlockAssociation']['CidrBlock'].should.equal(non_default_assoc_cidr_block['CidrBlock']) - response['CidrBlockAssociation']['AssociationId'].should.equal(non_default_assoc_cidr_block['AssociationId']) + non_default_assoc_cidr_block = next( + iter( + [ + x + for x in vpc.cidr_block_association_set + if vpc.cidr_block != x["CidrBlock"] + ] + ), + None, + ) + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId=non_default_assoc_cidr_block["AssociationId"] + ) + response["CidrBlockAssociation"]["CidrBlockState"]["State"].should.equal( + "disassociating" + ) + response["CidrBlockAssociation"]["CidrBlock"].should.equal( + non_default_assoc_cidr_block["CidrBlock"] + ) + response["CidrBlockAssociation"]["AssociationId"].should.equal( + non_default_assoc_cidr_block["AssociationId"] + ) # Error attempting to delete a non-existent CIDR_BLOCK association with assert_raises(ClientError) as ex: - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId='vpc-cidr-assoc-BORING123') + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId="vpc-cidr-assoc-BORING123" + ) str(ex.exception).should.equal( "An error occurred (InvalidVpcCidrBlockAssociationIdError.NotFound) when calling the " "DisassociateVpcCidrBlock operation: The vpc CIDR block association ID " - "'vpc-cidr-assoc-BORING123' does not exist") + "'vpc-cidr-assoc-BORING123' does not exist" + ) # Error attempting to delete Primary CIDR BLOCK association - vpc_base_cidr_assoc_id = next(iter([x for x in vpc.cidr_block_association_set - if vpc.cidr_block == x['CidrBlock']]), {})['AssociationId'] + vpc_base_cidr_assoc_id = next( + iter( + [ + x + for x in vpc.cidr_block_association_set + if vpc.cidr_block == x["CidrBlock"] + ] + ), + {}, + )["AssociationId"] with assert_raises(ClientError) as ex: - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=vpc_base_cidr_assoc_id) + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId=vpc_base_cidr_assoc_id + ) str(ex.exception).should.equal( "An error occurred (OperationNotPermitted) when calling the DisassociateVpcCidrBlock operation: " "The vpc CIDR block with association ID {} may not be disassociated. It is the primary " - "IPv4 CIDR block of the VPC".format(vpc_base_cidr_assoc_id)) + "IPv4 CIDR block of the VPC".format(vpc_base_cidr_assoc_id) + ) + @mock_ec2 def test_cidr_block_association_filters(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - vpc1 = ec2.create_vpc(CidrBlock='10.90.0.0/16') - vpc2 = ec2.create_vpc(CidrBlock='10.91.0.0/16') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock='10.10.0.0/19') - vpc3 = ec2.create_vpc(CidrBlock='10.92.0.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.1.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.2.0/24') - vpc3_assoc_response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.3.0/24') + ec2 = boto3.resource("ec2", region_name="us-west-1") + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") + vpc2 = ec2.create_vpc(CidrBlock="10.91.0.0/16") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock="10.10.0.0/19") + vpc3 = ec2.create_vpc(CidrBlock="10.92.0.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.1.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.2.0/24") + vpc3_assoc_response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc3.id, CidrBlock="10.92.3.0/24" + ) # Test filters for a cidr-block in all VPCs cidr-block-associations - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.cidr-block', - 'Values': ['10.10.0.0/19']}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "cidr-block-association.cidr-block", + "Values": ["10.10.0.0/19"], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc2.id) # Test filter for association id in VPCs - association_id = vpc3_assoc_response['CidrBlockAssociation']['AssociationId'] - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.association-id', - 'Values': [association_id]}])) + association_id = vpc3_assoc_response["CidrBlockAssociation"]["AssociationId"] + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "cidr-block-association.association-id", + "Values": [association_id], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc3.id) # Test filter for association state in VPC - this will never show anything in this test - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.association-id', - 'Values': ['failing']}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + {"Name": "cidr-block-association.association-id", "Values": ["failing"]} + ] + ) + ) filtered_vpcs.should.be.length_of(0) + @mock_ec2 def test_vpc_associate_ipv6_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Test create VPC with IPV6 cidr range - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24', AmazonProvidedIpv6CidrBlock=True) - ipv6_cidr_block_association_set = next(iter(vpc.ipv6_cidr_block_association_set), None) - ipv6_cidr_block_association_set['Ipv6CidrBlockState']['State'].should.equal('associated') - ipv6_cidr_block_association_set['Ipv6CidrBlock'].should.contain('::/56') - ipv6_cidr_block_association_set['AssociationId'].should.contain('vpc-cidr-assoc') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24", AmazonProvidedIpv6CidrBlock=True) + ipv6_cidr_block_association_set = next( + iter(vpc.ipv6_cidr_block_association_set), None + ) + ipv6_cidr_block_association_set["Ipv6CidrBlockState"]["State"].should.equal( + "associated" + ) + ipv6_cidr_block_association_set["Ipv6CidrBlock"].should.contain("::/56") + ipv6_cidr_block_association_set["AssociationId"].should.contain("vpc-cidr-assoc") # Test Fail on adding 2nd IPV6 association - AWS only allows 1 at this time! with assert_raises(ClientError) as ex: - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True) + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True + ) str(ex.exception).should.equal( "An error occurred (CidrLimitExceeded) when calling the AssociateVpcCidrBlock " - "operation: This network '{}' has met its maximum number of allowed CIDRs: 1".format(vpc.id)) + "operation: This network '{}' has met its maximum number of allowed CIDRs: 1".format( + vpc.id + ) + ) # Test associate ipv6 cidr block after vpc created - vpc = ec2.create_vpc(CidrBlock='10.10.50.0/24') - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True) - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlockState']['State'].should.equal('associating') - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'].should.contain('::/56') - response['Ipv6CidrBlockAssociation']['AssociationId'].should.contain('vpc-cidr-assoc-') + vpc = ec2.create_vpc(CidrBlock="10.10.50.0/24") + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlockState"]["State"].should.equal( + "associating" + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"].should.contain("::/56") + response["Ipv6CidrBlockAssociation"]["AssociationId"].should.contain( + "vpc-cidr-assoc-" + ) # Check on describe vpc that has ipv6 cidr block association vpc = ec2.Vpc(vpc.id) @@ -494,72 +580,248 @@ def test_vpc_associate_ipv6_cidr_block(): @mock_ec2 def test_vpc_disassociate_ipv6_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Test create VPC with IPV6 cidr range - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24', AmazonProvidedIpv6CidrBlock=True) + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24", AmazonProvidedIpv6CidrBlock=True) # Test disassociating the only IPV6 - assoc_id = vpc.ipv6_cidr_block_association_set[0]['AssociationId'] + assoc_id = vpc.ipv6_cidr_block_association_set[0]["AssociationId"] response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=assoc_id) - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlockState']['State'].should.equal('disassociating') - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'].should.contain('::/56') - response['Ipv6CidrBlockAssociation']['AssociationId'].should.equal(assoc_id) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlockState"]["State"].should.equal( + "disassociating" + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"].should.contain("::/56") + response["Ipv6CidrBlockAssociation"]["AssociationId"].should.equal(assoc_id) @mock_ec2 def test_ipv6_cidr_block_association_filters(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - vpc1 = ec2.create_vpc(CidrBlock='10.90.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") - vpc2 = ec2.create_vpc(CidrBlock='10.91.0.0/16', AmazonProvidedIpv6CidrBlock=True) - vpc2_assoc_ipv6_assoc_id = vpc2.ipv6_cidr_block_association_set[0]['AssociationId'] - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock='10.10.0.0/19') + vpc2 = ec2.create_vpc(CidrBlock="10.91.0.0/16", AmazonProvidedIpv6CidrBlock=True) + vpc2_assoc_ipv6_assoc_id = vpc2.ipv6_cidr_block_association_set[0]["AssociationId"] + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock="10.10.0.0/19") - vpc3 = ec2.create_vpc(CidrBlock='10.92.0.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.1.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.2.0/24') - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, AmazonProvidedIpv6CidrBlock=True) - vpc3_ipv6_cidr_block = response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'] + vpc3 = ec2.create_vpc(CidrBlock="10.92.0.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.1.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.2.0/24") + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc3.id, AmazonProvidedIpv6CidrBlock=True + ) + vpc3_ipv6_cidr_block = response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"] - vpc4 = ec2.create_vpc(CidrBlock='10.95.0.0/16') # Here for its looks + vpc4 = ec2.create_vpc(CidrBlock="10.95.0.0/16") # Here for its looks # Test filters for an ipv6 cidr-block in all VPCs cidr-block-associations - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.ipv6-cidr-block', - 'Values': [vpc3_ipv6_cidr_block]}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "ipv6-cidr-block-association.ipv6-cidr-block", + "Values": [vpc3_ipv6_cidr_block], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc3.id) # Test filter for association id in VPCs - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.association-id', - 'Values': [vpc2_assoc_ipv6_assoc_id]}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "ipv6-cidr-block-association.association-id", + "Values": [vpc2_assoc_ipv6_assoc_id], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc2.id) # Test filter for association state in VPC - this will never show anything in this test - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.state', - 'Values': ['associated']}])) - filtered_vpcs.should.be.length_of(2) # 2 of 4 VPCs + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + {"Name": "ipv6-cidr-block-association.state", "Values": ["associated"]} + ] + ) + ) + filtered_vpcs.should.be.length_of(2) # 2 of 4 VPCs @mock_ec2 def test_create_vpc_with_invalid_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc_cidr_block = '1000.1.0.0/20' + vpc_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: vpc = ec2.create_vpc(CidrBlock=vpc_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateVpc " - "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(vpc_cidr_block)) + "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + vpc_cidr_block + ) + ) @mock_ec2 def test_create_vpc_with_invalid_cidr_range(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc_cidr_block = '10.1.0.0/29' + vpc_cidr_block = "10.1.0.0/29" with assert_raises(ClientError) as ex: vpc = ec2.create_vpc(CidrBlock=vpc_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidVpc.Range) when calling the CreateVpc " - "operation: The CIDR '{}' is invalid.".format(vpc_cidr_block)) + "operation: The CIDR '{}' is invalid.".format(vpc_cidr_block) + ) + + +@mock_ec2 +def test_enable_vpc_classic_link(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.1.0.0/16") + + response = ec2.meta.client.enable_vpc_classic_link(VpcId=vpc.id) + assert response.get("Return").should.be.true + + +@mock_ec2 +def test_enable_vpc_classic_link_failure(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.90.0.0/16") + + response = ec2.meta.client.enable_vpc_classic_link(VpcId=vpc.id) + assert response.get("Return").should.be.false + + +@mock_ec2 +def test_disable_vpc_classic_link(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link(VpcId=vpc.id) + response = ec2.meta.client.disable_vpc_classic_link(VpcId=vpc.id) + assert response.get("Return").should.be.false + + +@mock_ec2 +def test_describe_classic_link_enabled(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link(VpcId=vpc.id) + response = ec2.meta.client.describe_vpc_classic_link(VpcIds=[vpc.id]) + assert response.get("Vpcs")[0].get("ClassicLinkEnabled").should.be.true + + +@mock_ec2 +def test_describe_classic_link_disabled(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.90.0.0/16") + + response = ec2.meta.client.describe_vpc_classic_link(VpcIds=[vpc.id]) + assert response.get("Vpcs")[0].get("ClassicLinkEnabled").should.be.false + + +@mock_ec2 +def test_describe_classic_link_multiple(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") + vpc2 = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link(VpcId=vpc2.id) + response = ec2.meta.client.describe_vpc_classic_link(VpcIds=[vpc1.id, vpc2.id]) + expected = [ + {"VpcId": vpc1.id, "ClassicLinkDnsSupported": False}, + {"VpcId": vpc2.id, "ClassicLinkDnsSupported": True}, + ] + + # Ensure response is sorted, because they can come in random order + assert response.get("Vpcs").sort(key=lambda x: x["VpcId"]) == expected.sort( + key=lambda x: x["VpcId"] + ) + + +@mock_ec2 +def test_enable_vpc_classic_link_dns_support(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.1.0.0/16") + + response = ec2.meta.client.enable_vpc_classic_link_dns_support(VpcId=vpc.id) + assert response.get("Return").should.be.true + + +@mock_ec2 +def test_disable_vpc_classic_link_dns_support(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link_dns_support(VpcId=vpc.id) + response = ec2.meta.client.disable_vpc_classic_link_dns_support(VpcId=vpc.id) + assert response.get("Return").should.be.false + + +@mock_ec2 +def test_describe_classic_link_dns_support_enabled(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link_dns_support(VpcId=vpc.id) + response = ec2.meta.client.describe_vpc_classic_link_dns_support(VpcIds=[vpc.id]) + assert response.get("Vpcs")[0].get("ClassicLinkDnsSupported").should.be.true + + +@mock_ec2 +def test_describe_classic_link_dns_support_disabled(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc = ec2.create_vpc(CidrBlock="10.90.0.0/16") + + response = ec2.meta.client.describe_vpc_classic_link_dns_support(VpcIds=[vpc.id]) + assert response.get("Vpcs")[0].get("ClassicLinkDnsSupported").should.be.false + + +@mock_ec2 +def test_describe_classic_link_dns_support_multiple(): + ec2 = boto3.resource("ec2", region_name="us-west-1") + + # Create VPC + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") + vpc2 = ec2.create_vpc(CidrBlock="10.0.0.0/16") + + ec2.meta.client.enable_vpc_classic_link_dns_support(VpcId=vpc2.id) + response = ec2.meta.client.describe_vpc_classic_link_dns_support( + VpcIds=[vpc1.id, vpc2.id] + ) + expected = [ + {"VpcId": vpc1.id, "ClassicLinkDnsSupported": False}, + {"VpcId": vpc2.id, "ClassicLinkDnsSupported": True}, + ] + + # Ensure response is sorted, because they can come in random order + assert response.get("Vpcs").sort(key=lambda x: x["VpcId"]) == expected.sort( + key=lambda x: x["VpcId"] + ) diff --git a/tests/test_ec2/test_vpn_connections.py b/tests/test_ec2/test_vpn_connections.py index 70c3f3e33..24396d3d1 100644 --- a/tests/test_ec2/test_vpn_connections.py +++ b/tests/test_ec2/test_vpn_connections.py @@ -1,51 +1,53 @@ -from __future__ import unicode_literals -import boto -from nose.tools import assert_raises -import sure # noqa -from boto.exception import EC2ResponseError - -from moto import mock_ec2_deprecated - - -@mock_ec2_deprecated -def test_create_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_connection = conn.create_vpn_connection( - 'ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') - vpn_connection.should_not.be.none - vpn_connection.id.should.match(r'vpn-\w+') - vpn_connection.type.should.equal('ipsec.1') - - -@mock_ec2_deprecated -def test_delete_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_connection = conn.create_vpn_connection( - 'ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') - list_of_vpn_connections = conn.get_all_vpn_connections() - list_of_vpn_connections.should.have.length_of(1) - conn.delete_vpn_connection(vpn_connection.id) - list_of_vpn_connections = conn.get_all_vpn_connections() - list_of_vpn_connections.should.have.length_of(0) - - -@mock_ec2_deprecated -def test_delete_vpn_connections_bad_id(): - conn = boto.connect_vpc('the_key', 'the_secret') - with assert_raises(EC2ResponseError): - conn.delete_vpn_connection('vpn-0123abcd') - - -@mock_ec2_deprecated -def test_describe_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') - list_of_vpn_connections = conn.get_all_vpn_connections() - list_of_vpn_connections.should.have.length_of(0) - conn.create_vpn_connection('ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') - list_of_vpn_connections = conn.get_all_vpn_connections() - list_of_vpn_connections.should.have.length_of(1) - vpn = conn.create_vpn_connection('ipsec.1', 'vgw-1234abcd', 'cgw-1234abcd') - list_of_vpn_connections = conn.get_all_vpn_connections() - list_of_vpn_connections.should.have.length_of(2) - list_of_vpn_connections = conn.get_all_vpn_connections(vpn.id) - list_of_vpn_connections.should.have.length_of(1) +from __future__ import unicode_literals +import boto +from nose.tools import assert_raises +import sure # noqa +from boto.exception import EC2ResponseError + +from moto import mock_ec2_deprecated + + +@mock_ec2_deprecated +def test_create_vpn_connections(): + conn = boto.connect_vpc("the_key", "the_secret") + vpn_connection = conn.create_vpn_connection( + "ipsec.1", "vgw-0123abcd", "cgw-0123abcd" + ) + vpn_connection.should_not.be.none + vpn_connection.id.should.match(r"vpn-\w+") + vpn_connection.type.should.equal("ipsec.1") + + +@mock_ec2_deprecated +def test_delete_vpn_connections(): + conn = boto.connect_vpc("the_key", "the_secret") + vpn_connection = conn.create_vpn_connection( + "ipsec.1", "vgw-0123abcd", "cgw-0123abcd" + ) + list_of_vpn_connections = conn.get_all_vpn_connections() + list_of_vpn_connections.should.have.length_of(1) + conn.delete_vpn_connection(vpn_connection.id) + list_of_vpn_connections = conn.get_all_vpn_connections() + list_of_vpn_connections.should.have.length_of(0) + + +@mock_ec2_deprecated +def test_delete_vpn_connections_bad_id(): + conn = boto.connect_vpc("the_key", "the_secret") + with assert_raises(EC2ResponseError): + conn.delete_vpn_connection("vpn-0123abcd") + + +@mock_ec2_deprecated +def test_describe_vpn_connections(): + conn = boto.connect_vpc("the_key", "the_secret") + list_of_vpn_connections = conn.get_all_vpn_connections() + list_of_vpn_connections.should.have.length_of(0) + conn.create_vpn_connection("ipsec.1", "vgw-0123abcd", "cgw-0123abcd") + list_of_vpn_connections = conn.get_all_vpn_connections() + list_of_vpn_connections.should.have.length_of(1) + vpn = conn.create_vpn_connection("ipsec.1", "vgw-1234abcd", "cgw-1234abcd") + list_of_vpn_connections = conn.get_all_vpn_connections() + list_of_vpn_connections.should.have.length_of(2) + list_of_vpn_connections = conn.get_all_vpn_connections(vpn.id) + list_of_vpn_connections.should.have.length_of(1) diff --git a/tests/test_ecr/test_ecr_boto3.py b/tests/test_ecr/test_ecr_boto3.py index ec0e4e732..9115e3fad 100644 --- a/tests/test_ecr/test_ecr_boto3.py +++ b/tests/test_ecr/test_ecr_boto3.py @@ -20,1062 +20,1035 @@ from nose import SkipTest def _create_image_digest(contents=None): if not contents: - contents = 'docker_image{0}'.format(int(random() * 10 ** 6)) - return "sha256:%s" % hashlib.sha256(contents.encode('utf-8')).hexdigest() + contents = "docker_image{0}".format(int(random() * 10 ** 6)) + return "sha256:%s" % hashlib.sha256(contents.encode("utf-8")).hexdigest() def _create_image_manifest(): return { "schemaVersion": 2, "mediaType": "application/vnd.docker.distribution.manifest.v2+json", - "config": - { - "mediaType": "application/vnd.docker.container.image.v1+json", - "size": 7023, - "digest": _create_image_digest("config") - }, + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "size": 7023, + "digest": _create_image_digest("config"), + }, "layers": [ { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 32654, - "digest": _create_image_digest("layer1") + "digest": _create_image_digest("layer1"), }, { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 16724, - "digest": _create_image_digest("layer2") + "digest": _create_image_digest("layer2"), }, { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 73109, # randomize image digest - "digest": _create_image_digest() - } - ] + "digest": _create_image_digest(), + }, + ], } @mock_ecr def test_create_repository(): - client = boto3.client('ecr', region_name='us-east-1') - response = client.create_repository( - repositoryName='test_ecr_repository' + client = boto3.client("ecr", region_name="us-east-1") + response = client.create_repository(repositoryName="test_ecr_repository") + response["repository"]["repositoryName"].should.equal("test_ecr_repository") + response["repository"]["repositoryArn"].should.equal( + "arn:aws:ecr:us-east-1:012345678910:repository/test_ecr_repository" + ) + response["repository"]["registryId"].should.equal("012345678910") + response["repository"]["repositoryUri"].should.equal( + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_ecr_repository" ) - response['repository']['repositoryName'].should.equal('test_ecr_repository') - response['repository']['repositoryArn'].should.equal( - 'arn:aws:ecr:us-east-1:012345678910:repository/test_ecr_repository') - response['repository']['registryId'].should.equal('012345678910') - response['repository']['repositoryUri'].should.equal( - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_ecr_repository') # response['repository']['createdAt'].should.equal(0) @mock_ecr def test_describe_repositories(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") response = client.describe_repositories() - len(response['repositories']).should.equal(2) + len(response["repositories"]).should.equal(2) - respository_arns = ['arn:aws:ecr:us-east-1:012345678910:repository/test_repository1', - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository0'] - set([response['repositories'][0]['repositoryArn'], - response['repositories'][1]['repositoryArn']]).should.equal(set(respository_arns)) + respository_arns = [ + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryArn"], + response["repositories"][1]["repositoryArn"], + ] + ).should.equal(set(respository_arns)) - respository_uris = ['012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1', - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0'] - set([response['repositories'][0]['repositoryUri'], - response['repositories'][1]['repositoryUri']]).should.equal(set(respository_uris)) + respository_uris = [ + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryUri"], + response["repositories"][1]["repositoryUri"], + ] + ).should.equal(set(respository_uris)) @mock_ecr def test_describe_repositories_1(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(registryId='012345678910') - len(response['repositories']).should.equal(2) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(registryId="012345678910") + len(response["repositories"]).should.equal(2) - respository_arns = ['arn:aws:ecr:us-east-1:012345678910:repository/test_repository1', - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository0'] - set([response['repositories'][0]['repositoryArn'], - response['repositories'][1]['repositoryArn']]).should.equal(set(respository_arns)) + respository_arns = [ + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryArn"], + response["repositories"][1]["repositoryArn"], + ] + ).should.equal(set(respository_arns)) - respository_uris = ['012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1', - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0'] - set([response['repositories'][0]['repositoryUri'], - response['repositories'][1]['repositoryUri']]).should.equal(set(respository_uris)) + respository_uris = [ + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryUri"], + response["repositories"][1]["repositoryUri"], + ] + ).should.equal(set(respository_uris)) @mock_ecr def test_describe_repositories_2(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(registryId='109876543210') - len(response['repositories']).should.equal(0) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(registryId="109876543210") + len(response["repositories"]).should.equal(0) @mock_ecr def test_describe_repositories_3(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(repositoryNames=['test_repository1']) - len(response['repositories']).should.equal(1) - respository_arn = 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository1' - response['repositories'][0]['repositoryArn'].should.equal(respository_arn) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(repositoryNames=["test_repository1"]) + len(response["repositories"]).should.equal(1) + respository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1" + response["repositories"][0]["repositoryArn"].should.equal(respository_arn) - respository_uri = '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1' - response['repositories'][0]['repositoryUri'].should.equal(respository_uri) + respository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1" + response["repositories"][0]["repositoryUri"].should.equal(respository_uri) @mock_ecr def test_describe_repositories_with_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - response = client.describe_repositories(repositoryNames=['test_repository']) - len(response['repositories']).should.equal(1) + response = client.describe_repositories(repositoryNames=["test_repository"]) + len(response["repositories"]).should.equal(1) @mock_ecr def test_delete_repository(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + response = client.delete_repository(repositoryName="test_repository") + response["repository"]["repositoryName"].should.equal("test_repository") + response["repository"]["repositoryArn"].should.equal( + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository" + ) + response["repository"]["registryId"].should.equal("012345678910") + response["repository"]["repositoryUri"].should.equal( + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository" ) - response = client.delete_repository(repositoryName='test_repository') - response['repository']['repositoryName'].should.equal('test_repository') - response['repository']['repositoryArn'].should.equal( - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository') - response['repository']['registryId'].should.equal('012345678910') - response['repository']['repositoryUri'].should.equal( - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository') # response['repository']['createdAt'].should.equal(0) response = client.describe_repositories() - len(response['repositories']).should.equal(0) + len(response["repositories"]).should.equal(0) @mock_ecr def test_put_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - response['image']['imageId']['imageTag'].should.equal('latest') - response['image']['imageId']['imageDigest'].should.contain("sha") - response['image']['repositoryName'].should.equal('test_repository') - response['image']['registryId'].should.equal('012345678910') + response["image"]["imageId"]["imageTag"].should.equal("latest") + response["image"]["imageId"]["imageDigest"].should.contain("sha") + response["image"]["repositoryName"].should.equal("test_repository") + response["image"]["registryId"].should.equal("012345678910") @mock_ecr def test_put_image_with_push_date(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") - with freeze_time('2018-08-28 00:00:00'): + with freeze_time("2018-08-28 00:00:00"): image1_date = datetime.now() _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - with freeze_time('2019-05-31 00:00:00'): + with freeze_time("2019-05-31 00:00:00"): image2_date = datetime.now() _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(2) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(2) - set([describe_response['imageDetails'][0]['imagePushedAt'], - describe_response['imageDetails'][1]['imagePushedAt']]).should.equal(set([image1_date, image2_date])) + set( + [ + describe_response["imageDetails"][0]["imagePushedAt"], + describe_response["imageDetails"][1]["imagePushedAt"], + ] + ).should.equal(set([image1_date, image2_date])) @mock_ecr def test_put_image_with_multiple_tags(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='v1' + imageTag="v1", ) - response['image']['imageId']['imageTag'].should.equal('v1') - response['image']['imageId']['imageDigest'].should.contain("sha") - response['image']['repositoryName'].should.equal('test_repository') - response['image']['registryId'].should.equal('012345678910') + response["image"]["imageId"]["imageTag"].should.equal("v1") + response["image"]["imageId"]["imageDigest"].should.contain("sha") + response["image"]["repositoryName"].should.equal("test_repository") + response["image"]["registryId"].should.equal("012345678910") response1 = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='latest' + imageTag="latest", ) - response1['image']['imageId']['imageTag'].should.equal('latest') - response1['image']['imageId']['imageDigest'].should.contain("sha") - response1['image']['repositoryName'].should.equal('test_repository') - response1['image']['registryId'].should.equal('012345678910') + response1["image"]["imageId"]["imageTag"].should.equal("latest") + response1["image"]["imageId"]["imageDigest"].should.contain("sha") + response1["image"]["repositoryName"].should.equal("test_repository") + response1["image"]["registryId"].should.equal("012345678910") - response2 = client.describe_images(repositoryName='test_repository') - type(response2['imageDetails']).should.be(list) - len(response2['imageDetails']).should.be(1) + response2 = client.describe_images(repositoryName="test_repository") + type(response2["imageDetails"]).should.be(list) + len(response2["imageDetails"]).should.be(1) - response2['imageDetails'][0]['imageDigest'].should.contain("sha") + response2["imageDetails"][0]["imageDigest"].should.contain("sha") - response2['imageDetails'][0]['registryId'].should.equal("012345678910") + response2["imageDetails"][0]["registryId"].should.equal("012345678910") - response2['imageDetails'][0]['repositoryName'].should.equal("test_repository") + response2["imageDetails"][0]["repositoryName"].should.equal("test_repository") - len(response2['imageDetails'][0]['imageTags']).should.be(2) - response2['imageDetails'][0]['imageTags'].should.be.equal(['v1', 'latest']) + len(response2["imageDetails"][0]["imageTags"]).should.be(2) + response2["imageDetails"][0]["imageTags"].should.be.equal(["v1", "latest"]) @mock_ecr def test_list_images(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository_1' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository_1") - _ = client.create_repository( - repositoryName='test_repository_2' + _ = client.create_repository(repositoryName="test_repository_2") + + _ = client.put_image( + repositoryName="test_repository_1", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_1", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_1", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' + imageTag="v2", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_2", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="oldest", ) - _ = client.put_image( - repositoryName='test_repository_2', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='oldest' - ) + response = client.list_images(repositoryName="test_repository_1") + type(response["imageIds"]).should.be(list) + len(response["imageIds"]).should.be(3) - response = client.list_images(repositoryName='test_repository_1') - type(response['imageIds']).should.be(list) - len(response['imageIds']).should.be(3) + image_tags = ["latest", "v1", "v2"] + set( + [ + response["imageIds"][0]["imageTag"], + response["imageIds"][1]["imageTag"], + response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(image_tags)) - image_tags = ['latest', 'v1', 'v2'] - set([response['imageIds'][0]['imageTag'], - response['imageIds'][1]['imageTag'], - response['imageIds'][2]['imageTag']]).should.equal(set(image_tags)) - - response = client.list_images(repositoryName='test_repository_2') - type(response['imageIds']).should.be(list) - len(response['imageIds']).should.be(1) - response['imageIds'][0]['imageTag'].should.equal('oldest') + response = client.list_images(repositoryName="test_repository_2") + type(response["imageIds"]).should.be(list) + len(response["imageIds"]).should.be(1) + response["imageIds"][0]["imageTag"].should.equal("oldest") @mock_ecr def test_list_images_from_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository_1' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository_1") # non existing repo error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.list_images.when.called_with( - repositoryName='repo-that-doesnt-exist', - registryId='123', + repositoryName="repo-that-doesnt-exist", registryId="123" ).should.throw(Exception, error_msg) # repo does not exist in specified registry error_msg = re.compile( r".*The repository with name 'test_repository_1' does not exist in the registry with id '222'.*", - re.MULTILINE) + re.MULTILINE, + ) client.list_images.when.called_with( - repositoryName='test_repository_1', - registryId='222', + repositoryName="test_repository_1", registryId="222" ).should.throw(Exception, error_msg) @mock_ecr def test_describe_images(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()) - ) - - _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v1", ) - response = client.describe_images(repositoryName='test_repository') - type(response['imageDetails']).should.be(list) - len(response['imageDetails']).should.be(4) + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="v2", + ) - response['imageDetails'][0]['imageDigest'].should.contain("sha") - response['imageDetails'][1]['imageDigest'].should.contain("sha") - response['imageDetails'][2]['imageDigest'].should.contain("sha") - response['imageDetails'][3]['imageDigest'].should.contain("sha") + response = client.describe_images(repositoryName="test_repository") + type(response["imageDetails"]).should.be(list) + len(response["imageDetails"]).should.be(4) - response['imageDetails'][0]['registryId'].should.equal("012345678910") - response['imageDetails'][1]['registryId'].should.equal("012345678910") - response['imageDetails'][2]['registryId'].should.equal("012345678910") - response['imageDetails'][3]['registryId'].should.equal("012345678910") + response["imageDetails"][0]["imageDigest"].should.contain("sha") + response["imageDetails"][1]["imageDigest"].should.contain("sha") + response["imageDetails"][2]["imageDigest"].should.contain("sha") + response["imageDetails"][3]["imageDigest"].should.contain("sha") - response['imageDetails'][0]['repositoryName'].should.equal("test_repository") - response['imageDetails'][1]['repositoryName'].should.equal("test_repository") - response['imageDetails'][2]['repositoryName'].should.equal("test_repository") - response['imageDetails'][3]['repositoryName'].should.equal("test_repository") + response["imageDetails"][0]["registryId"].should.equal("012345678910") + response["imageDetails"][1]["registryId"].should.equal("012345678910") + response["imageDetails"][2]["registryId"].should.equal("012345678910") + response["imageDetails"][3]["registryId"].should.equal("012345678910") - response['imageDetails'][0].should_not.have.key('imageTags') - len(response['imageDetails'][1]['imageTags']).should.be(1) - len(response['imageDetails'][2]['imageTags']).should.be(1) - len(response['imageDetails'][3]['imageTags']).should.be(1) + response["imageDetails"][0]["repositoryName"].should.equal("test_repository") + response["imageDetails"][1]["repositoryName"].should.equal("test_repository") + response["imageDetails"][2]["repositoryName"].should.equal("test_repository") + response["imageDetails"][3]["repositoryName"].should.equal("test_repository") - image_tags = ['latest', 'v1', 'v2'] - set([response['imageDetails'][1]['imageTags'][0], - response['imageDetails'][2]['imageTags'][0], - response['imageDetails'][3]['imageTags'][0]]).should.equal(set(image_tags)) + response["imageDetails"][0].should_not.have.key("imageTags") + len(response["imageDetails"][1]["imageTags"]).should.be(1) + len(response["imageDetails"][2]["imageTags"]).should.be(1) + len(response["imageDetails"][3]["imageTags"]).should.be(1) - response['imageDetails'][0]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][1]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][2]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][3]['imageSizeInBytes'].should.equal(52428800) + image_tags = ["latest", "v1", "v2"] + set( + [ + response["imageDetails"][1]["imageTags"][0], + response["imageDetails"][2]["imageTags"][0], + response["imageDetails"][3]["imageTags"][0], + ] + ).should.equal(set(image_tags)) + + response["imageDetails"][0]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][1]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][2]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][3]["imageSizeInBytes"].should.equal(52428800) @mock_ecr def test_describe_images_by_tag(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") tag_map = {} - for tag in ['latest', 'v1', 'v2']: + for tag in ["latest", "v1", "v2"]: put_response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag=tag + imageTag=tag, ) - tag_map[tag] = put_response['image'] + tag_map[tag] = put_response["image"] for tag, put_response in tag_map.items(): - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - image_detail['registryId'].should.equal("012345678910") - image_detail['repositoryName'].should.equal("test_repository") - image_detail['imageTags'].should.equal([put_response['imageId']['imageTag']]) - image_detail['imageDigest'].should.equal(put_response['imageId']['imageDigest']) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + image_detail["registryId"].should.equal("012345678910") + image_detail["repositoryName"].should.equal("test_repository") + image_detail["imageTags"].should.equal([put_response["imageId"]["imageTag"]]) + image_detail["imageDigest"].should.equal(put_response["imageId"]["imageDigest"]) @mock_ecr def test_describe_images_tags_should_not_contain_empty_tag1(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(manifest) + repositoryName="test_repository", imageManifest=json.dumps(manifest) ) - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - len(image_detail['imageTags']).should.equal(3) - image_detail['imageTags'].should.be.equal(tags) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + len(image_detail["imageTags"]).should.equal(3) + image_detail["imageTags"].should.be.equal(tags) @mock_ecr def test_describe_images_tags_should_not_contain_empty_tag2(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2'] + tags = ["v1", "v2"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(manifest) + repositoryName="test_repository", imageManifest=json.dumps(manifest) ) client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='latest' + imageTag="latest", ) - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - len(image_detail['imageTags']).should.equal(3) - image_detail['imageTags'].should.be.equal(['v1', 'v2', 'latest']) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + len(image_detail["imageTags"]).should.equal(3) + image_detail["imageTags"].should.be.equal(["v1", "v2", "latest"]) @mock_ecr def test_describe_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_repositories.when.called_with( - repositoryNames=['repo-that-doesnt-exist'], - registryId='123', + repositoryNames=["repo-that-doesnt-exist"], registryId="123" ).should.throw(ClientError, error_msg) + @mock_ecr def test_describe_image_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository(repositoryName='test_repository') + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") error_msg1 = re.compile( r".*The image with imageId {imageDigest:'null', imageTag:'testtag'} does not exist within " r"the repository with name 'test_repository' in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_images.when.called_with( - repositoryName='test_repository', imageIds=[{'imageTag': 'testtag'}], registryId='123', + repositoryName="test_repository", + imageIds=[{"imageTag": "testtag"}], + registryId="123", ).should.throw(ClientError, error_msg1) error_msg2 = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_images.when.called_with( - repositoryName='repo-that-doesnt-exist', imageIds=[{'imageTag': 'testtag'}], registryId='123', + repositoryName="repo-that-doesnt-exist", + imageIds=[{"imageTag": "testtag"}], + registryId="123", ).should.throw(ClientError, error_msg2) @mock_ecr def test_delete_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.delete_repository.when.called_with( - repositoryName='repo-that-doesnt-exist', - registryId='123').should.throw( - ClientError, error_msg) + repositoryName="repo-that-doesnt-exist", registryId="123" + ).should.throw(ClientError, error_msg) @mock_ecr def test_describe_images_by_digest(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") - tags = ['latest', 'v1', 'v2'] + tags = ["latest", "v1", "v2"] digest_map = {} for tag in tags: put_response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag=tag + imageTag=tag, ) - digest_map[put_response['image']['imageId']['imageDigest']] = put_response['image'] + digest_map[put_response["image"]["imageId"]["imageDigest"]] = put_response[ + "image" + ] for digest, put_response in digest_map.items(): - response = client.describe_images(repositoryName='test_repository', - imageIds=[{'imageDigest': digest}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - image_detail['registryId'].should.equal("012345678910") - image_detail['repositoryName'].should.equal("test_repository") - image_detail['imageTags'].should.equal([put_response['imageId']['imageTag']]) - image_detail['imageDigest'].should.equal(digest) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageDigest": digest}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + image_detail["registryId"].should.equal("012345678910") + image_detail["repositoryName"].should.equal("test_repository") + image_detail["imageTags"].should.equal([put_response["imageId"]["imageTag"]]) + image_detail["imageDigest"].should.equal(digest) @mock_ecr def test_get_authorization_token_assume_region(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") auth_token_response = client.get_authorization_token() - auth_token_response.should.contain('authorizationData') - auth_token_response.should.contain('ResponseMetadata') - auth_token_response['authorizationData'].should.equal([ - { - 'authorizationToken': 'QVdTOjAxMjM0NTY3ODkxMC1hdXRoLXRva2Vu', - 'proxyEndpoint': 'https://012345678910.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()) - }, - ]) + auth_token_response.should.contain("authorizationData") + auth_token_response.should.contain("ResponseMetadata") + auth_token_response["authorizationData"].should.equal( + [ + { + "authorizationToken": "QVdTOjAxMjM0NTY3ODkxMC1hdXRoLXRva2Vu", + "proxyEndpoint": "https://012345678910.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + } + ] + ) @mock_ecr def test_get_authorization_token_explicit_regions(): - client = boto3.client('ecr', region_name='us-east-1') - auth_token_response = client.get_authorization_token(registryIds=['10987654321', '878787878787']) + client = boto3.client("ecr", region_name="us-east-1") + auth_token_response = client.get_authorization_token( + registryIds=["10987654321", "878787878787"] + ) - auth_token_response.should.contain('authorizationData') - auth_token_response.should.contain('ResponseMetadata') - auth_token_response['authorizationData'].should.equal([ - { - 'authorizationToken': 'QVdTOjEwOTg3NjU0MzIxLWF1dGgtdG9rZW4=', - 'proxyEndpoint': 'https://10987654321.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()), - }, - { - 'authorizationToken': 'QVdTOjg3ODc4Nzg3ODc4Ny1hdXRoLXRva2Vu', - 'proxyEndpoint': 'https://878787878787.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()) - - } - ]) + auth_token_response.should.contain("authorizationData") + auth_token_response.should.contain("ResponseMetadata") + auth_token_response["authorizationData"].should.equal( + [ + { + "authorizationToken": "QVdTOjEwOTg3NjU0MzIxLWF1dGgtdG9rZW4=", + "proxyEndpoint": "https://10987654321.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + }, + { + "authorizationToken": "QVdTOjg3ODc4Nzg3ODc4Ny1hdXRoLXRva2Vu", + "proxyEndpoint": "https://878787878787.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + }, + ] + ) @mock_ecr def test_batch_get_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' - ) - - _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v2", ) response = client.batch_get_image( - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v2' - }, - ], + repositoryName="test_repository", imageIds=[{"imageTag": "v2"}] ) - type(response['images']).should.be(list) - len(response['images']).should.be(1) + type(response["images"]).should.be(list) + len(response["images"]).should.be(1) - response['images'][0]['imageManifest'].should.contain("vnd.docker.distribution.manifest.v2+json") - response['images'][0]['registryId'].should.equal("012345678910") - response['images'][0]['repositoryName'].should.equal("test_repository") + response["images"][0]["imageManifest"].should.contain( + "vnd.docker.distribution.manifest.v2+json" + ) + response["images"][0]["registryId"].should.equal("012345678910") + response["images"][0]["repositoryName"].should.equal("test_repository") - response['images'][0]['imageId']['imageTag'].should.equal("v2") - response['images'][0]['imageId']['imageDigest'].should.contain("sha") + response["images"][0]["imageId"]["imageTag"].should.equal("v2") + response["images"][0]["imageId"]["imageDigest"].should.contain("sha") - type(response['failures']).should.be(list) - len(response['failures']).should.be(0) + type(response["failures"]).should.be(list) + len(response["failures"]).should.be(0) @mock_ecr def test_batch_get_image_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' - ) - - _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v2", ) response = client.batch_get_image( - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v5' - }, - ], + repositoryName="test_repository", imageIds=[{"imageTag": "v5"}] ) - type(response['images']).should.be(list) - len(response['images']).should.be(0) + type(response["images"]).should.be(list) + len(response["images"]).should.be(0) - type(response['failures']).should.be(list) - len(response['failures']).should.be(1) - response['failures'][0]['failureReason'].should.equal("Requested image not found") - response['failures'][0]['failureCode'].should.equal("ImageNotFound") - response['failures'][0]['imageId']['imageTag'].should.equal("v5") + type(response["failures"]).should.be(list) + len(response["failures"]).should.be(1) + response["failures"][0]["failureReason"].should.equal("Requested image not found") + response["failures"][0]["failureCode"].should.equal("ImageNotFound") + response["failures"][0]["imageId"]["imageTag"].should.equal("v5") @mock_ecr def test_batch_get_image_no_tags(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) error_msg = re.compile( - r".*Missing required parameter in input: \"imageIds\".*", - re.MULTILINE) + r".*Missing required parameter in input: \"imageIds\".*", re.MULTILINE + ) client.batch_get_image.when.called_with( - repositoryName='test_repository').should.throw( - ParamValidationError, error_msg) + repositoryName="test_repository" + ).should.throw(ParamValidationError, error_msg) @mock_ecr def test_batch_delete_image_by_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), imageTag=tag, ) - describe_response1 = client.describe_images(repositoryName='test_repository') + describe_response1 = client.describe_images(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'latest' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": "latest"}], ) - describe_response2 = client.describe_images(repositoryName='test_repository') + describe_response2 = client.describe_images(repositoryName="test_repository") - type(describe_response1['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response1['imageDetails'][0]['imageTags']).should.be(3) + type(describe_response1["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response1["imageDetails"][0]["imageTags"]).should.be(3) - type(describe_response2['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response2['imageDetails'][0]['imageTags']).should.be(2) + type(describe_response2["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response2["imageDetails"][0]["imageTags"]).should.be(2) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(1) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(1) - batch_delete_response['imageIds'][0]['imageTag'].should.equal("latest") + batch_delete_response["imageIds"][0]["imageTag"].should.equal("latest") - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_delete_last_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1', + imageTag="v1", ) - describe_response1 = client.describe_images(repositoryName='test_repository') + describe_response1 = client.describe_images(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v1' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": "v1"}], ) - describe_response2 = client.describe_images(repositoryName='test_repository') + describe_response2 = client.describe_images(repositoryName="test_repository") - type(describe_response1['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response1['imageDetails'][0]['imageTags']).should.be(1) + type(describe_response1["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response1["imageDetails"][0]["imageTags"]).should.be(1) - type(describe_response2['imageDetails']).should.be(list) - len(describe_response2['imageDetails']).should.be(0) + type(describe_response2["imageDetails"]).should.be(list) + len(describe_response2["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(1) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(1) - batch_delete_response['imageIds'][0]['imageTag'].should.equal("v1") + batch_delete_response["imageIds"][0]["imageTag"].should.equal("v1") - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_nonexistent_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") missing_tag = "missing-tag" batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': missing_tag - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": missing_tag}], ) - type(describe_response['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response['imageDetails'][0]['imageTags']).should.be(3) + type(describe_response["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response["imageDetails"][0]["imageTags"]).should.be(3) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - batch_delete_response['failures'][0]['imageId']['imageTag'].should.equal(missing_tag) - batch_delete_response['failures'][0]['failureCode'].should.equal("ImageNotFound") - batch_delete_response['failures'][0]['failureReason'].should.equal("Requested image not found") + batch_delete_response["failures"][0]["imageId"]["imageTag"].should.equal( + missing_tag + ) + batch_delete_response["failures"][0]["failureCode"].should.equal("ImageNotFound") + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Requested image not found" + ) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) @mock_ecr def test_batch_delete_image_by_digest(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest}], ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(0) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(3) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(3) - batch_delete_response['imageIds'][0]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][1]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][2]['imageDigest'].should.equal(image_digest) + batch_delete_response["imageIds"][0]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][1]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][2]["imageDigest"].should.equal(image_digest) - set([ - batch_delete_response['imageIds'][0]['imageTag'], - batch_delete_response['imageIds'][1]['imageTag'], - batch_delete_response['imageIds'][2]['imageTag']]).should.equal(set(tags)) + set( + [ + batch_delete_response["imageIds"][0]["imageTag"], + batch_delete_response["imageIds"][1]["imageTag"], + batch_delete_response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(tags)) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_invalid_digest(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - invalid_image_digest = 'sha256:invalid-digest' + invalid_image_digest = "sha256:invalid-digest" batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': invalid_image_digest - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": invalid_image_digest}], ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['imageId']['imageDigest'].should.equal(invalid_image_digest) - batch_delete_response['failures'][0]['failureCode'].should.equal("InvalidImageDigest") - batch_delete_response['failures'][0]['failureReason'].should.equal("Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'") + batch_delete_response["failures"][0]["imageId"]["imageDigest"].should.equal( + invalid_image_digest + ) + batch_delete_response["failures"][0]["failureCode"].should.equal( + "InvalidImageDigest" + ) + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'" + ) @mock_ecr def test_batch_delete_image_with_missing_parameters(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - }, - ], + registryId="012345678910", repositoryName="test_repository", imageIds=[{}] ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['failureCode'].should.equal("MissingDigestAndTag") - batch_delete_response['failures'][0]['failureReason'].should.equal("Invalid request parameters: both tag and digest cannot be null") + batch_delete_response["failures"][0]["failureCode"].should.equal( + "MissingDigestAndTag" + ) + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Invalid request parameters: both tag and digest cannot be null" + ) @mock_ecr def test_batch_delete_image_with_matching_digest_and_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest, - 'imageTag': 'v1' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest, "imageTag": "v1"}], ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(0) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(3) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(3) - batch_delete_response['imageIds'][0]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][1]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][2]['imageDigest'].should.equal(image_digest) + batch_delete_response["imageIds"][0]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][1]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][2]["imageDigest"].should.equal(image_digest) - set([ - batch_delete_response['imageIds'][0]['imageTag'], - batch_delete_response['imageIds'][1]['imageTag'], - batch_delete_response['imageIds'][2]['imageTag']]).should.equal(set(tags)) + set( + [ + batch_delete_response["imageIds"][0]["imageTag"], + batch_delete_response["imageIds"][1]["imageTag"], + batch_delete_response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(tags)) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_mismatched_digest_and_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'latest'] + tags = ["v1", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest, - 'imageTag': 'v2' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest, "imageTag": "v2"}], ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['imageId']['imageDigest'].should.equal(image_digest) - batch_delete_response['failures'][0]['imageId']['imageTag'].should.equal("v2") - batch_delete_response['failures'][0]['failureCode'].should.equal("ImageNotFound") - batch_delete_response['failures'][0]['failureReason'].should.equal("Requested image not found") + batch_delete_response["failures"][0]["imageId"]["imageDigest"].should.equal( + image_digest + ) + batch_delete_response["failures"][0]["imageId"]["imageTag"].should.equal("v2") + batch_delete_response["failures"][0]["failureCode"].should.equal("ImageNotFound") + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Requested image not found" + ) diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index 9937af26b..973c95b81 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -18,658 +18,721 @@ from nose.tools import assert_raises @mock_ecs def test_create_cluster(): - client = boto3.client('ecs', region_name='us-east-1') - response = client.create_cluster( - clusterName='test_ecs_cluster' + client = boto3.client("ecs", region_name="us-east-1") + response = client.create_cluster(clusterName="test_ecs_cluster") + response["cluster"]["clusterName"].should.equal("test_ecs_cluster") + response["cluster"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" ) - response['cluster']['clusterName'].should.equal('test_ecs_cluster') - response['cluster']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['cluster']['status'].should.equal('ACTIVE') - response['cluster']['registeredContainerInstancesCount'].should.equal(0) - response['cluster']['runningTasksCount'].should.equal(0) - response['cluster']['pendingTasksCount'].should.equal(0) - response['cluster']['activeServicesCount'].should.equal(0) + response["cluster"]["status"].should.equal("ACTIVE") + response["cluster"]["registeredContainerInstancesCount"].should.equal(0) + response["cluster"]["runningTasksCount"].should.equal(0) + response["cluster"]["pendingTasksCount"].should.equal(0) + response["cluster"]["activeServicesCount"].should.equal(0) @mock_ecs def test_list_clusters(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_cluster0' - ) - _ = client.create_cluster( - clusterName='test_cluster1' - ) + client = boto3.client("ecs", region_name="us-east-2") + _ = client.create_cluster(clusterName="test_cluster0") + _ = client.create_cluster(clusterName="test_cluster1") response = client.list_clusters() - response['clusterArns'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_cluster0') - response['clusterArns'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_cluster1') + response["clusterArns"].should.contain( + "arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster0" + ) + response["clusterArns"].should.contain( + "arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster1" + ) @mock_ecs def test_describe_clusters(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.describe_clusters(clusters=["some-cluster"]) - response['failures'].should.contain({ - 'arn': 'arn:aws:ecs:us-east-1:012345678910:cluster/some-cluster', - 'reason': 'MISSING' - }) + response["failures"].should.contain( + { + "arn": "arn:aws:ecs:us-east-1:012345678910:cluster/some-cluster", + "reason": "MISSING", + } + ) + @mock_ecs def test_delete_cluster(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + response = client.delete_cluster(cluster="test_ecs_cluster") + response["cluster"]["clusterName"].should.equal("test_ecs_cluster") + response["cluster"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" ) - response = client.delete_cluster(cluster='test_ecs_cluster') - response['cluster']['clusterName'].should.equal('test_ecs_cluster') - response['cluster']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['cluster']['status'].should.equal('ACTIVE') - response['cluster']['registeredContainerInstancesCount'].should.equal(0) - response['cluster']['runningTasksCount'].should.equal(0) - response['cluster']['pendingTasksCount'].should.equal(0) - response['cluster']['activeServicesCount'].should.equal(0) + response["cluster"]["status"].should.equal("ACTIVE") + response["cluster"]["registeredContainerInstancesCount"].should.equal(0) + response["cluster"]["runningTasksCount"].should.equal(0) + response["cluster"]["pendingTasksCount"].should.equal(0) + response["cluster"]["activeServicesCount"].should.equal(0) response = client.list_clusters() - len(response['clusterArns']).should.equal(0) + len(response["clusterArns"]).should.equal(0) @mock_ecs def test_register_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } ], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['revision'].should.equal(1) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinition']['containerDefinitions'][ - 0]['name'].should.equal('hello_world') - response['taskDefinition']['containerDefinitions'][0][ - 'image'].should.equal('docker/hello-world:latest') - response['taskDefinition']['containerDefinitions'][ - 0]['cpu'].should.equal(1024) - response['taskDefinition']['containerDefinitions'][ - 0]['memory'].should.equal(400) - response['taskDefinition']['containerDefinitions'][ - 0]['essential'].should.equal(True) - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['name'].should.equal('AWS_ACCESS_KEY_ID') - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['value'].should.equal('SOME_ACCESS_KEY') - response['taskDefinition']['containerDefinitions'][0][ - 'logConfiguration']['logDriver'].should.equal('json-file') + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["revision"].should.equal(1) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["taskDefinition"]["containerDefinitions"][0]["name"].should.equal( + "hello_world" + ) + response["taskDefinition"]["containerDefinitions"][0]["image"].should.equal( + "docker/hello-world:latest" + ) + response["taskDefinition"]["containerDefinitions"][0]["cpu"].should.equal(1024) + response["taskDefinition"]["containerDefinitions"][0]["memory"].should.equal(400) + response["taskDefinition"]["containerDefinitions"][0]["essential"].should.equal( + True + ) + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "name" + ].should.equal("AWS_ACCESS_KEY_ID") + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "value" + ].should.equal("SOME_ACCESS_KEY") + response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ + "logDriver" + ].should.equal("json-file") @mock_ecs def test_list_task_definitions(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world2', - 'image': 'docker/hello-world2:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY2' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world2", + "image": "docker/hello-world2:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY2"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.list_task_definitions() - len(response['taskDefinitionArns']).should.equal(2) - response['taskDefinitionArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinitionArns'][1].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2') + len(response["taskDefinitionArns"]).should.equal(2) + response["taskDefinitionArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["taskDefinitionArns"][1].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2" + ) + + +@mock_ecs +def test_list_task_definitions_with_family_prefix(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.register_task_definition( + family="test_ecs_task_a", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.register_task_definition( + family="test_ecs_task_a", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + _ = client.register_task_definition( + family="test_ecs_task_b", + containerDefinitions=[ + { + "name": "hello_world2", + "image": "docker/hello-world2:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY2"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + empty_response = client.list_task_definitions(familyPrefix="test_ecs_task") + len(empty_response["taskDefinitionArns"]).should.equal(0) + filtered_response = client.list_task_definitions(familyPrefix="test_ecs_task_a") + len(filtered_response["taskDefinitionArns"]).should.equal(2) + filtered_response["taskDefinitionArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task_a:1" + ) + filtered_response["taskDefinitionArns"][1].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task_a:2" + ) @mock_ecs def test_describe_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world2', - 'image': 'docker/hello-world2:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY2' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world2", + "image": "docker/hello-world2:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY2"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world3', - 'image': 'docker/hello-world3:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY3' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world3", + "image": "docker/hello-world3:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY3"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], + ) + response = client.describe_task_definition(taskDefinition="test_ecs_task") + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:3" ) - response = client.describe_task_definition(taskDefinition='test_ecs_task') - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:3') - response = client.describe_task_definition( - taskDefinition='test_ecs_task:2') - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2') + response = client.describe_task_definition(taskDefinition="test_ecs_task:2") + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2" + ) @mock_ecs def test_deregister_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) - response = client.deregister_task_definition( - taskDefinition='test_ecs_task:1' + response = client.deregister_task_definition(taskDefinition="test_ecs_task:1") + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinition']['containerDefinitions'][ - 0]['name'].should.equal('hello_world') - response['taskDefinition']['containerDefinitions'][0][ - 'image'].should.equal('docker/hello-world:latest') - response['taskDefinition']['containerDefinitions'][ - 0]['cpu'].should.equal(1024) - response['taskDefinition']['containerDefinitions'][ - 0]['memory'].should.equal(400) - response['taskDefinition']['containerDefinitions'][ - 0]['essential'].should.equal(True) - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['name'].should.equal('AWS_ACCESS_KEY_ID') - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['value'].should.equal('SOME_ACCESS_KEY') - response['taskDefinition']['containerDefinitions'][0][ - 'logConfiguration']['logDriver'].should.equal('json-file') + response["taskDefinition"]["containerDefinitions"][0]["name"].should.equal( + "hello_world" + ) + response["taskDefinition"]["containerDefinitions"][0]["image"].should.equal( + "docker/hello-world:latest" + ) + response["taskDefinition"]["containerDefinitions"][0]["cpu"].should.equal(1024) + response["taskDefinition"]["containerDefinitions"][0]["memory"].should.equal(400) + response["taskDefinition"]["containerDefinitions"][0]["essential"].should.equal( + True + ) + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "name" + ].should.equal("AWS_ACCESS_KEY_ID") + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "value" + ].should.equal("SOME_ACCESS_KEY") + response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ + "logDriver" + ].should.equal("json-file") @mock_ecs def test_create_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['service']['schedulingStrategy'].should.equal('REPLICA') + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["service"]["schedulingStrategy"].should.equal("REPLICA") + @mock_ecs def test_create_service_scheduling_strategy(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, - schedulingStrategy='DAEMON', + schedulingStrategy="DAEMON", ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['service']['schedulingStrategy'].should.equal('DAEMON') + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["service"]["schedulingStrategy"].should.equal("DAEMON") @mock_ecs def test_list_services(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - schedulingStrategy='REPLICA', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", + schedulingStrategy="REPLICA", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', - schedulingStrategy='DAEMON', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + schedulingStrategy="DAEMON", + desiredCount=2, ) - unfiltered_response = client.list_services( - cluster='test_ecs_cluster' + unfiltered_response = client.list_services(cluster="test_ecs_cluster") + len(unfiltered_response["serviceArns"]).should.equal(2) + unfiltered_response["serviceArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + unfiltered_response["serviceArns"][1].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" ) - len(unfiltered_response['serviceArns']).should.equal(2) - unfiltered_response['serviceArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - unfiltered_response['serviceArns'][1].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') filtered_response = client.list_services( - cluster='test_ecs_cluster', - schedulingStrategy='REPLICA' + cluster="test_ecs_cluster", schedulingStrategy="REPLICA" ) - len(filtered_response['serviceArns']).should.equal(1) - filtered_response['serviceArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') + len(filtered_response["serviceArns"]).should.equal(1) + filtered_response["serviceArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + @mock_ecs def test_describe_services(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service3', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service3", + taskDefinition="test_ecs_task", + desiredCount=2, ) response = client.describe_services( - cluster='test_ecs_cluster', - services=['test_ecs_service1', - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2'] + cluster="test_ecs_cluster", + services=[ + "test_ecs_service1", + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2", + ], ) - len(response['services']).should.equal(2) - response['services'][0]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - response['services'][0]['serviceName'].should.equal('test_ecs_service1') - response['services'][1]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') - response['services'][1]['serviceName'].should.equal('test_ecs_service2') + len(response["services"]).should.equal(2) + response["services"][0]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + response["services"][0]["serviceName"].should.equal("test_ecs_service1") + response["services"][1]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" + ) + response["services"][1]["serviceName"].should.equal("test_ecs_service2") - response['services'][0]['deployments'][0]['desiredCount'].should.equal(2) - response['services'][0]['deployments'][0]['pendingCount'].should.equal(2) - response['services'][0]['deployments'][0]['runningCount'].should.equal(0) - response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY') - (datetime.now() - response['services'][0]['deployments'][0]["createdAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10) - (datetime.now() - response['services'][0]['deployments'][0]["updatedAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10) + response["services"][0]["deployments"][0]["desiredCount"].should.equal(2) + response["services"][0]["deployments"][0]["pendingCount"].should.equal(2) + response["services"][0]["deployments"][0]["runningCount"].should.equal(0) + response["services"][0]["deployments"][0]["status"].should.equal("PRIMARY") + ( + datetime.now() + - response["services"][0]["deployments"][0]["createdAt"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) + ( + datetime.now() + - response["services"][0]["deployments"][0]["updatedAt"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) @mock_ecs def test_describe_services_scheduling_strategy(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - desiredCount=2 - ) - _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", desiredCount=2, - schedulingStrategy='DAEMON' ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service3', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + desiredCount=2, + schedulingStrategy="DAEMON", + ) + _ = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service3", + taskDefinition="test_ecs_task", + desiredCount=2, ) response = client.describe_services( - cluster='test_ecs_cluster', - services=['test_ecs_service1', - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2', - 'test_ecs_service3'] + cluster="test_ecs_cluster", + services=[ + "test_ecs_service1", + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2", + "test_ecs_service3", + ], ) - len(response['services']).should.equal(3) - response['services'][0]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - response['services'][0]['serviceName'].should.equal('test_ecs_service1') - response['services'][1]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') - response['services'][1]['serviceName'].should.equal('test_ecs_service2') + len(response["services"]).should.equal(3) + response["services"][0]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + response["services"][0]["serviceName"].should.equal("test_ecs_service1") + response["services"][1]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" + ) + response["services"][1]["serviceName"].should.equal("test_ecs_service2") - response['services'][0]['deployments'][0]['desiredCount'].should.equal(2) - response['services'][0]['deployments'][0]['pendingCount'].should.equal(2) - response['services'][0]['deployments'][0]['runningCount'].should.equal(0) - response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY') + response["services"][0]["deployments"][0]["desiredCount"].should.equal(2) + response["services"][0]["deployments"][0]["pendingCount"].should.equal(2) + response["services"][0]["deployments"][0]["runningCount"].should.equal(0) + response["services"][0]["deployments"][0]["status"].should.equal("PRIMARY") - response['services'][0]['schedulingStrategy'].should.equal('REPLICA') - response['services'][1]['schedulingStrategy'].should.equal('DAEMON') - response['services'][2]['schedulingStrategy'].should.equal('REPLICA') + response["services"][0]["schedulingStrategy"].should.equal("REPLICA") + response["services"][1]["schedulingStrategy"].should.equal("DAEMON") + response["services"][2]["schedulingStrategy"].should.equal("REPLICA") @mock_ecs def test_update_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) - response['service']['desiredCount'].should.equal(2) + response["service"]["desiredCount"].should.equal(2) response = client.update_service( - cluster='test_ecs_cluster', - service='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=0 + cluster="test_ecs_cluster", + service="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=0, ) - response['service']['desiredCount'].should.equal(0) - response['service']['schedulingStrategy'].should.equal('REPLICA') + response["service"]["desiredCount"].should.equal(0) + response["service"]["schedulingStrategy"].should.equal("REPLICA") @mock_ecs def test_update_missing_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") client.update_service.when.called_with( - cluster='test_ecs_cluster', - service='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=0 + cluster="test_ecs_cluster", + service="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=0, ).should.throw(ClientError) @mock_ecs def test_delete_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.update_service( - cluster='test_ecs_cluster', - service='test_ecs_service', - desiredCount=0 + cluster="test_ecs_cluster", service="test_ecs_service", desiredCount=0 ) response = client.delete_service( - cluster='test_ecs_cluster', - service='test_ecs_service' + cluster="test_ecs_cluster", service="test_ecs_service" + ) + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(0) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["schedulingStrategy"].should.equal("REPLICA") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(0) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['schedulingStrategy'].should.equal('REPLICA') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') @mock_ecs def test_update_non_existant_service(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") try: client.update_service( - cluster="my-clustet", - service="my-service", - desiredCount=0, + cluster="my-clustet", service="my-service", desiredCount=0 ) except ClientError as exc: - error_code = exc.response['Error']['Code'] - error_code.should.equal('ServiceNotFoundException') + error_code = exc.response["Error"]["Code"] + error_code.should.equal("ServiceNotFoundException") else: raise Exception("Didn't raise ClientError") @@ -677,19 +740,15 @@ def test_update_non_existant_service(): @mock_ec2 @mock_ecs def test_register_container_instance(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -697,45 +756,37 @@ def test_register_container_instance(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - arn_part = full_arn.split('/') - arn_part[0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:container-instance') + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + arn_part = full_arn.split("/") + arn_part[0].should.equal("arn:aws:ecs:us-east-1:012345678910:container-instance") arn_part[1].should.equal(str(UUID(arn_part[1]))) - response['containerInstance']['status'].should.equal('ACTIVE') - len(response['containerInstance']['registeredResources']).should.equal(4) - len(response['containerInstance']['remainingResources']).should.equal(4) - response['containerInstance']['agentConnected'].should.equal(True) - response['containerInstance']['versionInfo'][ - 'agentVersion'].should.equal('1.0.0') - response['containerInstance']['versionInfo'][ - 'agentHash'].should.equal('4023248') - response['containerInstance']['versionInfo'][ - 'dockerVersion'].should.equal('DockerVersion: 1.5.0') + response["containerInstance"]["status"].should.equal("ACTIVE") + len(response["containerInstance"]["registeredResources"]).should.equal(4) + len(response["containerInstance"]["remainingResources"]).should.equal(4) + response["containerInstance"]["agentConnected"].should.equal(True) + response["containerInstance"]["versionInfo"]["agentVersion"].should.equal("1.0.0") + response["containerInstance"]["versionInfo"]["agentHash"].should.equal("4023248") + response["containerInstance"]["versionInfo"]["dockerVersion"].should.equal( + "DockerVersion: 1.5.0" + ) @mock_ec2 @mock_ecs def test_deregister_container_instance(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -743,87 +794,76 @@ def test_deregister_container_instance(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instance_id = response['containerInstance']['containerInstanceArn'] + container_instance_id = response["containerInstance"]["containerInstanceArn"] response = ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id + cluster=test_cluster_name, containerInstance=container_instance_id ) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(0) + len(container_instances_response["containerInstanceArns"]).should.equal(0) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instance_id = response['containerInstance']['containerInstanceArn'] + container_instance_id = response["containerInstance"]["containerInstanceArn"] _ = ecs_client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = ecs_client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='moto' + startedBy="moto", ) with assert_raises(Exception) as e: ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id + cluster=test_cluster_name, containerInstance=container_instance_id ).should.have.raised(Exception) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(1) + len(container_instances_response["containerInstanceArns"]).should.equal(1) ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id, - force=True + cluster=test_cluster_name, containerInstance=container_instance_id, force=True ) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(0) + len(container_instances_response["containerInstanceArns"]).should.equal(0) @mock_ec2 @mock_ecs def test_list_container_instances(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -831,37 +871,32 @@ def test_list_container_instances(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance'][ - 'containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) response = ecs_client.list_container_instances(cluster=test_cluster_name) - len(response['containerInstanceArns']).should.equal(instance_to_create) + len(response["containerInstanceArns"]).should.equal(instance_to_create) for arn in test_instance_arns: - response['containerInstanceArns'].should.contain(arn) + response["containerInstanceArns"].should.contain(arn) @mock_ec2 @mock_ecs def test_describe_container_instances(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -869,45 +904,46 @@ def test_describe_container_instances(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance'][ - 'containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - test_instance_ids = list( - map((lambda x: x.split('/')[1]), test_instance_arns)) + test_instance_ids = list(map((lambda x: x.split("/")[1]), test_instance_arns)) response = ecs_client.describe_container_instances( - cluster=test_cluster_name, containerInstances=test_instance_ids) - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_arns = [ci['containerInstanceArn'] - for ci in response['containerInstances']] + cluster=test_cluster_name, containerInstances=test_instance_ids + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_arns = [ + ci["containerInstanceArn"] for ci in response["containerInstances"] + ] for arn in test_instance_arns: response_arns.should.contain(arn) - for instance in response['containerInstances']: - instance.keys().should.contain('runningTasksCount') - instance.keys().should.contain('pendingTasksCount') + for instance in response["containerInstances"]: + instance.keys().should.contain("runningTasksCount") + instance.keys().should.contain("pendingTasksCount") + + with assert_raises(ClientError) as e: + ecs_client.describe_container_instances( + cluster=test_cluster_name, containerInstances=[] + ) @mock_ec2 @mock_ecs def test_update_container_instances_state(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -915,59 +951,61 @@ def test_update_container_instances_state(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance']['containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - test_instance_ids = list(map((lambda x: x.split('/')[1]), test_instance_arns)) - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + test_instance_ids = list(map((lambda x: x.split("/")[1]), test_instance_arns)) + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='ACTIVE') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, containerInstances=test_instance_ids, status="ACTIVE" + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('ACTIVE') - ecs_client.update_container_instances_state.when.called_with(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='test_status').should.throw(Exception) + status.should.equal("ACTIVE") + ecs_client.update_container_instances_state.when.called_with( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="test_status", + ).should.throw(Exception) @mock_ec2 @mock_ecs def test_update_container_instances_state_by_arn(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -975,56 +1013,60 @@ def test_update_container_instances_state_by_arn(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance']['containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='ACTIVE') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="ACTIVE", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('ACTIVE') - ecs_client.update_container_instances_state.when.called_with(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='test_status').should.throw(Exception) + status.should.equal("ACTIVE") + ecs_client.update_container_instances_state.when.called_with( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="test_status", + ).should.throw(Exception) @mock_ec2 @mock_ecs def test_run_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1032,66 +1074,64 @@ def test_run_task(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(2) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/') - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(2) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:container-instance/" + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ec2 @mock_ecs def test_start_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1099,73 +1139,73 @@ def test_start_task(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instances = client.list_container_instances( - cluster=test_cluster_name) - container_instance_id = container_instances[ - 'containerInstanceArns'][0].split('/')[-1] + container_instances = client.list_container_instances(cluster=test_cluster_name) + container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ + -1 + ] _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(1) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/{0}'.format(container_instance_id)) - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(1) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format( + container_instance_id + ) + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ec2 @mock_ecs def test_list_tasks(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1173,71 +1213,66 @@ def test_list_tasks(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instances = client.list_container_instances( - cluster=test_cluster_name) - container_instance_id = container_instances[ - 'containerInstanceArns'][0].split('/')[-1] + container_instances = client.list_container_instances(cluster=test_cluster_name) + container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ + -1 + ] _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='foo' + startedBy="foo", ) _ = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='bar' + startedBy="bar", ) - assert len(client.list_tasks()['taskArns']).should.equal(2) - assert len(client.list_tasks(cluster='test_ecs_cluster') - ['taskArns']).should.equal(2) - assert len(client.list_tasks(startedBy='foo')['taskArns']).should.equal(1) + assert len(client.list_tasks()["taskArns"]).should.equal(2) + assert len(client.list_tasks(cluster="test_ecs_cluster")["taskArns"]).should.equal( + 2 + ) + assert len(client.list_tasks(startedBy="foo")["taskArns"]).should.equal(1) @mock_ec2 @mock_ecs def test_describe_tasks(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1245,96 +1280,85 @@ def test_describe_tasks(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) tasks_arns = [ - task['taskArn'] for task in client.run_task( - cluster='test_ecs_cluster', + task["taskArn"] + for task in client.run_task( + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' - )['tasks'] + startedBy="moto", + )["tasks"] ] - response = client.describe_tasks( - cluster='test_ecs_cluster', - tasks=tasks_arns - ) + response = client.describe_tasks(cluster="test_ecs_cluster", tasks=tasks_arns) - len(response['tasks']).should.equal(2) - set([response['tasks'][0]['taskArn'], response['tasks'] - [1]['taskArn']]).should.equal(set(tasks_arns)) + len(response["tasks"]).should.equal(2) + set( + [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] + ).should.equal(set(tasks_arns)) # Test we can pass task ids instead of ARNs response = client.describe_tasks( - cluster='test_ecs_cluster', - tasks=[tasks_arns[0].split("/")[-1]] + cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] ) - len(response['tasks']).should.equal(1) + len(response["tasks"]).should.equal(1) @mock_ecs def describe_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") container_definition = { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [{"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"}], + "logConfiguration": {"logDriver": "json-file"}, } task_definition = client.register_task_definition( - family='test_ecs_task', - containerDefinitions=[container_definition] + family="test_ecs_task", containerDefinitions=[container_definition] ) - family = task_definition['family'] + family = task_definition["family"] task = client.describe_task_definition(taskDefinition=family) - task['containerDefinitions'][0].should.equal(container_definition) - task['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task2:1') - task['volumes'].should.equal([]) + task["containerDefinitions"][0].should.equal(container_definition) + task["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task2:1" + ) + task["volumes"].should.equal([]) @mock_ec2 @mock_ecs def test_stop_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1342,63 +1366,58 @@ def test_stop_task(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) stop_response = client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) - stop_response['task']['taskArn'].should.equal( - run_response['tasks'][0].get('taskArn')) - stop_response['task']['lastStatus'].should.equal('STOPPED') - stop_response['task']['desiredStatus'].should.equal('STOPPED') - stop_response['task']['stoppedReason'].should.equal('moto testing') + stop_response["task"]["taskArn"].should.equal( + run_response["tasks"][0].get("taskArn") + ) + stop_response["task"]["lastStatus"].should.equal("STOPPED") + stop_response["task"]["desiredStatus"].should.equal("STOPPED") + stop_response["task"]["stoppedReason"].should.equal("moto testing") @mock_ec2 @mock_ecs def test_resource_reservation_and_release(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1406,84 +1425,74 @@ def test_resource_reservation_and_release(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'hostPort': 80, - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"hostPort": 80, "containerPort": 8080}], } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) - container_instance_arn = run_response['tasks'][0].get('containerInstanceArn') + container_instance_arn = run_response["tasks"][0].get("containerInstanceArn") container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] remaining_resources, registered_resources = _fetch_container_instance_resources( - container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU'] - 1024) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY'] - 400) - registered_resources['PORTS'].append('80') - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(1) + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"] - 1024) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"] - 400) + registered_resources["PORTS"].append("80") + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(1) client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] remaining_resources, registered_resources = _fetch_container_instance_resources( - container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY']) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(0) + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"]) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(0) + @mock_ec2 @mock_ecs def test_resource_reservation_and_release_memory_reservation(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1491,63 +1500,58 @@ def test_resource_reservation_and_release_memory_reservation(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'memoryReservation': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "memoryReservation": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"containerPort": 8080}], } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) - container_instance_arn = run_response['tasks'][0].get('containerInstanceArn') + container_instance_arn = run_response["tasks"][0].get("containerInstanceArn") container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] - remaining_resources, registered_resources = _fetch_container_instance_resources(container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY'] - 400) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(1) + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] + remaining_resources, registered_resources = _fetch_container_instance_resources( + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"] - 400) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(1) client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] - remaining_resources, registered_resources = _fetch_container_instance_resources(container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY']) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(0) - + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] + remaining_resources, registered_resources = _fetch_container_instance_resources( + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"]) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(0) @mock_ecs @@ -1559,26 +1563,21 @@ def test_create_cluster_through_cloudformation(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, } - } + }, } template_json = json.dumps(template) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(0) + len(resp["clusterArns"]).should.equal(0) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) + len(resp["clusterArns"]).should.equal(1) @mock_ecs @@ -1589,22 +1588,15 @@ def test_create_cluster_through_cloudformation_no_name(): template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "ECS Cluster Test CloudFormation", - "Resources": { - "testCluster": { - "Type": "AWS::ECS::Cluster", - } - } + "Resources": {"testCluster": {"Type": "AWS::ECS::Cluster"}}, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) + len(resp["clusterArns"]).should.equal(1) @mock_ecs @@ -1616,31 +1608,24 @@ def test_update_cluster_name_through_cloudformation_should_trigger_a_replacement "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster1" - } + "Properties": {"ClusterName": "testcluster1"}, } - } + }, } template2 = deepcopy(template1) - template2['Resources']['testCluster'][ - 'Properties']['ClusterName'] = 'testcluster2' + template2["Resources"]["testCluster"]["Properties"]["ClusterName"] = "testcluster2" template1_json = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") stack_resp = cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template1_json, + StackName="test_stack", TemplateBody=template1_json ) template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName=stack_resp['StackId'], - TemplateBody=template2_json - ) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + cfn_conn.update_stack(StackName=stack_resp["StackId"], TemplateBody=template2_json) + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) - resp['clusterArns'][0].endswith('testcluster2').should.be.true + len(resp["clusterArns"]).should.equal(1) + resp["clusterArns"][0].endswith("testcluster2").should.be.true @mock_ecs @@ -1659,47 +1644,42 @@ def test_create_task_definition_through_cloudformation(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, } - } + }, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - stack_name = 'test_stack' - cfn_conn.create_stack( - StackName=stack_name, - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + stack_name = "test_stack" + cfn_conn.create_stack(StackName=stack_name, TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_task_definitions() - len(resp['taskDefinitionArns']).should.equal(1) - task_definition_arn = resp['taskDefinitionArns'][0] + len(resp["taskDefinitionArns"]).should.equal(1) + task_definition_arn = resp["taskDefinitionArns"][0] task_definition_details = cfn_conn.describe_stack_resource( - StackName=stack_name,LogicalResourceId='testTaskDefinition')['StackResourceDetail'] - task_definition_details['PhysicalResourceId'].should.equal(task_definition_arn) + StackName=stack_name, LogicalResourceId="testTaskDefinition" + )["StackResourceDetail"] + task_definition_details["PhysicalResourceId"].should.equal(task_definition_arn) + @mock_ec2 @mock_ecs def test_task_definitions_unable_to_be_placed(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1707,53 +1687,47 @@ def test_task_definitions_unable_to_be_placed(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 5000, - 'memory': 40000, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 5000, + "memory": 40000, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(0) + len(response["tasks"]).should.equal(0) @mock_ec2 @mock_ecs def test_task_definitions_with_port_clash(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1761,54 +1735,51 @@ def test_task_definitions_with_port_clash(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 256, - 'memory': 512, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'hostPort': 80, - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 256, + "memory": 512, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"hostPort": 80, "containerPort": 8080}], } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(1) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/') - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(1) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:container-instance/" + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ecs @@ -1828,35 +1799,29 @@ def test_update_task_definition_family_through_cloudformation_should_trigger_a_r "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, } - } + }, } template1_json = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template1_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template1_json) template2 = deepcopy(template1) - template2['Resources']['testTaskDefinition'][ - 'Properties']['Family'] = 'testTaskDefinition2' + template2["Resources"]["testTaskDefinition"]["Properties"][ + "Family" + ] = "testTaskDefinition2" template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName="test_stack", - TemplateBody=template2_json, - ) + cfn_conn.update_stack(StackName="test_stack", TemplateBody=template2_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_task_definitions(familyPrefix='testTaskDefinition') - len(resp['taskDefinitionArns']).should.equal(1) - resp['taskDefinitionArns'][0].endswith( - 'testTaskDefinition2:1').should.be.true + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_task_definitions(familyPrefix="testTaskDefinition2") + len(resp["taskDefinitionArns"]).should.equal(1) + resp["taskDefinitionArns"][0].endswith("testTaskDefinition2:1").should.be.true @mock_ecs @@ -1868,9 +1833,7 @@ def test_create_service_through_cloudformation(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, }, "testTaskDefinition": { "Type": "AWS::ECS::TaskDefinition", @@ -1881,11 +1844,11 @@ def test_create_service_through_cloudformation(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, }, "testService": { "Type": "AWS::ECS::Service", @@ -1893,20 +1856,17 @@ def test_create_service_through_cloudformation(): "Cluster": {"Ref": "testCluster"}, "DesiredCount": 10, "TaskDefinition": {"Ref": "testTaskDefinition"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_services(cluster='testcluster') - len(resp['serviceArns']).should.equal(1) + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_services(cluster="testcluster") + len(resp["serviceArns"]).should.equal(1) @mock_ecs @@ -1918,9 +1878,7 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, }, "testTaskDefinition": { "Type": "AWS::ECS::TaskDefinition", @@ -1931,11 +1889,11 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, }, "testService": { "Type": "AWS::ECS::Service", @@ -1943,47 +1901,37 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Cluster": {"Ref": "testCluster"}, "TaskDefinition": {"Ref": "testTaskDefinition"}, "DesiredCount": 10, - } - } - } + }, + }, + }, } template_json1 = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json1, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json1) template2 = deepcopy(template1) - template2['Resources']['testService']['Properties']['DesiredCount'] = 5 + template2["Resources"]["testService"]["Properties"]["DesiredCount"] = 5 template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName="test_stack", - TemplateBody=template2_json, - ) + cfn_conn.update_stack(StackName="test_stack", TemplateBody=template2_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_services(cluster='testcluster') - len(resp['serviceArns']).should.equal(1) + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_services(cluster="testcluster") + len(resp["serviceArns"]).should.equal(1) @mock_ec2 @mock_ecs def test_attributes(): # Combined put, list delete attributes into the same test due to the amount of setup - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instances = [] test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instances.append(test_instance) @@ -1992,18 +1940,14 @@ def test_attributes(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn1 = response['containerInstance']['containerInstanceArn'] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn1 = response["containerInstance"]["containerInstanceArn"] test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instances.append(test_instance) @@ -2012,133 +1956,143 @@ def test_attributes(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn2 = response['containerInstance']['containerInstanceArn'] - partial_arn2 = full_arn2.rsplit('/', 1)[-1] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn2 = response["containerInstance"]["containerInstanceArn"] + partial_arn2 = full_arn2.rsplit("/", 1)[-1] - full_arn2.should_not.equal(full_arn1) # uuid1 isnt unique enough when the pc is fast ;-) + full_arn2.should_not.equal( + full_arn1 + ) # uuid1 isnt unique enough when the pc is fast ;-) # Ok set instance 1 with 1 attribute, instance 2 with another, and all of them with a 3rd. ecs_client.put_attributes( cluster=test_cluster_name, attributes=[ - {'name': 'env', 'value': 'prod'}, - {'name': 'attr1', 'value': 'instance1', 'targetId': full_arn1}, - {'name': 'attr1', 'value': 'instance2', 'targetId': partial_arn2, - 'targetType': 'container-instance'} - ] + {"name": "env", "value": "prod"}, + {"name": "attr1", "value": "instance1", "targetId": full_arn1}, + { + "name": "attr1", + "value": "instance2", + "targetId": partial_arn2, + "targetType": "container-instance", + }, + ], ) resp = ecs_client.list_attributes( - cluster=test_cluster_name, - targetType='container-instance' + cluster=test_cluster_name, targetType="container-instance" ) - attrs = resp['attributes'] + attrs = resp["attributes"] NUM_CUSTOM_ATTRIBUTES = 4 # 2 specific to individual machines and 1 global, going to both machines (2 + 1*2) NUM_DEFAULT_ATTRIBUTES = 4 - len(attrs).should.equal(NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances))) + len(attrs).should.equal( + NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances)) + ) # Tests that the attrs have been set properly - len(list(filter(lambda item: item['name'] == 'env', attrs))).should.equal(2) - len(list( - filter(lambda item: item['name'] == 'attr1' and item['value'] == 'instance1', attrs))).should.equal(1) + len(list(filter(lambda item: item["name"] == "env", attrs))).should.equal(2) + len( + list( + filter( + lambda item: item["name"] == "attr1" and item["value"] == "instance1", + attrs, + ) + ) + ).should.equal(1) ecs_client.delete_attributes( cluster=test_cluster_name, attributes=[ - {'name': 'attr1', 'value': 'instance2', 'targetId': partial_arn2, - 'targetType': 'container-instance'} - ] + { + "name": "attr1", + "value": "instance2", + "targetId": partial_arn2, + "targetType": "container-instance", + } + ], ) NUM_CUSTOM_ATTRIBUTES -= 1 resp = ecs_client.list_attributes( - cluster=test_cluster_name, - targetType='container-instance' + cluster=test_cluster_name, targetType="container-instance" + ) + attrs = resp["attributes"] + len(attrs).should.equal( + NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances)) ) - attrs = resp['attributes'] - len(attrs).should.equal(NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances))) @mock_ecs def test_poll_endpoint(): # Combined put, list delete attributes into the same test due to the amount of setup - ecs_client = boto3.client('ecs', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") # Just a placeholder until someone actually wants useless data, just testing it doesnt raise an exception - resp = ecs_client.discover_poll_endpoint(cluster='blah', containerInstance='blah') - resp.should.contain('endpoint') - resp.should.contain('telemetryEndpoint') + resp = ecs_client.discover_poll_endpoint(cluster="blah", containerInstance="blah") + resp.should.contain("endpoint") + resp.should.contain("telemetryEndpoint") @mock_ecs def test_list_task_definition_families(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) client.register_task_definition( - family='alt_test_ecs_task', + family="alt_test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) resp1 = client.list_task_definition_families() - resp2 = client.list_task_definition_families(familyPrefix='alt') + resp2 = client.list_task_definition_families(familyPrefix="alt") - len(resp1['families']).should.equal(2) - len(resp2['families']).should.equal(1) + len(resp1["families"]).should.equal(2) + len(resp2["families"]).should.equal(1) @mock_ec2 @mock_ecs def test_default_container_instance_attributes(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" # Create cluster and EC2 instance - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -2147,44 +2101,42 @@ def test_default_container_instance_attributes(): # Register container instance response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - container_instance_id = full_arn.rsplit('/', 1)[-1] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + container_instance_id = full_arn.rsplit("/", 1)[-1] - default_attributes = response['containerInstance']['attributes'] + default_attributes = response["containerInstance"]["attributes"] assert len(default_attributes) == 4 expected_result = [ - {'name': 'ecs.availability-zone', 'value': test_instance.placement['AvailabilityZone']}, - {'name': 'ecs.ami-id', 'value': test_instance.image_id}, - {'name': 'ecs.instance-type', 'value': test_instance.instance_type}, - {'name': 'ecs.os-type', 'value': test_instance.platform or 'linux'} + { + "name": "ecs.availability-zone", + "value": test_instance.placement["AvailabilityZone"], + }, + {"name": "ecs.ami-id", "value": test_instance.image_id}, + {"name": "ecs.instance-type", "value": test_instance.instance_type}, + {"name": "ecs.os-type", "value": test_instance.platform or "linux"}, ] - assert sorted(default_attributes, key=lambda item: item['name']) == sorted(expected_result, - key=lambda item: item['name']) + assert sorted(default_attributes, key=lambda item: item["name"]) == sorted( + expected_result, key=lambda item: item["name"] + ) @mock_ec2 @mock_ecs def test_describe_container_instances_with_attributes(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" # Create cluster and EC2 instance - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -2193,170 +2145,395 @@ def test_describe_container_instances_with_attributes(): # Register container instance response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - container_instance_id = full_arn.rsplit('/', 1)[-1] - default_attributes = response['containerInstance']['attributes'] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + container_instance_id = full_arn.rsplit("/", 1)[-1] + default_attributes = response["containerInstance"]["attributes"] # Set attributes on container instance, one without a value attributes = [ - {'name': 'env', 'value': 'prod'}, - {'name': 'attr1', 'value': 'instance1', 'targetId': container_instance_id, - 'targetType': 'container-instance'}, - {'name': 'attr_without_value'} + {"name": "env", "value": "prod"}, + { + "name": "attr1", + "value": "instance1", + "targetId": container_instance_id, + "targetType": "container-instance", + }, + {"name": "attr_without_value"}, ] - ecs_client.put_attributes( - cluster=test_cluster_name, - attributes=attributes - ) + ecs_client.put_attributes(cluster=test_cluster_name, attributes=attributes) # Describe container instance, should have attributes previously set - described_instance = ecs_client.describe_container_instances(cluster=test_cluster_name, - containerInstances=[container_instance_id]) + described_instance = ecs_client.describe_container_instances( + cluster=test_cluster_name, containerInstances=[container_instance_id] + ) - assert len(described_instance['containerInstances']) == 1 - assert isinstance(described_instance['containerInstances'][0]['attributes'], list) + assert len(described_instance["containerInstances"]) == 1 + assert isinstance(described_instance["containerInstances"][0]["attributes"], list) # Remove additional info passed to put_attributes cleaned_attributes = [] for attribute in attributes: - attribute.pop('targetId', None) - attribute.pop('targetType', None) + attribute.pop("targetId", None) + attribute.pop("targetType", None) cleaned_attributes.append(attribute) - described_attributes = sorted(described_instance['containerInstances'][0]['attributes'], - key=lambda item: item['name']) - expected_attributes = sorted(default_attributes + cleaned_attributes, key=lambda item: item['name']) + described_attributes = sorted( + described_instance["containerInstances"][0]["attributes"], + key=lambda item: item["name"], + ) + expected_attributes = sorted( + default_attributes + cleaned_attributes, key=lambda item: item["name"] + ) assert described_attributes == expected_attributes def _fetch_container_instance_resources(container_instance_description): remaining_resources = {} registered_resources = {} - remaining_resources_list = container_instance_description['remainingResources'] - registered_resources_list = container_instance_description['registeredResources'] - remaining_resources['CPU'] = [x['integerValue'] for x in remaining_resources_list if x['name'] == 'CPU'][ - 0] - remaining_resources['MEMORY'] = \ - [x['integerValue'] for x in remaining_resources_list if x['name'] == 'MEMORY'][0] - remaining_resources['PORTS'] = \ - [x['stringSetValue'] for x in remaining_resources_list if x['name'] == 'PORTS'][0] - registered_resources['CPU'] = \ - [x['integerValue'] for x in registered_resources_list if x['name'] == 'CPU'][0] - registered_resources['MEMORY'] = \ - [x['integerValue'] for x in registered_resources_list if x['name'] == 'MEMORY'][0] - registered_resources['PORTS'] = \ - [x['stringSetValue'] for x in registered_resources_list if x['name'] == 'PORTS'][0] + remaining_resources_list = container_instance_description["remainingResources"] + registered_resources_list = container_instance_description["registeredResources"] + remaining_resources["CPU"] = [ + x["integerValue"] for x in remaining_resources_list if x["name"] == "CPU" + ][0] + remaining_resources["MEMORY"] = [ + x["integerValue"] for x in remaining_resources_list if x["name"] == "MEMORY" + ][0] + remaining_resources["PORTS"] = [ + x["stringSetValue"] for x in remaining_resources_list if x["name"] == "PORTS" + ][0] + registered_resources["CPU"] = [ + x["integerValue"] for x in registered_resources_list if x["name"] == "CPU" + ][0] + registered_resources["MEMORY"] = [ + x["integerValue"] for x in registered_resources_list if x["name"] == "MEMORY" + ][0] + registered_resources["PORTS"] = [ + x["stringSetValue"] for x in registered_resources_list if x["name"] == "PORTS" + ][0] return remaining_resources, registered_resources @mock_ecs def test_create_service_load_balancing(): - client = boto3.client('ecs', region_name='us-east-1') - client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + client.create_cluster(clusterName="test_ecs_cluster") client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, loadBalancers=[ { - 'targetGroupArn': 'test_target_group_arn', - 'loadBalancerName': 'test_load_balancer_name', - 'containerName': 'test_container_name', - 'containerPort': 123 + "targetGroupArn": "test_target_group_arn", + "loadBalancerName": "test_load_balancer_name", + "containerName": "test_container_name", + "containerPort": 123, } - ] + ], + ) + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(1) + response["service"]["loadBalancers"][0]["targetGroupArn"].should.equal( + "test_target_group_arn" + ) + response["service"]["loadBalancers"][0]["loadBalancerName"].should.equal( + "test_load_balancer_name" + ) + response["service"]["loadBalancers"][0]["containerName"].should.equal( + "test_container_name" + ) + response["service"]["loadBalancers"][0]["containerPort"].should.equal(123) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(1) - response['service']['loadBalancers'][0]['targetGroupArn'].should.equal( - 'test_target_group_arn') - response['service']['loadBalancers'][0]['loadBalancerName'].should.equal( - 'test_load_balancer_name') - response['service']['loadBalancers'][0]['containerName'].should.equal( - 'test_container_name') - response['service']['loadBalancers'][0]['containerPort'].should.equal(123) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') @mock_ecs def test_list_tags_for_resource(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } ], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["revision"].should.equal(1) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['revision'].should.equal(1) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - task_definition_arn = response['taskDefinition']['taskDefinitionArn'] + task_definition_arn = response["taskDefinition"]["taskDefinitionArn"] response = client.list_tags_for_resource(resourceArn=task_definition_arn) - type(response['tags']).should.be(list) - response['tags'].should.equal([ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ]) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] + ) @mock_ecs def test_list_tags_for_resource_unknown(): - client = boto3.client('ecs', region_name='us-east-1') - task_definition_arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/unknown:1' + client = boto3.client("ecs", region_name="us-east-1") + task_definition_arn = "arn:aws:ecs:us-east-1:012345678910:task-definition/unknown:1" try: client.list_tags_for_resource(resourceArn=task_definition_arn) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") + + +@mock_ecs +def test_list_tags_for_resource_ecs_service(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + response = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, + tags=[ + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] + ) + + +@mock_ecs +def test_list_tags_for_resource_unknown_service(): + client = boto3.client("ecs", region_name="us-east-1") + service_arn = "arn:aws:ecs:us-east-1:012345678910:service/unknown:1" + try: + client.list_tags_for_resource(resourceArn=service_arn) + except ClientError as err: + err.response["Error"]["Code"].should.equal("ServiceNotFoundException") + + +@mock_ecs +def test_ecs_service_tag_resource(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + response = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, + ) + client.tag_resource( + resourceArn=response["service"]["serviceArn"], + tags=[ + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] + ) + + +@mock_ecs +def test_ecs_service_tag_resource_overwrites_tag(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + response = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, + tags=[{"key": "foo", "value": "bar"}], + ) + client.tag_resource( + resourceArn=response["service"]["serviceArn"], + tags=[ + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "hello world"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [ + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "hello world"}, + ] + ) + + +@mock_ecs +def test_ecs_service_untag_resource(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + response = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, + tags=[{"key": "foo", "value": "bar"}], + ) + client.untag_resource( + resourceArn=response["service"]["serviceArn"], tagKeys=["foo"] + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + response["tags"].should.equal([]) + + +@mock_ecs +def test_ecs_service_untag_resource_multiple_tags(): + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + _ = client.register_task_definition( + family="test_ecs_task", + containerDefinitions=[ + { + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + } + ], + ) + response = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, + tags=[ + {"key": "foo", "value": "bar"}, + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "hello", "value": "world"}, + ], + ) + client.untag_resource( + resourceArn=response["service"]["serviceArn"], tagKeys=["foo", "createdBy"] + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + response["tags"].should.equal([{"key": "hello", "value": "world"}]) diff --git a/tests/test_elb/test_elb.py b/tests/test_elb/test_elb.py index 447896f15..1583ea544 100644 --- a/tests/test_elb/test_elb.py +++ b/tests/test_elb/test_elb.py @@ -15,6 +15,7 @@ from nose.tools import assert_raises import sure # noqa from moto import mock_elb, mock_ec2, mock_elb_deprecated, mock_ec2_deprecated +from moto.core import ACCOUNT_ID @mock_elb_deprecated @@ -23,19 +24,20 @@ def test_create_load_balancer(): conn = boto.connect_elb() ec2 = boto.ec2.connect_to_region("us-east-1") - security_group = ec2.create_security_group('sg-abc987', 'description') + security_group = ec2.create_security_group("sg-abc987", "description") - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports, scheme='internal', security_groups=[security_group.id]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer( + "my-lb", zones, ports, scheme="internal", security_groups=[security_group.id] + ) balancers = conn.get_all_load_balancers() balancer = balancers[0] balancer.name.should.equal("my-lb") balancer.scheme.should.equal("internal") list(balancer.security_groups).should.equal([security_group.id]) - set(balancer.availability_zones).should.equal( - set(['us-east-1a', 'us-east-1b'])) + set(balancer.availability_zones).should.equal(set(["us-east-1a", "us-east-1b"])) listener1 = balancer.listeners[0] listener1.load_balancer_port.should.equal(80) listener1.instance_port.should.equal(8080) @@ -50,19 +52,20 @@ def test_create_load_balancer(): def test_getting_missing_elb(): conn = boto.connect_elb() conn.get_all_load_balancers.when.called_with( - load_balancer_names='aaa').should.throw(BotoServerError) + load_balancer_names="aaa" + ).should.throw(BotoServerError) @mock_elb_deprecated def test_create_elb_in_multiple_region(): - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] west1_conn = boto.ec2.elb.connect_to_region("us-west-1") - west1_conn.create_load_balancer('my-lb', zones, ports) + west1_conn.create_load_balancer("my-lb", zones, ports) west2_conn = boto.ec2.elb.connect_to_region("us-west-2") - west2_conn.create_load_balancer('my-lb', zones, ports) + west2_conn.create_load_balancer("my-lb", zones, ports) list(west1_conn.get_all_load_balancers()).should.have.length_of(1) list(west2_conn.get_all_load_balancers()).should.have.length_of(1) @@ -72,117 +75,128 @@ def test_create_elb_in_multiple_region(): def test_create_load_balancer_with_certificate(): conn = boto.connect_elb() - zones = ['us-east-1a'] + zones = ["us-east-1a"] ports = [ - (443, 8443, 'https', 'arn:aws:iam:123456789012:server-certificate/test-cert')] - conn.create_load_balancer('my-lb', zones, ports) + ( + 443, + 8443, + "https", + "arn:aws:iam:{}:server-certificate/test-cert".format(ACCOUNT_ID), + ) + ] + conn.create_load_balancer("my-lb", zones, ports) balancers = conn.get_all_load_balancers() balancer = balancers[0] balancer.name.should.equal("my-lb") balancer.scheme.should.equal("internet-facing") - set(balancer.availability_zones).should.equal(set(['us-east-1a'])) + set(balancer.availability_zones).should.equal(set(["us-east-1a"])) listener = balancer.listeners[0] listener.load_balancer_port.should.equal(443) listener.instance_port.should.equal(8443) listener.protocol.should.equal("HTTPS") listener.ssl_certificate_id.should.equal( - 'arn:aws:iam:123456789012:server-certificate/test-cert') + "arn:aws:iam:{}:server-certificate/test-cert".format(ACCOUNT_ID) + ) @mock_elb def test_create_and_delete_boto3_support(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) - client.delete_load_balancer( - LoadBalancerName='my-lb' - ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(0) + client.delete_load_balancer(LoadBalancerName="my-lb") + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(0) @mock_elb def test_create_load_balancer_with_no_listeners_defined(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") with assert_raises(ClientError): client.create_load_balancer( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", Listeners=[], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + AvailabilityZones=["us-east-1a", "us-east-1b"], ) @mock_elb def test_describe_paginated_balancers(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") for i in range(51): client.create_load_balancer( - LoadBalancerName='my-lb%d' % i, + LoadBalancerName="my-lb%d" % i, Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + {"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080} + ], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) resp = client.describe_load_balancers() - resp['LoadBalancerDescriptions'].should.have.length_of(50) - resp['NextMarker'].should.equal(resp['LoadBalancerDescriptions'][-1]['LoadBalancerName']) - resp2 = client.describe_load_balancers(Marker=resp['NextMarker']) - resp2['LoadBalancerDescriptions'].should.have.length_of(1) - assert 'NextToken' not in resp2.keys() + resp["LoadBalancerDescriptions"].should.have.length_of(50) + resp["NextMarker"].should.equal( + resp["LoadBalancerDescriptions"][-1]["LoadBalancerName"] + ) + resp2 = client.describe_load_balancers(Marker=resp["NextMarker"]) + resp2["LoadBalancerDescriptions"].should.have.length_of(1) + assert "NextToken" not in resp2.keys() @mock_elb @mock_ec2 def test_apply_security_groups_to_load_balancer(): - client = boto3.client('elb', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") security_group = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) response = client.apply_security_groups_to_load_balancer( - LoadBalancerName='my-lb', - SecurityGroups=[security_group.id]) + LoadBalancerName="my-lb", SecurityGroups=[security_group.id] + ) - assert response['SecurityGroups'] == [security_group.id] - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - assert balancer['SecurityGroups'] == [security_group.id] + assert response["SecurityGroups"] == [security_group.id] + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + assert balancer["SecurityGroups"] == [security_group.id] # Using a not-real security group raises an error with assert_raises(ClientError) as error: response = client.apply_security_groups_to_load_balancer( - LoadBalancerName='my-lb', - SecurityGroups=['not-really-a-security-group']) - assert "One or more of the specified security groups do not exist." in str(error.exception) + LoadBalancerName="my-lb", SecurityGroups=["not-really-a-security-group"] + ) + assert "One or more of the specified security groups do not exist." in str( + error.exception + ) @mock_elb_deprecated def test_add_listener(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http')] - conn.create_load_balancer('my-lb', zones, ports) - new_listener = (443, 8443, 'tcp') - conn.create_load_balancer_listeners('my-lb', [new_listener]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http")] + conn.create_load_balancer("my-lb", zones, ports) + new_listener = (443, 8443, "tcp") + conn.create_load_balancer_listeners("my-lb", [new_listener]) balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -199,10 +213,10 @@ def test_add_listener(): def test_delete_listener(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) - conn.delete_load_balancer_listeners('my-lb', [443]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) + conn.delete_load_balancer_listeners("my-lb", [443]) balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -214,61 +228,57 @@ def test_delete_listener(): @mock_elb def test_create_and_delete_listener_boto3_support(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) client.create_load_balancer_listeners( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 443, 'InstancePort': 8443}] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 443, "InstancePort": 8443}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - list(balancer['ListenerDescriptions']).should.have.length_of(2) - balancer['ListenerDescriptions'][0][ - 'Listener']['Protocol'].should.equal('HTTP') - balancer['ListenerDescriptions'][0]['Listener'][ - 'LoadBalancerPort'].should.equal(80) - balancer['ListenerDescriptions'][0]['Listener'][ - 'InstancePort'].should.equal(8080) - balancer['ListenerDescriptions'][1][ - 'Listener']['Protocol'].should.equal('TCP') - balancer['ListenerDescriptions'][1]['Listener'][ - 'LoadBalancerPort'].should.equal(443) - balancer['ListenerDescriptions'][1]['Listener'][ - 'InstancePort'].should.equal(8443) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + list(balancer["ListenerDescriptions"]).should.have.length_of(2) + balancer["ListenerDescriptions"][0]["Listener"]["Protocol"].should.equal("HTTP") + balancer["ListenerDescriptions"][0]["Listener"]["LoadBalancerPort"].should.equal(80) + balancer["ListenerDescriptions"][0]["Listener"]["InstancePort"].should.equal(8080) + balancer["ListenerDescriptions"][1]["Listener"]["Protocol"].should.equal("TCP") + balancer["ListenerDescriptions"][1]["Listener"]["LoadBalancerPort"].should.equal( + 443 + ) + balancer["ListenerDescriptions"][1]["Listener"]["InstancePort"].should.equal(8443) # Creating this listener with an conflicting definition throws error with assert_raises(ClientError): client.create_load_balancer_listeners( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 443, 'InstancePort': 1234}] + {"Protocol": "tcp", "LoadBalancerPort": 443, "InstancePort": 1234} + ], ) client.delete_load_balancer_listeners( - LoadBalancerName='my-lb', - LoadBalancerPorts=[443]) + LoadBalancerName="my-lb", LoadBalancerPorts=[443] + ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - list(balancer['ListenerDescriptions']).should.have.length_of(1) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + list(balancer["ListenerDescriptions"]).should.have.length_of(1) @mock_elb_deprecated def test_set_sslcertificate(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) - conn.set_lb_listener_SSL_certificate('my-lb', '443', 'arn:certificate') + zones = ["us-east-1a", "us-east-1b"] + ports = [(443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) + conn.set_lb_listener_SSL_certificate("my-lb", "443", "arn:certificate") balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -282,26 +292,26 @@ def test_set_sslcertificate(): def test_get_load_balancers_by_name(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb1', zones, ports) - conn.create_load_balancer('my-lb2', zones, ports) - conn.create_load_balancer('my-lb3', zones, ports) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb1", zones, ports) + conn.create_load_balancer("my-lb2", zones, ports) + conn.create_load_balancer("my-lb3", zones, ports) conn.get_all_load_balancers().should.have.length_of(3) + conn.get_all_load_balancers(load_balancer_names=["my-lb1"]).should.have.length_of(1) conn.get_all_load_balancers( - load_balancer_names=['my-lb1']).should.have.length_of(1) - conn.get_all_load_balancers( - load_balancer_names=['my-lb1', 'my-lb2']).should.have.length_of(2) + load_balancer_names=["my-lb1", "my-lb2"] + ).should.have.length_of(2) @mock_elb_deprecated def test_delete_load_balancer(): conn = boto.connect_elb() - zones = ['us-east-1a'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) + zones = ["us-east-1a"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) balancers = conn.get_all_load_balancers() balancers.should.have.length_of(1) @@ -319,12 +329,12 @@ def test_create_health_check(): interval=20, healthy_threshold=3, unhealthy_threshold=5, - target='HTTP:8080/health', + target="HTTP:8080/health", timeout=23, ) - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.configure_health_check(hc) balancer = conn.get_all_load_balancers()[0] @@ -332,50 +342,49 @@ def test_create_health_check(): health_check.interval.should.equal(20) health_check.healthy_threshold.should.equal(3) health_check.unhealthy_threshold.should.equal(5) - health_check.target.should.equal('HTTP:8080/health') + health_check.target.should.equal("HTTP:8080/health") health_check.timeout.should.equal(23) @mock_elb def test_create_health_check_boto3(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.configure_health_check( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", HealthCheck={ - 'Target': 'HTTP:8080/health', - 'Interval': 20, - 'Timeout': 23, - 'HealthyThreshold': 3, - 'UnhealthyThreshold': 5 - } + "Target": "HTTP:8080/health", + "Interval": 20, + "Timeout": 23, + "HealthyThreshold": 3, + "UnhealthyThreshold": 5, + }, ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['HealthCheck']['Target'].should.equal('HTTP:8080/health') - balancer['HealthCheck']['Interval'].should.equal(20) - balancer['HealthCheck']['Timeout'].should.equal(23) - balancer['HealthCheck']['HealthyThreshold'].should.equal(3) - balancer['HealthCheck']['UnhealthyThreshold'].should.equal(5) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["HealthCheck"]["Target"].should.equal("HTTP:8080/health") + balancer["HealthCheck"]["Interval"].should.equal(20) + balancer["HealthCheck"]["Timeout"].should.equal(23) + balancer["HealthCheck"]["HealthyThreshold"].should.equal(3) + balancer["HealthCheck"]["UnhealthyThreshold"].should.equal(5) @mock_ec2_deprecated @mock_elb_deprecated def test_register_instances(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.register_instances([instance_id1, instance_id2]) @@ -387,29 +396,23 @@ def test_register_instances(): @mock_ec2 @mock_elb def test_register_instances_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + ec2 = boto3.resource("ec2", region_name="us-east-1") + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.register_instances_with_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1}, - {'InstanceId': instance_id2} - ] + LoadBalancerName="my-lb", + Instances=[{"InstanceId": instance_id1}, {"InstanceId": instance_id2}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - instance_ids = [instance['InstanceId'] - for instance in balancer['Instances']] + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + instance_ids = [instance["InstanceId"] for instance in balancer["Instances"]] set(instance_ids).should.equal(set([instance_id1, instance_id2])) @@ -417,13 +420,13 @@ def test_register_instances_boto3(): @mock_elb_deprecated def test_deregister_instances(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.register_instances([instance_id1, instance_id2]) @@ -438,47 +441,39 @@ def test_deregister_instances(): @mock_ec2 @mock_elb def test_deregister_instances_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + ec2 = boto3.resource("ec2", region_name="us-east-1") + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.register_instances_with_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1}, - {'InstanceId': instance_id2} - ] + LoadBalancerName="my-lb", + Instances=[{"InstanceId": instance_id1}, {"InstanceId": instance_id2}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['Instances'].should.have.length_of(2) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["Instances"].should.have.length_of(2) client.deregister_instances_from_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1} - ] + LoadBalancerName="my-lb", Instances=[{"InstanceId": instance_id1}] ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['Instances'].should.have.length_of(1) - balancer['Instances'][0]['InstanceId'].should.equal(instance_id2) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["Instances"].should.have.length_of(1) + balancer["Instances"][0]["InstanceId"].should.equal(instance_id2) @mock_elb_deprecated def test_default_attributes(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) attributes = lb.get_attributes() attributes.cross_zone_load_balancing.enabled.should.be.false @@ -490,8 +485,8 @@ def test_default_attributes(): @mock_elb_deprecated def test_cross_zone_load_balancing_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) conn.modify_lb_attribute("my-lb", "CrossZoneLoadBalancing", True) attributes = lb.get_attributes(force=True) @@ -505,28 +500,25 @@ def test_cross_zone_load_balancing_attribute(): @mock_elb_deprecated def test_connection_draining_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) connection_draining = ConnectionDrainingAttribute() connection_draining.enabled = True connection_draining.timeout = 60 - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.enabled.should.be.true attributes.connection_draining.timeout.should.equal(60) connection_draining.timeout = 30 - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.timeout.should.equal(30) connection_draining.enabled = False - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.enabled.should.be.false @@ -534,13 +526,13 @@ def test_connection_draining_attribute(): @mock_elb_deprecated def test_access_log_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) access_log = AccessLogAttribute() access_log.enabled = True - access_log.s3_bucket_name = 'bucket' - access_log.s3_bucket_prefix = 'prefix' + access_log.s3_bucket_name = "bucket" + access_log.s3_bucket_prefix = "prefix" access_log.emit_interval = 60 conn.modify_lb_attribute("my-lb", "AccessLog", access_log) @@ -559,20 +551,18 @@ def test_access_log_attribute(): @mock_elb_deprecated def test_connection_settings_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) connection_settings = ConnectionSettingAttribute(conn) connection_settings.idle_timeout = 120 - conn.modify_lb_attribute( - "my-lb", "ConnectingSettings", connection_settings) + conn.modify_lb_attribute("my-lb", "ConnectingSettings", connection_settings) attributes = lb.get_attributes(force=True) attributes.connecting_settings.idle_timeout.should.equal(120) connection_settings.idle_timeout = 60 - conn.modify_lb_attribute( - "my-lb", "ConnectingSettings", connection_settings) + conn.modify_lb_attribute("my-lb", "ConnectingSettings", connection_settings) attributes = lb.get_attributes(force=True) attributes.connecting_settings.idle_timeout.should.equal(60) @@ -580,8 +570,8 @@ def test_connection_settings_attribute(): @mock_elb_deprecated def test_create_lb_cookie_stickiness_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) cookie_expiration_period = 60 policy_name = "LBCookieStickinessPolicy" @@ -594,55 +584,49 @@ def test_create_lb_cookie_stickiness_policy(): # # To work around that, this value is converted to an int and checked. cookie_expiration_period_response_str = lb.policies.lb_cookie_stickiness_policies[ - 0].cookie_expiration_period - int(cookie_expiration_period_response_str).should.equal( - cookie_expiration_period) - lb.policies.lb_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + 0 + ].cookie_expiration_period + int(cookie_expiration_period_response_str).should.equal(cookie_expiration_period) + lb.policies.lb_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_lb_cookie_stickiness_policy_no_expiry(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) policy_name = "LBCookieStickinessPolicy" lb.create_cookie_stickiness_policy(None, policy_name) lb = conn.get_all_load_balancers()[0] - lb.policies.lb_cookie_stickiness_policies[ - 0].cookie_expiration_period.should.be.none - lb.policies.lb_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + lb.policies.lb_cookie_stickiness_policies[0].cookie_expiration_period.should.be.none + lb.policies.lb_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_app_cookie_stickiness_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) cookie_name = "my-stickiness-policy" policy_name = "AppCookieStickinessPolicy" lb.create_app_cookie_stickiness_policy(cookie_name, policy_name) lb = conn.get_all_load_balancers()[0] - lb.policies.app_cookie_stickiness_policies[ - 0].cookie_name.should.equal(cookie_name) - lb.policies.app_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + lb.policies.app_cookie_stickiness_policies[0].cookie_name.should.equal(cookie_name) + lb.policies.app_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_lb_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) policy_name = "ProxyPolicy" - lb.create_lb_policy(policy_name, 'ProxyProtocolPolicyType', { - 'ProxyProtocol': True}) + lb.create_lb_policy(policy_name, "ProxyProtocolPolicyType", {"ProxyProtocol": True}) lb = conn.get_all_load_balancers()[0] lb.policies.other_policies[0].policy_name.should.equal(policy_name) @@ -651,8 +635,8 @@ def test_create_lb_policy(): @mock_elb_deprecated def test_set_policies_of_listener(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) listener_port = 80 policy_name = "my-stickiness-policy" @@ -674,15 +658,14 @@ def test_set_policies_of_listener(): @mock_elb_deprecated def test_set_policies_of_backend_server(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) instance_port = 8080 policy_name = "ProxyPolicy" # in a real flow, it is necessary first to create a policy, # then to set that policy to the backend - lb.create_lb_policy(policy_name, 'ProxyProtocolPolicyType', { - 'ProxyProtocol': True}) + lb.create_lb_policy(policy_name, "ProxyProtocolPolicyType", {"ProxyProtocol": True}) lb.set_policies_of_backend_server(instance_port, [policy_name]) lb = conn.get_all_load_balancers()[0] @@ -696,287 +679,262 @@ def test_set_policies_of_backend_server(): @mock_elb_deprecated def test_describe_instance_health(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', zones, ports) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", zones, ports) - instances_health = conn.describe_instance_health('my-lb') + instances_health = conn.describe_instance_health("my-lb") instances_health.should.be.empty lb.register_instances([instance_id1, instance_id2]) - instances_health = conn.describe_instance_health('my-lb') + instances_health = conn.describe_instance_health("my-lb") instances_health.should.have.length_of(2) for instance_health in instances_health: - instance_health.instance_id.should.be.within( - [instance_id1, instance_id2]) - instance_health.state.should.equal('InService') + instance_health.instance_id.should.be.within([instance_id1, instance_id2]) + instance_health.state.should.equal("InService") - instances_health = conn.describe_instance_health('my-lb', [instance_id1]) + instances_health = conn.describe_instance_health("my-lb", [instance_id1]) instances_health.should.have.length_of(1) instances_health[0].instance_id.should.equal(instance_id1) - instances_health[0].state.should.equal('InService') + instances_health[0].state.should.equal("InService") @mock_ec2 @mock_elb def test_describe_instance_health_boto3(): - elb = boto3.client('elb', region_name="us-east-1") - ec2 = boto3.client('ec2', region_name="us-east-1") - instances = ec2.run_instances(MinCount=2, MaxCount=2)['Instances'] + elb = boto3.client("elb", region_name="us-east-1") + ec2 = boto3.client("ec2", region_name="us-east-1") + instances = ec2.run_instances(MinCount=2, MaxCount=2)["Instances"] lb_name = "my_load_balancer" elb.create_load_balancer( - Listeners=[{ - 'InstancePort': 80, - 'LoadBalancerPort': 8080, - 'Protocol': 'HTTP' - }], + Listeners=[{"InstancePort": 80, "LoadBalancerPort": 8080, "Protocol": "HTTP"}], LoadBalancerName=lb_name, ) elb.register_instances_with_load_balancer( - LoadBalancerName=lb_name, - Instances=[{'InstanceId': instances[0]['InstanceId']}] + LoadBalancerName=lb_name, Instances=[{"InstanceId": instances[0]["InstanceId"]}] ) instances_health = elb.describe_instance_health( LoadBalancerName=lb_name, - Instances=[{'InstanceId': instance['InstanceId']} for instance in instances] + Instances=[{"InstanceId": instance["InstanceId"]} for instance in instances], ) - instances_health['InstanceStates'].should.have.length_of(2) - instances_health['InstanceStates'][0]['InstanceId'].\ - should.equal(instances[0]['InstanceId']) - instances_health['InstanceStates'][0]['State'].\ - should.equal('InService') - instances_health['InstanceStates'][1]['InstanceId'].\ - should.equal(instances[1]['InstanceId']) - instances_health['InstanceStates'][1]['State'].\ - should.equal('Unknown') + instances_health["InstanceStates"].should.have.length_of(2) + instances_health["InstanceStates"][0]["InstanceId"].should.equal( + instances[0]["InstanceId"] + ) + instances_health["InstanceStates"][0]["State"].should.equal("InService") + instances_health["InstanceStates"][1]["InstanceId"].should.equal( + instances[1]["InstanceId"] + ) + instances_health["InstanceStates"][1]["State"].should.equal("Unknown") @mock_elb def test_add_remove_tags(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") - client.add_tags.when.called_with(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + client.add_tags.when.called_with( + LoadBalancerNames=["my-lb"], Tags=[{"Key": "a", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + client.add_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "a", "Value": "b"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) - tags.should.have.key('a').which.should.equal('b') + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) + tags.should.have.key("a").which.should.equal("b") - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }, { - 'Key': 'b', - 'Value': 'b' - }, { - 'Key': 'c', - 'Value': 'b' - }, { - 'Key': 'd', - 'Value': 'b' - }, { - 'Key': 'e', - 'Value': 'b' - }, { - 'Key': 'f', - 'Value': 'b' - }, { - 'Key': 'g', - 'Value': 'b' - }, { - 'Key': 'h', - 'Value': 'b' - }, { - 'Key': 'i', - 'Value': 'b' - }, { - 'Key': 'j', - 'Value': 'b' - }]) + client.add_tags( + LoadBalancerNames=["my-lb"], + Tags=[ + {"Key": "a", "Value": "b"}, + {"Key": "b", "Value": "b"}, + {"Key": "c", "Value": "b"}, + {"Key": "d", "Value": "b"}, + {"Key": "e", "Value": "b"}, + {"Key": "f", "Value": "b"}, + {"Key": "g", "Value": "b"}, + {"Key": "h", "Value": "b"}, + {"Key": "i", "Value": "b"}, + {"Key": "j", "Value": "b"}, + ], + ) - client.add_tags.when.called_with(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'k', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + client.add_tags.when.called_with( + LoadBalancerNames=["my-lb"], Tags=[{"Key": "k", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'j', - 'Value': 'c' - }]) + client.add_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "j", "Value": "c"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) - tags.should.have.key('a').which.should.equal('b') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('i').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') - tags.shouldnt.have.key('k') + tags.should.have.key("a").which.should.equal("b") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("i").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") + tags.shouldnt.have.key("k") - client.remove_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a' - }]) + client.remove_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "a"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) - tags.shouldnt.have.key('a') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('i').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') + tags.shouldnt.have.key("a") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("i").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") client.create_load_balancer( - LoadBalancerName='other-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 433, 'InstancePort': 8433}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="other-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 433, "InstancePort": 8433}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client.add_tags(LoadBalancerNames=['other-lb'], - Tags=[{ - 'Key': 'other', - 'Value': 'something' - }]) + client.add_tags( + LoadBalancerNames=["other-lb"], Tags=[{"Key": "other", "Value": "something"}] + ) - lb_tags = dict([(l['LoadBalancerName'], dict([(d['Key'], d['Value']) for d in l['Tags']])) - for l in client.describe_tags(LoadBalancerNames=['my-lb', 'other-lb'])['TagDescriptions']]) + lb_tags = dict( + [ + (l["LoadBalancerName"], dict([(d["Key"], d["Value"]) for d in l["Tags"]])) + for l in client.describe_tags(LoadBalancerNames=["my-lb", "other-lb"])[ + "TagDescriptions" + ] + ] + ) - lb_tags.should.have.key('my-lb') - lb_tags.should.have.key('other-lb') + lb_tags.should.have.key("my-lb") + lb_tags.should.have.key("other-lb") - lb_tags['my-lb'].shouldnt.have.key('other') - lb_tags[ - 'other-lb'].should.have.key('other').which.should.equal('something') + lb_tags["my-lb"].shouldnt.have.key("other") + lb_tags["other-lb"].should.have.key("other").which.should.equal("something") @mock_elb def test_create_with_tags(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'], - Tags=[{ - 'Key': 'k', - 'Value': 'v' - }] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], + Tags=[{"Key": "k", "Value": "v"}], ) - tags = dict((d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']) - tags.should.have.key('k').which.should.equal('v') + tags = dict( + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])["TagDescriptions"][ + 0 + ]["Tags"] + ) + tags.should.have.key("k").which.should.equal("v") @mock_elb def test_modify_attributes(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) # Default ConnectionDraining timeout of 300 seconds client.modify_load_balancer_attributes( - LoadBalancerName='my-lb', - LoadBalancerAttributes={ - 'ConnectionDraining': {'Enabled': True}, - } + LoadBalancerName="my-lb", + LoadBalancerAttributes={"ConnectionDraining": {"Enabled": True}}, + ) + lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName="my-lb") + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Enabled"].should.equal( + True + ) + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Timeout"].should.equal( + 300 ) - lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName='my-lb') - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Enabled'].should.equal(True) - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Timeout'].should.equal(300) # specify a custom ConnectionDraining timeout client.modify_load_balancer_attributes( - LoadBalancerName='my-lb', - LoadBalancerAttributes={ - 'ConnectionDraining': { - 'Enabled': True, - 'Timeout': 45, - }, - } + LoadBalancerName="my-lb", + LoadBalancerAttributes={"ConnectionDraining": {"Enabled": True, "Timeout": 45}}, ) - lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName='my-lb') - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Enabled'].should.equal(True) - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Timeout'].should.equal(45) + lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName="my-lb") + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Enabled"].should.equal( + True + ) + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Timeout"].should.equal(45) @mock_ec2 @mock_elb def test_subnets(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc( - CidrBlock='172.28.7.0/24', - InstanceTenancy='default' - ) - subnet = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26' - ) - client = boto3.client('elb', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="172.28.7.192/26") + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - Subnets=[subnet.id] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + Subnets=[subnet.id], ) - lb = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - lb.should.have.key('Subnets').which.should.have.length_of(1) - lb['Subnets'][0].should.equal(subnet.id) + lb = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + lb.should.have.key("Subnets").which.should.have.length_of(1) + lb["Subnets"][0].should.equal(subnet.id) - lb.should.have.key('VPCId').which.should.equal(vpc.id) + lb.should.have.key("VPCId").which.should.equal(vpc.id) @mock_elb_deprecated def test_create_load_balancer_duplicate(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', [], ports) - conn.create_load_balancer.when.called_with( - 'my-lb', [], ports).should.throw(BotoServerError) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", [], ports) + conn.create_load_balancer.when.called_with("my-lb", [], ports).should.throw( + BotoServerError + ) diff --git a/tests/test_elb/test_server.py b/tests/test_elb/test_server.py index 159da970d..0f432cef4 100644 --- a/tests/test_elb/test_server.py +++ b/tests/test_elb/test_server.py @@ -1,17 +1,17 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_elb_describe_instances(): - backend = server.create_backend_app("elb") - test_client = backend.test_client() - - res = test_client.get('/?Action=DescribeLoadBalancers&Version=2015-12-01') - - res.data.should.contain(b'DescribeLoadBalancersResponse') +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_elb_describe_instances(): + backend = server.create_backend_app("elb") + test_client = backend.test_client() + + res = test_client.get("/?Action=DescribeLoadBalancers&Version=2015-12-01") + + res.data.should.contain(b"DescribeLoadBalancersResponse") diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 36772c02e..eb5df14c3 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -4,667 +4,672 @@ import json import os import boto3 import botocore -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, ParamValidationError from nose.tools import assert_raises import sure # noqa from moto import mock_elbv2, mock_ec2, mock_acm, mock_cloudformation from moto.elbv2 import elbv2_backends +from moto.core import ACCOUNT_ID @mock_elbv2 @mock_ec2 def test_create_load_balancer(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - lb = response.get('LoadBalancers')[0] + lb = response.get("LoadBalancers")[0] - lb.get('DNSName').should.equal("my-lb-1.us-east-1.elb.amazonaws.com") - lb.get('LoadBalancerArn').should.equal( - 'arn:aws:elasticloadbalancing:us-east-1:1:loadbalancer/my-lb/50dc6c495c0c9188') - lb.get('SecurityGroups').should.equal([security_group.id]) - lb.get('AvailabilityZones').should.equal([ - {'SubnetId': subnet1.id, 'ZoneName': 'us-east-1a'}, - {'SubnetId': subnet2.id, 'ZoneName': 'us-east-1b'}]) + lb.get("DNSName").should.equal("my-lb-1.us-east-1.elb.amazonaws.com") + lb.get("LoadBalancerArn").should.equal( + "arn:aws:elasticloadbalancing:us-east-1:1:loadbalancer/my-lb/50dc6c495c0c9188" + ) + lb.get("SecurityGroups").should.equal([security_group.id]) + lb.get("AvailabilityZones").should.equal( + [ + {"SubnetId": subnet1.id, "ZoneName": "us-east-1a"}, + {"SubnetId": subnet2.id, "ZoneName": "us-east-1b"}, + ] + ) # Ensure the tags persisted - response = conn.describe_tags(ResourceArns=[lb.get('LoadBalancerArn')]) - tags = {d['Key']: d['Value'] - for d in response['TagDescriptions'][0]['Tags']} - tags.should.equal({'key_name': 'a_value'}) + response = conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")]) + tags = {d["Key"]: d["Value"] for d in response["TagDescriptions"][0]["Tags"]} + tags.should.equal({"key_name": "a_value"}) @mock_elbv2 @mock_ec2 def test_describe_load_balancers(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.describe_load_balancers() - response.get('LoadBalancers').should.have.length_of(1) - lb = response.get('LoadBalancers')[0] - lb.get('LoadBalancerName').should.equal('my-lb') + response.get("LoadBalancers").should.have.length_of(1) + lb = response.get("LoadBalancers")[0] + lb.get("LoadBalancerName").should.equal("my-lb") response = conn.describe_load_balancers( - LoadBalancerArns=[lb.get('LoadBalancerArn')]) - response.get('LoadBalancers')[0].get( - 'LoadBalancerName').should.equal('my-lb') + LoadBalancerArns=[lb.get("LoadBalancerArn")] + ) + response.get("LoadBalancers")[0].get("LoadBalancerName").should.equal("my-lb") - response = conn.describe_load_balancers(Names=['my-lb']) - response.get('LoadBalancers')[0].get( - 'LoadBalancerName').should.equal('my-lb') + response = conn.describe_load_balancers(Names=["my-lb"]) + response.get("LoadBalancers")[0].get("LoadBalancerName").should.equal("my-lb") with assert_raises(ClientError): - conn.describe_load_balancers(LoadBalancerArns=['not-a/real/arn']) + conn.describe_load_balancers(LoadBalancerArns=["not-a/real/arn"]) with assert_raises(ClientError): - conn.describe_load_balancers(Names=['nope']) + conn.describe_load_balancers(Names=["nope"]) @mock_elbv2 @mock_ec2 def test_add_remove_tags(): - conn = boto3.client('elbv2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - lbs = conn.describe_load_balancers()['LoadBalancers'] + lbs = conn.describe_load_balancers()["LoadBalancers"] lbs.should.have.length_of(1) lb = lbs[0] with assert_raises(ClientError): - conn.add_tags(ResourceArns=['missing-arn'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + conn.add_tags(ResourceArns=["missing-arn"], Tags=[{"Key": "a", "Value": "b"}]) - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "a", "Value": "b"}] + ) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} - tags.should.have.key('a').which.should.equal('b') + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } + tags.should.have.key("a").which.should.equal("b") - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }, { - 'Key': 'b', - 'Value': 'b' - }, { - 'Key': 'c', - 'Value': 'b' - }, { - 'Key': 'd', - 'Value': 'b' - }, { - 'Key': 'e', - 'Value': 'b' - }, { - 'Key': 'f', - 'Value': 'b' - }, { - 'Key': 'g', - 'Value': 'b' - }, { - 'Key': 'h', - 'Value': 'b' - }, { - 'Key': 'j', - 'Value': 'b' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], + Tags=[ + {"Key": "a", "Value": "b"}, + {"Key": "b", "Value": "b"}, + {"Key": "c", "Value": "b"}, + {"Key": "d", "Value": "b"}, + {"Key": "e", "Value": "b"}, + {"Key": "f", "Value": "b"}, + {"Key": "g", "Value": "b"}, + {"Key": "h", "Value": "b"}, + {"Key": "j", "Value": "b"}, + ], + ) - conn.add_tags.when.called_with(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'k', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + conn.add_tags.when.called_with( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "k", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'j', - 'Value': 'c' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "j", "Value": "c"}] + ) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } - tags.should.have.key('a').which.should.equal('b') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') - tags.shouldnt.have.key('k') + tags.should.have.key("a").which.should.equal("b") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") + tags.shouldnt.have.key("k") - conn.remove_tags(ResourceArns=[lb.get('LoadBalancerArn')], - TagKeys=['a']) + conn.remove_tags(ResourceArns=[lb.get("LoadBalancerArn")], TagKeys=["a"]) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } - tags.shouldnt.have.key('a') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') + tags.shouldnt.have.key("a") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") @mock_elbv2 @mock_ec2 def test_create_elb_in_multiple_region(): - for region in ['us-west-1', 'us-west-2']: - conn = boto3.client('elbv2', region_name=region) - ec2 = boto3.resource('ec2', region_name=region) + for region in ["us-west-1", "us-west-2"]: + conn = boto3.client("elbv2", region_name=region) + ec2 = boto3.resource("ec2", region_name=region) security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc( - CidrBlock='172.28.7.0/24', - InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone=region + 'a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone=region + "a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone=region + 'b') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone=region + "b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) list( - boto3.client( - 'elbv2', - region_name='us-west-1').describe_load_balancers().get('LoadBalancers') + boto3.client("elbv2", region_name="us-west-1") + .describe_load_balancers() + .get("LoadBalancers") ).should.have.length_of(1) list( - boto3.client( - 'elbv2', - region_name='us-west-2').describe_load_balancers().get('LoadBalancers') + boto3.client("elbv2", region_name="us-west-2") + .describe_load_balancers() + .get("LoadBalancers") ).should.have.length_of(1) @mock_elbv2 @mock_ec2 def test_create_target_group_and_listeners(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") # Can't create a target group with an invalid protocol with assert_raises(ClientError): conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='/HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="/HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] - target_group_arn = target_group['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] + target_group_arn = target_group["TargetGroupArn"] # Add tags to the target group - conn.add_tags(ResourceArns=[target_group_arn], Tags=[ - {'Key': 'target', 'Value': 'group'}]) - conn.describe_tags(ResourceArns=[target_group_arn])['TagDescriptions'][0]['Tags'].should.equal( - [{'Key': 'target', 'Value': 'group'}]) + conn.add_tags( + ResourceArns=[target_group_arn], Tags=[{"Key": "target", "Value": "group"}] + ) + conn.describe_tags(ResourceArns=[target_group_arn])["TagDescriptions"][0][ + "Tags" + ].should.equal([{"Key": "target", "Value": "group"}]) # Check it's in the describe_target_groups response response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) + response.get("TargetGroups").should.have.length_of(1) # Plain HTTP listener response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(80) - listener.get('Protocol').should.equal('HTTP') - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) - http_listener_arn = listener.get('ListenerArn') + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(80) + listener.get("Protocol").should.equal("HTTP") + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) + http_listener_arn = listener.get("ListenerArn") - response = conn.describe_target_groups(LoadBalancerArn=load_balancer_arn, - Names=['a-target']) - response.get('TargetGroups').should.have.length_of(1) + response = conn.describe_target_groups( + LoadBalancerArn=load_balancer_arn, Names=["a-target"] + ) + response.get("TargetGroups").should.have.length_of(1) # And another with SSL response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTPS', + Protocol="HTTPS", Port=443, Certificates=[ - {'CertificateArn': 'arn:aws:iam:123456789012:server-certificate/test-cert'}], - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(443) - listener.get('Protocol').should.equal('HTTPS') - listener.get('Certificates').should.equal([{ - 'CertificateArn': 'arn:aws:iam:123456789012:server-certificate/test-cert', - }]) - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) + { + "CertificateArn": "arn:aws:iam:{}:server-certificate/test-cert".format( + ACCOUNT_ID + ) + } + ], + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(443) + listener.get("Protocol").should.equal("HTTPS") + listener.get("Certificates").should.equal( + [ + { + "CertificateArn": "arn:aws:iam:{}:server-certificate/test-cert".format( + ACCOUNT_ID + ) + } + ] + ) + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) - https_listener_arn = listener.get('ListenerArn') + https_listener_arn = listener.get("ListenerArn") response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(2) + response.get("Listeners").should.have.length_of(2) response = conn.describe_listeners(ListenerArns=[https_listener_arn]) - response.get('Listeners').should.have.length_of(1) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(443) - listener.get('Protocol').should.equal('HTTPS') + response.get("Listeners").should.have.length_of(1) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(443) + listener.get("Protocol").should.equal("HTTPS") response = conn.describe_listeners( - ListenerArns=[ - http_listener_arn, - https_listener_arn]) - response.get('Listeners').should.have.length_of(2) + ListenerArns=[http_listener_arn, https_listener_arn] + ) + response.get("Listeners").should.have.length_of(2) # Try to delete the target group and it fails because there's a # listener referencing it with assert_raises(ClientError) as e: - conn.delete_target_group( - TargetGroupArn=target_group.get('TargetGroupArn')) - e.exception.operation_name.should.equal('DeleteTargetGroup') - e.exception.args.should.equal(("An error occurred (ResourceInUse) when calling the DeleteTargetGroup operation: The target group 'arn:aws:elasticloadbalancing:us-east-1:1:targetgroup/a-target/50dc6c495c0c9188' is currently in use by a listener or a rule", )) # NOQA + conn.delete_target_group(TargetGroupArn=target_group.get("TargetGroupArn")) + e.exception.operation_name.should.equal("DeleteTargetGroup") + e.exception.args.should.equal( + ( + "An error occurred (ResourceInUse) when calling the DeleteTargetGroup operation: The target group 'arn:aws:elasticloadbalancing:us-east-1:1:targetgroup/a-target/50dc6c495c0c9188' is currently in use by a listener or a rule", + ) + ) # NOQA # Delete one listener response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(2) + response.get("Listeners").should.have.length_of(2) conn.delete_listener(ListenerArn=http_listener_arn) response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(1) + response.get("Listeners").should.have.length_of(1) # Then delete the load balancer conn.delete_load_balancer(LoadBalancerArn=load_balancer_arn) # It's gone response = conn.describe_load_balancers() - response.get('LoadBalancers').should.have.length_of(0) + response.get("LoadBalancers").should.have.length_of(0) # And it deleted the remaining listener response = conn.describe_listeners( - ListenerArns=[ - http_listener_arn, - https_listener_arn]) - response.get('Listeners').should.have.length_of(0) + ListenerArns=[http_listener_arn, https_listener_arn] + ) + response.get("Listeners").should.have.length_of(0) # But not the target groups response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) + response.get("TargetGroups").should.have.length_of(1) # Which we'll now delete - conn.delete_target_group(TargetGroupArn=target_group.get('TargetGroupArn')) + conn.delete_target_group(TargetGroupArn=target_group.get("TargetGroupArn")) response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(0) + response.get("TargetGroups").should.have.length_of(0) @mock_elbv2 @mock_ec2 def test_create_target_group_without_non_required_parameters(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) # request without HealthCheckIntervalSeconds parameter # which is default to 30 seconds response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080' + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", ) - target_group = response.get('TargetGroups')[0] + target_group = response.get("TargetGroups")[0] target_group.should_not.be.none @mock_elbv2 @mock_ec2 def test_create_invalid_target_group(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") # Fail to create target group with name which length is 33 - long_name = 'A' * 33 + long_name = "A" * 33 with assert_raises(ClientError): conn.create_target_group( Name=long_name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - invalid_names = [ - '-name', - 'name-', - '-name-', - 'example.com', - 'test@test', - 'Na--me'] + invalid_names = ["-name", "name-", "-name-", "example.com", "test@test", "Na--me"] for name in invalid_names: with assert_raises(ClientError): conn.create_target_group( Name=name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - valid_names = ['name', 'Name', '000'] + valid_names = ["name", "Name", "000"] for name in valid_names: conn.create_target_group( Name=name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) @mock_elbv2 @mock_ec2 def test_describe_paginated_balancers(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) for i in range(51): conn.create_load_balancer( - Name='my-lb%d' % i, + Name="my-lb%d" % i, Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) resp = conn.describe_load_balancers() - resp['LoadBalancers'].should.have.length_of(50) - resp['NextMarker'].should.equal( - resp['LoadBalancers'][-1]['LoadBalancerName']) - resp2 = conn.describe_load_balancers(Marker=resp['NextMarker']) - resp2['LoadBalancers'].should.have.length_of(1) - assert 'NextToken' not in resp2.keys() + resp["LoadBalancers"].should.have.length_of(50) + resp["NextMarker"].should.equal(resp["LoadBalancers"][-1]["LoadBalancerName"]) + resp2 = conn.describe_load_balancers(Marker=resp["NextMarker"]) + resp2["LoadBalancers"].should.have.length_of(1) + assert "NextToken" not in resp2.keys() @mock_elbv2 @mock_ec2 def test_delete_load_balancer(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers').should.have.length_of(1) - lb = response.get('LoadBalancers')[0] + response.get("LoadBalancers").should.have.length_of(1) + lb = response.get("LoadBalancers")[0] - conn.delete_load_balancer(LoadBalancerArn=lb.get('LoadBalancerArn')) - balancers = conn.describe_load_balancers().get('LoadBalancers') + conn.delete_load_balancer(LoadBalancerArn=lb.get("LoadBalancerArn")) + balancers = conn.describe_load_balancers().get("LoadBalancers") balancers.should.have.length_of(0) @mock_ec2 @mock_elbv2 def test_register_targets(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # No targets registered yet response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id response = conn.register_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[ - { - 'Id': instance_id1, - 'Port': 5060, - }, - { - 'Id': instance_id2, - 'Port': 4030, - }, - ]) + {"Id": instance_id1, "Port": 5060}, + {"Id": instance_id2, "Port": 4030}, + ], + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(2) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(2) response = conn.deregister_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), - Targets=[{'Id': instance_id2}]) + TargetGroupArn=target_group.get("TargetGroupArn"), + Targets=[{"Id": instance_id2}], + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) @mock_ec2 @@ -672,289 +677,350 @@ def test_register_targets(): def test_stopped_instance_target(): target_group_port = 8080 - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=target_group_port, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # No targets registered yet response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=1, MaxCount=1) + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) instance = response[0] - target_dict = { - 'Id': instance.id, - 'Port': 500 - } + target_dict = {"Id": instance.id, "Port": 500} response = conn.register_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), - Targets=[target_dict]) + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[target_dict] + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) - target_health_description = response.get('TargetHealthDescriptions')[0] + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] - target_health_description['Target'].should.equal(target_dict) - target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) - target_health_description['TargetHealth'].should.equal({ - 'State': 'healthy' - }) + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal({"State": "healthy"}) instance.stop() response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) - target_health_description = response.get('TargetHealthDescriptions')[0] - target_health_description['Target'].should.equal(target_dict) - target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) - target_health_description['TargetHealth'].should.equal({ - 'State': 'unused', - 'Reason': 'Target.InvalidState', - 'Description': 'Target is in the stopped state' - }) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal( + { + "State": "unused", + "Reason": "Target.InvalidState", + "Description": "Target is in the stopped state", + } + ) + + +@mock_ec2 +@mock_elbv2 +def test_terminated_instance_target(): + target_group_port = 8080 + + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) + subnet2 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) + + conn.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + + response = conn.create_target_group( + Name="a-target", + Protocol="HTTP", + Port=target_group_port, + VpcId=vpc.id, + HealthCheckProtocol="HTTP", + HealthCheckPath="/", + HealthCheckIntervalSeconds=5, + HealthCheckTimeoutSeconds=5, + HealthyThresholdCount=5, + UnhealthyThresholdCount=2, + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] + + # No targets registered yet + response = conn.describe_target_health( + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) + + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) + instance = response[0] + + target_dict = {"Id": instance.id, "Port": 500} + + response = conn.register_targets( + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[target_dict] + ) + + response = conn.describe_target_health( + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] + + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal({"State": "healthy"}) + + instance.terminate() + + response = conn.describe_target_health( + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) @mock_ec2 @mock_elbv2 def test_target_group_attributes(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # Check it's in the describe_target_groups response response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) - target_group_arn = target_group['TargetGroupArn'] + response.get("TargetGroups").should.have.length_of(1) + target_group_arn = target_group["TargetGroupArn"] # check if Names filter works response = conn.describe_target_groups(Names=[]) - response = conn.describe_target_groups(Names=['a-target']) - response.get('TargetGroups').should.have.length_of(1) - target_group_arn = target_group['TargetGroupArn'] + response = conn.describe_target_groups(Names=["a-target"]) + response.get("TargetGroups").should.have.length_of(1) + target_group_arn = target_group["TargetGroupArn"] # The attributes should start with the two defaults - response = conn.describe_target_group_attributes( - TargetGroupArn=target_group_arn) - response['Attributes'].should.have.length_of(2) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['deregistration_delay.timeout_seconds'].should.equal('300') - attributes['stickiness.enabled'].should.equal('false') + response = conn.describe_target_group_attributes(TargetGroupArn=target_group_arn) + response["Attributes"].should.have.length_of(2) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["deregistration_delay.timeout_seconds"].should.equal("300") + attributes["stickiness.enabled"].should.equal("false") # Add cookie stickiness response = conn.modify_target_group_attributes( TargetGroupArn=target_group_arn, Attributes=[ - { - 'Key': 'stickiness.enabled', - 'Value': 'true', - }, - { - 'Key': 'stickiness.type', - 'Value': 'lb_cookie', - }, - ]) + {"Key": "stickiness.enabled", "Value": "true"}, + {"Key": "stickiness.type", "Value": "lb_cookie"}, + ], + ) # The response should have only the keys updated - response['Attributes'].should.have.length_of(2) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['stickiness.type'].should.equal('lb_cookie') - attributes['stickiness.enabled'].should.equal('true') + response["Attributes"].should.have.length_of(2) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["stickiness.type"].should.equal("lb_cookie") + attributes["stickiness.enabled"].should.equal("true") # These new values should be in the full attribute list - response = conn.describe_target_group_attributes( - TargetGroupArn=target_group_arn) - response['Attributes'].should.have.length_of(3) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['stickiness.type'].should.equal('lb_cookie') - attributes['stickiness.enabled'].should.equal('true') + response = conn.describe_target_group_attributes(TargetGroupArn=target_group_arn) + response["Attributes"].should.have.length_of(3) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["stickiness.type"].should.equal("lb_cookie") + attributes["stickiness.enabled"].should.equal("true") @mock_elbv2 @mock_ec2 def test_handle_listener_rules(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") # Can't create a target group with an invalid protocol with assert_raises(ClientError): conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='/HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="/HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # Plain HTTP listener response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(80) - listener.get('Protocol').should.equal('HTTP') - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) - http_listener_arn = listener.get('ListenerArn') + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(80) + listener.get("Protocol").should.equal("HTTP") + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) + http_listener_arn = listener.get("ListenerArn") # create first rule priority = 100 - host = 'xxx.example.com' - path_pattern = 'foobar' + host = "xxx.example.com" + path_pattern = "foobar" created_rule = conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] - )['Rules'][0] - created_rule['Priority'].should.equal('100') + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ + {"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"} + ], + )["Rules"][0] + created_rule["Priority"].should.equal("100") # check if rules is sorted by priority priority = 50 - host = 'yyy.example.com' - path_pattern = 'foobar' + host = "yyy.example.com" + path_pattern = "foobar" rules = conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ + {"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"} + ], ) # test for PriorityInUse @@ -962,46 +1028,43 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for describe listeners obtained_rules = conn.describe_rules(ListenerArn=http_listener_arn) - len(obtained_rules['Rules']).should.equal(3) - priorities = [rule['Priority'] for rule in obtained_rules['Rules']] - priorities.should.equal(['50', '100', 'default']) + len(obtained_rules["Rules"]).should.equal(3) + priorities = [rule["Priority"] for rule in obtained_rules["Rules"]] + priorities.should.equal(["50", "100", "default"]) - first_rule = obtained_rules['Rules'][0] - second_rule = obtained_rules['Rules'][1] - obtained_rules = conn.describe_rules(RuleArns=[first_rule['RuleArn']]) - obtained_rules['Rules'].should.equal([first_rule]) + first_rule = obtained_rules["Rules"][0] + second_rule = obtained_rules["Rules"][1] + obtained_rules = conn.describe_rules(RuleArns=[first_rule["RuleArn"]]) + obtained_rules["Rules"].should.equal([first_rule]) # test for pagination - obtained_rules = conn.describe_rules( - ListenerArn=http_listener_arn, PageSize=1) - len(obtained_rules['Rules']).should.equal(1) - obtained_rules.should.have.key('NextMarker') - next_marker = obtained_rules['NextMarker'] + obtained_rules = conn.describe_rules(ListenerArn=http_listener_arn, PageSize=1) + len(obtained_rules["Rules"]).should.equal(1) + obtained_rules.should.have.key("NextMarker") + next_marker = obtained_rules["NextMarker"] following_rules = conn.describe_rules( - ListenerArn=http_listener_arn, - PageSize=1, - Marker=next_marker) - len(following_rules['Rules']).should.equal(1) - following_rules.should.have.key('NextMarker') - following_rules['Rules'][0]['RuleArn'].should_not.equal( - obtained_rules['Rules'][0]['RuleArn']) + ListenerArn=http_listener_arn, PageSize=1, Marker=next_marker + ) + len(following_rules["Rules"]).should.equal(1) + following_rules.should.have.key("NextMarker") + following_rules["Rules"][0]["RuleArn"].should_not.equal( + obtained_rules["Rules"][0]["RuleArn"] + ) # test for invalid describe rule request with assert_raises(ClientError): @@ -1010,52 +1073,50 @@ def test_handle_listener_rules(): conn.describe_rules(RuleArns=[]) with assert_raises(ClientError): conn.describe_rules( - ListenerArn=http_listener_arn, - RuleArns=[first_rule['RuleArn']] + ListenerArn=http_listener_arn, RuleArns=[first_rule["RuleArn"]] ) # modify rule partially - new_host = 'new.example.com' - new_path_pattern = 'new_path' + new_host = "new.example.com" + new_path_pattern = "new_path" modified_rule = conn.modify_rule( - RuleArn=first_rule['RuleArn'], - Conditions=[{ - 'Field': 'host-header', - 'Values': [new_host] - }, - { - 'Field': 'path-pattern', - 'Values': [new_path_pattern] - }] - )['Rules'][0] + RuleArn=first_rule["RuleArn"], + Conditions=[ + {"Field": "host-header", "Values": [new_host]}, + {"Field": "path-pattern", "Values": [new_path_pattern]}, + ], + )["Rules"][0] rules = conn.describe_rules(ListenerArn=http_listener_arn) - obtained_rule = rules['Rules'][0] + obtained_rule = rules["Rules"][0] modified_rule.should.equal(obtained_rule) - obtained_rule['Conditions'][0]['Values'][0].should.equal(new_host) - obtained_rule['Conditions'][1]['Values'][0].should.equal(new_path_pattern) - obtained_rule['Actions'][0]['TargetGroupArn'].should.equal( - target_group.get('TargetGroupArn')) + obtained_rule["Conditions"][0]["Values"][0].should.equal(new_host) + obtained_rule["Conditions"][1]["Values"][0].should.equal(new_path_pattern) + obtained_rule["Actions"][0]["TargetGroupArn"].should.equal( + target_group.get("TargetGroupArn") + ) # modify priority conn.set_rule_priorities( RulePriorities=[ - {'RuleArn': first_rule['RuleArn'], - 'Priority': int(first_rule['Priority']) - 1} + { + "RuleArn": first_rule["RuleArn"], + "Priority": int(first_rule["Priority"]) - 1, + } ] ) with assert_raises(ClientError): conn.set_rule_priorities( RulePriorities=[ - {'RuleArn': first_rule['RuleArn'], 'Priority': 999}, - {'RuleArn': second_rule['RuleArn'], 'Priority': 999} + {"RuleArn": first_rule["RuleArn"], "Priority": 999}, + {"RuleArn": second_rule["RuleArn"], "Priority": 999}, ] ) # delete - arn = first_rule['RuleArn'] + arn = first_rule["RuleArn"] conn.delete_rule(RuleArn=arn) - rules = conn.describe_rules(ListenerArn=http_listener_arn)['Rules'] + rules = conn.describe_rules(ListenerArn=http_listener_arn)["Rules"] len(rules).should.equal(2) # test for invalid action type @@ -1064,39 +1125,30 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward2' - }] + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward2", + } + ], ) # test for invalid action type safe_priority = 2 - invalid_target_group_arn = target_group.get('TargetGroupArn') + 'x' + invalid_target_group_arn = target_group.get("TargetGroupArn") + "x" with assert_raises(ClientError): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': invalid_target_group_arn, - 'Type': 'forward' - }] + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[{"TargetGroupArn": invalid_target_group_arn, "Type": "forward"}], ) # test for invalid condition field_name @@ -1105,14 +1157,13 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'xxxxxxx', - 'Values': [host] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "xxxxxxx", "Values": [host]}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for emptry condition value @@ -1121,14 +1172,13 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "host-header", "Values": []}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for multiple condition value @@ -1137,444 +1187,440 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host, host] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "host-header", "Values": [host, host]}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) @mock_elbv2 @mock_ec2 def test_describe_invalid_target_group(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers')[0].get('LoadBalancerArn') + response.get("LoadBalancers")[0].get("LoadBalancerArn") response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) # Check error raises correctly with assert_raises(ClientError): - conn.describe_target_groups(Names=['invalid']) + conn.describe_target_groups(Names=["invalid"]) @mock_elbv2 @mock_ec2 def test_describe_target_groups_no_arguments(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers')[0].get('LoadBalancerArn') + response.get("LoadBalancers")[0].get("LoadBalancerArn") conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - assert len(conn.describe_target_groups()['TargetGroups']) == 1 + assert len(conn.describe_target_groups()["TargetGroups"]) == 1 @mock_elbv2 def test_describe_account_limits(): - client = boto3.client('elbv2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") resp = client.describe_account_limits() - resp['Limits'][0].should.contain('Name') - resp['Limits'][0].should.contain('Max') + resp["Limits"][0].should.contain("Name") + resp["Limits"][0].should.contain("Max") @mock_elbv2 def test_describe_ssl_policies(): - client = boto3.client('elbv2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") resp = client.describe_ssl_policies() - len(resp['SslPolicies']).should.equal(5) + len(resp["SslPolicies"]).should.equal(5) - resp = client.describe_ssl_policies(Names=['ELBSecurityPolicy-TLS-1-2-2017-01', 'ELBSecurityPolicy-2016-08']) - len(resp['SslPolicies']).should.equal(2) + resp = client.describe_ssl_policies( + Names=["ELBSecurityPolicy-TLS-1-2-2017-01", "ELBSecurityPolicy-2016-08"] + ) + len(resp["SslPolicies"]).should.equal(2) @mock_elbv2 @mock_ec2 def test_set_ip_address_type(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] # Internal LBs cant be dualstack yet with assert_raises(ClientError): - client.set_ip_address_type( - LoadBalancerArn=arn, - IpAddressType='dualstack' - ) + client.set_ip_address_type(LoadBalancerArn=arn, IpAddressType="dualstack") # Create internet facing one response = client.create_load_balancer( - Name='my-lb2', + Name="my-lb2", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internet-facing', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] - - client.set_ip_address_type( - LoadBalancerArn=arn, - IpAddressType='dualstack' + Scheme="internet-facing", + Tags=[{"Key": "key_name", "Value": "a_value"}], ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] + + client.set_ip_address_type(LoadBalancerArn=arn, IpAddressType="dualstack") @mock_elbv2 @mock_ec2 def test_set_security_groups(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') + GroupName="a-security-group", Description="First One" + ) security_group2 = ec2.create_security_group( - GroupName='b-security-group', Description='Second One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="b-security-group", Description="Second One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.set_security_groups( - LoadBalancerArn=arn, - SecurityGroups=[security_group.id, security_group2.id] + LoadBalancerArn=arn, SecurityGroups=[security_group.id, security_group2.id] ) resp = client.describe_load_balancers(LoadBalancerArns=[arn]) - len(resp['LoadBalancers'][0]['SecurityGroups']).should.equal(2) + len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2) with assert_raises(ClientError): - client.set_security_groups( - LoadBalancerArn=arn, - SecurityGroups=['non_existant'] - ) + client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existant"]) @mock_elbv2 @mock_ec2 def test_set_subnets(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.64/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.64/26", AvailabilityZone="us-east-1b" + ) subnet3 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1c') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1c" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id, subnet2.id, subnet3.id] + LoadBalancerArn=arn, Subnets=[subnet1.id, subnet2.id, subnet3.id] ) resp = client.describe_load_balancers(LoadBalancerArns=[arn]) - len(resp['LoadBalancers'][0]['AvailabilityZones']).should.equal(3) + len(resp["LoadBalancers"][0]["AvailabilityZones"]).should.equal(3) # Only 1 AZ with assert_raises(ClientError): - client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id] - ) + client.set_subnets(LoadBalancerArn=arn, Subnets=[subnet1.id]) # Multiple subnets in same AZ with assert_raises(ClientError): client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id, subnet2.id, subnet2.id] + LoadBalancerArn=arn, Subnets=[subnet1.id, subnet2.id, subnet2.id] ) @mock_elbv2 @mock_ec2 def test_set_subnets(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.modify_load_balancer_attributes( LoadBalancerArn=arn, - Attributes=[{'Key': 'idle_timeout.timeout_seconds', 'Value': '600'}] + Attributes=[{"Key": "idle_timeout.timeout_seconds", "Value": "600"}], ) # Check its 600 not 60 - response = client.describe_load_balancer_attributes( - LoadBalancerArn=arn - ) - idle_timeout = list(filter(lambda item: item['Key'] == 'idle_timeout.timeout_seconds', response['Attributes']))[0] - idle_timeout['Value'].should.equal('600') + response = client.describe_load_balancer_attributes(LoadBalancerArn=arn) + idle_timeout = list( + filter( + lambda item: item["Key"] == "idle_timeout.timeout_seconds", + response["Attributes"], + ) + )[0] + idle_timeout["Value"].should.equal("600") @mock_elbv2 @mock_ec2 def test_modify_target_group(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") response = client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - arn = response.get('TargetGroups')[0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + arn = response.get("TargetGroups")[0]["TargetGroupArn"] client.modify_target_group( TargetGroupArn=arn, - HealthCheckProtocol='HTTPS', - HealthCheckPort='8081', - HealthCheckPath='/status', + HealthCheckProtocol="HTTPS", + HealthCheckPort="8081", + HealthCheckPath="/status", HealthCheckIntervalSeconds=10, HealthCheckTimeoutSeconds=10, HealthyThresholdCount=10, UnhealthyThresholdCount=4, - Matcher={'HttpCode': '200-399'} + Matcher={"HttpCode": "200-399"}, ) - response = client.describe_target_groups( - TargetGroupArns=[arn] - ) - response['TargetGroups'][0]['Matcher']['HttpCode'].should.equal('200-399') - response['TargetGroups'][0]['HealthCheckIntervalSeconds'].should.equal(10) - response['TargetGroups'][0]['HealthCheckPath'].should.equal('/status') - response['TargetGroups'][0]['HealthCheckPort'].should.equal('8081') - response['TargetGroups'][0]['HealthCheckProtocol'].should.equal('HTTPS') - response['TargetGroups'][0]['HealthCheckTimeoutSeconds'].should.equal(10) - response['TargetGroups'][0]['HealthyThresholdCount'].should.equal(10) - response['TargetGroups'][0]['UnhealthyThresholdCount'].should.equal(4) + response = client.describe_target_groups(TargetGroupArns=[arn]) + response["TargetGroups"][0]["Matcher"]["HttpCode"].should.equal("200-399") + response["TargetGroups"][0]["HealthCheckIntervalSeconds"].should.equal(10) + response["TargetGroups"][0]["HealthCheckPath"].should.equal("/status") + response["TargetGroups"][0]["HealthCheckPort"].should.equal("8081") + response["TargetGroups"][0]["HealthCheckProtocol"].should.equal("HTTPS") + response["TargetGroups"][0]["HealthCheckTimeoutSeconds"].should.equal(10) + response["TargetGroups"][0]["HealthyThresholdCount"].should.equal(10) + response["TargetGroups"][0]["UnhealthyThresholdCount"].should.equal(4) @mock_elbv2 @mock_ec2 @mock_acm def test_modify_listener_http_to_https(): - client = boto3.client('elbv2', region_name='eu-central-1') - acm = boto3.client('acm', region_name='eu-central-1') - ec2 = boto3.resource('ec2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") + acm = boto3.client("acm", region_name="eu-central-1") + ec2 = boto3.resource("ec2", region_name="eu-central-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='eu-central-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="eu-central-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='eu-central-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="eu-central-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") response = client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] - target_group_arn = target_group['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] + target_group_arn = target_group["TargetGroupArn"] # Plain HTTP listener response = client.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group_arn}] + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) - listener_arn = response['Listeners'][0]['ListenerArn'] + listener_arn = response["Listeners"][0]["ListenerArn"] response = acm.request_certificate( - DomainName='google.com', - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + DomainName="google.com", + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - google_arn = response['CertificateArn'] + google_arn = response["CertificateArn"] response = acm.request_certificate( - DomainName='yahoo.com', - SubjectAlternativeNames=['yahoo.com', 'www.yahoo.com', 'mail.yahoo.com'], + DomainName="yahoo.com", + SubjectAlternativeNames=["yahoo.com", "www.yahoo.com", "mail.yahoo.com"], ) - yahoo_arn = response['CertificateArn'] + yahoo_arn = response["CertificateArn"] response = client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", Certificates=[ - {'CertificateArn': google_arn, 'IsDefault': False}, - {'CertificateArn': yahoo_arn, 'IsDefault': True} + {"CertificateArn": google_arn, "IsDefault": False}, + {"CertificateArn": yahoo_arn, "IsDefault": True}, ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) - response['Listeners'][0]['Port'].should.equal(443) - response['Listeners'][0]['Protocol'].should.equal('HTTPS') - response['Listeners'][0]['SslPolicy'].should.equal('ELBSecurityPolicy-TLS-1-2-2017-01') - len(response['Listeners'][0]['Certificates']).should.equal(2) + response["Listeners"][0]["Port"].should.equal(443) + response["Listeners"][0]["Protocol"].should.equal("HTTPS") + response["Listeners"][0]["SslPolicy"].should.equal( + "ELBSecurityPolicy-TLS-1-2-2017-01" + ) + len(response["Listeners"][0]["Certificates"]).should.equal(2) # Check default cert, can't do this in server mode - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': - listener = elbv2_backends['eu-central-1'].load_balancers[load_balancer_arn].listeners[listener_arn] + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": + listener = ( + elbv2_backends["eu-central-1"] + .load_balancers[load_balancer_arn] + .listeners[listener_arn] + ) listener.certificate.should.equal(yahoo_arn) # No default cert @@ -1582,14 +1628,10 @@ def test_modify_listener_http_to_https(): client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', - Certificates=[ - {'CertificateArn': google_arn, 'IsDefault': False} - ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", + Certificates=[{"CertificateArn": google_arn, "IsDefault": False}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) # Bad cert @@ -1597,14 +1639,10 @@ def test_modify_listener_http_to_https(): client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', - Certificates=[ - {'CertificateArn': 'lalala', 'IsDefault': True} - ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", + Certificates=[{"CertificateArn": "lalala", "IsDefault": True}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) @@ -1612,8 +1650,8 @@ def test_modify_listener_http_to_https(): @mock_elbv2 @mock_cloudformation def test_create_target_groups_through_cloudformation(): - cfn_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cfn_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") # test that setting a name manually as well as letting cloudformation create a name both work # this is a special case because test groups have a name length limit of 22 characters, and must be unique @@ -1624,9 +1662,7 @@ def test_create_target_groups_through_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "testGroup1": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -1653,93 +1689,117 @@ def test_create_target_groups_through_cloudformation(): "VpcId": {"Ref": "testVPC"}, }, }, - } + }, } template_json = json.dumps(template) - cfn_conn.create_stack( - StackName="test-stack", - TemplateBody=template_json, - ) + cfn_conn.create_stack(StackName="test-stack", TemplateBody=template_json) describe_target_groups_response = elbv2_client.describe_target_groups() - target_group_dicts = describe_target_groups_response['TargetGroups'] + target_group_dicts = describe_target_groups_response["TargetGroups"] assert len(target_group_dicts) == 3 # there should be 2 target groups with the same prefix of 10 characters (since the random suffix is 12) # and one named MyTargetGroup - assert len([tg for tg in target_group_dicts if tg['TargetGroupName'] == 'MyTargetGroup']) == 1 - assert len( - [tg for tg in target_group_dicts if tg['TargetGroupName'].startswith('test-stack')] - ) == 2 + assert ( + len( + [ + tg + for tg in target_group_dicts + if tg["TargetGroupName"] == "MyTargetGroup" + ] + ) + == 1 + ) + assert ( + len( + [ + tg + for tg in target_group_dicts + if tg["TargetGroupName"].startswith("test-stack") + ] + ) + == 2 + ) @mock_elbv2 @mock_ec2 def test_redirect_action_listener_rule(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") - response = conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[ - {'Type': 'redirect', - 'RedirectConfig': { - 'Protocol': 'HTTPS', - 'Port': '443', - 'StatusCode': 'HTTP_301' - }}]) + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[ + { + "Type": "redirect", + "RedirectConfig": { + "Protocol": "HTTPS", + "Port": "443", + "StatusCode": "HTTP_301", + }, + } + ], + ) - listener = response.get('Listeners')[0] - expected_default_actions = [{ - 'Type': 'redirect', - 'RedirectConfig': { - 'Protocol': 'HTTPS', - 'Port': '443', - 'StatusCode': 'HTTP_301' + listener = response.get("Listeners")[0] + expected_default_actions = [ + { + "Type": "redirect", + "RedirectConfig": { + "Protocol": "HTTPS", + "Port": "443", + "StatusCode": "HTTP_301", + }, } - }] - listener.get('DefaultActions').should.equal(expected_default_actions) - listener_arn = listener.get('ListenerArn') + ] + listener.get("DefaultActions").should.equal(expected_default_actions) + listener_arn = listener.get("ListenerArn") describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) - describe_rules_response['Rules'][0]['Actions'].should.equal(expected_default_actions) + describe_rules_response["Rules"][0]["Actions"].should.equal( + expected_default_actions + ) - describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) - describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'] + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ] describe_listener_actions.should.equal(expected_default_actions) modify_listener_response = conn.modify_listener(ListenerArn=listener_arn, Port=81) - modify_listener_actions = modify_listener_response['Listeners'][0]['DefaultActions'] + modify_listener_actions = modify_listener_response["Listeners"][0]["DefaultActions"] modify_listener_actions.should.equal(expected_default_actions) @mock_elbv2 @mock_cloudformation def test_redirect_action_listener_rule_cloudformation(): - cnf_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1747,9 +1807,7 @@ def test_redirect_action_listener_rule_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "subnet1": { "Type": "AWS::EC2::Subnet", @@ -1774,7 +1832,7 @@ def test_redirect_action_listener_rule_cloudformation(): "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], "Type": "application", "SecurityGroups": [], - } + }, }, "testListener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", @@ -1782,93 +1840,112 @@ def test_redirect_action_listener_rule_cloudformation(): "LoadBalancerArn": {"Ref": "testLb"}, "Port": 80, "Protocol": "HTTP", - "DefaultActions": [{ - "Type": "redirect", - "RedirectConfig": { - "Port": "443", - "Protocol": "HTTPS", - "StatusCode": "HTTP_301", + "DefaultActions": [ + { + "Type": "redirect", + "RedirectConfig": { + "Port": "443", + "Protocol": "HTTPS", + "StatusCode": "HTTP_301", + }, } - }] - } - - } - } + ], + }, + }, + }, } template_json = json.dumps(template) cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) - describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) - describe_load_balancers_response['LoadBalancers'].should.have.length_of(1) - load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + describe_load_balancers_response["LoadBalancers"].should.have.length_of(1) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] - describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) - describe_listeners_response['Listeners'].should.have.length_of(1) - describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ - 'Type': 'redirect', - 'RedirectConfig': { - 'Port': '443', 'Protocol': 'HTTPS', 'StatusCode': 'HTTP_301', - } - },]) + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "redirect", + "RedirectConfig": { + "Port": "443", + "Protocol": "HTTPS", + "StatusCode": "HTTP_301", + }, + } + ] + ) @mock_elbv2 @mock_ec2 def test_cognito_action_listener_rule(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") action = { - 'Type': 'authenticate-cognito', - 'AuthenticateCognitoConfig': { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', - } + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:{}:userpool/us-east-1_ABCD1234".format( + ACCOUNT_ID + ), + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, } - response = conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[action]) + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[action], + ) - listener = response.get('Listeners')[0] - listener.get('DefaultActions')[0].should.equal(action) - listener_arn = listener.get('ListenerArn') + listener = response.get("Listeners")[0] + listener.get("DefaultActions")[0].should.equal(action) + listener_arn = listener.get("ListenerArn") describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) - describe_rules_response['Rules'][0]['Actions'][0].should.equal(action) + describe_rules_response["Rules"][0]["Actions"][0].should.equal(action) - describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) - describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0] + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ][0] describe_listener_actions.should.equal(action) @mock_elbv2 @mock_cloudformation def test_cognito_action_listener_rule_cloudformation(): - cnf_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1876,9 +1953,7 @@ def test_cognito_action_listener_rule_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "subnet1": { "Type": "AWS::EC2::Subnet", @@ -1903,7 +1978,7 @@ def test_cognito_action_listener_rule_cloudformation(): "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], "Type": "application", "SecurityGroups": [], - } + }, }, "testListener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", @@ -1911,32 +1986,348 @@ def test_cognito_action_listener_rule_cloudformation(): "LoadBalancerArn": {"Ref": "testLb"}, "Port": 80, "Protocol": "HTTP", - "DefaultActions": [{ - "Type": "authenticate-cognito", - "AuthenticateCognitoConfig": { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', + "DefaultActions": [ + { + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:{}:userpool/us-east-1_ABCD1234".format( + ACCOUNT_ID + ), + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, } - }] - } - - } - } + ], + }, + }, + }, } template_json = json.dumps(template) cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) - describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) - load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] - describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) - describe_listeners_response['Listeners'].should.have.length_of(1) - describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ - 'Type': 'authenticate-cognito', - "AuthenticateCognitoConfig": { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:{}:userpool/us-east-1_ABCD1234".format( + ACCOUNT_ID + ), + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, + } + ] + ) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule(): + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) + subnet2 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) + + response = conn.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") + + action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, + } + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[action], + ) + + listener = response.get("Listeners")[0] + listener.get("DefaultActions")[0].should.equal(action) + listener_arn = listener.get("ListenerArn") + + describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) + describe_rules_response["Rules"][0]["Actions"][0].should.equal(action) + + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ][0] + describe_listener_actions.should.equal(action) + + +@mock_elbv2 +@mock_cloudformation +def test_fixed_response_action_listener_rule_cloudformation(): + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") + + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Description": "ECS Cluster Test CloudFormation", + "Resources": { + "testVPC": { + "Type": "AWS::EC2::VPC", + "Properties": {"CidrBlock": "10.0.0.0/16"}, + }, + "subnet1": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.0.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "subnet2": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.1.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "testLb": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "my-lb", + "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], + "Type": "application", + "SecurityGroups": [], + }, + }, + "testListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": {"Ref": "testLb"}, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, + } + ], + }, + }, + }, + } + template_json = json.dumps(template) + cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) + + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) + + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, + } + ] + ) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) + subnet2 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) + + response = conn.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") + + missing_status_code_action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + }, + } + with assert_raises(ParamValidationError): + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[missing_status_code_action], + ) + + invalid_status_code_action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "100", + }, + } + + @mock_elbv2 + @mock_ec2 + def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) + subnet2 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) + + response = conn.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") + + missing_status_code_action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + }, } - },]) + with assert_raises(ParamValidationError): + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[missing_status_code_action], + ) + + invalid_status_code_action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "100", + }, + } + + with assert_raises(ClientError) as invalid_status_code_exception: + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[invalid_status_code_action], + ) + + invalid_status_code_exception.exception.response["Error"]["Code"].should.equal( + "ValidationError" + ) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_content_type(): + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) + subnet2 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) + + response = conn.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") + + invalid_content_type_action = { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "Fake content type", + "MessageBody": "This page does not exist", + "StatusCode": "200", + }, + } + with assert_raises(ClientError) as invalid_content_type_exception: + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[invalid_content_type_action], + ) + invalid_content_type_exception.exception.response["Error"]["Code"].should.equal( + "InvalidLoadBalancerAction" + ) diff --git a/tests/test_elbv2/test_server.py b/tests/test_elbv2/test_server.py index 7d47d23ad..7d2ce4b01 100644 --- a/tests/test_elbv2/test_server.py +++ b/tests/test_elbv2/test_server.py @@ -1,17 +1,17 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_elbv2_describe_load_balancers(): - backend = server.create_backend_app("elbv2") - test_client = backend.test_client() - - res = test_client.get('/?Action=DescribeLoadBalancers&Version=2015-12-01') - - res.data.should.contain(b'DescribeLoadBalancersResponse') +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_elbv2_describe_load_balancers(): + backend = server.create_backend_app("elbv2") + test_client = backend.test_client() + + res = test_client.get("/?Action=DescribeLoadBalancers&Version=2015-12-01") + + res.data.should.contain(b"DescribeLoadBalancersResponse") diff --git a/tests/test_emr/test_emr.py b/tests/test_emr/test_emr.py index a1918ac30..0dea23066 100644 --- a/tests/test_emr/test_emr.py +++ b/tests/test_emr/test_emr.py @@ -1,658 +1,664 @@ -from __future__ import unicode_literals -import time -from datetime import datetime - -import boto -import pytz -from boto.emr.bootstrap_action import BootstrapAction -from boto.emr.instance_group import InstanceGroup -from boto.emr.step import StreamingStep - -import six -import sure # noqa - -from moto import mock_emr_deprecated -from tests.helpers import requires_boto_gte - - -run_jobflow_args = dict( - job_flow_role='EMR_EC2_DefaultRole', - keep_alive=True, - log_uri='s3://some_bucket/jobflow_logs', - master_instance_type='c1.medium', - name='My jobflow', - num_instances=2, - service_role='EMR_DefaultRole', - slave_instance_type='c1.medium', -) - - -input_instance_groups = [ - InstanceGroup(1, 'MASTER', 'c1.medium', 'ON_DEMAND', 'master'), - InstanceGroup(3, 'CORE', 'c1.medium', 'ON_DEMAND', 'core'), - InstanceGroup(6, 'TASK', 'c1.large', 'SPOT', 'task-1', '0.07'), - InstanceGroup(10, 'TASK', 'c1.xlarge', 'SPOT', 'task-2', '0.05'), -] - - -@mock_emr_deprecated -def test_describe_cluster(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - args.update(dict( - api_params={ - 'Applications.member.1.Name': 'Spark', - 'Applications.member.1.Version': '2.4.2', - 'Configurations.member.1.Classification': 'yarn-site', - 'Configurations.member.1.Properties.entry.1.key': 'someproperty', - 'Configurations.member.1.Properties.entry.1.value': 'somevalue', - 'Configurations.member.1.Properties.entry.2.key': 'someotherproperty', - 'Configurations.member.1.Properties.entry.2.value': 'someothervalue', - 'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', - 'Instances.Ec2SubnetId': 'subnet-8be41cec', - }, - availability_zone='us-east-2b', - ec2_keyname='mykey', - job_flow_role='EMR_EC2_DefaultRole', - keep_alive=False, - log_uri='s3://some_bucket/jobflow_logs', - name='My jobflow', - service_role='EMR_DefaultRole', - visible_to_all_users=True, - )) - cluster_id = conn.run_jobflow(**args) - input_tags = {'tag1': 'val1', 'tag2': 'val2'} - conn.add_tags(cluster_id, input_tags) - - cluster = conn.describe_cluster(cluster_id) - cluster.applications[0].name.should.equal('Spark') - cluster.applications[0].version.should.equal('2.4.2') - cluster.autoterminate.should.equal('true') - - # configurations appear not be supplied as attributes? - - attrs = cluster.ec2instanceattributes - # AdditionalMasterSecurityGroups - # AdditionalSlaveSecurityGroups - attrs.ec2availabilityzone.should.equal(args['availability_zone']) - attrs.ec2keyname.should.equal(args['ec2_keyname']) - attrs.ec2subnetid.should.equal(args['api_params']['Instances.Ec2SubnetId']) - # EmrManagedMasterSecurityGroups - # EmrManagedSlaveSecurityGroups - attrs.iaminstanceprofile.should.equal(args['job_flow_role']) - # ServiceAccessSecurityGroup - - cluster.id.should.equal(cluster_id) - cluster.loguri.should.equal(args['log_uri']) - cluster.masterpublicdnsname.should.be.a(six.string_types) - cluster.name.should.equal(args['name']) - int(cluster.normalizedinstancehours).should.equal(0) - # cluster.release_label - cluster.shouldnt.have.property('requestedamiversion') - cluster.runningamiversion.should.equal('1.0.0') - # cluster.securityconfiguration - cluster.servicerole.should.equal(args['service_role']) - - cluster.status.state.should.equal('TERMINATED') - cluster.status.statechangereason.message.should.be.a(six.string_types) - cluster.status.statechangereason.code.should.be.a(six.string_types) - cluster.status.timeline.creationdatetime.should.be.a(six.string_types) - # cluster.status.timeline.enddatetime.should.be.a(six.string_types) - # cluster.status.timeline.readydatetime.should.be.a(six.string_types) - - dict((item.key, item.value) - for item in cluster.tags).should.equal(input_tags) - - cluster.terminationprotected.should.equal('false') - cluster.visibletoallusers.should.equal('true') - - -@mock_emr_deprecated -def test_describe_jobflows(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - expected = {} - - for idx in range(4): - cluster_name = 'cluster' + str(idx) - args['name'] = cluster_name - cluster_id = conn.run_jobflow(**args) - expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'state': 'WAITING' - } - - # need sleep since it appears the timestamp is always rounded to - # the nearest second internally - time.sleep(1) - timestamp = datetime.now(pytz.utc) - time.sleep(1) - - for idx in range(4, 6): - cluster_name = 'cluster' + str(idx) - args['name'] = cluster_name - cluster_id = conn.run_jobflow(**args) - conn.terminate_jobflow(cluster_id) - expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'state': 'TERMINATED' - } - jobs = conn.describe_jobflows() - jobs.should.have.length_of(6) - - for cluster_id, y in expected.items(): - resp = conn.describe_jobflows(jobflow_ids=[cluster_id]) - resp.should.have.length_of(1) - resp[0].jobflowid.should.equal(cluster_id) - - resp = conn.describe_jobflows(states=['WAITING']) - resp.should.have.length_of(4) - for x in resp: - x.state.should.equal('WAITING') - - resp = conn.describe_jobflows(created_before=timestamp) - resp.should.have.length_of(4) - - resp = conn.describe_jobflows(created_after=timestamp) - resp.should.have.length_of(2) - - -@mock_emr_deprecated -def test_describe_jobflow(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - args.update(dict( - ami_version='3.8.1', - api_params={ - #'Applications.member.1.Name': 'Spark', - #'Applications.member.1.Version': '2.4.2', - #'Configurations.member.1.Classification': 'yarn-site', - #'Configurations.member.1.Properties.entry.1.key': 'someproperty', - #'Configurations.member.1.Properties.entry.1.value': 'somevalue', - #'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', - 'Instances.Ec2SubnetId': 'subnet-8be41cec', - }, - ec2_keyname='mykey', - hadoop_version='2.4.0', - - name='My jobflow', - log_uri='s3://some_bucket/jobflow_logs', - keep_alive=True, - master_instance_type='c1.medium', - slave_instance_type='c1.medium', - num_instances=2, - - availability_zone='us-west-2b', - - job_flow_role='EMR_EC2_DefaultRole', - service_role='EMR_DefaultRole', - visible_to_all_users=True, - )) - - cluster_id = conn.run_jobflow(**args) - jf = conn.describe_jobflow(cluster_id) - jf.amiversion.should.equal(args['ami_version']) - jf.bootstrapactions.should.equal(None) - jf.creationdatetime.should.be.a(six.string_types) - jf.should.have.property('laststatechangereason') - jf.readydatetime.should.be.a(six.string_types) - jf.startdatetime.should.be.a(six.string_types) - jf.state.should.equal('WAITING') - - jf.ec2keyname.should.equal(args['ec2_keyname']) - # Ec2SubnetId - jf.hadoopversion.should.equal(args['hadoop_version']) - int(jf.instancecount).should.equal(2) - - for ig in jf.instancegroups: - ig.creationdatetime.should.be.a(six.string_types) - # ig.enddatetime.should.be.a(six.string_types) - ig.should.have.property('instancegroupid').being.a(six.string_types) - int(ig.instancerequestcount).should.equal(1) - ig.instancerole.should.be.within(['MASTER', 'CORE']) - int(ig.instancerunningcount).should.equal(1) - ig.instancetype.should.equal('c1.medium') - ig.laststatechangereason.should.be.a(six.string_types) - ig.market.should.equal('ON_DEMAND') - ig.name.should.be.a(six.string_types) - ig.readydatetime.should.be.a(six.string_types) - ig.startdatetime.should.be.a(six.string_types) - ig.state.should.equal('RUNNING') - - jf.keepjobflowalivewhennosteps.should.equal('true') - jf.masterinstanceid.should.be.a(six.string_types) - jf.masterinstancetype.should.equal(args['master_instance_type']) - jf.masterpublicdnsname.should.be.a(six.string_types) - int(jf.normalizedinstancehours).should.equal(0) - jf.availabilityzone.should.equal(args['availability_zone']) - jf.slaveinstancetype.should.equal(args['slave_instance_type']) - jf.terminationprotected.should.equal('false') - - jf.jobflowid.should.equal(cluster_id) - # jf.jobflowrole.should.equal(args['job_flow_role']) - jf.loguri.should.equal(args['log_uri']) - jf.name.should.equal(args['name']) - # jf.servicerole.should.equal(args['service_role']) - - jf.steps.should.have.length_of(0) - - list(i.value for i in jf.supported_products).should.equal([]) - jf.visibletoallusers.should.equal('true') - - -@mock_emr_deprecated -def test_list_clusters(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - expected = {} - - for idx in range(40): - cluster_name = 'jobflow' + str(idx) - args['name'] = cluster_name - cluster_id = conn.run_jobflow(**args) - expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'normalizedinstancehours': '0', - 'state': 'WAITING' - } - - # need sleep since it appears the timestamp is always rounded to - # the nearest second internally - time.sleep(1) - timestamp = datetime.now(pytz.utc) - time.sleep(1) - - for idx in range(40, 70): - cluster_name = 'jobflow' + str(idx) - args['name'] = cluster_name - cluster_id = conn.run_jobflow(**args) - conn.terminate_jobflow(cluster_id) - expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'normalizedinstancehours': '0', - 'state': 'TERMINATED' - } - - args = {} - while 1: - resp = conn.list_clusters(**args) - clusters = resp.clusters - len(clusters).should.be.lower_than_or_equal_to(50) - for x in clusters: - y = expected[x.id] - x.id.should.equal(y['id']) - x.name.should.equal(y['name']) - x.normalizedinstancehours.should.equal( - y['normalizedinstancehours']) - x.status.state.should.equal(y['state']) - x.status.timeline.creationdatetime.should.be.a(six.string_types) - if y['state'] == 'TERMINATED': - x.status.timeline.enddatetime.should.be.a(six.string_types) - else: - x.status.timeline.shouldnt.have.property('enddatetime') - x.status.timeline.readydatetime.should.be.a(six.string_types) - if not hasattr(resp, 'marker'): - break - args = {'marker': resp.marker} - - resp = conn.list_clusters(cluster_states=['TERMINATED']) - resp.clusters.should.have.length_of(30) - for x in resp.clusters: - x.status.state.should.equal('TERMINATED') - - resp = conn.list_clusters(created_before=timestamp) - resp.clusters.should.have.length_of(40) - - resp = conn.list_clusters(created_after=timestamp) - resp.clusters.should.have.length_of(30) - - -@mock_emr_deprecated -def test_run_jobflow(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - job_id = conn.run_jobflow(**args) - job_flow = conn.describe_jobflow(job_id) - job_flow.state.should.equal('WAITING') - job_flow.jobflowid.should.equal(job_id) - job_flow.name.should.equal(args['name']) - job_flow.masterinstancetype.should.equal(args['master_instance_type']) - job_flow.slaveinstancetype.should.equal(args['slave_instance_type']) - job_flow.loguri.should.equal(args['log_uri']) - job_flow.visibletoallusers.should.equal('false') - int(job_flow.normalizedinstancehours).should.equal(0) - job_flow.steps.should.have.length_of(0) - - -@mock_emr_deprecated -def test_run_jobflow_in_multiple_regions(): - regions = {} - for region in ['us-east-1', 'eu-west-1']: - conn = boto.emr.connect_to_region(region) - args = run_jobflow_args.copy() - args['name'] = region - cluster_id = conn.run_jobflow(**args) - regions[region] = {'conn': conn, 'cluster_id': cluster_id} - - for region in regions.keys(): - conn = regions[region]['conn'] - jf = conn.describe_jobflow(regions[region]['cluster_id']) - jf.name.should.equal(region) - - -@requires_boto_gte("2.8") -@mock_emr_deprecated -def test_run_jobflow_with_new_params(): - # Test that run_jobflow works with newer params - conn = boto.connect_emr() - conn.run_jobflow(**run_jobflow_args) - - -@requires_boto_gte("2.8") -@mock_emr_deprecated -def test_run_jobflow_with_visible_to_all_users(): - conn = boto.connect_emr() - for expected in (True, False): - job_id = conn.run_jobflow( - visible_to_all_users=expected, - **run_jobflow_args - ) - job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal(str(expected).lower()) - - -@requires_boto_gte("2.8") -@mock_emr_deprecated -def test_run_jobflow_with_instance_groups(): - input_groups = dict((g.name, g) for g in input_instance_groups) - conn = boto.connect_emr() - job_id = conn.run_jobflow(instance_groups=input_instance_groups, - **run_jobflow_args) - job_flow = conn.describe_jobflow(job_id) - int(job_flow.instancecount).should.equal( - sum(g.num_instances for g in input_instance_groups)) - for instance_group in job_flow.instancegroups: - expected = input_groups[instance_group.name] - instance_group.should.have.property('instancegroupid') - int(instance_group.instancerunningcount).should.equal( - expected.num_instances) - instance_group.instancerole.should.equal(expected.role) - instance_group.instancetype.should.equal(expected.type) - instance_group.market.should.equal(expected.market) - if hasattr(expected, 'bidprice'): - instance_group.bidprice.should.equal(expected.bidprice) - - -@requires_boto_gte("2.8") -@mock_emr_deprecated -def test_set_termination_protection(): - conn = boto.connect_emr() - job_id = conn.run_jobflow(**run_jobflow_args) - job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('false') - - conn.set_termination_protection(job_id, True) - job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('true') - - conn.set_termination_protection(job_id, False) - job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('false') - - -@requires_boto_gte("2.8") -@mock_emr_deprecated -def test_set_visible_to_all_users(): - conn = boto.connect_emr() - args = run_jobflow_args.copy() - args['visible_to_all_users'] = False - job_id = conn.run_jobflow(**args) - job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('false') - - conn.set_visible_to_all_users(job_id, True) - job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('true') - - conn.set_visible_to_all_users(job_id, False) - job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('false') - - -@mock_emr_deprecated -def test_terminate_jobflow(): - conn = boto.connect_emr() - job_id = conn.run_jobflow(**run_jobflow_args) - flow = conn.describe_jobflows()[0] - flow.state.should.equal('WAITING') - - conn.terminate_jobflow(job_id) - flow = conn.describe_jobflows()[0] - flow.state.should.equal('TERMINATED') - - -# testing multiple end points for each feature - -@mock_emr_deprecated -def test_bootstrap_actions(): - bootstrap_actions = [ - BootstrapAction( - name='bs1', - path='path/to/script', - bootstrap_action_args=['arg1', 'arg2&arg3']), - BootstrapAction( - name='bs2', - path='path/to/anotherscript', - bootstrap_action_args=[]) - ] - - conn = boto.connect_emr() - cluster_id = conn.run_jobflow( - bootstrap_actions=bootstrap_actions, - **run_jobflow_args - ) - - jf = conn.describe_jobflow(cluster_id) - for x, y in zip(jf.bootstrapactions, bootstrap_actions): - x.name.should.equal(y.name) - x.path.should.equal(y.path) - list(o.value for o in x.args).should.equal(y.args()) - - resp = conn.list_bootstrap_actions(cluster_id) - for i, y in enumerate(bootstrap_actions): - x = resp.actions[i] - x.name.should.equal(y.name) - x.scriptpath.should.equal(y.path) - list(arg.value for arg in x.args).should.equal(y.args()) - - -@mock_emr_deprecated -def test_instance_groups(): - input_groups = dict((g.name, g) for g in input_instance_groups) - - conn = boto.connect_emr() - args = run_jobflow_args.copy() - for key in ['master_instance_type', 'slave_instance_type', 'num_instances']: - del args[key] - args['instance_groups'] = input_instance_groups[:2] - job_id = conn.run_jobflow(**args) - - jf = conn.describe_jobflow(job_id) - base_instance_count = int(jf.instancecount) - - conn.add_instance_groups(job_id, input_instance_groups[2:]) - - jf = conn.describe_jobflow(job_id) - int(jf.instancecount).should.equal( - sum(g.num_instances for g in input_instance_groups)) - for x in jf.instancegroups: - y = input_groups[x.name] - if hasattr(y, 'bidprice'): - x.bidprice.should.equal(y.bidprice) - x.creationdatetime.should.be.a(six.string_types) - # x.enddatetime.should.be.a(six.string_types) - x.should.have.property('instancegroupid') - int(x.instancerequestcount).should.equal(y.num_instances) - x.instancerole.should.equal(y.role) - int(x.instancerunningcount).should.equal(y.num_instances) - x.instancetype.should.equal(y.type) - x.laststatechangereason.should.be.a(six.string_types) - x.market.should.equal(y.market) - x.name.should.be.a(six.string_types) - x.readydatetime.should.be.a(six.string_types) - x.startdatetime.should.be.a(six.string_types) - x.state.should.equal('RUNNING') - - for x in conn.list_instance_groups(job_id).instancegroups: - y = input_groups[x.name] - if hasattr(y, 'bidprice'): - x.bidprice.should.equal(y.bidprice) - # Configurations - # EbsBlockDevices - # EbsOptimized - x.should.have.property('id') - x.instancegrouptype.should.equal(y.role) - x.instancetype.should.equal(y.type) - x.market.should.equal(y.market) - x.name.should.equal(y.name) - int(x.requestedinstancecount).should.equal(y.num_instances) - int(x.runninginstancecount).should.equal(y.num_instances) - # ShrinkPolicy - x.status.state.should.equal('RUNNING') - x.status.statechangereason.code.should.be.a(six.string_types) - x.status.statechangereason.message.should.be.a(six.string_types) - x.status.timeline.creationdatetime.should.be.a(six.string_types) - # x.status.timeline.enddatetime.should.be.a(six.string_types) - x.status.timeline.readydatetime.should.be.a(six.string_types) - - igs = dict((g.name, g) for g in jf.instancegroups) - - conn.modify_instance_groups( - [igs['task-1'].instancegroupid, igs['task-2'].instancegroupid], - [2, 3]) - jf = conn.describe_jobflow(job_id) - int(jf.instancecount).should.equal(base_instance_count + 5) - igs = dict((g.name, g) for g in jf.instancegroups) - int(igs['task-1'].instancerunningcount).should.equal(2) - int(igs['task-2'].instancerunningcount).should.equal(3) - - -@mock_emr_deprecated -def test_steps(): - input_steps = [ - StreamingStep( - name='My wordcount example', - mapper='s3n://elasticmapreduce/samples/wordcount/wordSplitter.py', - reducer='aggregate', - input='s3n://elasticmapreduce/samples/wordcount/input', - output='s3n://output_bucket/output/wordcount_output'), - StreamingStep( - name='My wordcount example & co.', - mapper='s3n://elasticmapreduce/samples/wordcount/wordSplitter2.py', - reducer='aggregate', - input='s3n://elasticmapreduce/samples/wordcount/input2', - output='s3n://output_bucket/output/wordcount_output2') - ] - - # TODO: implementation and test for cancel_steps - - conn = boto.connect_emr() - cluster_id = conn.run_jobflow( - steps=[input_steps[0]], - **run_jobflow_args) - - jf = conn.describe_jobflow(cluster_id) - jf.steps.should.have.length_of(1) - - conn.add_jobflow_steps(cluster_id, [input_steps[1]]) - - jf = conn.describe_jobflow(cluster_id) - jf.steps.should.have.length_of(2) - for step in jf.steps: - step.actiononfailure.should.equal('TERMINATE_JOB_FLOW') - list(arg.value for arg in step.args).should.have.length_of(8) - step.creationdatetime.should.be.a(six.string_types) - # step.enddatetime.should.be.a(six.string_types) - step.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') - step.laststatechangereason.should.be.a(six.string_types) - step.mainclass.should.equal('') - step.name.should.be.a(six.string_types) - # step.readydatetime.should.be.a(six.string_types) - # step.startdatetime.should.be.a(six.string_types) - step.state.should.be.within(['STARTING', 'PENDING']) - - expected = dict((s.name, s) for s in input_steps) - - steps = conn.list_steps(cluster_id).steps - for x in steps: - y = expected[x.name] - # actiononfailure - list(arg.value for arg in x.config.args).should.equal([ - '-mapper', y.mapper, - '-reducer', y.reducer, - '-input', y.input, - '-output', y.output, - ]) - x.config.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') - x.config.mainclass.should.equal('') - # properties - x.should.have.property('id').should.be.a(six.string_types) - x.name.should.equal(y.name) - x.status.state.should.be.within(['STARTING', 'PENDING']) - # x.status.statechangereason - x.status.timeline.creationdatetime.should.be.a(six.string_types) - # x.status.timeline.enddatetime.should.be.a(six.string_types) - # x.status.timeline.startdatetime.should.be.a(six.string_types) - - x = conn.describe_step(cluster_id, x.id) - list(arg.value for arg in x.config.args).should.equal([ - '-mapper', y.mapper, - '-reducer', y.reducer, - '-input', y.input, - '-output', y.output, - ]) - x.config.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') - x.config.mainclass.should.equal('') - # properties - x.should.have.property('id').should.be.a(six.string_types) - x.name.should.equal(y.name) - x.status.state.should.be.within(['STARTING', 'PENDING']) - # x.status.statechangereason - x.status.timeline.creationdatetime.should.be.a(six.string_types) - # x.status.timeline.enddatetime.should.be.a(six.string_types) - # x.status.timeline.startdatetime.should.be.a(six.string_types) - - @requires_boto_gte('2.39') - def test_list_steps_with_states(): - # boto's list_steps prior to 2.39 has a bug that ignores - # step_states argument. - steps = conn.list_steps(cluster_id).steps - step_id = steps[0].id - steps = conn.list_steps(cluster_id, step_states=['STARTING']).steps - steps.should.have.length_of(1) - steps[0].id.should.equal(step_id) - test_list_steps_with_states() - - -@mock_emr_deprecated -def test_tags(): - input_tags = {"tag1": "val1", "tag2": "val2"} - - conn = boto.connect_emr() - cluster_id = conn.run_jobflow(**run_jobflow_args) - - conn.add_tags(cluster_id, input_tags) - cluster = conn.describe_cluster(cluster_id) - cluster.tags.should.have.length_of(2) - dict((t.key, t.value) for t in cluster.tags).should.equal(input_tags) - - conn.remove_tags(cluster_id, list(input_tags.keys())) - cluster = conn.describe_cluster(cluster_id) - cluster.tags.should.have.length_of(0) +from __future__ import unicode_literals +import time +from datetime import datetime + +import boto +import pytz +from boto.emr.bootstrap_action import BootstrapAction +from boto.emr.instance_group import InstanceGroup +from boto.emr.step import StreamingStep + +import six +import sure # noqa + +from moto import mock_emr_deprecated +from tests.helpers import requires_boto_gte + + +run_jobflow_args = dict( + job_flow_role="EMR_EC2_DefaultRole", + keep_alive=True, + log_uri="s3://some_bucket/jobflow_logs", + master_instance_type="c1.medium", + name="My jobflow", + num_instances=2, + service_role="EMR_DefaultRole", + slave_instance_type="c1.medium", +) + + +input_instance_groups = [ + InstanceGroup(1, "MASTER", "c1.medium", "ON_DEMAND", "master"), + InstanceGroup(3, "CORE", "c1.medium", "ON_DEMAND", "core"), + InstanceGroup(6, "TASK", "c1.large", "SPOT", "task-1", "0.07"), + InstanceGroup(10, "TASK", "c1.xlarge", "SPOT", "task-2", "0.05"), +] + + +@mock_emr_deprecated +def test_describe_cluster(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + args.update( + dict( + api_params={ + "Applications.member.1.Name": "Spark", + "Applications.member.1.Version": "2.4.2", + "Configurations.member.1.Classification": "yarn-site", + "Configurations.member.1.Properties.entry.1.key": "someproperty", + "Configurations.member.1.Properties.entry.1.value": "somevalue", + "Configurations.member.1.Properties.entry.2.key": "someotherproperty", + "Configurations.member.1.Properties.entry.2.value": "someothervalue", + "Instances.EmrManagedMasterSecurityGroup": "master-security-group", + "Instances.Ec2SubnetId": "subnet-8be41cec", + }, + availability_zone="us-east-2b", + ec2_keyname="mykey", + job_flow_role="EMR_EC2_DefaultRole", + keep_alive=False, + log_uri="s3://some_bucket/jobflow_logs", + name="My jobflow", + service_role="EMR_DefaultRole", + visible_to_all_users=True, + ) + ) + cluster_id = conn.run_jobflow(**args) + input_tags = {"tag1": "val1", "tag2": "val2"} + conn.add_tags(cluster_id, input_tags) + + cluster = conn.describe_cluster(cluster_id) + cluster.applications[0].name.should.equal("Spark") + cluster.applications[0].version.should.equal("2.4.2") + cluster.autoterminate.should.equal("true") + + # configurations appear not be supplied as attributes? + + attrs = cluster.ec2instanceattributes + # AdditionalMasterSecurityGroups + # AdditionalSlaveSecurityGroups + attrs.ec2availabilityzone.should.equal(args["availability_zone"]) + attrs.ec2keyname.should.equal(args["ec2_keyname"]) + attrs.ec2subnetid.should.equal(args["api_params"]["Instances.Ec2SubnetId"]) + # EmrManagedMasterSecurityGroups + # EmrManagedSlaveSecurityGroups + attrs.iaminstanceprofile.should.equal(args["job_flow_role"]) + # ServiceAccessSecurityGroup + + cluster.id.should.equal(cluster_id) + cluster.loguri.should.equal(args["log_uri"]) + cluster.masterpublicdnsname.should.be.a(six.string_types) + cluster.name.should.equal(args["name"]) + int(cluster.normalizedinstancehours).should.equal(0) + # cluster.release_label + cluster.shouldnt.have.property("requestedamiversion") + cluster.runningamiversion.should.equal("1.0.0") + # cluster.securityconfiguration + cluster.servicerole.should.equal(args["service_role"]) + + cluster.status.state.should.equal("TERMINATED") + cluster.status.statechangereason.message.should.be.a(six.string_types) + cluster.status.statechangereason.code.should.be.a(six.string_types) + cluster.status.timeline.creationdatetime.should.be.a(six.string_types) + # cluster.status.timeline.enddatetime.should.be.a(six.string_types) + # cluster.status.timeline.readydatetime.should.be.a(six.string_types) + + dict((item.key, item.value) for item in cluster.tags).should.equal(input_tags) + + cluster.terminationprotected.should.equal("false") + cluster.visibletoallusers.should.equal("true") + + +@mock_emr_deprecated +def test_describe_jobflows(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + expected = {} + + for idx in range(4): + cluster_name = "cluster" + str(idx) + args["name"] = cluster_name + cluster_id = conn.run_jobflow(**args) + expected[cluster_id] = { + "id": cluster_id, + "name": cluster_name, + "state": "WAITING", + } + + # need sleep since it appears the timestamp is always rounded to + # the nearest second internally + time.sleep(1) + timestamp = datetime.now(pytz.utc) + time.sleep(1) + + for idx in range(4, 6): + cluster_name = "cluster" + str(idx) + args["name"] = cluster_name + cluster_id = conn.run_jobflow(**args) + conn.terminate_jobflow(cluster_id) + expected[cluster_id] = { + "id": cluster_id, + "name": cluster_name, + "state": "TERMINATED", + } + jobs = conn.describe_jobflows() + jobs.should.have.length_of(6) + + for cluster_id, y in expected.items(): + resp = conn.describe_jobflows(jobflow_ids=[cluster_id]) + resp.should.have.length_of(1) + resp[0].jobflowid.should.equal(cluster_id) + + resp = conn.describe_jobflows(states=["WAITING"]) + resp.should.have.length_of(4) + for x in resp: + x.state.should.equal("WAITING") + + resp = conn.describe_jobflows(created_before=timestamp) + resp.should.have.length_of(4) + + resp = conn.describe_jobflows(created_after=timestamp) + resp.should.have.length_of(2) + + +@mock_emr_deprecated +def test_describe_jobflow(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + args.update( + dict( + ami_version="3.8.1", + api_params={ + #'Applications.member.1.Name': 'Spark', + #'Applications.member.1.Version': '2.4.2', + #'Configurations.member.1.Classification': 'yarn-site', + #'Configurations.member.1.Properties.entry.1.key': 'someproperty', + #'Configurations.member.1.Properties.entry.1.value': 'somevalue', + #'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', + "Instances.Ec2SubnetId": "subnet-8be41cec" + }, + ec2_keyname="mykey", + hadoop_version="2.4.0", + name="My jobflow", + log_uri="s3://some_bucket/jobflow_logs", + keep_alive=True, + master_instance_type="c1.medium", + slave_instance_type="c1.medium", + num_instances=2, + availability_zone="us-west-2b", + job_flow_role="EMR_EC2_DefaultRole", + service_role="EMR_DefaultRole", + visible_to_all_users=True, + ) + ) + + cluster_id = conn.run_jobflow(**args) + jf = conn.describe_jobflow(cluster_id) + jf.amiversion.should.equal(args["ami_version"]) + jf.bootstrapactions.should.equal(None) + jf.creationdatetime.should.be.a(six.string_types) + jf.should.have.property("laststatechangereason") + jf.readydatetime.should.be.a(six.string_types) + jf.startdatetime.should.be.a(six.string_types) + jf.state.should.equal("WAITING") + + jf.ec2keyname.should.equal(args["ec2_keyname"]) + # Ec2SubnetId + jf.hadoopversion.should.equal(args["hadoop_version"]) + int(jf.instancecount).should.equal(2) + + for ig in jf.instancegroups: + ig.creationdatetime.should.be.a(six.string_types) + # ig.enddatetime.should.be.a(six.string_types) + ig.should.have.property("instancegroupid").being.a(six.string_types) + int(ig.instancerequestcount).should.equal(1) + ig.instancerole.should.be.within(["MASTER", "CORE"]) + int(ig.instancerunningcount).should.equal(1) + ig.instancetype.should.equal("c1.medium") + ig.laststatechangereason.should.be.a(six.string_types) + ig.market.should.equal("ON_DEMAND") + ig.name.should.be.a(six.string_types) + ig.readydatetime.should.be.a(six.string_types) + ig.startdatetime.should.be.a(six.string_types) + ig.state.should.equal("RUNNING") + + jf.keepjobflowalivewhennosteps.should.equal("true") + jf.masterinstanceid.should.be.a(six.string_types) + jf.masterinstancetype.should.equal(args["master_instance_type"]) + jf.masterpublicdnsname.should.be.a(six.string_types) + int(jf.normalizedinstancehours).should.equal(0) + jf.availabilityzone.should.equal(args["availability_zone"]) + jf.slaveinstancetype.should.equal(args["slave_instance_type"]) + jf.terminationprotected.should.equal("false") + + jf.jobflowid.should.equal(cluster_id) + # jf.jobflowrole.should.equal(args['job_flow_role']) + jf.loguri.should.equal(args["log_uri"]) + jf.name.should.equal(args["name"]) + # jf.servicerole.should.equal(args['service_role']) + + jf.steps.should.have.length_of(0) + + list(i.value for i in jf.supported_products).should.equal([]) + jf.visibletoallusers.should.equal("true") + + +@mock_emr_deprecated +def test_list_clusters(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + expected = {} + + for idx in range(40): + cluster_name = "jobflow" + str(idx) + args["name"] = cluster_name + cluster_id = conn.run_jobflow(**args) + expected[cluster_id] = { + "id": cluster_id, + "name": cluster_name, + "normalizedinstancehours": "0", + "state": "WAITING", + } + + # need sleep since it appears the timestamp is always rounded to + # the nearest second internally + time.sleep(1) + timestamp = datetime.now(pytz.utc) + time.sleep(1) + + for idx in range(40, 70): + cluster_name = "jobflow" + str(idx) + args["name"] = cluster_name + cluster_id = conn.run_jobflow(**args) + conn.terminate_jobflow(cluster_id) + expected[cluster_id] = { + "id": cluster_id, + "name": cluster_name, + "normalizedinstancehours": "0", + "state": "TERMINATED", + } + + args = {} + while 1: + resp = conn.list_clusters(**args) + clusters = resp.clusters + len(clusters).should.be.lower_than_or_equal_to(50) + for x in clusters: + y = expected[x.id] + x.id.should.equal(y["id"]) + x.name.should.equal(y["name"]) + x.normalizedinstancehours.should.equal(y["normalizedinstancehours"]) + x.status.state.should.equal(y["state"]) + x.status.timeline.creationdatetime.should.be.a(six.string_types) + if y["state"] == "TERMINATED": + x.status.timeline.enddatetime.should.be.a(six.string_types) + else: + x.status.timeline.shouldnt.have.property("enddatetime") + x.status.timeline.readydatetime.should.be.a(six.string_types) + if not hasattr(resp, "marker"): + break + args = {"marker": resp.marker} + + resp = conn.list_clusters(cluster_states=["TERMINATED"]) + resp.clusters.should.have.length_of(30) + for x in resp.clusters: + x.status.state.should.equal("TERMINATED") + + resp = conn.list_clusters(created_before=timestamp) + resp.clusters.should.have.length_of(40) + + resp = conn.list_clusters(created_after=timestamp) + resp.clusters.should.have.length_of(30) + + +@mock_emr_deprecated +def test_run_jobflow(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + job_id = conn.run_jobflow(**args) + job_flow = conn.describe_jobflow(job_id) + job_flow.state.should.equal("WAITING") + job_flow.jobflowid.should.equal(job_id) + job_flow.name.should.equal(args["name"]) + job_flow.masterinstancetype.should.equal(args["master_instance_type"]) + job_flow.slaveinstancetype.should.equal(args["slave_instance_type"]) + job_flow.loguri.should.equal(args["log_uri"]) + job_flow.visibletoallusers.should.equal("false") + int(job_flow.normalizedinstancehours).should.equal(0) + job_flow.steps.should.have.length_of(0) + + +@mock_emr_deprecated +def test_run_jobflow_in_multiple_regions(): + regions = {} + for region in ["us-east-1", "eu-west-1"]: + conn = boto.emr.connect_to_region(region) + args = run_jobflow_args.copy() + args["name"] = region + cluster_id = conn.run_jobflow(**args) + regions[region] = {"conn": conn, "cluster_id": cluster_id} + + for region in regions.keys(): + conn = regions[region]["conn"] + jf = conn.describe_jobflow(regions[region]["cluster_id"]) + jf.name.should.equal(region) + + +@requires_boto_gte("2.8") +@mock_emr_deprecated +def test_run_jobflow_with_new_params(): + # Test that run_jobflow works with newer params + conn = boto.connect_emr() + conn.run_jobflow(**run_jobflow_args) + + +@requires_boto_gte("2.8") +@mock_emr_deprecated +def test_run_jobflow_with_visible_to_all_users(): + conn = boto.connect_emr() + for expected in (True, False): + job_id = conn.run_jobflow(visible_to_all_users=expected, **run_jobflow_args) + job_flow = conn.describe_jobflow(job_id) + job_flow.visibletoallusers.should.equal(str(expected).lower()) + + +@requires_boto_gte("2.8") +@mock_emr_deprecated +def test_run_jobflow_with_instance_groups(): + input_groups = dict((g.name, g) for g in input_instance_groups) + conn = boto.connect_emr() + job_id = conn.run_jobflow(instance_groups=input_instance_groups, **run_jobflow_args) + job_flow = conn.describe_jobflow(job_id) + int(job_flow.instancecount).should.equal( + sum(g.num_instances for g in input_instance_groups) + ) + for instance_group in job_flow.instancegroups: + expected = input_groups[instance_group.name] + instance_group.should.have.property("instancegroupid") + int(instance_group.instancerunningcount).should.equal(expected.num_instances) + instance_group.instancerole.should.equal(expected.role) + instance_group.instancetype.should.equal(expected.type) + instance_group.market.should.equal(expected.market) + if hasattr(expected, "bidprice"): + instance_group.bidprice.should.equal(expected.bidprice) + + +@requires_boto_gte("2.8") +@mock_emr_deprecated +def test_set_termination_protection(): + conn = boto.connect_emr() + job_id = conn.run_jobflow(**run_jobflow_args) + job_flow = conn.describe_jobflow(job_id) + job_flow.terminationprotected.should.equal("false") + + conn.set_termination_protection(job_id, True) + job_flow = conn.describe_jobflow(job_id) + job_flow.terminationprotected.should.equal("true") + + conn.set_termination_protection(job_id, False) + job_flow = conn.describe_jobflow(job_id) + job_flow.terminationprotected.should.equal("false") + + +@requires_boto_gte("2.8") +@mock_emr_deprecated +def test_set_visible_to_all_users(): + conn = boto.connect_emr() + args = run_jobflow_args.copy() + args["visible_to_all_users"] = False + job_id = conn.run_jobflow(**args) + job_flow = conn.describe_jobflow(job_id) + job_flow.visibletoallusers.should.equal("false") + + conn.set_visible_to_all_users(job_id, True) + job_flow = conn.describe_jobflow(job_id) + job_flow.visibletoallusers.should.equal("true") + + conn.set_visible_to_all_users(job_id, False) + job_flow = conn.describe_jobflow(job_id) + job_flow.visibletoallusers.should.equal("false") + + +@mock_emr_deprecated +def test_terminate_jobflow(): + conn = boto.connect_emr() + job_id = conn.run_jobflow(**run_jobflow_args) + flow = conn.describe_jobflows()[0] + flow.state.should.equal("WAITING") + + conn.terminate_jobflow(job_id) + flow = conn.describe_jobflows()[0] + flow.state.should.equal("TERMINATED") + + +# testing multiple end points for each feature + + +@mock_emr_deprecated +def test_bootstrap_actions(): + bootstrap_actions = [ + BootstrapAction( + name="bs1", + path="path/to/script", + bootstrap_action_args=["arg1", "arg2&arg3"], + ), + BootstrapAction( + name="bs2", path="path/to/anotherscript", bootstrap_action_args=[] + ), + ] + + conn = boto.connect_emr() + cluster_id = conn.run_jobflow( + bootstrap_actions=bootstrap_actions, **run_jobflow_args + ) + + jf = conn.describe_jobflow(cluster_id) + for x, y in zip(jf.bootstrapactions, bootstrap_actions): + x.name.should.equal(y.name) + x.path.should.equal(y.path) + list(o.value for o in x.args).should.equal(y.args()) + + resp = conn.list_bootstrap_actions(cluster_id) + for i, y in enumerate(bootstrap_actions): + x = resp.actions[i] + x.name.should.equal(y.name) + x.scriptpath.should.equal(y.path) + list(arg.value for arg in x.args).should.equal(y.args()) + + +@mock_emr_deprecated +def test_instance_groups(): + input_groups = dict((g.name, g) for g in input_instance_groups) + + conn = boto.connect_emr() + args = run_jobflow_args.copy() + for key in ["master_instance_type", "slave_instance_type", "num_instances"]: + del args[key] + args["instance_groups"] = input_instance_groups[:2] + job_id = conn.run_jobflow(**args) + + jf = conn.describe_jobflow(job_id) + base_instance_count = int(jf.instancecount) + + conn.add_instance_groups(job_id, input_instance_groups[2:]) + + jf = conn.describe_jobflow(job_id) + int(jf.instancecount).should.equal( + sum(g.num_instances for g in input_instance_groups) + ) + for x in jf.instancegroups: + y = input_groups[x.name] + if hasattr(y, "bidprice"): + x.bidprice.should.equal(y.bidprice) + x.creationdatetime.should.be.a(six.string_types) + # x.enddatetime.should.be.a(six.string_types) + x.should.have.property("instancegroupid") + int(x.instancerequestcount).should.equal(y.num_instances) + x.instancerole.should.equal(y.role) + int(x.instancerunningcount).should.equal(y.num_instances) + x.instancetype.should.equal(y.type) + x.laststatechangereason.should.be.a(six.string_types) + x.market.should.equal(y.market) + x.name.should.be.a(six.string_types) + x.readydatetime.should.be.a(six.string_types) + x.startdatetime.should.be.a(six.string_types) + x.state.should.equal("RUNNING") + + for x in conn.list_instance_groups(job_id).instancegroups: + y = input_groups[x.name] + if hasattr(y, "bidprice"): + x.bidprice.should.equal(y.bidprice) + # Configurations + # EbsBlockDevices + # EbsOptimized + x.should.have.property("id") + x.instancegrouptype.should.equal(y.role) + x.instancetype.should.equal(y.type) + x.market.should.equal(y.market) + x.name.should.equal(y.name) + int(x.requestedinstancecount).should.equal(y.num_instances) + int(x.runninginstancecount).should.equal(y.num_instances) + # ShrinkPolicy + x.status.state.should.equal("RUNNING") + x.status.statechangereason.code.should.be.a(six.string_types) + x.status.statechangereason.message.should.be.a(six.string_types) + x.status.timeline.creationdatetime.should.be.a(six.string_types) + # x.status.timeline.enddatetime.should.be.a(six.string_types) + x.status.timeline.readydatetime.should.be.a(six.string_types) + + igs = dict((g.name, g) for g in jf.instancegroups) + + conn.modify_instance_groups( + [igs["task-1"].instancegroupid, igs["task-2"].instancegroupid], [2, 3] + ) + jf = conn.describe_jobflow(job_id) + int(jf.instancecount).should.equal(base_instance_count + 5) + igs = dict((g.name, g) for g in jf.instancegroups) + int(igs["task-1"].instancerunningcount).should.equal(2) + int(igs["task-2"].instancerunningcount).should.equal(3) + + +@mock_emr_deprecated +def test_steps(): + input_steps = [ + StreamingStep( + name="My wordcount example", + mapper="s3n://elasticmapreduce/samples/wordcount/wordSplitter.py", + reducer="aggregate", + input="s3n://elasticmapreduce/samples/wordcount/input", + output="s3n://output_bucket/output/wordcount_output", + ), + StreamingStep( + name="My wordcount example & co.", + mapper="s3n://elasticmapreduce/samples/wordcount/wordSplitter2.py", + reducer="aggregate", + input="s3n://elasticmapreduce/samples/wordcount/input2", + output="s3n://output_bucket/output/wordcount_output2", + ), + ] + + # TODO: implementation and test for cancel_steps + + conn = boto.connect_emr() + cluster_id = conn.run_jobflow(steps=[input_steps[0]], **run_jobflow_args) + + jf = conn.describe_jobflow(cluster_id) + jf.steps.should.have.length_of(1) + + conn.add_jobflow_steps(cluster_id, [input_steps[1]]) + + jf = conn.describe_jobflow(cluster_id) + jf.steps.should.have.length_of(2) + for step in jf.steps: + step.actiononfailure.should.equal("TERMINATE_JOB_FLOW") + list(arg.value for arg in step.args).should.have.length_of(8) + step.creationdatetime.should.be.a(six.string_types) + # step.enddatetime.should.be.a(six.string_types) + step.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") + step.laststatechangereason.should.be.a(six.string_types) + step.mainclass.should.equal("") + step.name.should.be.a(six.string_types) + # step.readydatetime.should.be.a(six.string_types) + # step.startdatetime.should.be.a(six.string_types) + step.state.should.be.within(["STARTING", "PENDING"]) + + expected = dict((s.name, s) for s in input_steps) + + steps = conn.list_steps(cluster_id).steps + for x in steps: + y = expected[x.name] + # actiononfailure + list(arg.value for arg in x.config.args).should.equal( + [ + "-mapper", + y.mapper, + "-reducer", + y.reducer, + "-input", + y.input, + "-output", + y.output, + ] + ) + x.config.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") + x.config.mainclass.should.equal("") + # properties + x.should.have.property("id").should.be.a(six.string_types) + x.name.should.equal(y.name) + x.status.state.should.be.within(["STARTING", "PENDING"]) + # x.status.statechangereason + x.status.timeline.creationdatetime.should.be.a(six.string_types) + # x.status.timeline.enddatetime.should.be.a(six.string_types) + # x.status.timeline.startdatetime.should.be.a(six.string_types) + + x = conn.describe_step(cluster_id, x.id) + list(arg.value for arg in x.config.args).should.equal( + [ + "-mapper", + y.mapper, + "-reducer", + y.reducer, + "-input", + y.input, + "-output", + y.output, + ] + ) + x.config.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") + x.config.mainclass.should.equal("") + # properties + x.should.have.property("id").should.be.a(six.string_types) + x.name.should.equal(y.name) + x.status.state.should.be.within(["STARTING", "PENDING"]) + # x.status.statechangereason + x.status.timeline.creationdatetime.should.be.a(six.string_types) + # x.status.timeline.enddatetime.should.be.a(six.string_types) + # x.status.timeline.startdatetime.should.be.a(six.string_types) + + @requires_boto_gte("2.39") + def test_list_steps_with_states(): + # boto's list_steps prior to 2.39 has a bug that ignores + # step_states argument. + steps = conn.list_steps(cluster_id).steps + step_id = steps[0].id + steps = conn.list_steps(cluster_id, step_states=["STARTING"]).steps + steps.should.have.length_of(1) + steps[0].id.should.equal(step_id) + + test_list_steps_with_states() + + +@mock_emr_deprecated +def test_tags(): + input_tags = {"tag1": "val1", "tag2": "val2"} + + conn = boto.connect_emr() + cluster_id = conn.run_jobflow(**run_jobflow_args) + + conn.add_tags(cluster_id, input_tags) + cluster = conn.describe_cluster(cluster_id) + cluster.tags.should.have.length_of(2) + dict((t.key, t.value) for t in cluster.tags).should.equal(input_tags) + + conn.remove_tags(cluster_id, list(input_tags.keys())) + cluster = conn.describe_cluster(cluster_id) + cluster.tags.should.have.length_of(0) diff --git a/tests/test_emr/test_emr_boto3.py b/tests/test_emr/test_emr_boto3.py index b9a5025d9..212444abf 100644 --- a/tests/test_emr/test_emr_boto3.py +++ b/tests/test_emr/test_emr_boto3.py @@ -16,158 +16,176 @@ from moto import mock_emr run_job_flow_args = dict( Instances={ - 'InstanceCount': 3, - 'KeepJobFlowAliveWhenNoSteps': True, - 'MasterInstanceType': 'c3.medium', - 'Placement': {'AvailabilityZone': 'us-east-1a'}, - 'SlaveInstanceType': 'c3.xlarge', + "InstanceCount": 3, + "KeepJobFlowAliveWhenNoSteps": True, + "MasterInstanceType": "c3.medium", + "Placement": {"AvailabilityZone": "us-east-1a"}, + "SlaveInstanceType": "c3.xlarge", }, - JobFlowRole='EMR_EC2_DefaultRole', - LogUri='s3://mybucket/log', - Name='cluster', - ServiceRole='EMR_DefaultRole', - VisibleToAllUsers=True) + JobFlowRole="EMR_EC2_DefaultRole", + LogUri="s3://mybucket/log", + Name="cluster", + ServiceRole="EMR_DefaultRole", + VisibleToAllUsers=True, +) input_instance_groups = [ - {'InstanceCount': 1, - 'InstanceRole': 'MASTER', - 'InstanceType': 'c1.medium', - 'Market': 'ON_DEMAND', - 'Name': 'master'}, - {'InstanceCount': 3, - 'InstanceRole': 'CORE', - 'InstanceType': 'c1.medium', - 'Market': 'ON_DEMAND', - 'Name': 'core'}, - {'InstanceCount': 6, - 'InstanceRole': 'TASK', - 'InstanceType': 'c1.large', - 'Market': 'SPOT', - 'Name': 'task-1', - 'BidPrice': '0.07'}, - {'InstanceCount': 10, - 'InstanceRole': 'TASK', - 'InstanceType': 'c1.xlarge', - 'Market': 'SPOT', - 'Name': 'task-2', - 'BidPrice': '0.05'}, + { + "InstanceCount": 1, + "InstanceRole": "MASTER", + "InstanceType": "c1.medium", + "Market": "ON_DEMAND", + "Name": "master", + }, + { + "InstanceCount": 3, + "InstanceRole": "CORE", + "InstanceType": "c1.medium", + "Market": "ON_DEMAND", + "Name": "core", + }, + { + "InstanceCount": 6, + "InstanceRole": "TASK", + "InstanceType": "c1.large", + "Market": "SPOT", + "Name": "task-1", + "BidPrice": "0.07", + }, + { + "InstanceCount": 10, + "InstanceRole": "TASK", + "InstanceType": "c1.xlarge", + "Market": "SPOT", + "Name": "task-2", + "BidPrice": "0.05", + }, ] @mock_emr def test_describe_cluster(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Applications'] = [{'Name': 'Spark', 'Version': '2.4.2'}] - args['Configurations'] = [ - {'Classification': 'yarn-site', - 'Properties': {'someproperty': 'somevalue', - 'someotherproperty': 'someothervalue'}}, - {'Classification': 'nested-configs', - 'Properties': {}, - 'Configurations': [ - { - 'Classification': 'nested-config', - 'Properties': { - 'nested-property': 'nested-value' - } - } - ]} + args["Applications"] = [{"Name": "Spark", "Version": "2.4.2"}] + args["Configurations"] = [ + { + "Classification": "yarn-site", + "Properties": { + "someproperty": "somevalue", + "someotherproperty": "someothervalue", + }, + }, + { + "Classification": "nested-configs", + "Properties": {}, + "Configurations": [ + { + "Classification": "nested-config", + "Properties": {"nested-property": "nested-value"}, + } + ], + }, ] - args['Instances']['AdditionalMasterSecurityGroups'] = ['additional-master'] - args['Instances']['AdditionalSlaveSecurityGroups'] = ['additional-slave'] - args['Instances']['Ec2KeyName'] = 'mykey' - args['Instances']['Ec2SubnetId'] = 'subnet-8be41cec' - args['Instances']['EmrManagedMasterSecurityGroup'] = 'master-security-group' - args['Instances']['EmrManagedSlaveSecurityGroup'] = 'slave-security-group' - args['Instances']['KeepJobFlowAliveWhenNoSteps'] = False - args['Instances']['ServiceAccessSecurityGroup'] = 'service-access-security-group' - args['Tags'] = [{'Key': 'tag1', 'Value': 'val1'}, - {'Key': 'tag2', 'Value': 'val2'}] + args["Instances"]["AdditionalMasterSecurityGroups"] = ["additional-master"] + args["Instances"]["AdditionalSlaveSecurityGroups"] = ["additional-slave"] + args["Instances"]["Ec2KeyName"] = "mykey" + args["Instances"]["Ec2SubnetId"] = "subnet-8be41cec" + args["Instances"]["EmrManagedMasterSecurityGroup"] = "master-security-group" + args["Instances"]["EmrManagedSlaveSecurityGroup"] = "slave-security-group" + args["Instances"]["KeepJobFlowAliveWhenNoSteps"] = False + args["Instances"]["ServiceAccessSecurityGroup"] = "service-access-security-group" + args["Tags"] = [{"Key": "tag1", "Value": "val1"}, {"Key": "tag2", "Value": "val2"}] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - cl = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - cl['Applications'][0]['Name'].should.equal('Spark') - cl['Applications'][0]['Version'].should.equal('2.4.2') - cl['AutoTerminate'].should.equal(True) + cl = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + cl["Applications"][0]["Name"].should.equal("Spark") + cl["Applications"][0]["Version"].should.equal("2.4.2") + cl["AutoTerminate"].should.equal(True) - config = cl['Configurations'][0] - config['Classification'].should.equal('yarn-site') - config['Properties'].should.equal(args['Configurations'][0]['Properties']) + config = cl["Configurations"][0] + config["Classification"].should.equal("yarn-site") + config["Properties"].should.equal(args["Configurations"][0]["Properties"]) - nested_config = cl['Configurations'][1] - nested_config['Classification'].should.equal('nested-configs') - nested_config['Properties'].should.equal(args['Configurations'][1]['Properties']) + nested_config = cl["Configurations"][1] + nested_config["Classification"].should.equal("nested-configs") + nested_config["Properties"].should.equal(args["Configurations"][1]["Properties"]) - attrs = cl['Ec2InstanceAttributes'] - attrs['AdditionalMasterSecurityGroups'].should.equal( - args['Instances']['AdditionalMasterSecurityGroups']) - attrs['AdditionalSlaveSecurityGroups'].should.equal( - args['Instances']['AdditionalSlaveSecurityGroups']) - attrs['Ec2AvailabilityZone'].should.equal('us-east-1a') - attrs['Ec2KeyName'].should.equal(args['Instances']['Ec2KeyName']) - attrs['Ec2SubnetId'].should.equal(args['Instances']['Ec2SubnetId']) - attrs['EmrManagedMasterSecurityGroup'].should.equal( - args['Instances']['EmrManagedMasterSecurityGroup']) - attrs['EmrManagedSlaveSecurityGroup'].should.equal( - args['Instances']['EmrManagedSlaveSecurityGroup']) - attrs['IamInstanceProfile'].should.equal(args['JobFlowRole']) - attrs['ServiceAccessSecurityGroup'].should.equal( - args['Instances']['ServiceAccessSecurityGroup']) - cl['Id'].should.equal(cluster_id) - cl['LogUri'].should.equal(args['LogUri']) - cl['MasterPublicDnsName'].should.be.a(six.string_types) - cl['Name'].should.equal(args['Name']) - cl['NormalizedInstanceHours'].should.equal(0) + attrs = cl["Ec2InstanceAttributes"] + attrs["AdditionalMasterSecurityGroups"].should.equal( + args["Instances"]["AdditionalMasterSecurityGroups"] + ) + attrs["AdditionalSlaveSecurityGroups"].should.equal( + args["Instances"]["AdditionalSlaveSecurityGroups"] + ) + attrs["Ec2AvailabilityZone"].should.equal("us-east-1a") + attrs["Ec2KeyName"].should.equal(args["Instances"]["Ec2KeyName"]) + attrs["Ec2SubnetId"].should.equal(args["Instances"]["Ec2SubnetId"]) + attrs["EmrManagedMasterSecurityGroup"].should.equal( + args["Instances"]["EmrManagedMasterSecurityGroup"] + ) + attrs["EmrManagedSlaveSecurityGroup"].should.equal( + args["Instances"]["EmrManagedSlaveSecurityGroup"] + ) + attrs["IamInstanceProfile"].should.equal(args["JobFlowRole"]) + attrs["ServiceAccessSecurityGroup"].should.equal( + args["Instances"]["ServiceAccessSecurityGroup"] + ) + cl["Id"].should.equal(cluster_id) + cl["LogUri"].should.equal(args["LogUri"]) + cl["MasterPublicDnsName"].should.be.a(six.string_types) + cl["Name"].should.equal(args["Name"]) + cl["NormalizedInstanceHours"].should.equal(0) # cl['ReleaseLabel'].should.equal('emr-5.0.0') - cl.shouldnt.have.key('RequestedAmiVersion') - cl['RunningAmiVersion'].should.equal('1.0.0') + cl.shouldnt.have.key("RequestedAmiVersion") + cl["RunningAmiVersion"].should.equal("1.0.0") # cl['SecurityConfiguration'].should.be.a(six.string_types) - cl['ServiceRole'].should.equal(args['ServiceRole']) + cl["ServiceRole"].should.equal(args["ServiceRole"]) - status = cl['Status'] - status['State'].should.equal('TERMINATED') + status = cl["Status"] + status["State"].should.equal("TERMINATED") # cluster['Status']['StateChangeReason'] - status['Timeline']['CreationDateTime'].should.be.a('datetime.datetime') + status["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # status['Timeline']['EndDateTime'].should.equal(datetime(2014, 1, 24, 2, 19, 46, tzinfo=pytz.utc)) - status['Timeline']['ReadyDateTime'].should.be.a('datetime.datetime') + status["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") - dict((t['Key'], t['Value']) for t in cl['Tags']).should.equal( - dict((t['Key'], t['Value']) for t in args['Tags'])) + dict((t["Key"], t["Value"]) for t in cl["Tags"]).should.equal( + dict((t["Key"], t["Value"]) for t in args["Tags"]) + ) - cl['TerminationProtected'].should.equal(False) - cl['VisibleToAllUsers'].should.equal(True) + cl["TerminationProtected"].should.equal(False) + cl["VisibleToAllUsers"].should.equal(True) @mock_emr def test_describe_cluster_not_found(): - conn = boto3.client('emr', region_name='us-east-1') + conn = boto3.client("emr", region_name="us-east-1") raised = False try: - cluster = conn.describe_cluster(ClusterId='DummyId') + cluster = conn.describe_cluster(ClusterId="DummyId") except ClientError as e: - if e.response['Error']['Code'] == "ResourceNotFoundException": + if e.response["Error"]["Code"] == "ResourceNotFoundException": raised = True raised.should.equal(True) @mock_emr def test_describe_job_flows(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) expected = {} for idx in range(4): - cluster_name = 'cluster' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "cluster" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'State': 'WAITING' + "Id": cluster_id, + "Name": cluster_name, + "State": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -177,117 +195,119 @@ def test_describe_job_flows(): time.sleep(1) for idx in range(4, 6): - cluster_name = 'cluster' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "cluster" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] client.terminate_job_flows(JobFlowIds=[cluster_id]) expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'State': 'TERMINATED' + "Id": cluster_id, + "Name": cluster_name, + "State": "TERMINATED", } resp = client.describe_job_flows() - resp['JobFlows'].should.have.length_of(6) + resp["JobFlows"].should.have.length_of(6) for cluster_id, y in expected.items(): resp = client.describe_job_flows(JobFlowIds=[cluster_id]) - resp['JobFlows'].should.have.length_of(1) - resp['JobFlows'][0]['JobFlowId'].should.equal(cluster_id) + resp["JobFlows"].should.have.length_of(1) + resp["JobFlows"][0]["JobFlowId"].should.equal(cluster_id) - resp = client.describe_job_flows(JobFlowStates=['WAITING']) - resp['JobFlows'].should.have.length_of(4) - for x in resp['JobFlows']: - x['ExecutionStatusDetail']['State'].should.equal('WAITING') + resp = client.describe_job_flows(JobFlowStates=["WAITING"]) + resp["JobFlows"].should.have.length_of(4) + for x in resp["JobFlows"]: + x["ExecutionStatusDetail"]["State"].should.equal("WAITING") resp = client.describe_job_flows(CreatedBefore=timestamp) - resp['JobFlows'].should.have.length_of(4) + resp["JobFlows"].should.have.length_of(4) resp = client.describe_job_flows(CreatedAfter=timestamp) - resp['JobFlows'].should.have.length_of(2) + resp["JobFlows"].should.have.length_of(2) @mock_emr def test_describe_job_flow(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['AmiVersion'] = '3.8.1' - args['Instances'].update( - {'Ec2KeyName': 'ec2keyname', - 'Ec2SubnetId': 'subnet-8be41cec', - 'HadoopVersion': '2.4.0'}) - args['VisibleToAllUsers'] = True + args["AmiVersion"] = "3.8.1" + args["Instances"].update( + { + "Ec2KeyName": "ec2keyname", + "Ec2SubnetId": "subnet-8be41cec", + "HadoopVersion": "2.4.0", + } + ) + args["VisibleToAllUsers"] = True - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] - jf['AmiVersion'].should.equal(args['AmiVersion']) - jf.shouldnt.have.key('BootstrapActions') - esd = jf['ExecutionStatusDetail'] - esd['CreationDateTime'].should.be.a('datetime.datetime') + jf["AmiVersion"].should.equal(args["AmiVersion"]) + jf.shouldnt.have.key("BootstrapActions") + esd = jf["ExecutionStatusDetail"] + esd["CreationDateTime"].should.be.a("datetime.datetime") # esd['EndDateTime'].should.be.a('datetime.datetime') # esd['LastStateChangeReason'].should.be.a(six.string_types) - esd['ReadyDateTime'].should.be.a('datetime.datetime') - esd['StartDateTime'].should.be.a('datetime.datetime') - esd['State'].should.equal('WAITING') - attrs = jf['Instances'] - attrs['Ec2KeyName'].should.equal(args['Instances']['Ec2KeyName']) - attrs['Ec2SubnetId'].should.equal(args['Instances']['Ec2SubnetId']) - attrs['HadoopVersion'].should.equal(args['Instances']['HadoopVersion']) - attrs['InstanceCount'].should.equal(args['Instances']['InstanceCount']) - for ig in attrs['InstanceGroups']: + esd["ReadyDateTime"].should.be.a("datetime.datetime") + esd["StartDateTime"].should.be.a("datetime.datetime") + esd["State"].should.equal("WAITING") + attrs = jf["Instances"] + attrs["Ec2KeyName"].should.equal(args["Instances"]["Ec2KeyName"]) + attrs["Ec2SubnetId"].should.equal(args["Instances"]["Ec2SubnetId"]) + attrs["HadoopVersion"].should.equal(args["Instances"]["HadoopVersion"]) + attrs["InstanceCount"].should.equal(args["Instances"]["InstanceCount"]) + for ig in attrs["InstanceGroups"]: # ig['BidPrice'] - ig['CreationDateTime'].should.be.a('datetime.datetime') + ig["CreationDateTime"].should.be.a("datetime.datetime") # ig['EndDateTime'].should.be.a('datetime.datetime') - ig['InstanceGroupId'].should.be.a(six.string_types) - ig['InstanceRequestCount'].should.be.a(int) - ig['InstanceRole'].should.be.within(['MASTER', 'CORE']) - ig['InstanceRunningCount'].should.be.a(int) - ig['InstanceType'].should.be.within(['c3.medium', 'c3.xlarge']) + ig["InstanceGroupId"].should.be.a(six.string_types) + ig["InstanceRequestCount"].should.be.a(int) + ig["InstanceRole"].should.be.within(["MASTER", "CORE"]) + ig["InstanceRunningCount"].should.be.a(int) + ig["InstanceType"].should.be.within(["c3.medium", "c3.xlarge"]) # ig['LastStateChangeReason'].should.be.a(six.string_types) - ig['Market'].should.equal('ON_DEMAND') - ig['Name'].should.be.a(six.string_types) - ig['ReadyDateTime'].should.be.a('datetime.datetime') - ig['StartDateTime'].should.be.a('datetime.datetime') - ig['State'].should.equal('RUNNING') - attrs['KeepJobFlowAliveWhenNoSteps'].should.equal(True) + ig["Market"].should.equal("ON_DEMAND") + ig["Name"].should.be.a(six.string_types) + ig["ReadyDateTime"].should.be.a("datetime.datetime") + ig["StartDateTime"].should.be.a("datetime.datetime") + ig["State"].should.equal("RUNNING") + attrs["KeepJobFlowAliveWhenNoSteps"].should.equal(True) # attrs['MasterInstanceId'].should.be.a(six.string_types) - attrs['MasterInstanceType'].should.equal( - args['Instances']['MasterInstanceType']) - attrs['MasterPublicDnsName'].should.be.a(six.string_types) - attrs['NormalizedInstanceHours'].should.equal(0) - attrs['Placement']['AvailabilityZone'].should.equal( - args['Instances']['Placement']['AvailabilityZone']) - attrs['SlaveInstanceType'].should.equal( - args['Instances']['SlaveInstanceType']) - attrs['TerminationProtected'].should.equal(False) - jf['JobFlowId'].should.equal(cluster_id) - jf['JobFlowRole'].should.equal(args['JobFlowRole']) - jf['LogUri'].should.equal(args['LogUri']) - jf['Name'].should.equal(args['Name']) - jf['ServiceRole'].should.equal(args['ServiceRole']) - jf['Steps'].should.equal([]) - jf['SupportedProducts'].should.equal([]) - jf['VisibleToAllUsers'].should.equal(True) + attrs["MasterInstanceType"].should.equal(args["Instances"]["MasterInstanceType"]) + attrs["MasterPublicDnsName"].should.be.a(six.string_types) + attrs["NormalizedInstanceHours"].should.equal(0) + attrs["Placement"]["AvailabilityZone"].should.equal( + args["Instances"]["Placement"]["AvailabilityZone"] + ) + attrs["SlaveInstanceType"].should.equal(args["Instances"]["SlaveInstanceType"]) + attrs["TerminationProtected"].should.equal(False) + jf["JobFlowId"].should.equal(cluster_id) + jf["JobFlowRole"].should.equal(args["JobFlowRole"]) + jf["LogUri"].should.equal(args["LogUri"]) + jf["Name"].should.equal(args["Name"]) + jf["ServiceRole"].should.equal(args["ServiceRole"]) + jf["Steps"].should.equal([]) + jf["SupportedProducts"].should.equal([]) + jf["VisibleToAllUsers"].should.equal(True) @mock_emr def test_list_clusters(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) expected = {} for idx in range(40): - cluster_name = 'jobflow' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "jobflow" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'NormalizedInstanceHours': 0, - 'State': 'WAITING' + "Id": cluster_id, + "Name": cluster_name, + "NormalizedInstanceHours": 0, + "State": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -297,465 +317,484 @@ def test_list_clusters(): time.sleep(1) for idx in range(40, 70): - cluster_name = 'jobflow' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "jobflow" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] client.terminate_job_flows(JobFlowIds=[cluster_id]) expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'NormalizedInstanceHours': 0, - 'State': 'TERMINATED' + "Id": cluster_id, + "Name": cluster_name, + "NormalizedInstanceHours": 0, + "State": "TERMINATED", } args = {} while 1: resp = client.list_clusters(**args) - clusters = resp['Clusters'] + clusters = resp["Clusters"] len(clusters).should.be.lower_than_or_equal_to(50) for x in clusters: - y = expected[x['Id']] - x['Id'].should.equal(y['Id']) - x['Name'].should.equal(y['Name']) - x['NormalizedInstanceHours'].should.equal( - y['NormalizedInstanceHours']) - x['Status']['State'].should.equal(y['State']) - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') - if y['State'] == 'TERMINATED': - x['Status']['Timeline'][ - 'EndDateTime'].should.be.a('datetime.datetime') + y = expected[x["Id"]] + x["Id"].should.equal(y["Id"]) + x["Name"].should.equal(y["Name"]) + x["NormalizedInstanceHours"].should.equal(y["NormalizedInstanceHours"]) + x["Status"]["State"].should.equal(y["State"]) + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") + if y["State"] == "TERMINATED": + x["Status"]["Timeline"]["EndDateTime"].should.be.a("datetime.datetime") else: - x['Status']['Timeline'].shouldnt.have.key('EndDateTime') - x['Status']['Timeline'][ - 'ReadyDateTime'].should.be.a('datetime.datetime') - marker = resp.get('Marker') + x["Status"]["Timeline"].shouldnt.have.key("EndDateTime") + x["Status"]["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") + marker = resp.get("Marker") if marker is None: break - args = {'Marker': marker} + args = {"Marker": marker} - resp = client.list_clusters(ClusterStates=['TERMINATED']) - resp['Clusters'].should.have.length_of(30) - for x in resp['Clusters']: - x['Status']['State'].should.equal('TERMINATED') + resp = client.list_clusters(ClusterStates=["TERMINATED"]) + resp["Clusters"].should.have.length_of(30) + for x in resp["Clusters"]: + x["Status"]["State"].should.equal("TERMINATED") resp = client.list_clusters(CreatedBefore=timestamp) - resp['Clusters'].should.have.length_of(40) + resp["Clusters"].should.have.length_of(40) resp = client.list_clusters(CreatedAfter=timestamp) - resp['Clusters'].should.have.length_of(30) + resp["Clusters"].should.have.length_of(30) @mock_emr def test_run_job_flow(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - cluster_id = client.run_job_flow(**args)['JobFlowId'] - resp = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - resp['ExecutionStatusDetail']['State'].should.equal('WAITING') - resp['JobFlowId'].should.equal(cluster_id) - resp['Name'].should.equal(args['Name']) - resp['Instances']['MasterInstanceType'].should.equal( - args['Instances']['MasterInstanceType']) - resp['Instances']['SlaveInstanceType'].should.equal( - args['Instances']['SlaveInstanceType']) - resp['LogUri'].should.equal(args['LogUri']) - resp['VisibleToAllUsers'].should.equal(args['VisibleToAllUsers']) - resp['Instances']['NormalizedInstanceHours'].should.equal(0) - resp['Steps'].should.equal([]) + cluster_id = client.run_job_flow(**args)["JobFlowId"] + resp = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + resp["ExecutionStatusDetail"]["State"].should.equal("WAITING") + resp["JobFlowId"].should.equal(cluster_id) + resp["Name"].should.equal(args["Name"]) + resp["Instances"]["MasterInstanceType"].should.equal( + args["Instances"]["MasterInstanceType"] + ) + resp["Instances"]["SlaveInstanceType"].should.equal( + args["Instances"]["SlaveInstanceType"] + ) + resp["LogUri"].should.equal(args["LogUri"]) + resp["VisibleToAllUsers"].should.equal(args["VisibleToAllUsers"]) + resp["Instances"]["NormalizedInstanceHours"].should.equal(0) + resp["Steps"].should.equal([]) @mock_emr def test_run_job_flow_with_invalid_params(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") with assert_raises(ClientError) as ex: # cannot set both AmiVersion and ReleaseLabel args = deepcopy(run_job_flow_args) - args['AmiVersion'] = '2.4' - args['ReleaseLabel'] = 'emr-5.0.0' + args["AmiVersion"] = "2.4" + args["ReleaseLabel"] = "emr-5.0.0" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") @mock_emr def test_run_job_flow_in_multiple_regions(): regions = {} - for region in ['us-east-1', 'eu-west-1']: - client = boto3.client('emr', region_name=region) + for region in ["us-east-1", "eu-west-1"]: + client = boto3.client("emr", region_name=region) args = deepcopy(run_job_flow_args) - args['Name'] = region - cluster_id = client.run_job_flow(**args)['JobFlowId'] - regions[region] = {'client': client, 'cluster_id': cluster_id} + args["Name"] = region + cluster_id = client.run_job_flow(**args)["JobFlowId"] + regions[region] = {"client": client, "cluster_id": cluster_id} for region in regions.keys(): - client = regions[region]['client'] - resp = client.describe_cluster(ClusterId=regions[region]['cluster_id']) - resp['Cluster']['Name'].should.equal(region) + client = regions[region]["client"] + resp = client.describe_cluster(ClusterId=regions[region]["cluster_id"]) + resp["Cluster"]["Name"].should.equal(region) @mock_emr def test_run_job_flow_with_new_params(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") resp = client.run_job_flow(**run_job_flow_args) - resp.should.have.key('JobFlowId') + resp.should.have.key("JobFlowId") @mock_emr def test_run_job_flow_with_visible_to_all_users(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") for expected in (True, False): args = deepcopy(run_job_flow_args) - args['VisibleToAllUsers'] = expected + args["VisibleToAllUsers"] = expected resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(expected) + resp["Cluster"]["VisibleToAllUsers"].should.equal(expected) @mock_emr def test_run_job_flow_with_instance_groups(): - input_groups = dict((g['Name'], g) for g in input_instance_groups) - client = boto3.client('emr', region_name='us-east-1') + input_groups = dict((g["Name"], g) for g in input_instance_groups) + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Instances'] = {'InstanceGroups': input_instance_groups} - cluster_id = client.run_job_flow(**args)['JobFlowId'] - groups = client.list_instance_groups(ClusterId=cluster_id)[ - 'InstanceGroups'] + args["Instances"] = {"InstanceGroups": input_instance_groups} + cluster_id = client.run_job_flow(**args)["JobFlowId"] + groups = client.list_instance_groups(ClusterId=cluster_id)["InstanceGroups"] for x in groups: - y = input_groups[x['Name']] - x.should.have.key('Id') - x['RequestedInstanceCount'].should.equal(y['InstanceCount']) - x['InstanceGroupType'].should.equal(y['InstanceRole']) - x['InstanceType'].should.equal(y['InstanceType']) - x['Market'].should.equal(y['Market']) - if 'BidPrice' in y: - x['BidPrice'].should.equal(y['BidPrice']) + y = input_groups[x["Name"]] + x.should.have.key("Id") + x["RequestedInstanceCount"].should.equal(y["InstanceCount"]) + x["InstanceGroupType"].should.equal(y["InstanceRole"]) + x["InstanceType"].should.equal(y["InstanceType"]) + x["Market"].should.equal(y["Market"]) + if "BidPrice" in y: + x["BidPrice"].should.equal(y["BidPrice"]) @mock_emr def test_run_job_flow_with_custom_ami(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") with assert_raises(ClientError) as ex: # CustomAmiId available in Amazon EMR 5.7.0 and later args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['ReleaseLabel'] = 'emr-5.6.0' + args["CustomAmiId"] = "MyEmrCustomId" + args["ReleaseLabel"] = "emr-5.6.0" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Custom AMI is not allowed') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal("Custom AMI is not allowed") with assert_raises(ClientError) as ex: args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['AmiVersion'] = '3.8.1' + args["CustomAmiId"] = "MyEmrCustomId" + args["AmiVersion"] = "3.8.1" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal( - 'Custom AMI is not supported in this version of EMR') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Custom AMI is not supported in this version of EMR" + ) with assert_raises(ClientError) as ex: # AMI version and release label exception raises before CustomAmi exception args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['ReleaseLabel'] = 'emr-5.6.0' - args['AmiVersion'] = '3.8.1' + args["CustomAmiId"] = "MyEmrCustomId" + args["ReleaseLabel"] = "emr-5.6.0" + args["AmiVersion"] = "3.8.1" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.contain( - 'Only one AMI version and release label may be specified.') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.contain( + "Only one AMI version and release label may be specified." + ) args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomAmi' - args['ReleaseLabel'] = 'emr-5.7.0' - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["CustomAmiId"] = "MyEmrCustomAmi" + args["ReleaseLabel"] = "emr-5.7.0" + cluster_id = client.run_job_flow(**args)["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['CustomAmiId'].should.equal('MyEmrCustomAmi') + resp["Cluster"]["CustomAmiId"].should.equal("MyEmrCustomAmi") @mock_emr def test_set_termination_protection(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Instances']['TerminationProtected'] = False + args["Instances"]["TerminationProtected"] = False resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['TerminationProtected'].should.equal(False) + resp["Cluster"]["TerminationProtected"].should.equal(False) for expected in (True, False): - resp = client.set_termination_protection(JobFlowIds=[cluster_id], - TerminationProtected=expected) + resp = client.set_termination_protection( + JobFlowIds=[cluster_id], TerminationProtected=expected + ) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['TerminationProtected'].should.equal(expected) + resp["Cluster"]["TerminationProtected"].should.equal(expected) @mock_emr def test_set_visible_to_all_users(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['VisibleToAllUsers'] = False + args["VisibleToAllUsers"] = False resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(False) + resp["Cluster"]["VisibleToAllUsers"].should.equal(False) for expected in (True, False): - resp = client.set_visible_to_all_users(JobFlowIds=[cluster_id], - VisibleToAllUsers=expected) + resp = client.set_visible_to_all_users( + JobFlowIds=[cluster_id], VisibleToAllUsers=expected + ) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(expected) + resp["Cluster"]["VisibleToAllUsers"].should.equal(expected) @mock_emr def test_terminate_job_flows(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") resp = client.run_job_flow(**run_job_flow_args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['Status']['State'].should.equal('WAITING') + resp["Cluster"]["Status"]["State"].should.equal("WAITING") resp = client.terminate_job_flows(JobFlowIds=[cluster_id]) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['Status']['State'].should.equal('TERMINATED') + resp["Cluster"]["Status"]["State"].should.equal("TERMINATED") # testing multiple end points for each feature + @mock_emr def test_bootstrap_actions(): bootstrap_actions = [ - {'Name': 'bs1', - 'ScriptBootstrapAction': { - 'Args': ['arg1', 'arg2'], - 'Path': 's3://path/to/script'}}, - {'Name': 'bs2', - 'ScriptBootstrapAction': { - 'Args': [], - 'Path': 's3://path/to/anotherscript'}} + { + "Name": "bs1", + "ScriptBootstrapAction": { + "Args": ["arg1", "arg2"], + "Path": "s3://path/to/script", + }, + }, + { + "Name": "bs2", + "ScriptBootstrapAction": {"Args": [], "Path": "s3://path/to/anotherscript"}, + }, ] - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['BootstrapActions'] = bootstrap_actions - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["BootstrapActions"] = bootstrap_actions + cluster_id = client.run_job_flow(**args)["JobFlowId"] - cl = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - for x, y in zip(cl['BootstrapActions'], bootstrap_actions): - x['BootstrapActionConfig'].should.equal(y) + cl = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + for x, y in zip(cl["BootstrapActions"], bootstrap_actions): + x["BootstrapActionConfig"].should.equal(y) resp = client.list_bootstrap_actions(ClusterId=cluster_id) - for x, y in zip(resp['BootstrapActions'], bootstrap_actions): - x['Name'].should.equal(y['Name']) - if 'Args' in y['ScriptBootstrapAction']: - x['Args'].should.equal(y['ScriptBootstrapAction']['Args']) - x['ScriptPath'].should.equal(y['ScriptBootstrapAction']['Path']) + for x, y in zip(resp["BootstrapActions"], bootstrap_actions): + x["Name"].should.equal(y["Name"]) + if "Args" in y["ScriptBootstrapAction"]: + x["Args"].should.equal(y["ScriptBootstrapAction"]["Args"]) + x["ScriptPath"].should.equal(y["ScriptBootstrapAction"]["Path"]) @mock_emr def test_instance_groups(): - input_groups = dict((g['Name'], g) for g in input_instance_groups) + input_groups = dict((g["Name"], g) for g in input_instance_groups) - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - for key in ['MasterInstanceType', 'SlaveInstanceType', 'InstanceCount']: - del args['Instances'][key] - args['Instances']['InstanceGroups'] = input_instance_groups[:2] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + for key in ["MasterInstanceType", "SlaveInstanceType", "InstanceCount"]: + del args["Instances"][key] + args["Instances"]["InstanceGroups"] = input_instance_groups[:2] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - base_instance_count = jf['Instances']['InstanceCount'] + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + base_instance_count = jf["Instances"]["InstanceCount"] client.add_instance_groups( - JobFlowId=cluster_id, InstanceGroups=input_instance_groups[2:]) + JobFlowId=cluster_id, InstanceGroups=input_instance_groups[2:] + ) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Instances']['InstanceCount'].should.equal( - sum(g['InstanceCount'] for g in input_instance_groups)) - for x in jf['Instances']['InstanceGroups']: - y = input_groups[x['Name']] - if hasattr(y, 'BidPrice'): - x['BidPrice'].should.equal('BidPrice') - x['CreationDateTime'].should.be.a('datetime.datetime') + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Instances"]["InstanceCount"].should.equal( + sum(g["InstanceCount"] for g in input_instance_groups) + ) + for x in jf["Instances"]["InstanceGroups"]: + y = input_groups[x["Name"]] + if hasattr(y, "BidPrice"): + x["BidPrice"].should.equal("BidPrice") + x["CreationDateTime"].should.be.a("datetime.datetime") # x['EndDateTime'].should.be.a('datetime.datetime') - x.should.have.key('InstanceGroupId') - x['InstanceRequestCount'].should.equal(y['InstanceCount']) - x['InstanceRole'].should.equal(y['InstanceRole']) - x['InstanceRunningCount'].should.equal(y['InstanceCount']) - x['InstanceType'].should.equal(y['InstanceType']) + x.should.have.key("InstanceGroupId") + x["InstanceRequestCount"].should.equal(y["InstanceCount"]) + x["InstanceRole"].should.equal(y["InstanceRole"]) + x["InstanceRunningCount"].should.equal(y["InstanceCount"]) + x["InstanceType"].should.equal(y["InstanceType"]) # x['LastStateChangeReason'].should.equal(y['LastStateChangeReason']) - x['Market'].should.equal(y['Market']) - x['Name'].should.equal(y['Name']) - x['ReadyDateTime'].should.be.a('datetime.datetime') - x['StartDateTime'].should.be.a('datetime.datetime') - x['State'].should.equal('RUNNING') + x["Market"].should.equal(y["Market"]) + x["Name"].should.equal(y["Name"]) + x["ReadyDateTime"].should.be.a("datetime.datetime") + x["StartDateTime"].should.be.a("datetime.datetime") + x["State"].should.equal("RUNNING") - groups = client.list_instance_groups(ClusterId=cluster_id)[ - 'InstanceGroups'] + groups = client.list_instance_groups(ClusterId=cluster_id)["InstanceGroups"] for x in groups: - y = input_groups[x['Name']] - if hasattr(y, 'BidPrice'): - x['BidPrice'].should.equal('BidPrice') + y = input_groups[x["Name"]] + if hasattr(y, "BidPrice"): + x["BidPrice"].should.equal("BidPrice") # Configurations # EbsBlockDevices # EbsOptimized - x.should.have.key('Id') - x['InstanceGroupType'].should.equal(y['InstanceRole']) - x['InstanceType'].should.equal(y['InstanceType']) - x['Market'].should.equal(y['Market']) - x['Name'].should.equal(y['Name']) - x['RequestedInstanceCount'].should.equal(y['InstanceCount']) - x['RunningInstanceCount'].should.equal(y['InstanceCount']) + x.should.have.key("Id") + x["InstanceGroupType"].should.equal(y["InstanceRole"]) + x["InstanceType"].should.equal(y["InstanceType"]) + x["Market"].should.equal(y["Market"]) + x["Name"].should.equal(y["Name"]) + x["RequestedInstanceCount"].should.equal(y["InstanceCount"]) + x["RunningInstanceCount"].should.equal(y["InstanceCount"]) # ShrinkPolicy - x['Status']['State'].should.equal('RUNNING') - x['Status']['StateChangeReason']['Code'].should.be.a(six.string_types) + x["Status"]["State"].should.equal("RUNNING") + x["Status"]["StateChangeReason"]["Code"].should.be.a(six.string_types) # x['Status']['StateChangeReason']['Message'].should.be.a(six.string_types) - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') - x['Status']['Timeline'][ - 'ReadyDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") - igs = dict((g['Name'], g) for g in groups) + igs = dict((g["Name"], g) for g in groups) client.modify_instance_groups( InstanceGroups=[ - {'InstanceGroupId': igs['task-1']['Id'], - 'InstanceCount': 2}, - {'InstanceGroupId': igs['task-2']['Id'], - 'InstanceCount': 3}]) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Instances']['InstanceCount'].should.equal(base_instance_count + 5) - igs = dict((g['Name'], g) for g in jf['Instances']['InstanceGroups']) - igs['task-1']['InstanceRunningCount'].should.equal(2) - igs['task-2']['InstanceRunningCount'].should.equal(3) + {"InstanceGroupId": igs["task-1"]["Id"], "InstanceCount": 2}, + {"InstanceGroupId": igs["task-2"]["Id"], "InstanceCount": 3}, + ] + ) + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Instances"]["InstanceCount"].should.equal(base_instance_count + 5) + igs = dict((g["Name"], g) for g in jf["Instances"]["InstanceGroups"]) + igs["task-1"]["InstanceRunningCount"].should.equal(2) + igs["task-2"]["InstanceRunningCount"].should.equal(3) @mock_emr def test_steps(): - input_steps = [{ - 'HadoopJarStep': { - 'Args': [ - 'hadoop-streaming', - '-files', 's3://elasticmapreduce/samples/wordcount/wordSplitter.py#wordSplitter.py', - '-mapper', 'python wordSplitter.py', - '-input', 's3://elasticmapreduce/samples/wordcount/input', - '-output', 's3://output_bucket/output/wordcount_output', - '-reducer', 'aggregate' - ], - 'Jar': 'command-runner.jar', + input_steps = [ + { + "HadoopJarStep": { + "Args": [ + "hadoop-streaming", + "-files", + "s3://elasticmapreduce/samples/wordcount/wordSplitter.py#wordSplitter.py", + "-mapper", + "python wordSplitter.py", + "-input", + "s3://elasticmapreduce/samples/wordcount/input", + "-output", + "s3://output_bucket/output/wordcount_output", + "-reducer", + "aggregate", + ], + "Jar": "command-runner.jar", + }, + "Name": "My wordcount example", }, - 'Name': 'My wordcount example', - }, { - 'HadoopJarStep': { - 'Args': [ - 'hadoop-streaming', - '-files', 's3://elasticmapreduce/samples/wordcount/wordSplitter2.py#wordSplitter2.py', - '-mapper', 'python wordSplitter2.py', - '-input', 's3://elasticmapreduce/samples/wordcount/input2', - '-output', 's3://output_bucket/output/wordcount_output2', - '-reducer', 'aggregate' - ], - 'Jar': 'command-runner.jar', + { + "HadoopJarStep": { + "Args": [ + "hadoop-streaming", + "-files", + "s3://elasticmapreduce/samples/wordcount/wordSplitter2.py#wordSplitter2.py", + "-mapper", + "python wordSplitter2.py", + "-input", + "s3://elasticmapreduce/samples/wordcount/input2", + "-output", + "s3://output_bucket/output/wordcount_output2", + "-reducer", + "aggregate", + ], + "Jar": "command-runner.jar", + }, + "Name": "My wordcount example2", }, - 'Name': 'My wordcount example2', - }] + ] # TODO: implementation and test for cancel_steps - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Steps'] = [input_steps[0]] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["Steps"] = [input_steps[0]] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Steps'].should.have.length_of(1) + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Steps"].should.have.length_of(1) client.add_job_flow_steps(JobFlowId=cluster_id, Steps=[input_steps[1]]) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Steps'].should.have.length_of(2) - for idx, (x, y) in enumerate(zip(jf['Steps'], input_steps)): - x['ExecutionStatusDetail'].should.have.key('CreationDateTime') + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Steps"].should.have.length_of(2) + for idx, (x, y) in enumerate(zip(jf["Steps"], input_steps)): + x["ExecutionStatusDetail"].should.have.key("CreationDateTime") # x['ExecutionStatusDetail'].should.have.key('EndDateTime') # x['ExecutionStatusDetail'].should.have.key('LastStateChangeReason') # x['ExecutionStatusDetail'].should.have.key('StartDateTime') - x['ExecutionStatusDetail']['State'].should.equal( - 'STARTING' if idx == 0 else 'PENDING') - x['StepConfig']['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['StepConfig']['HadoopJarStep'][ - 'Args'].should.equal(y['HadoopJarStep']['Args']) - x['StepConfig']['HadoopJarStep'][ - 'Jar'].should.equal(y['HadoopJarStep']['Jar']) - if 'MainClass' in y['HadoopJarStep']: - x['StepConfig']['HadoopJarStep']['MainClass'].should.equal( - y['HadoopJarStep']['MainClass']) - if 'Properties' in y['HadoopJarStep']: - x['StepConfig']['HadoopJarStep']['Properties'].should.equal( - y['HadoopJarStep']['Properties']) - x['StepConfig']['Name'].should.equal(y['Name']) + x["ExecutionStatusDetail"]["State"].should.equal( + "STARTING" if idx == 0 else "PENDING" + ) + x["StepConfig"]["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["StepConfig"]["HadoopJarStep"]["Args"].should.equal( + y["HadoopJarStep"]["Args"] + ) + x["StepConfig"]["HadoopJarStep"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) + if "MainClass" in y["HadoopJarStep"]: + x["StepConfig"]["HadoopJarStep"]["MainClass"].should.equal( + y["HadoopJarStep"]["MainClass"] + ) + if "Properties" in y["HadoopJarStep"]: + x["StepConfig"]["HadoopJarStep"]["Properties"].should.equal( + y["HadoopJarStep"]["Properties"] + ) + x["StepConfig"]["Name"].should.equal(y["Name"]) - expected = dict((s['Name'], s) for s in input_steps) + expected = dict((s["Name"], s) for s in input_steps) - steps = client.list_steps(ClusterId=cluster_id)['Steps'] + steps = client.list_steps(ClusterId=cluster_id)["Steps"] steps.should.have.length_of(2) for x in steps: - y = expected[x['Name']] - x['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['Config']['Args'].should.equal(y['HadoopJarStep']['Args']) - x['Config']['Jar'].should.equal(y['HadoopJarStep']['Jar']) + y = expected[x["Name"]] + x["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["Config"]["Args"].should.equal(y["HadoopJarStep"]["Args"]) + x["Config"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) # x['Config']['MainClass'].should.equal(y['HadoopJarStep']['MainClass']) # Properties - x['Id'].should.be.a(six.string_types) - x['Name'].should.equal(y['Name']) - x['Status']['State'].should.be.within(['STARTING', 'PENDING']) + x["Id"].should.be.a(six.string_types) + x["Name"].should.equal(y["Name"]) + x["Status"]["State"].should.be.within(["STARTING", "PENDING"]) # StateChangeReason - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') # x['Status']['Timeline']['StartDateTime'].should.be.a('datetime.datetime') - x = client.describe_step(ClusterId=cluster_id, StepId=x['Id'])['Step'] - x['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['Config']['Args'].should.equal(y['HadoopJarStep']['Args']) - x['Config']['Jar'].should.equal(y['HadoopJarStep']['Jar']) + x = client.describe_step(ClusterId=cluster_id, StepId=x["Id"])["Step"] + x["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["Config"]["Args"].should.equal(y["HadoopJarStep"]["Args"]) + x["Config"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) # x['Config']['MainClass'].should.equal(y['HadoopJarStep']['MainClass']) # Properties - x['Id'].should.be.a(six.string_types) - x['Name'].should.equal(y['Name']) - x['Status']['State'].should.be.within(['STARTING', 'PENDING']) + x["Id"].should.be.a(six.string_types) + x["Name"].should.equal(y["Name"]) + x["Status"]["State"].should.be.within(["STARTING", "PENDING"]) # StateChangeReason - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') # x['Status']['Timeline']['StartDateTime'].should.be.a('datetime.datetime') - step_id = steps[0]['Id'] - steps = client.list_steps(ClusterId=cluster_id, StepIds=[step_id])['Steps'] + step_id = steps[0]["Id"] + steps = client.list_steps(ClusterId=cluster_id, StepIds=[step_id])["Steps"] steps.should.have.length_of(1) - steps[0]['Id'].should.equal(step_id) + steps[0]["Id"].should.equal(step_id) - steps = client.list_steps(ClusterId=cluster_id, - StepStates=['STARTING'])['Steps'] + steps = client.list_steps(ClusterId=cluster_id, StepStates=["STARTING"])["Steps"] steps.should.have.length_of(1) - steps[0]['Id'].should.equal(step_id) + steps[0]["Id"].should.equal(step_id) @mock_emr def test_tags(): - input_tags = [{'Key': 'newkey1', 'Value': 'newval1'}, - {'Key': 'newkey2', 'Value': 'newval2'}] + input_tags = [ + {"Key": "newkey1", "Value": "newval1"}, + {"Key": "newkey2", "Value": "newval2"}, + ] - client = boto3.client('emr', region_name='us-east-1') - cluster_id = client.run_job_flow(**run_job_flow_args)['JobFlowId'] + client = boto3.client("emr", region_name="us-east-1") + cluster_id = client.run_job_flow(**run_job_flow_args)["JobFlowId"] client.add_tags(ResourceId=cluster_id, Tags=input_tags) - resp = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - resp['Tags'].should.have.length_of(2) - dict((t['Key'], t['Value']) for t in resp['Tags']).should.equal( - dict((t['Key'], t['Value']) for t in input_tags)) + resp = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + resp["Tags"].should.have.length_of(2) + dict((t["Key"], t["Value"]) for t in resp["Tags"]).should.equal( + dict((t["Key"], t["Value"]) for t in input_tags) + ) - client.remove_tags(ResourceId=cluster_id, TagKeys=[ - t['Key'] for t in input_tags]) - resp = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - resp['Tags'].should.equal([]) + client.remove_tags(ResourceId=cluster_id, TagKeys=[t["Key"] for t in input_tags]) + resp = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + resp["Tags"].should.equal([]) diff --git a/tests/test_emr/test_server.py b/tests/test_emr/test_server.py index f2b215ec7..4dbd02553 100644 --- a/tests/test_emr/test_server.py +++ b/tests/test_emr/test_server.py @@ -1,18 +1,18 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_describe_jobflows(): - backend = server.create_backend_app("emr") - test_client = backend.test_client() - - res = test_client.get('/?Action=DescribeJobFlows') - - res.data.should.contain(b'') - res.data.should.contain(b'') +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_describe_jobflows(): + backend = server.create_backend_app("emr") + test_client = backend.test_client() + + res = test_client.get("/?Action=DescribeJobFlows") + + res.data.should.contain(b"") + res.data.should.contain(b"") diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index a9d90ec32..14d872806 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -1,48 +1,50 @@ import random import boto3 import json +import sure # noqa from moto.events import mock_events from botocore.exceptions import ClientError from nose.tools import assert_raises +from moto.core import ACCOUNT_ID RULES = [ - {'Name': 'test1', 'ScheduleExpression': 'rate(5 minutes)'}, - {'Name': 'test2', 'ScheduleExpression': 'rate(1 minute)'}, - {'Name': 'test3', 'EventPattern': '{"source": ["test-source"]}'} + {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, + {"Name": "test2", "ScheduleExpression": "rate(1 minute)"}, + {"Name": "test3", "EventPattern": '{"source": ["test-source"]}'}, ] TARGETS = { - 'test-target-1': { - 'Id': 'test-target-1', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-1', - 'Rules': ['test1', 'test2'] + "test-target-1": { + "Id": "test-target-1", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-1", + "Rules": ["test1", "test2"], }, - 'test-target-2': { - 'Id': 'test-target-2', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-2', - 'Rules': ['test1', 'test3'] + "test-target-2": { + "Id": "test-target-2", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-2", + "Rules": ["test1", "test3"], }, - 'test-target-3': { - 'Id': 'test-target-3', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-3', - 'Rules': ['test1', 'test2'] + "test-target-3": { + "Id": "test-target-3", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-3", + "Rules": ["test1", "test2"], }, - 'test-target-4': { - 'Id': 'test-target-4', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-4', - 'Rules': ['test1', 'test3'] + "test-target-4": { + "Id": "test-target-4", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-4", + "Rules": ["test1", "test3"], }, - 'test-target-5': { - 'Id': 'test-target-5', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-5', - 'Rules': ['test1', 'test2'] + "test-target-5": { + "Id": "test-target-5", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-5", + "Rules": ["test1", "test2"], + }, + "test-target-6": { + "Id": "test-target-6", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-6", + "Rules": ["test1", "test3"], }, - 'test-target-6': { - 'Id': 'test-target-6', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-6', - 'Rules': ['test1', 'test3'] - } } @@ -51,21 +53,21 @@ def get_random_rule(): def generate_environment(): - client = boto3.client('events', 'us-west-2') + client = boto3.client("events", "us-west-2") for rule in RULES: client.put_rule( - Name=rule['Name'], - ScheduleExpression=rule.get('ScheduleExpression', ''), - EventPattern=rule.get('EventPattern', '') + Name=rule["Name"], + ScheduleExpression=rule.get("ScheduleExpression", ""), + EventPattern=rule.get("EventPattern", ""), ) targets = [] for target in TARGETS: - if rule['Name'] in TARGETS[target].get('Rules'): - targets.append({'Id': target, 'Arn': TARGETS[target]['Arn']}) + if rule["Name"] in TARGETS[target].get("Rules"): + targets.append({"Id": target, "Arn": TARGETS[target]["Arn"]}) - client.put_targets(Rule=rule['Name'], Targets=targets) + client.put_targets(Rule=rule["Name"], Targets=targets) return client @@ -75,61 +77,63 @@ def test_list_rules(): client = generate_environment() response = client.list_rules() - assert(response is not None) - assert(len(response['Rules']) > 0) + assert response is not None + assert len(response["Rules"]) > 0 @mock_events def test_describe_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() response = client.describe_rule(Name=rule_name) - assert(response is not None) - assert(response.get('Name') == rule_name) - assert(response.get('Arn') is not None) + assert response is not None + assert response.get("Name") == rule_name + assert response.get( + "Arn" + ) == "arn:aws:events:us-west-2:111111111111:rule/{0}".format(rule_name) @mock_events def test_enable_disable_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() # Rules should start out enabled in these tests. rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'ENABLED') + assert rule["State"] == "ENABLED" client.disable_rule(Name=rule_name) rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'DISABLED') + assert rule["State"] == "DISABLED" client.enable_rule(Name=rule_name) rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'ENABLED') + assert rule["State"] == "ENABLED" # Test invalid name try: - client.enable_rule(Name='junk') + client.enable_rule(Name="junk") except ClientError as ce: - assert ce.response['Error']['Code'] == 'ResourceNotFoundException' + assert ce.response["Error"]["Code"] == "ResourceNotFoundException" @mock_events def test_list_rule_names_by_target(): - test_1_target = TARGETS['test-target-1'] - test_2_target = TARGETS['test-target-2'] + test_1_target = TARGETS["test-target-1"] + test_2_target = TARGETS["test-target-2"] client = generate_environment() - rules = client.list_rule_names_by_target(TargetArn=test_1_target['Arn']) - assert(len(rules['RuleNames']) == len(test_1_target['Rules'])) - for rule in rules['RuleNames']: - assert(rule in test_1_target['Rules']) + rules = client.list_rule_names_by_target(TargetArn=test_1_target["Arn"]) + assert len(rules["RuleNames"]) == len(test_1_target["Rules"]) + for rule in rules["RuleNames"]: + assert rule in test_1_target["Rules"] - rules = client.list_rule_names_by_target(TargetArn=test_2_target['Arn']) - assert(len(rules['RuleNames']) == len(test_2_target['Rules'])) - for rule in rules['RuleNames']: - assert(rule in test_2_target['Rules']) + rules = client.list_rule_names_by_target(TargetArn=test_2_target["Arn"]) + assert len(rules["RuleNames"]) == len(test_2_target["Rules"]) + for rule in rules["RuleNames"]: + assert rule in test_2_target["Rules"] @mock_events @@ -137,80 +141,323 @@ def test_list_rules(): client = generate_environment() rules = client.list_rules() - assert(len(rules['Rules']) == len(RULES)) + assert len(rules["Rules"]) == len(RULES) @mock_events def test_delete_rule(): client = generate_environment() - client.delete_rule(Name=RULES[0]['Name']) + client.delete_rule(Name=RULES[0]["Name"]) rules = client.list_rules() - assert(len(rules['Rules']) == len(RULES) - 1) + assert len(rules["Rules"]) == len(RULES) - 1 @mock_events def test_list_targets_by_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() targets = client.list_targets_by_rule(Rule=rule_name) expected_targets = [] for target in TARGETS: - if rule_name in TARGETS[target].get('Rules'): + if rule_name in TARGETS[target].get("Rules"): expected_targets.append(target) - assert(len(targets['Targets']) == len(expected_targets)) + assert len(targets["Targets"]) == len(expected_targets) @mock_events def test_remove_targets(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() - targets = client.list_targets_by_rule(Rule=rule_name)['Targets'] + targets = client.list_targets_by_rule(Rule=rule_name)["Targets"] targets_before = len(targets) - assert(targets_before > 0) + assert targets_before > 0 - client.remove_targets(Rule=rule_name, Ids=[targets[0]['Id']]) + client.remove_targets(Rule=rule_name, Ids=[targets[0]["Id"]]) - targets = client.list_targets_by_rule(Rule=rule_name)['Targets'] + targets = client.list_targets_by_rule(Rule=rule_name)["Targets"] targets_after = len(targets) - assert(targets_before - 1 == targets_after) + assert targets_before - 1 == targets_after @mock_events def test_permissions(): - client = boto3.client('events', 'eu-central-1') + client = boto3.client("events", "eu-central-1") - client.put_permission(Action='events:PutEvents', Principal='111111111111', StatementId='Account1') - client.put_permission(Action='events:PutEvents', Principal='222222222222', StatementId='Account2') + client.put_permission( + Action="events:PutEvents", Principal="111111111111", StatementId="Account1" + ) + client.put_permission( + Action="events:PutEvents", Principal="222222222222", StatementId="Account2" + ) resp = client.describe_event_bus() - resp_policy = json.loads(resp['Policy']) - assert len(resp_policy['Statement']) == 2 + resp_policy = json.loads(resp["Policy"]) + assert len(resp_policy["Statement"]) == 2 - client.remove_permission(StatementId='Account2') + client.remove_permission(StatementId="Account2") resp = client.describe_event_bus() - resp_policy = json.loads(resp['Policy']) - assert len(resp_policy['Statement']) == 1 - assert resp_policy['Statement'][0]['Sid'] == 'Account1' + resp_policy = json.loads(resp["Policy"]) + assert len(resp_policy["Statement"]) == 1 + assert resp_policy["Statement"][0]["Sid"] == "Account1" + + +@mock_events +def test_put_permission_errors(): + client = boto3.client("events", "us-east-1") + client.create_event_bus(Name="test-bus") + + client.put_permission.when.called_with( + EventBusName="non-existing", + Action="events:PutEvents", + Principal="111111111111", + StatementId="test", + ).should.throw(ClientError, "Event bus non-existing does not exist.") + + client.put_permission.when.called_with( + EventBusName="test-bus", + Action="events:PutPermission", + Principal="111111111111", + StatementId="test", + ).should.throw( + ClientError, "Provided value in parameter 'action' is not supported." + ) + + +@mock_events +def test_remove_permission_errors(): + client = boto3.client("events", "us-east-1") + client.create_event_bus(Name="test-bus") + + client.remove_permission.when.called_with( + EventBusName="non-existing", StatementId="test" + ).should.throw(ClientError, "Event bus non-existing does not exist.") + + client.remove_permission.when.called_with( + EventBusName="test-bus", StatementId="test" + ).should.throw(ClientError, "EventBus does not have a policy.") + + client.put_permission( + EventBusName="test-bus", + Action="events:PutEvents", + Principal="111111111111", + StatementId="test", + ) + + client.remove_permission.when.called_with( + EventBusName="test-bus", StatementId="non-existing" + ).should.throw(ClientError, "Statement with the provided id does not exist.") @mock_events def test_put_events(): - client = boto3.client('events', 'eu-central-1') + client = boto3.client("events", "eu-central-1") event = { "Source": "com.mycompany.myapp", "Detail": '{"key1": "value3", "key2": "value4"}', "Resources": ["resource1", "resource2"], - "DetailType": "myDetailType" + "DetailType": "myDetailType", } client.put_events(Entries=[event]) # Boto3 would error if it didn't return 200 OK with assert_raises(ClientError): - client.put_events(Entries=[event]*20) + client.put_events(Entries=[event] * 20) + + +@mock_events +def test_create_event_bus(): + client = boto3.client("events", "us-east-1") + response = client.create_event_bus(Name="test-bus") + + response["EventBusArn"].should.equal( + "arn:aws:events:us-east-1:{}:event-bus/test-bus".format(ACCOUNT_ID) + ) + + +@mock_events +def test_create_event_bus_errors(): + client = boto3.client("events", "us-east-1") + client.create_event_bus(Name="test-bus") + + client.create_event_bus.when.called_with(Name="test-bus").should.throw( + ClientError, "Event bus test-bus already exists." + ) + + # the 'default' name is already used for the account's default event bus. + client.create_event_bus.when.called_with(Name="default").should.throw( + ClientError, "Event bus default already exists." + ) + + # non partner event buses can't contain the '/' character + client.create_event_bus.when.called_with(Name="test/test-bus").should.throw( + ClientError, "Event bus name must not contain '/'." + ) + + client.create_event_bus.when.called_with( + Name="aws.partner/test/test-bus", EventSourceName="aws.partner/test/test-bus" + ).should.throw( + ClientError, "Event source aws.partner/test/test-bus does not exist." + ) + + +@mock_events +def test_describe_event_bus(): + client = boto3.client("events", "us-east-1") + + response = client.describe_event_bus() + + response["Name"].should.equal("default") + response["Arn"].should.equal( + "arn:aws:events:us-east-1:{}:event-bus/default".format(ACCOUNT_ID) + ) + response.should_not.have.key("Policy") + + client.create_event_bus(Name="test-bus") + client.put_permission( + EventBusName="test-bus", + Action="events:PutEvents", + Principal="111111111111", + StatementId="test", + ) + + response = client.describe_event_bus(Name="test-bus") + + response["Name"].should.equal("test-bus") + response["Arn"].should.equal( + "arn:aws:events:us-east-1:{}:event-bus/test-bus".format(ACCOUNT_ID) + ) + json.loads(response["Policy"]).should.equal( + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "test", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::111111111111:root"}, + "Action": "events:PutEvents", + "Resource": "arn:aws:events:us-east-1:{}:event-bus/test-bus".format( + ACCOUNT_ID + ), + } + ], + } + ) + + +@mock_events +def test_describe_event_bus_errors(): + client = boto3.client("events", "us-east-1") + + client.describe_event_bus.when.called_with(Name="non-existing").should.throw( + ClientError, "Event bus non-existing does not exist." + ) + + +@mock_events +def test_list_event_buses(): + client = boto3.client("events", "us-east-1") + client.create_event_bus(Name="test-bus-1") + client.create_event_bus(Name="test-bus-2") + client.create_event_bus(Name="other-bus-1") + client.create_event_bus(Name="other-bus-2") + + response = client.list_event_buses() + + response["EventBuses"].should.have.length_of(5) + sorted(response["EventBuses"], key=lambda i: i["Name"]).should.equal( + [ + { + "Name": "default", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/default".format( + ACCOUNT_ID + ), + }, + { + "Name": "other-bus-1", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/other-bus-1".format( + ACCOUNT_ID + ), + }, + { + "Name": "other-bus-2", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/other-bus-2".format( + ACCOUNT_ID + ), + }, + { + "Name": "test-bus-1", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/test-bus-1".format( + ACCOUNT_ID + ), + }, + { + "Name": "test-bus-2", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/test-bus-2".format( + ACCOUNT_ID + ), + }, + ] + ) + + response = client.list_event_buses(NamePrefix="other-bus") + + response["EventBuses"].should.have.length_of(2) + sorted(response["EventBuses"], key=lambda i: i["Name"]).should.equal( + [ + { + "Name": "other-bus-1", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/other-bus-1".format( + ACCOUNT_ID + ), + }, + { + "Name": "other-bus-2", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/other-bus-2".format( + ACCOUNT_ID + ), + }, + ] + ) + + +@mock_events +def test_delete_event_bus(): + client = boto3.client("events", "us-east-1") + client.create_event_bus(Name="test-bus") + + response = client.list_event_buses() + response["EventBuses"].should.have.length_of(2) + + client.delete_event_bus(Name="test-bus") + + response = client.list_event_buses() + response["EventBuses"].should.have.length_of(1) + response["EventBuses"].should.equal( + [ + { + "Name": "default", + "Arn": "arn:aws:events:us-east-1:{}:event-bus/default".format( + ACCOUNT_ID + ), + } + ] + ) + + # deleting non existing event bus should be successful + client.delete_event_bus(Name="non-existing") + + +@mock_events +def test_delete_event_bus_errors(): + client = boto3.client("events", "us-east-1") + + client.delete_event_bus.when.called_with(Name="default").should.throw( + ClientError, "Cannot delete event bus default." + ) diff --git a/tests/test_glacier/test_glacier_jobs.py b/tests/test_glacier/test_glacier_jobs.py index 761b47a66..11077d7f2 100644 --- a/tests/test_glacier/test_glacier_jobs.py +++ b/tests/test_glacier/test_glacier_jobs.py @@ -1,90 +1,91 @@ -from __future__ import unicode_literals - -import json -import time - -from boto.glacier.layer1 import Layer1 -import sure # noqa - -from moto import mock_glacier_deprecated - - -@mock_glacier_deprecated -def test_init_glacier_job(): - conn = Layer1(region_name="us-west-2") - vault_name = "my_vault" - conn.create_vault(vault_name) - archive_id = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") - - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] - job_response['Location'].should.equal( - "//vaults/my_vault/jobs/{0}".format(job_id)) - - -@mock_glacier_deprecated -def test_describe_job(): - conn = Layer1(region_name="us-west-2") - vault_name = "my_vault" - conn.create_vault(vault_name) - archive_id = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] - - job = conn.describe_job(vault_name, job_id) - joboutput = json.loads(job.read().decode("utf-8")) - - joboutput.should.have.key('Tier').which.should.equal('Standard') - joboutput.should.have.key('StatusCode').which.should.equal('InProgress') - joboutput.should.have.key('VaultARN').which.should.equal('arn:aws:glacier:RegionInfo:us-west-2:012345678901:vaults/my_vault') - - -@mock_glacier_deprecated -def test_list_glacier_jobs(): - conn = Layer1(region_name="us-west-2") - vault_name = "my_vault" - conn.create_vault(vault_name) - archive_id1 = conn.upload_archive( - vault_name, "some stuff", "", "", "some description")['ArchiveId'] - archive_id2 = conn.upload_archive( - vault_name, "some other stuff", "", "", "some description")['ArchiveId'] - - conn.initiate_job(vault_name, { - "ArchiveId": archive_id1, - "Type": "archive-retrieval", - }) - conn.initiate_job(vault_name, { - "ArchiveId": archive_id2, - "Type": "archive-retrieval", - }) - - jobs = conn.list_jobs(vault_name) - len(jobs['JobList']).should.equal(2) - - -@mock_glacier_deprecated -def test_get_job_output(): - conn = Layer1(region_name="us-west-2") - vault_name = "my_vault" - conn.create_vault(vault_name) - archive_response = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") - archive_id = archive_response['ArchiveId'] - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] - - time.sleep(6) - - output = conn.get_job_output(vault_name, job_id) - output.read().decode("utf-8").should.equal("some stuff") +from __future__ import unicode_literals + +import json +import time + +from boto.glacier.layer1 import Layer1 +import sure # noqa + +from moto import mock_glacier_deprecated + + +@mock_glacier_deprecated +def test_init_glacier_job(): + conn = Layer1(region_name="us-west-2") + vault_name = "my_vault" + conn.create_vault(vault_name) + archive_id = conn.upload_archive( + vault_name, "some stuff", "", "", "some description" + ) + + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] + job_response["Location"].should.equal("//vaults/my_vault/jobs/{0}".format(job_id)) + + +@mock_glacier_deprecated +def test_describe_job(): + conn = Layer1(region_name="us-west-2") + vault_name = "my_vault" + conn.create_vault(vault_name) + archive_id = conn.upload_archive( + vault_name, "some stuff", "", "", "some description" + ) + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] + + job = conn.describe_job(vault_name, job_id) + joboutput = json.loads(job.read().decode("utf-8")) + + joboutput.should.have.key("Tier").which.should.equal("Standard") + joboutput.should.have.key("StatusCode").which.should.equal("InProgress") + joboutput.should.have.key("VaultARN").which.should.equal( + "arn:aws:glacier:RegionInfo:us-west-2:012345678901:vaults/my_vault" + ) + + +@mock_glacier_deprecated +def test_list_glacier_jobs(): + conn = Layer1(region_name="us-west-2") + vault_name = "my_vault" + conn.create_vault(vault_name) + archive_id1 = conn.upload_archive( + vault_name, "some stuff", "", "", "some description" + )["ArchiveId"] + archive_id2 = conn.upload_archive( + vault_name, "some other stuff", "", "", "some description" + )["ArchiveId"] + + conn.initiate_job( + vault_name, {"ArchiveId": archive_id1, "Type": "archive-retrieval"} + ) + conn.initiate_job( + vault_name, {"ArchiveId": archive_id2, "Type": "archive-retrieval"} + ) + + jobs = conn.list_jobs(vault_name) + len(jobs["JobList"]).should.equal(2) + + +@mock_glacier_deprecated +def test_get_job_output(): + conn = Layer1(region_name="us-west-2") + vault_name = "my_vault" + conn.create_vault(vault_name) + archive_response = conn.upload_archive( + vault_name, "some stuff", "", "", "some description" + ) + archive_id = archive_response["ArchiveId"] + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] + + time.sleep(6) + + output = conn.get_job_output(vault_name, job_id) + output.read().decode("utf-8").should.equal("some stuff") diff --git a/tests/test_glacier/test_glacier_server.py b/tests/test_glacier/test_glacier_server.py index b6c03428e..d43dd4e8a 100644 --- a/tests/test_glacier/test_glacier_server.py +++ b/tests/test_glacier/test_glacier_server.py @@ -1,22 +1,21 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_glacier - -''' -Test the different server responses -''' - - -@mock_glacier -def test_list_vaults(): - backend = server.create_backend_app("glacier") - test_client = backend.test_client() - - res = test_client.get('/1234bcd/vaults') - - json.loads(res.data.decode("utf-8") - ).should.equal({u'Marker': None, u'VaultList': []}) +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_glacier + +""" +Test the different server responses +""" + + +@mock_glacier +def test_list_vaults(): + backend = server.create_backend_app("glacier") + test_client = backend.test_client() + + res = test_client.get("/1234bcd/vaults") + + json.loads(res.data.decode("utf-8")).should.equal({"Marker": None, "VaultList": []}) diff --git a/tests/test_glue/fixtures/datacatalog.py b/tests/test_glue/fixtures/datacatalog.py index 13136158b..11cb30ca9 100644 --- a/tests/test_glue/fixtures/datacatalog.py +++ b/tests/test_glue/fixtures/datacatalog.py @@ -1,56 +1,55 @@ -from __future__ import unicode_literals - -TABLE_INPUT = { - 'Owner': 'a_fake_owner', - 'Parameters': { - 'EXTERNAL': 'TRUE', - }, - 'Retention': 0, - 'StorageDescriptor': { - 'BucketColumns': [], - 'Compressed': False, - 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', - 'NumberOfBuckets': -1, - 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', - 'Parameters': {}, - 'SerdeInfo': { - 'Parameters': { - 'serialization.format': '1' - }, - 'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' - }, - 'SkewedInfo': { - 'SkewedColumnNames': [], - 'SkewedColumnValueLocationMaps': {}, - 'SkewedColumnValues': [] - }, - 'SortColumns': [], - 'StoredAsSubDirectories': False - }, - 'TableType': 'EXTERNAL_TABLE', -} - - -PARTITION_INPUT = { - # 'DatabaseName': 'dbname', - 'StorageDescriptor': { - 'BucketColumns': [], - 'Columns': [], - 'Compressed': False, - 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', - 'Location': 's3://.../partition=value', - 'NumberOfBuckets': -1, - 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', - 'Parameters': {}, - 'SerdeInfo': { - 'Parameters': {'path': 's3://...', 'serialization.format': '1'}, - 'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'}, - 'SkewedInfo': {'SkewedColumnNames': [], - 'SkewedColumnValueLocationMaps': {}, - 'SkewedColumnValues': []}, - 'SortColumns': [], - 'StoredAsSubDirectories': False, - }, - # 'TableName': 'source_table', - # 'Values': ['2018-06-26'], -} +from __future__ import unicode_literals + +TABLE_INPUT = { + "Owner": "a_fake_owner", + "Parameters": {"EXTERNAL": "TRUE"}, + "Retention": 0, + "StorageDescriptor": { + "BucketColumns": [], + "Compressed": False, + "InputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "NumberOfBuckets": -1, + "OutputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "Parameters": {}, + "SerdeInfo": { + "Parameters": {"serialization.format": "1"}, + "SerializationLibrary": "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + }, + "SkewedInfo": { + "SkewedColumnNames": [], + "SkewedColumnValueLocationMaps": {}, + "SkewedColumnValues": [], + }, + "SortColumns": [], + "StoredAsSubDirectories": False, + }, + "TableType": "EXTERNAL_TABLE", +} + + +PARTITION_INPUT = { + # 'DatabaseName': 'dbname', + "StorageDescriptor": { + "BucketColumns": [], + "Columns": [], + "Compressed": False, + "InputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "Location": "s3://.../partition=value", + "NumberOfBuckets": -1, + "OutputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "Parameters": {}, + "SerdeInfo": { + "Parameters": {"path": "s3://...", "serialization.format": "1"}, + "SerializationLibrary": "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + }, + "SkewedInfo": { + "SkewedColumnNames": [], + "SkewedColumnValueLocationMaps": {}, + "SkewedColumnValues": [], + }, + "SortColumns": [], + "StoredAsSubDirectories": False, + }, + # 'TableName': 'source_table', + # 'Values': ['2018-06-26'], +} diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index 48908532c..9003a1358 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -6,11 +6,7 @@ from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT def create_database(client, database_name): - return client.create_database( - DatabaseInput={ - 'Name': database_name - } - ) + return client.create_database(DatabaseInput={"Name": database_name}) def get_database(client, database_name): @@ -19,12 +15,13 @@ def get_database(client, database_name): def create_table_input(database_name, table_name, columns=[], partition_keys=[]): table_input = copy.deepcopy(TABLE_INPUT) - table_input['Name'] = table_name - table_input['PartitionKeys'] = partition_keys - table_input['StorageDescriptor']['Columns'] = columns - table_input['StorageDescriptor']['Location'] = 's3://my-bucket/{database_name}/{table_name}'.format( - database_name=database_name, - table_name=table_name + table_input["Name"] = table_name + table_input["PartitionKeys"] = partition_keys + table_input["StorageDescriptor"]["Columns"] = columns + table_input["StorageDescriptor"][ + "Location" + ] = "s3://my-bucket/{database_name}/{table_name}".format( + database_name=database_name, table_name=table_name ) return table_input @@ -33,60 +30,43 @@ def create_table(client, database_name, table_name, table_input=None, **kwargs): if table_input is None: table_input = create_table_input(database_name, table_name, **kwargs) - return client.create_table( - DatabaseName=database_name, - TableInput=table_input - ) + return client.create_table(DatabaseName=database_name, TableInput=table_input) def update_table(client, database_name, table_name, table_input=None, **kwargs): if table_input is None: table_input = create_table_input(database_name, table_name, **kwargs) - return client.update_table( - DatabaseName=database_name, - TableInput=table_input, - ) + return client.update_table(DatabaseName=database_name, TableInput=table_input) def get_table(client, database_name, table_name): - return client.get_table( - DatabaseName=database_name, - Name=table_name - ) + return client.get_table(DatabaseName=database_name, Name=table_name) def get_tables(client, database_name): - return client.get_tables( - DatabaseName=database_name - ) + return client.get_tables(DatabaseName=database_name) def get_table_versions(client, database_name, table_name): - return client.get_table_versions( - DatabaseName=database_name, - TableName=table_name - ) + return client.get_table_versions(DatabaseName=database_name, TableName=table_name) def get_table_version(client, database_name, table_name, version_id): return client.get_table_version( - DatabaseName=database_name, - TableName=table_name, - VersionId=version_id, + DatabaseName=database_name, TableName=table_name, VersionId=version_id ) def create_partition_input(database_name, table_name, values=[], columns=[]): - root_path = 's3://my-bucket/{database_name}/{table_name}'.format( - database_name=database_name, - table_name=table_name + root_path = "s3://my-bucket/{database_name}/{table_name}".format( + database_name=database_name, table_name=table_name ) part_input = copy.deepcopy(PARTITION_INPUT) - part_input['Values'] = values - part_input['StorageDescriptor']['Columns'] = columns - part_input['StorageDescriptor']['SerdeInfo']['Parameters']['path'] = root_path + part_input["Values"] = values + part_input["StorageDescriptor"]["Columns"] = columns + part_input["StorageDescriptor"]["SerdeInfo"]["Parameters"]["path"] = root_path return part_input @@ -94,13 +74,13 @@ def create_partition(client, database_name, table_name, partiton_input=None, **k if partiton_input is None: partiton_input = create_partition_input(database_name, table_name, **kwargs) return client.create_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionInput=partiton_input + DatabaseName=database_name, TableName=table_name, PartitionInput=partiton_input ) -def update_partition(client, database_name, table_name, old_values=[], partiton_input=None, **kwargs): +def update_partition( + client, database_name, table_name, old_values=[], partiton_input=None, **kwargs +): if partiton_input is None: partiton_input = create_partition_input(database_name, table_name, **kwargs) return client.update_partition( @@ -113,7 +93,5 @@ def update_partition(client, database_name, table_name, old_values=[], partiton_ def get_partition(client, database_name, table_name, values): return client.get_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index 9034feb55..28281b18f 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -16,80 +16,82 @@ from . import helpers @mock_glue def test_create_database(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) response = helpers.get_database(client, database_name) - database = response['Database'] + database = response["Database"] - database.should.equal({'Name': database_name}) + database.should.equal({"Name": database_name}) @mock_glue def test_create_database_already_exists(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'cantcreatethisdatabasetwice' + client = boto3.client("glue", region_name="us-east-1") + database_name = "cantcreatethisdatabasetwice" helpers.create_database(client, database_name) with assert_raises(ClientError) as exc: helpers.create_database(client, database_name) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_database_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'nosuchdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "nosuchdatabase" with assert_raises(ClientError) as exc: helpers.get_database(client, database_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Database nosuchdatabase not found" + ) @mock_glue def test_create_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) response = helpers.get_table(client, database_name, table_name) - table = response['Table'] + table = response["Table"] - table['Name'].should.equal(table_input['Name']) - table['StorageDescriptor'].should.equal(table_input['StorageDescriptor']) - table['PartitionKeys'].should.equal(table_input['PartitionKeys']) + table["Name"].should.equal(table_input["Name"]) + table["StorageDescriptor"].should.equal(table_input["StorageDescriptor"]) + table["PartitionKeys"].should.equal(table_input["PartitionKeys"]) @mock_glue def test_create_table_already_exists(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'cantcreatethistabletwice' + table_name = "cantcreatethistabletwice" helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: helpers.create_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_tables(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_names = ['myfirsttable', 'mysecondtable', 'mythirdtable'] + table_names = ["myfirsttable", "mysecondtable", "mythirdtable"] table_inputs = {} for table_name in table_names: @@ -99,31 +101,33 @@ def test_get_tables(): response = helpers.get_tables(client, database_name) - tables = response['TableList'] + tables = response["TableList"] tables.should.have.length_of(3) for table in tables: - table_name = table['Name'] - table_name.should.equal(table_inputs[table_name]['Name']) - table['StorageDescriptor'].should.equal(table_inputs[table_name]['StorageDescriptor']) - table['PartitionKeys'].should.equal(table_inputs[table_name]['PartitionKeys']) + table_name = table["Name"] + table_name.should.equal(table_inputs[table_name]["Name"]) + table["StorageDescriptor"].should.equal( + table_inputs[table_name]["StorageDescriptor"] + ) + table["PartitionKeys"].should.equal(table_inputs[table_name]["PartitionKeys"]) @mock_glue def test_get_table_versions(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myfirsttable' + table_name = "myfirsttable" version_inputs = {} table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) version_inputs["1"] = table_input - columns = [{'Name': 'country', 'Type': 'string'}] + columns = [{"Name": "country", "Type": "string"}] table_input = helpers.create_table_input(database_name, table_name, columns=columns) helpers.update_table(client, database_name, table_name, table_input) version_inputs["2"] = table_input @@ -134,174 +138,189 @@ def test_get_table_versions(): response = helpers.get_table_versions(client, database_name, table_name) - vers = response['TableVersions'] + vers = response["TableVersions"] vers.should.have.length_of(3) - vers[0]['Table']['StorageDescriptor']['Columns'].should.equal([]) - vers[-1]['Table']['StorageDescriptor']['Columns'].should.equal(columns) + vers[0]["Table"]["StorageDescriptor"]["Columns"].should.equal([]) + vers[-1]["Table"]["StorageDescriptor"]["Columns"].should.equal(columns) for n, ver in enumerate(vers): n = str(n + 1) - ver['VersionId'].should.equal(n) - ver['Table']['Name'].should.equal(table_name) - ver['Table']['StorageDescriptor'].should.equal(version_inputs[n]['StorageDescriptor']) - ver['Table']['PartitionKeys'].should.equal(version_inputs[n]['PartitionKeys']) + ver["VersionId"].should.equal(n) + ver["Table"]["Name"].should.equal(table_name) + ver["Table"]["StorageDescriptor"].should.equal( + version_inputs[n]["StorageDescriptor"] + ) + ver["Table"]["PartitionKeys"].should.equal(version_inputs[n]["PartitionKeys"]) response = helpers.get_table_version(client, database_name, table_name, "3") - ver = response['TableVersion'] + ver = response["TableVersion"] - ver['VersionId'].should.equal("3") - ver['Table']['Name'].should.equal(table_name) - ver['Table']['StorageDescriptor']['Columns'].should.equal(columns) + ver["VersionId"].should.equal("3") + ver["Table"]["Name"].should.equal(table_name) + ver["Table"]["StorageDescriptor"]["Columns"].should.equal(columns) @mock_glue def test_get_table_version_not_found(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.get_table_version(client, database_name, 'myfirsttable', "20") + helpers.get_table_version(client, database_name, "myfirsttable", "20") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('version', re.I) + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("version", re.I) @mock_glue def test_get_table_version_invalid_input(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.get_table_version(client, database_name, 'myfirsttable', "10not-an-int") + helpers.get_table_version(client, database_name, "myfirsttable", "10not-an-int") - exc.exception.response['Error']['Code'].should.equal('InvalidInputException') + exc.exception.response["Error"]["Code"].should.equal("InvalidInputException") @mock_glue def test_get_table_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) with assert_raises(ClientError) as exc: - helpers.get_table(client, database_name, 'myfirsttable') + helpers.get_table(client, database_name, "myfirsttable") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myfirsttable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myfirsttable not found" + ) @mock_glue def test_get_table_when_database_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'nosuchdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "nosuchdatabase" with assert_raises(ClientError) as exc: - helpers.get_table(client, database_name, 'myfirsttable') + helpers.get_table(client, database_name, "myfirsttable") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Database nosuchdatabase not found" + ) @mock_glue def test_delete_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) result = client.delete_table(DatabaseName=database_name, Name=table_name) - result['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # confirm table is deleted with assert_raises(ClientError) as exc: helpers.get_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myspecialtable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myspecialtable not found" + ) + @mock_glue def test_batch_delete_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) - result = client.batch_delete_table(DatabaseName=database_name, TablesToDelete=[table_name]) - result['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + result = client.batch_delete_table( + DatabaseName=database_name, TablesToDelete=[table_name] + ) + result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # confirm table is deleted with assert_raises(ClientError) as exc: helpers.get_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myspecialtable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myspecialtable not found" + ) @mock_glue def test_get_partitions_empty(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - response['Partitions'].should.have.length_of(0) + response["Partitions"].should.have.length_of(0) @mock_glue def test_create_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) before = datetime.now(pytz.utc) - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) helpers.create_partition(client, database_name, table_name, part_input) after = datetime.now(pytz.utc) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(1) partition = partitions[0] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor'].should.equal(part_input['StorageDescriptor']) - partition['Values'].should.equal(values) - partition['CreationTime'].should.be.greater_than(before) - partition['CreationTime'].should.be.lower_than(after) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"].should.equal(part_input["StorageDescriptor"]) + partition["Values"].should.equal(values) + partition["CreationTime"].should.be.greater_than(before) + partition["CreationTime"].should.be.lower_than(after) @mock_glue def test_create_partition_already_exist(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -311,15 +330,15 @@ def test_create_partition_already_exist(): with assert_raises(ClientError) as exc: helpers.create_partition(client, database_name, table_name, values=values) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_partition_not_found(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -327,14 +346,15 @@ def test_get_partition_not_found(): with assert_raises(ClientError) as exc: helpers.get_partition(client, database_name, table_name, values) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") + @mock_glue def test_batch_create_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -344,197 +364,221 @@ def test_batch_create_partition(): partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) after = datetime.now(pytz.utc) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(20) for idx, partition in enumerate(partitions): partition_input = partition_inputs[idx] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor'].should.equal(partition_input['StorageDescriptor']) - partition['Values'].should.equal(partition_input['Values']) - partition['CreationTime'].should.be.greater_than(before) - partition['CreationTime'].should.be.lower_than(after) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"].should.equal( + partition_input["StorageDescriptor"] + ) + partition["Values"].should.equal(partition_input["Values"]) + partition["CreationTime"].should.be.greater_than(before) + partition["CreationTime"].should.be.lower_than(after) @mock_glue def test_batch_create_partition_already_exist(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) helpers.create_partition(client, database_name, table_name, values=values) - partition_input = helpers.create_partition_input(database_name, table_name, values=values) + partition_input = helpers.create_partition_input( + database_name, table_name, values=values + ) response = client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=[partition_input] + PartitionInputList=[partition_input], ) - response.should.have.key('Errors') - response['Errors'].should.have.length_of(1) - response['Errors'][0]['PartitionValues'].should.equal(values) - response['Errors'][0]['ErrorDetail']['ErrorCode'].should.equal('AlreadyExistsException') + response.should.have.key("Errors") + response["Errors"].should.have.length_of(1) + response["Errors"][0]["PartitionValues"].should.equal(values) + response["Errors"][0]["ErrorDetail"]["ErrorCode"].should.equal( + "AlreadyExistsException" + ) @mock_glue def test_get_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values[1]) + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=values[1] + ) - partition = response['Partition'] + partition = response["Partition"] - partition['TableName'].should.equal(table_name) - partition['Values'].should.equal(values[1]) + partition["TableName"].should.equal(table_name) + partition["Values"].should.equal(values[1]) @mock_glue def test_batch_get_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) - partitions_to_get = [ - {'Values': values[0]}, - {'Values': values[1]}, - ] - response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get) + partitions_to_get = [{"Values": values[0]}, {"Values": values[1]}] + response = client.batch_get_partition( + DatabaseName=database_name, + TableName=table_name, + PartitionsToGet=partitions_to_get, + ) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(2) partition = partitions[1] - partition['TableName'].should.equal(table_name) - partition['Values'].should.equal(values[1]) + partition["TableName"].should.equal(table_name) + partition["Values"].should.equal(values[1]) @mock_glue def test_batch_get_partition_missing_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01'], ['2018-08-01']] + values = [["2018-10-01"], ["2018-09-01"], ["2018-08-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[2]) partitions_to_get = [ - {'Values': values[0]}, - {'Values': values[1]}, - {'Values': values[2]}, + {"Values": values[0]}, + {"Values": values[1]}, + {"Values": values[2]}, ] - response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get) + response = client.batch_get_partition( + DatabaseName=database_name, + TableName=table_name, + PartitionsToGet=partitions_to_get, + ) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(2) - partitions[0]['Values'].should.equal(values[0]) - partitions[1]['Values'].should.equal(values[2]) - + partitions[0]["Values"].should.equal(values[0]) + partitions[1]["Values"].should.equal(values[2]) @mock_glue def test_update_partition_not_found_moving(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=['0000-00-00'], values=['2018-10-02']) + helpers.update_partition( + client, + database_name, + table_name, + old_values=["0000-00-00"], + values=["2018-10-02"], + ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") @mock_glue def test_update_partition_not_found_change_in_place(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=values, values=values) + helpers.update_partition( + client, database_name, table_name, old_values=values, values=values + ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") @mock_glue def test_update_partition_cannot_overwrite(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=values[0], values=values[1]) + helpers.update_partition( + client, database_name, table_name, old_values=values[0], values=values[1] + ) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_update_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -546,23 +590,27 @@ def test_update_partition(): table_name, old_values=values, values=values, - columns=[{'Name': 'country', 'Type': 'string'}], + columns=[{"Name": "country", "Type": "string"}], ) - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values) - partition = response['Partition'] + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=values + ) + partition = response["Partition"] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}]) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"]["Columns"].should.equal( + [{"Name": "country", "Type": "string"}] + ) @mock_glue def test_update_partition_move(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] - new_values = ['2018-09-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] + new_values = ["2018-09-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -574,79 +622,86 @@ def test_update_partition_move(): table_name, old_values=values, values=new_values, - columns=[{'Name': 'country', 'Type': 'string'}], + columns=[{"Name": "country", "Type": "string"}], ) with assert_raises(ClientError) as exc: helpers.get_partition(client, database_name, table_name, values) # Old partition shouldn't exist anymore - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=new_values) - partition = response['Partition'] + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=new_values + ) + partition = response["Partition"] + + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"]["Columns"].should.equal( + [{"Name": "country", "Type": "string"}] + ) - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}]) @mock_glue def test_delete_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) helpers.create_partition(client, database_name, table_name, part_input) client.delete_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.be.empty + @mock_glue def test_delete_partition_bad_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: client.delete_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + @mock_glue def test_batch_delete_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) partition_values = [{"Values": p["Values"]} for p in partition_inputs] @@ -657,26 +712,29 @@ def test_batch_delete_partition(): PartitionsToDelete=partition_values, ) - response.should_not.have.key('Errors') + response.should_not.have.key("Errors") + @mock_glue def test_batch_delete_partition_with_bad_partitions(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) partition_values = [{"Values": p["Values"]} for p in partition_inputs] @@ -691,9 +749,9 @@ def test_batch_delete_partition_with_bad_partitions(): PartitionsToDelete=partition_values, ) - response.should.have.key('Errors') - response['Errors'].should.have.length_of(3) - error_partitions = map(lambda x: x['PartitionValues'], response['Errors']) - ['2018-11-01'].should.be.within(error_partitions) - ['2018-11-02'].should.be.within(error_partitions) - ['2018-11-03'].should.be.within(error_partitions) + response.should.have.key("Errors") + response["Errors"].should.have.length_of(3) + error_partitions = map(lambda x: x["PartitionValues"], response["Errors"]) + ["2018-11-01"].should.be.within(error_partitions) + ["2018-11-02"].should.be.within(error_partitions) + ["2018-11-03"].should.be.within(error_partitions) diff --git a/tests/test_iam/test_iam.py b/tests/test_iam/test_iam.py index fe2117a3a..6311dce9c 100644 --- a/tests/test_iam/test_iam.py +++ b/tests/test_iam/test_iam.py @@ -9,13 +9,17 @@ import sure # noqa import sys from boto.exception import BotoServerError from botocore.exceptions import ClientError +from dateutil.tz import tzutc + from moto import mock_iam, mock_iam_deprecated from moto.iam.models import aws_managed_policies +from moto.core import ACCOUNT_ID from nose.tools import assert_raises, assert_equals from nose.tools import raises from datetime import datetime from tests.helpers import requires_boto_gte +from uuid import uuid4 MOCK_CERT = """-----BEGIN CERTIFICATE----- @@ -74,13 +78,15 @@ def test_get_all_server_certs(): conn = boto.connect_iam() conn.upload_server_cert("certname", "certbody", "privatekey") - certs = conn.get_all_server_certs()['list_server_certificates_response'][ - 'list_server_certificates_result']['server_certificate_metadata_list'] + certs = conn.get_all_server_certs()["list_server_certificates_response"][ + "list_server_certificates_result" + ]["server_certificate_metadata_list"] certs.should.have.length_of(1) cert1 = certs[0] cert1.server_certificate_name.should.equal("certname") cert1.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + "arn:aws:iam::{}:server-certificate/certname".format(ACCOUNT_ID) + ) @mock_iam_deprecated() @@ -99,7 +105,8 @@ def test_get_server_cert(): cert = conn.get_server_certificate("certname") cert.server_certificate_name.should.equal("certname") cert.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + "arn:aws:iam::{}:server-certificate/certname".format(ACCOUNT_ID) + ) @mock_iam_deprecated() @@ -110,7 +117,8 @@ def test_upload_server_cert(): cert = conn.get_server_certificate("certname") cert.server_certificate_name.should.equal("certname") cert.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + "arn:aws:iam::{}:server-certificate/certname".format(ACCOUNT_ID) + ) @mock_iam_deprecated() @@ -131,7 +139,7 @@ def test_delete_server_cert(): def test_get_role__should_throw__when_role_does_not_exist(): conn = boto.connect_iam() - conn.get_role('unexisting_role') + conn.get_role("unexisting_role") @mock_iam_deprecated() @@ -139,7 +147,7 @@ def test_get_role__should_throw__when_role_does_not_exist(): def test_get_instance_profile__should_throw__when_instance_profile_does_not_exist(): conn = boto.connect_iam() - conn.get_instance_profile('unexisting_instance_profile') + conn.get_instance_profile("unexisting_instance_profile") @mock_iam_deprecated() @@ -147,7 +155,8 @@ def test_create_role_and_instance_profile(): conn = boto.connect_iam() conn.create_instance_profile("my-profile", path="my-path") conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.add_role_to_instance_profile("my-profile", "my-role") @@ -158,26 +167,36 @@ def test_create_role_and_instance_profile(): profile = conn.get_instance_profile("my-profile") profile.path.should.equal("my-path") role_from_profile = list(profile.roles.values())[0] - role_from_profile['role_id'].should.equal(role.role_id) - role_from_profile['role_name'].should.equal("my-role") + role_from_profile["role_id"].should.equal(role.role_id) + role_from_profile["role_name"].should.equal("my-role") - conn.list_roles().roles[0].role_name.should.equal('my-role') + conn.list_roles().roles[0].role_name.should.equal("my-role") # Test with an empty path: - profile = conn.create_instance_profile('my-other-profile') - profile.path.should.equal('/') + profile = conn.create_instance_profile("my-other-profile") + profile.path.should.equal("/") + + +@mock_iam +def test_create_instance_profile_should_throw_when_name_is_not_unique(): + conn = boto3.client("iam", region_name="us-east-1") + conn.create_instance_profile(InstanceProfileName="unique-instance-profile") + with assert_raises(ClientError): + conn.create_instance_profile(InstanceProfileName="unique-instance-profile") + @mock_iam_deprecated() def test_remove_role_from_instance_profile(): conn = boto.connect_iam() conn.create_instance_profile("my-profile", path="my-path") conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.add_role_to_instance_profile("my-profile", "my-role") profile = conn.get_instance_profile("my-profile") role_from_profile = list(profile.roles.values())[0] - role_from_profile['role_name'].should.equal("my-role") + role_from_profile["role_name"].should.equal("my-role") conn.remove_role_from_instance_profile("my-profile", "my-role") @@ -187,41 +206,89 @@ def test_remove_role_from_instance_profile(): @mock_iam() def test_get_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='my-pass') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="my-pass") - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile']['UserName'].should.equal('my-user') + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"]["UserName"].should.equal("my-user") @mock_iam() def test_update_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='my-pass') - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile'].get('PasswordResetRequired').should.equal(None) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="my-pass") + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"].get("PasswordResetRequired").should.equal(None) - conn.update_login_profile(UserName='my-user', Password='new-pass', PasswordResetRequired=True) - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile'].get('PasswordResetRequired').should.equal(True) + conn.update_login_profile( + UserName="my-user", Password="new-pass", PasswordResetRequired=True + ) + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"].get("PasswordResetRequired").should.equal(True) @mock_iam() def test_delete_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") - with assert_raises(ClientError): + with assert_raises(conn.exceptions.NoSuchEntityException): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") - role = conn.get_role(RoleName="my-role") - role.get('Role').get('Arn').should.equal('arn:aws:iam::123456789012:role/my-path/my-role') - + # Test deletion failure with a managed policy + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + response = conn.create_policy( + PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY + ) + conn.attach_role_policy(PolicyArn=response["Policy"]["Arn"], RoleName="my-role") + with assert_raises(conn.exceptions.DeleteConflictException): + conn.delete_role(RoleName="my-role") + conn.detach_role_policy(PolicyArn=response["Policy"]["Arn"], RoleName="my-role") + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) conn.delete_role(RoleName="my-role") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_role(RoleName="my-role") - with assert_raises(ClientError): + # Test deletion failure with an inline policy + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="my-role-policy", PolicyDocument=MOCK_POLICY + ) + with assert_raises(conn.exceptions.DeleteConflictException): + conn.delete_role(RoleName="my-role") + conn.delete_role_policy(RoleName="my-role", PolicyName="my-role-policy") + conn.delete_role(RoleName="my-role") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_role(RoleName="my-role") + + # Test deletion failure with attachment to an instance profile + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.create_instance_profile(InstanceProfileName="my-profile") + conn.add_role_to_instance_profile( + InstanceProfileName="my-profile", RoleName="my-role" + ) + with assert_raises(conn.exceptions.DeleteConflictException): + conn.delete_role(RoleName="my-role") + conn.remove_role_from_instance_profile( + InstanceProfileName="my-profile", RoleName="my-role" + ) + conn.delete_role(RoleName="my-role") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_role(RoleName="my-role") + + # Test deletion with no conflicts + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.delete_role(RoleName="my-role") + with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role(RoleName="my-role") @@ -244,37 +311,43 @@ def test_list_instance_profiles(): def test_list_instance_profiles_for_role(): conn = boto.connect_iam() - conn.create_role(role_name="my-role", - assume_role_policy_document="some policy", path="my-path") - conn.create_role(role_name="my-role2", - assume_role_policy_document="some policy2", path="my-path2") + conn.create_role( + role_name="my-role", assume_role_policy_document="some policy", path="my-path" + ) + conn.create_role( + role_name="my-role2", + assume_role_policy_document="some policy2", + path="my-path2", + ) - profile_name_list = ['my-profile', 'my-profile2'] - profile_path_list = ['my-path', 'my-path2'] + profile_name_list = ["my-profile", "my-profile2"] + profile_path_list = ["my-path", "my-path2"] for profile_count in range(0, 2): conn.create_instance_profile( - profile_name_list[profile_count], path=profile_path_list[profile_count]) + profile_name_list[profile_count], path=profile_path_list[profile_count] + ) for profile_count in range(0, 2): - conn.add_role_to_instance_profile( - profile_name_list[profile_count], "my-role") + conn.add_role_to_instance_profile(profile_name_list[profile_count], "my-role") profile_dump = conn.list_instance_profiles_for_role(role_name="my-role") - profile_list = profile_dump['list_instance_profiles_for_role_response'][ - 'list_instance_profiles_for_role_result']['instance_profiles'] + profile_list = profile_dump["list_instance_profiles_for_role_response"][ + "list_instance_profiles_for_role_result" + ]["instance_profiles"] for profile_count in range(0, len(profile_list)): - profile_name_list.remove(profile_list[profile_count][ - "instance_profile_name"]) + profile_name_list.remove(profile_list[profile_count]["instance_profile_name"]) profile_path_list.remove(profile_list[profile_count]["path"]) - profile_list[profile_count]["roles"]["member"][ - "role_name"].should.equal("my-role") + profile_list[profile_count]["roles"]["member"]["role_name"].should.equal( + "my-role" + ) len(profile_name_list).should.equal(0) len(profile_path_list).should.equal(0) profile_dump2 = conn.list_instance_profiles_for_role(role_name="my-role2") - profile_list = profile_dump2['list_instance_profiles_for_role_response'][ - 'list_instance_profiles_for_role_result']['instance_profiles'] + profile_list = profile_dump2["list_instance_profiles_for_role_response"][ + "list_instance_profiles_for_role_result" + ]["instance_profiles"] len(profile_list).should.equal(0) @@ -304,18 +377,21 @@ def test_list_role_policies(): def test_put_role_policy(): conn = boto.connect_iam() conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.put_role_policy("my-role", "test policy", MOCK_POLICY) - policy = conn.get_role_policy( - "my-role", "test policy")['get_role_policy_response']['get_role_policy_result']['policy_name'] + policy = conn.get_role_policy("my-role", "test policy")["get_role_policy_response"][ + "get_role_policy_result" + ]["policy_name"] policy.should.equal("test policy") @mock_iam def test_get_role_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role( - RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="my-path") + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="my-path" + ) with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role_policy(RoleName="my-role", PolicyName="does-not-exist") @@ -331,329 +407,394 @@ def test_update_assume_role_policy(): @mock_iam def test_create_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_policy( - PolicyName="TestCreatePolicy", - PolicyDocument=MOCK_POLICY) - response['Policy']['Arn'].should.equal("arn:aws:iam::123456789012:policy/TestCreatePolicy") + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + response["Policy"]["Arn"].should.equal( + "arn:aws:iam::{}:policy/TestCreatePolicy".format(ACCOUNT_ID) + ) + + +@mock_iam +def test_create_policy_already_exists(): + conn = boto3.client("iam", region_name="us-east-1") + response = conn.create_policy( + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + with assert_raises(conn.exceptions.EntityAlreadyExistsException) as ex: + response = conn.create_policy( + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + ex.exception.response["Error"]["Code"].should.equal("EntityAlreadyExists") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(409) + ex.exception.response["Error"]["Message"].should.contain("TestCreatePolicy") + + +@mock_iam +def test_delete_policy(): + conn = boto3.client("iam", region_name="us-east-1") + response = conn.create_policy( + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + [ + pol["PolicyName"] for pol in conn.list_policies(Scope="Local")["Policies"] + ].should.equal(["TestCreatePolicy"]) + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) + assert conn.list_policies(Scope="Local")["Policies"].should.be.empty @mock_iam def test_create_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - PolicyDocument='{"some":"policy"}') - conn.create_policy( - PolicyName="TestCreatePolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestCreatePolicyVersion".format( + ACCOUNT_ID + ), + PolicyDocument='{"some":"policy"}', + ) + conn.create_policy(PolicyName="TestCreatePolicyVersion", PolicyDocument=MOCK_POLICY) version = conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", + PolicyArn="arn:aws:iam::{}:policy/TestCreatePolicyVersion".format(ACCOUNT_ID), PolicyDocument=MOCK_POLICY, - SetAsDefault=True) - version.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) - version.get('PolicyVersion').get('VersionId').should.equal("v2") - version.get('PolicyVersion').get('IsDefaultVersion').should.be.ok + SetAsDefault=True, + ) + version.get("PolicyVersion").get("Document").should.equal(json.loads(MOCK_POLICY)) + version.get("PolicyVersion").get("VersionId").should.equal("v2") + version.get("PolicyVersion").get("IsDefaultVersion").should.be.ok conn.delete_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - VersionId="v1") + PolicyArn="arn:aws:iam::{}:policy/TestCreatePolicyVersion".format(ACCOUNT_ID), + VersionId="v1", + ) version = conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - PolicyDocument=MOCK_POLICY) - version.get('PolicyVersion').get('VersionId').should.equal("v3") - version.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok + PolicyArn="arn:aws:iam::{}:policy/TestCreatePolicyVersion".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY, + ) + version.get("PolicyVersion").get("VersionId").should.equal("v3") + version.get("PolicyVersion").get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_create_many_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyName="TestCreateManyPolicyVersions", PolicyDocument=MOCK_POLICY + ) for _ in range(0, 4): conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestCreateManyPolicyVersions".format( + ACCOUNT_ID + ), + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestCreateManyPolicyVersions".format( + ACCOUNT_ID + ), + PolicyDocument=MOCK_POLICY, + ) @mock_iam def test_set_default_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestSetDefaultPolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyName="TestSetDefaultPolicyVersion", PolicyDocument=MOCK_POLICY + ) conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion", + PolicyArn="arn:aws:iam::{}:policy/TestSetDefaultPolicyVersion".format( + ACCOUNT_ID + ), PolicyDocument=MOCK_POLICY_2, - SetAsDefault=True) + SetAsDefault=True, + ) conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion", + PolicyArn="arn:aws:iam::{}:policy/TestSetDefaultPolicyVersion".format( + ACCOUNT_ID + ), PolicyDocument=MOCK_POLICY_3, - SetAsDefault=True) + SetAsDefault=True, + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion") - versions.get('Versions')[0].get('Document').should.equal(json.loads(MOCK_POLICY)) - versions.get('Versions')[0].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2)) - versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3)) - versions.get('Versions')[2].get('IsDefaultVersion').should.be.ok + PolicyArn="arn:aws:iam::{}:policy/TestSetDefaultPolicyVersion".format( + ACCOUNT_ID + ) + ) + versions.get("Versions")[0].get("Document").should.equal(json.loads(MOCK_POLICY)) + versions.get("Versions")[0].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[1].get("Document").should.equal(json.loads(MOCK_POLICY_2)) + versions.get("Versions")[1].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[2].get("Document").should.equal(json.loads(MOCK_POLICY_3)) + versions.get("Versions")[2].get("IsDefaultVersion").should.be.ok @mock_iam def test_get_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_policy( - PolicyName="TestGetPolicy", - PolicyDocument=MOCK_POLICY) + PolicyName="TestGetPolicy", PolicyDocument=MOCK_POLICY + ) policy = conn.get_policy( - PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicy") - policy['Policy']['Arn'].should.equal("arn:aws:iam::123456789012:policy/TestGetPolicy") + PolicyArn="arn:aws:iam::{}:policy/TestGetPolicy".format(ACCOUNT_ID) + ) + policy["Policy"]["Arn"].should.equal( + "arn:aws:iam::{}:policy/TestGetPolicy".format(ACCOUNT_ID) + ) @mock_iam def test_get_aws_managed_policy(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/IAMUserChangePassword' - managed_policy_create_date = datetime.strptime("2016-11-15T00:25:16+00:00", "%Y-%m-%dT%H:%M:%S+00:00") - policy = conn.get_policy( - PolicyArn=managed_policy_arn) - policy['Policy']['Arn'].should.equal(managed_policy_arn) - policy['Policy']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_create_date) + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = "arn:aws:iam::aws:policy/IAMUserChangePassword" + managed_policy_create_date = datetime.strptime( + "2016-11-15T00:25:16+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) + policy = conn.get_policy(PolicyArn=managed_policy_arn) + policy["Policy"]["Arn"].should.equal(managed_policy_arn) + policy["Policy"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_create_date + ) @mock_iam def test_get_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestGetPolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestGetPolicyVersion", PolicyDocument=MOCK_POLICY) version = conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestGetPolicyVersion".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.get_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - VersionId='v2-does-not-exist') + PolicyArn="arn:aws:iam::{}:policy/TestGetPolicyVersion".format(ACCOUNT_ID), + VersionId="v2-does-not-exist", + ) retrieved = conn.get_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - VersionId=version.get('PolicyVersion').get('VersionId')) - retrieved.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) - retrieved.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok + PolicyArn="arn:aws:iam::{}:policy/TestGetPolicyVersion".format(ACCOUNT_ID), + VersionId=version.get("PolicyVersion").get("VersionId"), + ) + retrieved.get("PolicyVersion").get("Document").should.equal(json.loads(MOCK_POLICY)) + retrieved.get("PolicyVersion").get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_get_aws_managed_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' - managed_policy_version_create_date = datetime.strptime("2015-04-09T15:03:43+00:00", "%Y-%m-%dT%H:%M:%S+00:00") + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = ( + "arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole" + ) + managed_policy_version_create_date = datetime.strptime( + "2015-04-09T15:03:43+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) with assert_raises(ClientError): conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId='v2-does-not-exist') - retrieved = conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId="v1") - retrieved['PolicyVersion']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_version_create_date) - retrieved['PolicyVersion']['Document'].should.be.an(dict) + PolicyArn=managed_policy_arn, VersionId="v2-does-not-exist" + ) + retrieved = conn.get_policy_version(PolicyArn=managed_policy_arn, VersionId="v1") + retrieved["PolicyVersion"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_version_create_date + ) + retrieved["PolicyVersion"]["Document"].should.be.an(dict) @mock_iam def test_get_aws_managed_policy_v4_version(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/job-function/SystemAdministrator' - managed_policy_version_create_date = datetime.strptime("2018-10-08T21:33:45+00:00", "%Y-%m-%dT%H:%M:%S+00:00") + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = "arn:aws:iam::aws:policy/job-function/SystemAdministrator" + managed_policy_version_create_date = datetime.strptime( + "2018-10-08T21:33:45+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) with assert_raises(ClientError): conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId='v2-does-not-exist') - retrieved = conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId="v4") - retrieved['PolicyVersion']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_version_create_date) - retrieved['PolicyVersion']['Document'].should.be.an(dict) + PolicyArn=managed_policy_arn, VersionId="v2-does-not-exist" + ) + retrieved = conn.get_policy_version(PolicyArn=managed_policy_arn, VersionId="v4") + retrieved["PolicyVersion"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_version_create_date + ) + retrieved["PolicyVersion"]["Document"].should.be.an(dict) @mock_iam def test_list_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - conn.create_policy( - PolicyName="TestListPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestListPolicyVersions".format(ACCOUNT_ID) + ) + conn.create_policy(PolicyName="TestListPolicyVersions", PolicyDocument=MOCK_POLICY) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - versions.get('Versions')[0].get('VersionId').should.equal('v1') - versions.get('Versions')[0].get('IsDefaultVersion').should.be.ok + PolicyArn="arn:aws:iam::{}:policy/TestListPolicyVersions".format(ACCOUNT_ID) + ) + versions.get("Versions")[0].get("VersionId").should.equal("v1") + versions.get("Versions")[0].get("IsDefaultVersion").should.be.ok conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", - PolicyDocument=MOCK_POLICY_2) + PolicyArn="arn:aws:iam::{}:policy/TestListPolicyVersions".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY_2, + ) conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", - PolicyDocument=MOCK_POLICY_3) + PolicyArn="arn:aws:iam::{}:policy/TestListPolicyVersions".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY_3, + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2)) - versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3)) - versions.get('Versions')[2].get('IsDefaultVersion').shouldnt.be.ok + PolicyArn="arn:aws:iam::{}:policy/TestListPolicyVersions".format(ACCOUNT_ID) + ) + versions.get("Versions")[1].get("Document").should.equal(json.loads(MOCK_POLICY_2)) + versions.get("Versions")[1].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[2].get("Document").should.equal(json.loads(MOCK_POLICY_3)) + versions.get("Versions")[2].get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_delete_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestDeletePolicyVersion", PolicyDocument=MOCK_POLICY) conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.delete_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v2-nope-this-does-not-exist') + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format( + ACCOUNT_ID + ), + VersionId="v2-nope-this-does-not-exist", + ) conn.delete_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v2') + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format(ACCOUNT_ID), + VersionId="v2", + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion") - len(versions.get('Versions')).should.equal(1) + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format(ACCOUNT_ID) + ) + len(versions.get("Versions")).should.equal(1) @mock_iam def test_delete_default_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestDeletePolicyVersion", PolicyDocument=MOCK_POLICY) conn.create_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY_2) + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format(ACCOUNT_ID), + PolicyDocument=MOCK_POLICY_2, + ) with assert_raises(ClientError): conn.delete_policy_version( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v1') + PolicyArn="arn:aws:iam::{}:policy/TestDeletePolicyVersion".format( + ACCOUNT_ID + ), + VersionId="v1", + ) @mock_iam_deprecated() def test_create_user(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.create_user('my-user') + conn.create_user("my-user") @mock_iam_deprecated() def test_get_user(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.get_user('my-user') - conn.create_user('my-user') - conn.get_user('my-user') + conn.get_user("my-user") + conn.create_user("my-user") + conn.get_user("my-user") @mock_iam() def test_update_user(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.update_user(UserName='my-user') - conn.create_user(UserName='my-user') - conn.update_user(UserName='my-user', NewPath='/new-path/', NewUserName='new-user') - response = conn.get_user(UserName='new-user') - response['User'].get('Path').should.equal('/new-path/') + conn.update_user(UserName="my-user") + conn.create_user(UserName="my-user") + conn.update_user(UserName="my-user", NewPath="/new-path/", NewUserName="new-user") + response = conn.get_user(UserName="new-user") + response["User"].get("Path").should.equal("/new-path/") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.get_user(UserName='my-user') + conn.get_user(UserName="my-user") @mock_iam_deprecated() def test_get_current_user(): """If no user is specific, IAM returns the current user""" conn = boto.connect_iam() - user = conn.get_user()['get_user_response']['get_user_result']['user'] - user['user_name'].should.equal('default_user') + user = conn.get_user()["get_user_response"]["get_user_result"]["user"] + user["user_name"].should.equal("default_user") @mock_iam() def test_list_users(): - path_prefix = '/' + path_prefix = "/" max_items = 10 - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") response = conn.list_users(PathPrefix=path_prefix, MaxItems=max_items) - user = response['Users'][0] - user['UserName'].should.equal('my-user') - user['Path'].should.equal('/') - user['Arn'].should.equal('arn:aws:iam::123456789012:user/my-user') + user = response["Users"][0] + user["UserName"].should.equal("my-user") + user["Path"].should.equal("/") + user["Arn"].should.equal("arn:aws:iam::{}:user/my-user".format(ACCOUNT_ID)) @mock_iam() def test_user_policies(): - policy_name = 'UserManagedPolicy' - user_name = 'my-user' - conn = boto3.client('iam', region_name='us-east-1') + policy_name = "UserManagedPolicy" + user_name = "my-user" + conn = boto3.client("iam", region_name="us-east-1") conn.create_user(UserName=user_name) conn.put_user_policy( - UserName=user_name, - PolicyName=policy_name, - PolicyDocument=MOCK_POLICY + UserName=user_name, PolicyName=policy_name, PolicyDocument=MOCK_POLICY ) - policy_doc = conn.get_user_policy( - UserName=user_name, - PolicyName=policy_name - ) - policy_doc['PolicyDocument'].should.equal(json.loads(MOCK_POLICY)) + policy_doc = conn.get_user_policy(UserName=user_name, PolicyName=policy_name) + policy_doc["PolicyDocument"].should.equal(json.loads(MOCK_POLICY)) policies = conn.list_user_policies(UserName=user_name) - len(policies['PolicyNames']).should.equal(1) - policies['PolicyNames'][0].should.equal(policy_name) + len(policies["PolicyNames"]).should.equal(1) + policies["PolicyNames"][0].should.equal(policy_name) - conn.delete_user_policy( - UserName=user_name, - PolicyName=policy_name - ) + conn.delete_user_policy(UserName=user_name, PolicyName=policy_name) policies = conn.list_user_policies(UserName=user_name) - len(policies['PolicyNames']).should.equal(0) + len(policies["PolicyNames"]).should.equal(0) @mock_iam_deprecated() def test_create_login_profile(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.create_login_profile('my-user', 'my-pass') - conn.create_user('my-user') - conn.create_login_profile('my-user', 'my-pass') + conn.create_login_profile("my-user", "my-pass") + conn.create_user("my-user") + conn.create_login_profile("my-user", "my-pass") with assert_raises(BotoServerError): - conn.create_login_profile('my-user', 'my-pass') + conn.create_login_profile("my-user", "my-pass") @mock_iam_deprecated() def test_delete_login_profile(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.delete_login_profile('my-user') - conn.create_login_profile('my-user', 'my-pass') - conn.delete_login_profile('my-user') + conn.delete_login_profile("my-user") + conn.create_login_profile("my-user", "my-pass") + conn.delete_login_profile("my-user") @mock_iam() def test_create_access_key(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): - conn.create_access_key(UserName='my-user') - conn.create_user(UserName='my-user') - access_key = conn.create_access_key(UserName='my-user')["AccessKey"] - (datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None)).seconds.should.be.within(0, 10) + conn.create_access_key(UserName="my-user") + conn.create_user(UserName="my-user") + access_key = conn.create_access_key(UserName="my-user")["AccessKey"] + ( + datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) access_key["AccessKeyId"].should.have.length_of(20) access_key["SecretAccessKey"].should.have.length_of(40) assert access_key["AccessKeyId"].startswith("AKIA") @@ -664,884 +805,1870 @@ def test_get_all_access_keys(): """If no access keys exist there should be none in the response, if an access key is present it should have the correct fields present""" conn = boto.connect_iam() - conn.create_user('my-user') - response = conn.get_all_access_keys('my-user') + conn.create_user("my-user") + response = conn.get_all_access_keys("my-user") assert_equals( - response['list_access_keys_response'][ - 'list_access_keys_result']['access_key_metadata'], - [] + response["list_access_keys_response"]["list_access_keys_result"][ + "access_key_metadata" + ], + [], ) - conn.create_access_key('my-user') - response = conn.get_all_access_keys('my-user') + conn.create_access_key("my-user") + response = conn.get_all_access_keys("my-user") assert_equals( - sorted(response['list_access_keys_response'][ - 'list_access_keys_result']['access_key_metadata'][0].keys()), - sorted(['status', 'create_date', 'user_name', 'access_key_id']) + sorted( + response["list_access_keys_response"]["list_access_keys_result"][ + "access_key_metadata" + ][0].keys() + ), + sorted(["status", "create_date", "user_name", "access_key_id"]), ) @mock_iam_deprecated() def test_delete_access_key(): conn = boto.connect_iam() - conn.create_user('my-user') - access_key_id = conn.create_access_key('my-user')['create_access_key_response'][ - 'create_access_key_result']['access_key']['access_key_id'] - conn.delete_access_key(access_key_id, 'my-user') + conn.create_user("my-user") + access_key_id = conn.create_access_key("my-user")["create_access_key_response"][ + "create_access_key_result" + ]["access_key"]["access_key_id"] + conn.delete_access_key(access_key_id, "my-user") @mock_iam() def test_mfa_devices(): # Test enable device - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") conn.enable_mfa_device( - UserName='my-user', - SerialNumber='123456789', - AuthenticationCode1='234567', - AuthenticationCode2='987654' + UserName="my-user", + SerialNumber="123456789", + AuthenticationCode1="234567", + AuthenticationCode2="987654", ) # Test list mfa devices - response = conn.list_mfa_devices(UserName='my-user') - device = response['MFADevices'][0] - device['SerialNumber'].should.equal('123456789') + response = conn.list_mfa_devices(UserName="my-user") + device = response["MFADevices"][0] + device["SerialNumber"].should.equal("123456789") # Test deactivate mfa device - conn.deactivate_mfa_device(UserName='my-user', SerialNumber='123456789') - response = conn.list_mfa_devices(UserName='my-user') - len(response['MFADevices']).should.equal(0) + conn.deactivate_mfa_device(UserName="my-user", SerialNumber="123456789") + response = conn.list_mfa_devices(UserName="my-user") + len(response["MFADevices"]).should.equal(0) + + +@mock_iam +def test_create_virtual_mfa_device(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + device = response["VirtualMFADevice"] + + device["SerialNumber"].should.equal( + "arn:aws:iam::{}:mfa/test-device".format(ACCOUNT_ID) + ) + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty + + response = client.create_virtual_mfa_device( + Path="/", VirtualMFADeviceName="test-device-2" + ) + device = response["VirtualMFADevice"] + + device["SerialNumber"].should.equal( + "arn:aws:iam::{}:mfa/test-device-2".format(ACCOUNT_ID) + ) + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty + + response = client.create_virtual_mfa_device( + Path="/test/", VirtualMFADeviceName="test-device" + ) + device = response["VirtualMFADevice"] + + device["SerialNumber"].should.equal( + "arn:aws:iam::{}:mfa/test/test-device".format(ACCOUNT_ID) + ) + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty + + +@mock_iam +def test_create_virtual_mfa_device_errors(): + client = boto3.client("iam", region_name="us-east-1") + client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + + client.create_virtual_mfa_device.when.called_with( + VirtualMFADeviceName="test-device" + ).should.throw( + ClientError, "MFADevice entity at the same path and name already exists." + ) + + client.create_virtual_mfa_device.when.called_with( + Path="test", VirtualMFADeviceName="test-device" + ).should.throw( + ClientError, + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters.", + ) + + client.create_virtual_mfa_device.when.called_with( + Path="/test//test/", VirtualMFADeviceName="test-device" + ).should.throw( + ClientError, + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters.", + ) + + too_long_path = "/{}/".format("b" * 511) + client.create_virtual_mfa_device.when.called_with( + Path=too_long_path, VirtualMFADeviceName="test-device" + ).should.throw( + ClientError, + "1 validation error detected: " + 'Value "{}" at "path" failed to satisfy constraint: ' + "Member must have length less than or equal to 512", + ) + + +@mock_iam +def test_delete_virtual_mfa_device(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number = response["VirtualMFADevice"]["SerialNumber"] + + client.delete_virtual_mfa_device(SerialNumber=serial_number) + + response = client.list_virtual_mfa_devices() + + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok + + +@mock_iam +def test_delete_virtual_mfa_device_errors(): + client = boto3.client("iam", region_name="us-east-1") + + serial_number = "arn:aws:iam::{}:mfa/not-existing".format(ACCOUNT_ID) + client.delete_virtual_mfa_device.when.called_with( + SerialNumber=serial_number + ).should.throw( + ClientError, + "VirtualMFADevice with serial number {0} doesn't exist.".format(serial_number), + ) + + +@mock_iam +def test_list_virtual_mfa_devices(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number_1 = response["VirtualMFADevice"]["SerialNumber"] + + response = client.create_virtual_mfa_device( + Path="/test/", VirtualMFADeviceName="test-device" + ) + serial_number_2 = response["VirtualMFADevice"]["SerialNumber"] + + response = client.list_virtual_mfa_devices() + + response["VirtualMFADevices"].should.equal( + [{"SerialNumber": serial_number_1}, {"SerialNumber": serial_number_2}] + ) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") + + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") + + response["VirtualMFADevices"].should.equal( + [{"SerialNumber": serial_number_1}, {"SerialNumber": serial_number_2}] + ) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Any", MaxItems=1) + + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number_1}]) + response["IsTruncated"].should.be.ok + response["Marker"].should.equal("1") + + response = client.list_virtual_mfa_devices( + AssignmentStatus="Any", Marker=response["Marker"] + ) + + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number_2}]) + response["IsTruncated"].should_not.be.ok + + +@mock_iam +def test_list_virtual_mfa_devices_errors(): + client = boto3.client("iam", region_name="us-east-1") + client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + + client.list_virtual_mfa_devices.when.called_with(Marker="100").should.throw( + ClientError, "Invalid Marker." + ) + + +@mock_iam +def test_enable_virtual_mfa_device(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number = response["VirtualMFADevice"]["SerialNumber"] + + client.create_user(UserName="test-user") + client.enable_mfa_device( + UserName="test-user", + SerialNumber=serial_number, + AuthenticationCode1="234567", + AuthenticationCode2="987654", + ) + + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") + + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") + + device = response["VirtualMFADevices"][0] + device["SerialNumber"].should.equal(serial_number) + device["User"]["Path"].should.equal("/") + device["User"]["UserName"].should.equal("test-user") + device["User"]["UserId"].should_not.be.empty + device["User"]["Arn"].should.equal( + "arn:aws:iam::{}:user/test-user".format(ACCOUNT_ID) + ) + device["User"]["CreateDate"].should.be.a(datetime) + device["EnableDate"].should.be.a(datetime) + response["IsTruncated"].should_not.be.ok + + client.deactivate_mfa_device(UserName="test-user", SerialNumber=serial_number) + + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") + + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") + + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number}]) + response["IsTruncated"].should_not.be.ok @mock_iam_deprecated() -def test_delete_user(): +def test_delete_user_deprecated(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.delete_user('my-user') - conn.create_user('my-user') - conn.delete_user('my-user') + conn.delete_user("my-user") + conn.create_user("my-user") + conn.delete_user("my-user") + + +@mock_iam() +def test_delete_user(): + conn = boto3.client("iam", region_name="us-east-1") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.delete_user(UserName="my-user") + + # Test deletion failure with a managed policy + conn.create_user(UserName="my-user") + response = conn.create_policy( + PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY + ) + conn.attach_user_policy(PolicyArn=response["Policy"]["Arn"], UserName="my-user") + with assert_raises(conn.exceptions.DeleteConflictException): + conn.delete_user(UserName="my-user") + conn.detach_user_policy(PolicyArn=response["Policy"]["Arn"], UserName="my-user") + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) + conn.delete_user(UserName="my-user") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_user(UserName="my-user") + + # Test deletion failure with an inline policy + conn.create_user(UserName="my-user") + conn.put_user_policy( + UserName="my-user", PolicyName="my-user-policy", PolicyDocument=MOCK_POLICY + ) + with assert_raises(conn.exceptions.DeleteConflictException): + conn.delete_user(UserName="my-user") + conn.delete_user_policy(UserName="my-user", PolicyName="my-user-policy") + conn.delete_user(UserName="my-user") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_user(UserName="my-user") + + # Test deletion with no conflicts + conn.create_user(UserName="my-user") + conn.delete_user(UserName="my-user") + with assert_raises(conn.exceptions.NoSuchEntityException): + conn.get_user(UserName="my-user") @mock_iam_deprecated() def test_generate_credential_report(): conn = boto.connect_iam() result = conn.generate_credential_report() - result['generate_credential_report_response'][ - 'generate_credential_report_result']['state'].should.equal('STARTED') + result["generate_credential_report_response"]["generate_credential_report_result"][ + "state" + ].should.equal("STARTED") result = conn.generate_credential_report() - result['generate_credential_report_response'][ - 'generate_credential_report_result']['state'].should.equal('COMPLETE') + result["generate_credential_report_response"]["generate_credential_report_result"][ + "state" + ].should.equal("COMPLETE") + @mock_iam def test_boto3_generate_credential_report(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") result = conn.generate_credential_report() - result['State'].should.equal('STARTED') + result["State"].should.equal("STARTED") result = conn.generate_credential_report() - result['State'].should.equal('COMPLETE') + result["State"].should.equal("COMPLETE") @mock_iam_deprecated() def test_get_credential_report(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): conn.get_credential_report() result = conn.generate_credential_report() - while result['generate_credential_report_response']['generate_credential_report_result']['state'] != 'COMPLETE': + while ( + result["generate_credential_report_response"][ + "generate_credential_report_result" + ]["state"] + != "COMPLETE" + ): result = conn.generate_credential_report() result = conn.get_credential_report() - report = base64.b64decode(result['get_credential_report_response'][ - 'get_credential_report_result']['content'].encode('ascii')).decode('ascii') - report.should.match(r'.*my-user.*') + report = base64.b64decode( + result["get_credential_report_response"]["get_credential_report_result"][ + "content" + ].encode("ascii") + ).decode("ascii") + report.should.match(r".*my-user.*") @mock_iam def test_boto3_get_credential_report(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") with assert_raises(ClientError): conn.get_credential_report() result = conn.generate_credential_report() - while result['State'] != 'COMPLETE': + while result["State"] != "COMPLETE": result = conn.generate_credential_report() result = conn.get_credential_report() - report = result['Content'].decode('utf-8') - report.should.match(r'.*my-user.*') + report = result["Content"].decode("utf-8") + report.should.match(r".*my-user.*") -@requires_boto_gte('2.39') +@requires_boto_gte("2.39") @mock_iam_deprecated() def test_managed_policy(): conn = boto.connect_iam() - conn.create_policy(policy_name='UserManagedPolicy', - policy_document=MOCK_POLICY, - path='/mypolicy/', - description='my user managed policy') + conn.create_policy( + policy_name="UserManagedPolicy", + policy_document=MOCK_POLICY, + path="/mypolicy/", + description="my user managed policy", + ) marker = 0 aws_policies = [] while marker is not None: - response = conn.list_policies(scope='AWS', marker=marker)[ - 'list_policies_response']['list_policies_result'] - for policy in response['policies']: + response = conn.list_policies(scope="AWS", marker=marker)[ + "list_policies_response" + ]["list_policies_result"] + for policy in response["policies"]: aws_policies.append(policy) - marker = response.get('marker') + marker = response.get("marker") set(p.name for p in aws_managed_policies).should.equal( - set(p['policy_name'] for p in aws_policies)) + set(p["policy_name"] for p in aws_policies) + ) - user_policies = conn.list_policies(scope='Local')['list_policies_response'][ - 'list_policies_result']['policies'] - set(['UserManagedPolicy']).should.equal( - set(p['policy_name'] for p in user_policies)) + user_policies = conn.list_policies(scope="Local")["list_policies_response"][ + "list_policies_result" + ]["policies"] + set(["UserManagedPolicy"]).should.equal( + set(p["policy_name"] for p in user_policies) + ) marker = 0 all_policies = [] while marker is not None: - response = conn.list_policies(marker=marker)[ - 'list_policies_response']['list_policies_result'] - for policy in response['policies']: + response = conn.list_policies(marker=marker)["list_policies_response"][ + "list_policies_result" + ] + for policy in response["policies"]: all_policies.append(policy) - marker = response.get('marker') - set(p['policy_name'] for p in aws_policies + - user_policies).should.equal(set(p['policy_name'] for p in all_policies)) + marker = response.get("marker") + set(p["policy_name"] for p in aws_policies + user_policies).should.equal( + set(p["policy_name"] for p in all_policies) + ) - role_name = 'my-role' - conn.create_role(role_name, assume_role_policy_document={ - 'policy': 'test'}, path="my-path") - for policy_name in ['AmazonElasticMapReduceRole', - 'AmazonElasticMapReduceforEC2Role']: - policy_arn = 'arn:aws:iam::aws:policy/service-role/' + policy_name + role_name = "my-role" + conn.create_role( + role_name, assume_role_policy_document={"policy": "test"}, path="my-path" + ) + for policy_name in [ + "AmazonElasticMapReduceRole", + "AmazonElasticMapReduceforEC2Role", + ]: + policy_arn = "arn:aws:iam::aws:policy/service-role/" + policy_name conn.attach_role_policy(policy_arn, role_name) - rows = conn.list_policies(only_attached=True)['list_policies_response'][ - 'list_policies_result']['policies'] + rows = conn.list_policies(only_attached=True)["list_policies_response"][ + "list_policies_result" + ]["policies"] rows.should.have.length_of(2) for x in rows: - int(x['attachment_count']).should.be.greater_than(0) + int(x["attachment_count"]).should.be.greater_than(0) # boto has not implemented this end point but accessible this way - resp = conn.get_response('ListAttachedRolePolicies', - {'RoleName': role_name}, - list_marker='AttachedPolicies') - resp['list_attached_role_policies_response']['list_attached_role_policies_result'][ - 'attached_policies'].should.have.length_of(2) + resp = conn.get_response( + "ListAttachedRolePolicies", + {"RoleName": role_name}, + list_marker="AttachedPolicies", + ) + resp["list_attached_role_policies_response"]["list_attached_role_policies_result"][ + "attached_policies" + ].should.have.length_of(2) conn.detach_role_policy( - "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", - role_name) - rows = conn.list_policies(only_attached=True)['list_policies_response'][ - 'list_policies_result']['policies'] + "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", role_name + ) + rows = conn.list_policies(only_attached=True)["list_policies_response"][ + "list_policies_result" + ]["policies"] rows.should.have.length_of(1) for x in rows: - int(x['attachment_count']).should.be.greater_than(0) + int(x["attachment_count"]).should.be.greater_than(0) # boto has not implemented this end point but accessible this way - resp = conn.get_response('ListAttachedRolePolicies', - {'RoleName': role_name}, - list_marker='AttachedPolicies') - resp['list_attached_role_policies_response']['list_attached_role_policies_result'][ - 'attached_policies'].should.have.length_of(1) + resp = conn.get_response( + "ListAttachedRolePolicies", + {"RoleName": role_name}, + list_marker="AttachedPolicies", + ) + resp["list_attached_role_policies_response"]["list_attached_role_policies_result"][ + "attached_policies" + ].should.have.length_of(1) with assert_raises(BotoServerError): conn.detach_role_policy( - "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", - role_name) + "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", role_name + ) with assert_raises(BotoServerError): - conn.detach_role_policy( - "arn:aws:iam::aws:policy/Nonexistent", role_name) + conn.detach_role_policy("arn:aws:iam::aws:policy/Nonexistent", role_name) @mock_iam def test_boto3_create_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_login_profile(UserName="my-user", Password="Password") - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="Password") with assert_raises(ClientError): - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_login_profile(UserName="my-user", Password="Password") @mock_iam() def test_attach_detach_user_policy(): - iam = boto3.resource('iam', region_name='us-east-1') - client = boto3.client('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") + client = boto3.client("iam", region_name="us-east-1") - user = iam.create_user(UserName='test-user') + user = iam.create_user(UserName="test-user") - policy_name = 'UserAttachedPolicy' - policy = iam.create_policy(PolicyName=policy_name, - PolicyDocument=MOCK_POLICY, - Path='/mypolicy/', - Description='my user attached policy') + policy_name = "UserAttachedPolicy" + policy = iam.create_policy( + PolicyName=policy_name, + PolicyDocument=MOCK_POLICY, + Path="/mypolicy/", + Description="my user attached policy", + ) client.attach_user_policy(UserName=user.name, PolicyArn=policy.arn) resp = client.list_attached_user_policies(UserName=user.name) - resp['AttachedPolicies'].should.have.length_of(1) - attached_policy = resp['AttachedPolicies'][0] - attached_policy['PolicyArn'].should.equal(policy.arn) - attached_policy['PolicyName'].should.equal(policy_name) + resp["AttachedPolicies"].should.have.length_of(1) + attached_policy = resp["AttachedPolicies"][0] + attached_policy["PolicyArn"].should.equal(policy.arn) + attached_policy["PolicyName"].should.equal(policy_name) client.detach_user_policy(UserName=user.name, PolicyArn=policy.arn) resp = client.list_attached_user_policies(UserName=user.name) - resp['AttachedPolicies'].should.have.length_of(0) + resp["AttachedPolicies"].should.have.length_of(0) @mock_iam def test_update_access_key(): - iam = boto3.resource('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") client = iam.meta.client - username = 'test-user' + username = "test-user" iam.create_user(UserName=username) with assert_raises(ClientError): - client.update_access_key(UserName=username, - AccessKeyId='non-existent-key', - Status='Inactive') - key = client.create_access_key(UserName=username)['AccessKey'] - client.update_access_key(UserName=username, - AccessKeyId=key['AccessKeyId'], - Status='Inactive') + client.update_access_key( + UserName=username, AccessKeyId="non-existent-key", Status="Inactive" + ) + key = client.create_access_key(UserName=username)["AccessKey"] + client.update_access_key( + UserName=username, AccessKeyId=key["AccessKeyId"], Status="Inactive" + ) resp = client.list_access_keys(UserName=username) - resp['AccessKeyMetadata'][0]['Status'].should.equal('Inactive') + resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive") @mock_iam def test_get_access_key_last_used(): - iam = boto3.resource('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") client = iam.meta.client - username = 'test-user' + username = "test-user" iam.create_user(UserName=username) with assert_raises(ClientError): - client.get_access_key_last_used(AccessKeyId='non-existent-key-id') - create_key_response = client.create_access_key(UserName=username)['AccessKey'] - resp = client.get_access_key_last_used(AccessKeyId=create_key_response['AccessKeyId']) + client.get_access_key_last_used(AccessKeyId="non-existent-key-id") + create_key_response = client.create_access_key(UserName=username)["AccessKey"] + resp = client.get_access_key_last_used( + AccessKeyId=create_key_response["AccessKeyId"] + ) - datetime.strftime(resp["AccessKeyLastUsed"]["LastUsedDate"], "%Y-%m-%d").should.equal(datetime.strftime( - datetime.utcnow(), - "%Y-%m-%d" - )) + datetime.strftime( + resp["AccessKeyLastUsed"]["LastUsedDate"], "%Y-%m-%d" + ).should.equal(datetime.strftime(datetime.utcnow(), "%Y-%m-%d")) resp["UserName"].should.equal(create_key_response["UserName"]) @mock_iam -def test_get_account_authorization_details(): - test_policy = json.dumps({ - "Version": "2012-10-17", - "Statement": [ - { - "Action": "s3:ListBucket", - "Resource": "*", - "Effect": "Allow", - } - ] - }) +def test_upload_ssh_public_key(): + iam = boto3.resource("iam", region_name="us-east-1") + client = iam.meta.client + username = "test-user" + iam.create_user(UserName=username) + public_key = MOCK_CERT - conn = boto3.client('iam', region_name='us-east-1') - boundary = 'arn:aws:iam::123456789012:policy/boundary' - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/", Description='testing', PermissionsBoundary=boundary) - conn.create_user(Path='/', UserName='testUser') - conn.create_group(Path='/', GroupName='testGroup') + resp = client.upload_ssh_public_key(UserName=username, SSHPublicKeyBody=public_key) + pubkey = resp["SSHPublicKey"] + pubkey["SSHPublicKeyBody"].should.equal(public_key) + pubkey["UserName"].should.equal(username) + pubkey["SSHPublicKeyId"].should.have.length_of(20) + assert pubkey["SSHPublicKeyId"].startswith("APKA") + pubkey.should.have.key("Fingerprint") + pubkey["Status"].should.equal("Active") + ( + datetime.utcnow() - pubkey["UploadDate"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) + + +@mock_iam +def test_get_ssh_public_key(): + iam = boto3.resource("iam", region_name="us-east-1") + client = iam.meta.client + username = "test-user" + iam.create_user(UserName=username) + public_key = MOCK_CERT + + with assert_raises(ClientError): + client.get_ssh_public_key( + UserName=username, SSHPublicKeyId="xxnon-existent-keyxx", Encoding="SSH" + ) + + resp = client.upload_ssh_public_key(UserName=username, SSHPublicKeyBody=public_key) + ssh_public_key_id = resp["SSHPublicKey"]["SSHPublicKeyId"] + + resp = client.get_ssh_public_key( + UserName=username, SSHPublicKeyId=ssh_public_key_id, Encoding="SSH" + ) + resp["SSHPublicKey"]["SSHPublicKeyBody"].should.equal(public_key) + + +@mock_iam +def test_list_ssh_public_keys(): + iam = boto3.resource("iam", region_name="us-east-1") + client = iam.meta.client + username = "test-user" + iam.create_user(UserName=username) + public_key = MOCK_CERT + + resp = client.list_ssh_public_keys(UserName=username) + resp["SSHPublicKeys"].should.have.length_of(0) + + resp = client.upload_ssh_public_key(UserName=username, SSHPublicKeyBody=public_key) + ssh_public_key_id = resp["SSHPublicKey"]["SSHPublicKeyId"] + + resp = client.list_ssh_public_keys(UserName=username) + resp["SSHPublicKeys"].should.have.length_of(1) + resp["SSHPublicKeys"][0]["SSHPublicKeyId"].should.equal(ssh_public_key_id) + + +@mock_iam +def test_update_ssh_public_key(): + iam = boto3.resource("iam", region_name="us-east-1") + client = iam.meta.client + username = "test-user" + iam.create_user(UserName=username) + public_key = MOCK_CERT + + with assert_raises(ClientError): + client.update_ssh_public_key( + UserName=username, SSHPublicKeyId="xxnon-existent-keyxx", Status="Inactive" + ) + + resp = client.upload_ssh_public_key(UserName=username, SSHPublicKeyBody=public_key) + ssh_public_key_id = resp["SSHPublicKey"]["SSHPublicKeyId"] + resp["SSHPublicKey"]["Status"].should.equal("Active") + + resp = client.update_ssh_public_key( + UserName=username, SSHPublicKeyId=ssh_public_key_id, Status="Inactive" + ) + + resp = client.get_ssh_public_key( + UserName=username, SSHPublicKeyId=ssh_public_key_id, Encoding="SSH" + ) + resp["SSHPublicKey"]["Status"].should.equal("Inactive") + + +@mock_iam +def test_delete_ssh_public_key(): + iam = boto3.resource("iam", region_name="us-east-1") + client = iam.meta.client + username = "test-user" + iam.create_user(UserName=username) + public_key = MOCK_CERT + + with assert_raises(ClientError): + client.delete_ssh_public_key( + UserName=username, SSHPublicKeyId="xxnon-existent-keyxx" + ) + + resp = client.upload_ssh_public_key(UserName=username, SSHPublicKeyBody=public_key) + ssh_public_key_id = resp["SSHPublicKey"]["SSHPublicKeyId"] + + resp = client.list_ssh_public_keys(UserName=username) + resp["SSHPublicKeys"].should.have.length_of(1) + + resp = client.delete_ssh_public_key( + UserName=username, SSHPublicKeyId=ssh_public_key_id + ) + + resp = client.list_ssh_public_keys(UserName=username) + resp["SSHPublicKeys"].should.have.length_of(0) + + +@mock_iam +def test_get_account_authorization_details(): + test_policy = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Action": "s3:ListBucket", "Resource": "*", "Effect": "Allow"} + ], + } + ) + + conn = boto3.client("iam", region_name="us-east-1") + boundary = "arn:aws:iam::{}:policy/boundary".format(ACCOUNT_ID) + conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Path="/my-path/", + Description="testing", + PermissionsBoundary=boundary, + ) + conn.create_user(Path="/", UserName="testUser") + conn.create_group(Path="/", GroupName="testGroup") conn.create_policy( - PolicyName='testPolicy', - Path='/', + PolicyName="testPolicy", + Path="/", PolicyDocument=test_policy, - Description='Test Policy' + Description="Test Policy", ) # Attach things to the user and group: - conn.put_user_policy(UserName='testUser', PolicyName='testPolicy', PolicyDocument=test_policy) - conn.put_group_policy(GroupName='testGroup', PolicyName='testPolicy', PolicyDocument=test_policy) + conn.put_user_policy( + UserName="testUser", PolicyName="testPolicy", PolicyDocument=test_policy + ) + conn.put_group_policy( + GroupName="testGroup", PolicyName="testPolicy", PolicyDocument=test_policy + ) - conn.attach_user_policy(UserName='testUser', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') - conn.attach_group_policy(GroupName='testGroup', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.attach_user_policy( + UserName="testUser", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) + conn.attach_group_policy( + GroupName="testGroup", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) - conn.add_user_to_group(UserName='testUser', GroupName='testGroup') + conn.add_user_to_group(UserName="testUser", GroupName="testGroup") # Add things to the role: - conn.create_instance_profile(InstanceProfileName='ipn') - conn.add_role_to_instance_profile(InstanceProfileName='ipn', RoleName='my-role') - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) - conn.put_role_policy(RoleName='my-role', PolicyName='test-policy', PolicyDocument=test_policy) - conn.attach_role_policy(RoleName='my-role', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.create_instance_profile(InstanceProfileName="ipn") + conn.add_role_to_instance_profile(InstanceProfileName="ipn", RoleName="my-role") + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="test-policy", PolicyDocument=test_policy + ) + conn.attach_role_policy( + RoleName="my-role", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) - result = conn.get_account_authorization_details(Filter=['Role']) - assert len(result['RoleDetailList']) == 1 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 0 - assert len(result['RoleDetailList'][0]['InstanceProfileList']) == 1 - assert result['RoleDetailList'][0]['InstanceProfileList'][0]['Roles'][0]['Description'] == 'testing' - assert result['RoleDetailList'][0]['InstanceProfileList'][0]['Roles'][0]['PermissionsBoundary'] == { - 'PermissionsBoundaryType': 'PermissionsBoundaryPolicy', - 'PermissionsBoundaryArn': 'arn:aws:iam::123456789012:policy/boundary' + result = conn.get_account_authorization_details(Filter=["Role"]) + assert len(result["RoleDetailList"]) == 1 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 0 + assert len(result["RoleDetailList"][0]["InstanceProfileList"]) == 1 + assert ( + result["RoleDetailList"][0]["InstanceProfileList"][0]["Roles"][0]["Description"] + == "testing" + ) + assert result["RoleDetailList"][0]["InstanceProfileList"][0]["Roles"][0][ + "PermissionsBoundary" + ] == { + "PermissionsBoundaryType": "PermissionsBoundaryPolicy", + "PermissionsBoundaryArn": "arn:aws:iam::{}:policy/boundary".format(ACCOUNT_ID), } - assert len(result['RoleDetailList'][0]['Tags']) == 2 - assert len(result['RoleDetailList'][0]['RolePolicyList']) == 1 - assert len(result['RoleDetailList'][0]['AttachedManagedPolicies']) == 1 - assert result['RoleDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['RoleDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + assert len(result["RoleDetailList"][0]["Tags"]) == 2 + assert len(result["RoleDetailList"][0]["RolePolicyList"]) == 1 + assert len(result["RoleDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert ( + result["RoleDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert result["RoleDetailList"][0]["AttachedManagedPolicies"][0][ + "PolicyArn" + ] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID) - result = conn.get_account_authorization_details(Filter=['User']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 1 - assert len(result['UserDetailList'][0]['GroupList']) == 1 - assert len(result['UserDetailList'][0]['AttachedManagedPolicies']) == 1 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 0 - assert result['UserDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['UserDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + result = conn.get_account_authorization_details(Filter=["User"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 1 + assert len(result["UserDetailList"][0]["GroupList"]) == 1 + assert len(result["UserDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 0 + assert ( + result["UserDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert result["UserDetailList"][0]["AttachedManagedPolicies"][0][ + "PolicyArn" + ] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID) - result = conn.get_account_authorization_details(Filter=['Group']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 1 - assert len(result['GroupDetailList'][0]['GroupPolicyList']) == 1 - assert len(result['GroupDetailList'][0]['AttachedManagedPolicies']) == 1 - assert len(result['Policies']) == 0 - assert result['GroupDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['GroupDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + result = conn.get_account_authorization_details(Filter=["Group"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 1 + assert len(result["GroupDetailList"][0]["GroupPolicyList"]) == 1 + assert len(result["GroupDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert len(result["Policies"]) == 0 + assert ( + result["GroupDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert result["GroupDetailList"][0]["AttachedManagedPolicies"][0][ + "PolicyArn" + ] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID) - result = conn.get_account_authorization_details(Filter=['LocalManagedPolicy']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 1 - assert len(result['Policies'][0]['PolicyVersionList']) == 1 + result = conn.get_account_authorization_details(Filter=["LocalManagedPolicy"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 1 + assert len(result["Policies"][0]["PolicyVersionList"]) == 1 # Check for greater than 1 since this should always be greater than one but might change. # See iam/aws_managed_policies.py - result = conn.get_account_authorization_details(Filter=['AWSManagedPolicy']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) > 1 + result = conn.get_account_authorization_details(Filter=["AWSManagedPolicy"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) > 1 result = conn.get_account_authorization_details() - assert len(result['RoleDetailList']) == 1 - assert len(result['UserDetailList']) == 1 - assert len(result['GroupDetailList']) == 1 - assert len(result['Policies']) > 1 + assert len(result["RoleDetailList"]) == 1 + assert len(result["UserDetailList"]) == 1 + assert len(result["GroupDetailList"]) == 1 + assert len(result["Policies"]) > 1 @mock_iam def test_signing_certs(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") # Create the IAM user first: - client.create_user(UserName='testing') + client.create_user(UserName="testing") # Upload the cert: - resp = client.upload_signing_certificate(UserName='testing', CertificateBody=MOCK_CERT)['Certificate'] - cert_id = resp['CertificateId'] + resp = client.upload_signing_certificate( + UserName="testing", CertificateBody=MOCK_CERT + )["Certificate"] + cert_id = resp["CertificateId"] - assert resp['UserName'] == 'testing' - assert resp['Status'] == 'Active' - assert resp['CertificateBody'] == MOCK_CERT - assert resp['CertificateId'] + assert resp["UserName"] == "testing" + assert resp["Status"] == "Active" + assert resp["CertificateBody"] == MOCK_CERT + assert resp["CertificateId"] # Upload a the cert with an invalid body: with assert_raises(ClientError) as ce: - client.upload_signing_certificate(UserName='testing', CertificateBody='notacert') - assert ce.exception.response['Error']['Code'] == 'MalformedCertificate' + client.upload_signing_certificate( + UserName="testing", CertificateBody="notacert" + ) + assert ce.exception.response["Error"]["Code"] == "MalformedCertificate" # Upload with an invalid user: with assert_raises(ClientError): - client.upload_signing_certificate(UserName='notauser', CertificateBody=MOCK_CERT) + client.upload_signing_certificate( + UserName="notauser", CertificateBody=MOCK_CERT + ) # Update: - client.update_signing_certificate(UserName='testing', CertificateId=cert_id, Status='Inactive') + client.update_signing_certificate( + UserName="testing", CertificateId=cert_id, Status="Inactive" + ) with assert_raises(ClientError): - client.update_signing_certificate(UserName='notauser', CertificateId=cert_id, Status='Inactive') + client.update_signing_certificate( + UserName="notauser", CertificateId=cert_id, Status="Inactive" + ) with assert_raises(ClientError) as ce: - client.update_signing_certificate(UserName='testing', CertificateId='x' * 32, Status='Inactive') + client.update_signing_certificate( + UserName="testing", CertificateId="x" * 32, Status="Inactive" + ) - assert ce.exception.response['Error']['Message'] == 'The Certificate with id {id} cannot be found.'.format( - id='x' * 32) + assert ce.exception.response["Error"][ + "Message" + ] == "The Certificate with id {id} cannot be found.".format(id="x" * 32) # List the certs: - resp = client.list_signing_certificates(UserName='testing')['Certificates'] + resp = client.list_signing_certificates(UserName="testing")["Certificates"] assert len(resp) == 1 - assert resp[0]['CertificateBody'] == MOCK_CERT - assert resp[0]['Status'] == 'Inactive' # Changed with the update call above. + assert resp[0]["CertificateBody"] == MOCK_CERT + assert resp[0]["Status"] == "Inactive" # Changed with the update call above. with assert_raises(ClientError): - client.list_signing_certificates(UserName='notauser') + client.list_signing_certificates(UserName="notauser") # Delete: - client.delete_signing_certificate(UserName='testing', CertificateId=cert_id) + client.delete_signing_certificate(UserName="testing", CertificateId=cert_id) with assert_raises(ClientError): - client.delete_signing_certificate(UserName='notauser', CertificateId=cert_id) + client.delete_signing_certificate(UserName="notauser", CertificateId=cert_id) @mock_iam() def test_create_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 + ) + response["SAMLProviderArn"].should.equal( + "arn:aws:iam::{}:saml-provider/TestSAMLProvider".format(ACCOUNT_ID) ) - response['SAMLProviderArn'].should.equal("arn:aws:iam::123456789012:saml-provider/TestSAMLProvider") @mock_iam() def test_get_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") saml_provider_create = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 ) response = conn.get_saml_provider( - SAMLProviderArn=saml_provider_create['SAMLProviderArn'] + SAMLProviderArn=saml_provider_create["SAMLProviderArn"] ) - response['SAMLMetadataDocument'].should.equal('a' * 1024) + response["SAMLMetadataDocument"].should.equal("a" * 1024) @mock_iam() def test_list_saml_providers(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 - ) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_saml_provider(Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024) response = conn.list_saml_providers() - response['SAMLProviderList'][0]['Arn'].should.equal("arn:aws:iam::123456789012:saml-provider/TestSAMLProvider") + response["SAMLProviderList"][0]["Arn"].should.equal( + "arn:aws:iam::{}:saml-provider/TestSAMLProvider".format(ACCOUNT_ID) + ) @mock_iam() def test_delete_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") saml_provider_create = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 ) response = conn.list_saml_providers() - len(response['SAMLProviderList']).should.equal(1) - conn.delete_saml_provider( - SAMLProviderArn=saml_provider_create['SAMLProviderArn'] - ) + len(response["SAMLProviderList"]).should.equal(1) + conn.delete_saml_provider(SAMLProviderArn=saml_provider_create["SAMLProviderArn"]) response = conn.list_saml_providers() - len(response['SAMLProviderList']).should.equal(0) - conn.create_user(UserName='testing') + len(response["SAMLProviderList"]).should.equal(0) + conn.create_user(UserName="testing") - cert_id = '123456789012345678901234' + cert_id = "123456789012345678901234" with assert_raises(ClientError) as ce: - conn.delete_signing_certificate(UserName='testing', CertificateId=cert_id) + conn.delete_signing_certificate(UserName="testing", CertificateId=cert_id) - assert ce.exception.response['Error']['Message'] == 'The Certificate with id {id} cannot be found.'.format( - id=cert_id) + assert ce.exception.response["Error"][ + "Message" + ] == "The Certificate with id {id} cannot be found.".format(id=cert_id) # Verify that it's not in the list: - resp = conn.list_signing_certificates(UserName='testing') - assert not resp['Certificates'] + resp = conn.list_signing_certificates(UserName="testing") + assert not resp["Certificates"] + + +@mock_iam() +def test_create_role_defaults(): + """Tests default values""" + conn = boto3.client("iam", region_name="us-east-1") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="{}", + ) + + # Get role: + role = conn.get_role(RoleName="my-role")["Role"] + + assert role["MaxSessionDuration"] == 3600 + assert role.get("Description") is None @mock_iam() def test_create_role_with_tags(): """Tests both the tag_role and get_role_tags capability""" - conn = boto3.client('iam', region_name='us-east-1') - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}", Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ], Description='testing') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="{}", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + Description="testing", + ) # Get role: - role = conn.get_role(RoleName='my-role')['Role'] - assert len(role['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' - assert role['Description'] == 'testing' + role = conn.get_role(RoleName="my-role")["Role"] + assert len(role["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" + assert role["Description"] == "testing" # Empty is good: - conn.create_role(RoleName="my-role2", AssumeRolePolicyDocument="{}", Tags=[ - { - 'Key': 'somekey', - 'Value': '' - } - ]) - tags = conn.list_role_tags(RoleName='my-role2') - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == '' + conn.create_role( + RoleName="my-role2", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "somekey", "Value": ""}], + ) + tags = conn.list_role_tags(RoleName="my-role2") + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "" # Test creating tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - too_many_tags = list(map(lambda x: {'Key': str(x), 'Value': str(x)}, range(0, 51))) - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=too_many_tags) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] + too_many_tags = list( + map(lambda x: {"Key": str(x), "Value": str(x)}, range(0, 51)) + ) + conn.create_role( + RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=too_many_tags + ) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) # With a duplicate tag: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0', 'Value': ''}, {'Key': '0', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0", "Value": ""}, {"Key": "0", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # Duplicate tag with different casing: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': 'a', 'Value': ''}, {'Key': 'A', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "a", "Value": ""}, {"Key": "A", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # With a really big key: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0' * 129, 'Value': ''}]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0" * 129, "Value": ""}], + ) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) # With a really big value: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0', 'Value': '0' * 257}]) - assert 'Member must have length less than or equal to 256.' in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0", "Value": "0" * 257}], + ) + assert ( + "Member must have length less than or equal to 256." + in ce.exception.response["Error"]["Message"] + ) # With an invalid character: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': 'NOWAY!', 'Value': ''}]) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "NOWAY!", "Value": ""}], + ) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) @mock_iam() def test_tag_role(): """Tests both the tag_role and get_role_tags capability""" - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}") # Get without tags: - role = conn.get_role(RoleName='my-role')['Role'] - assert not role.get('Tags') + role = conn.get_role(RoleName="my-role")["Role"] + assert not role.get("Tags") # With proper tag values: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) # Get role: - role = conn.get_role(RoleName='my-role')['Role'] - assert len(role['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' + role = conn.get_role(RoleName="my-role")["Role"] + assert len(role["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" # Same -- but for list_role_tags: - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' - assert not tags['IsTruncated'] - assert not tags.get('Marker') + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" + assert not tags["IsTruncated"] + assert not tags.get("Marker") # Test pagination: - tags = conn.list_role_tags(RoleName='my-role', MaxItems=1) - assert len(tags['Tags']) == 1 - assert tags['IsTruncated'] - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == 'somevalue' - assert tags['Marker'] == '1' + tags = conn.list_role_tags(RoleName="my-role", MaxItems=1) + assert len(tags["Tags"]) == 1 + assert tags["IsTruncated"] + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "somevalue" + assert tags["Marker"] == "1" - tags = conn.list_role_tags(RoleName='my-role', Marker=tags['Marker']) - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'someotherkey' - assert tags['Tags'][0]['Value'] == 'someothervalue' - assert not tags['IsTruncated'] - assert not tags.get('Marker') + tags = conn.list_role_tags(RoleName="my-role", Marker=tags["Marker"]) + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "someotherkey" + assert tags["Tags"][0]["Value"] == "someothervalue" + assert not tags["IsTruncated"] + assert not tags.get("Marker") # Test updating an existing tag: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somenewvalue' - } - ]) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == 'somenewvalue' + conn.tag_role( + RoleName="my-role", Tags=[{"Key": "somekey", "Value": "somenewvalue"}] + ) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "somenewvalue" # Empty is good: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': '' - } - ]) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == '' + conn.tag_role(RoleName="my-role", Tags=[{"Key": "somekey", "Value": ""}]) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "" # Test creating tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - too_many_tags = list(map(lambda x: {'Key': str(x), 'Value': str(x)}, range(0, 51))) - conn.tag_role(RoleName='my-role', Tags=too_many_tags) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] + too_many_tags = list( + map(lambda x: {"Key": str(x), "Value": str(x)}, range(0, 51)) + ) + conn.tag_role(RoleName="my-role", Tags=too_many_tags) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) # With a duplicate tag: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0', 'Value': ''}, {'Key': '0', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.tag_role( + RoleName="my-role", + Tags=[{"Key": "0", "Value": ""}, {"Key": "0", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # Duplicate tag with different casing: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': 'a', 'Value': ''}, {'Key': 'A', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.tag_role( + RoleName="my-role", + Tags=[{"Key": "a", "Value": ""}, {"Key": "A", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # With a really big key: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0' * 129, 'Value': ''}]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "0" * 129, "Value": ""}]) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) # With a really big value: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0', 'Value': '0' * 257}]) - assert 'Member must have length less than or equal to 256.' in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "0", "Value": "0" * 257}]) + assert ( + "Member must have length less than or equal to 256." + in ce.exception.response["Error"]["Message"] + ) # With an invalid character: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': 'NOWAY!', 'Value': ''}]) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "NOWAY!", "Value": ""}]) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) # With a role that doesn't exist: with assert_raises(ClientError): - conn.tag_role(RoleName='notarole', Tags=[{'Key': 'some', 'Value': 'value'}]) + conn.tag_role(RoleName="notarole", Tags=[{"Key": "some", "Value": "value"}]) @mock_iam def test_untag_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}") # With proper tag values: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) # Remove them: - conn.untag_role(RoleName='my-role', TagKeys=['somekey']) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'someotherkey' - assert tags['Tags'][0]['Value'] == 'someothervalue' + conn.untag_role(RoleName="my-role", TagKeys=["somekey"]) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "someotherkey" + assert tags["Tags"][0]["Value"] == "someothervalue" # And again: - conn.untag_role(RoleName='my-role', TagKeys=['someotherkey']) - tags = conn.list_role_tags(RoleName='my-role') - assert not tags['Tags'] + conn.untag_role(RoleName="my-role", TagKeys=["someotherkey"]) + tags = conn.list_role_tags(RoleName="my-role") + assert not tags["Tags"] # Test removing tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=[str(x) for x in range(0, 51)]) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=[str(x) for x in range(0, 51)]) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With a really big key: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=['0' * 129]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=["0" * 129]) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With an invalid character: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=['NOWAY!']) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=["NOWAY!"]) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With a role that doesn't exist: with assert_raises(ClientError): - conn.untag_role(RoleName='notarole', TagKeys=['somevalue']) + conn.untag_role(RoleName="notarole", TagKeys=["somevalue"]) @mock_iam() def test_update_role_description(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role_description(RoleName="my-role", Description="test") - assert response['Role']['RoleName'] == 'my-role' + assert response["Role"]["RoleName"] == "my-role" @mock_iam() def test_update_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role_description(RoleName="my-role", Description="test") - assert response['Role']['RoleName'] == 'my-role' + assert response["Role"]["RoleName"] == "my-role" @mock_iam() def test_update_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role(RoleName="my-role", Description="test") assert len(response.keys()) == 1 @mock_iam() -def test_list_entities_for_policy(): - test_policy = json.dumps({ - "Version": "2012-10-17", - "Statement": [ - { - "Action": "s3:ListBucket", - "Resource": "*", - "Effect": "Allow", - } - ] - }) +def test_update_role_defaults(): + conn = boto3.client("iam", region_name="us-east-1") - conn = boto3.client('iam', region_name='us-east-1') - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") - conn.create_user(Path='/', UserName='testUser') - conn.create_group(Path='/', GroupName='testGroup') + with assert_raises(ClientError): + conn.delete_role(RoleName="my-role") + + conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Description="test", + Path="/my-path/", + ) + response = conn.update_role(RoleName="my-role") + assert len(response.keys()) == 1 + + role = conn.get_role(RoleName="my-role")["Role"] + + assert role["MaxSessionDuration"] == 3600 + assert role.get("Description") is None + + +@mock_iam() +def test_list_entities_for_policy(): + test_policy = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Action": "s3:ListBucket", "Resource": "*", "Effect": "Allow"} + ], + } + ) + + conn = boto3.client("iam", region_name="us-east-1") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.create_user(Path="/", UserName="testUser") + conn.create_group(Path="/", GroupName="testGroup") conn.create_policy( - PolicyName='testPolicy', - Path='/', + PolicyName="testPolicy", + Path="/", PolicyDocument=test_policy, - Description='Test Policy' + Description="Test Policy", ) # Attach things to the user and group: - conn.put_user_policy(UserName='testUser', PolicyName='testPolicy', PolicyDocument=test_policy) - conn.put_group_policy(GroupName='testGroup', PolicyName='testPolicy', PolicyDocument=test_policy) + conn.put_user_policy( + UserName="testUser", PolicyName="testPolicy", PolicyDocument=test_policy + ) + conn.put_group_policy( + GroupName="testGroup", PolicyName="testPolicy", PolicyDocument=test_policy + ) - conn.attach_user_policy(UserName='testUser', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') - conn.attach_group_policy(GroupName='testGroup', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.attach_user_policy( + UserName="testUser", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) + conn.attach_group_policy( + GroupName="testGroup", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) - conn.add_user_to_group(UserName='testUser', GroupName='testGroup') + conn.add_user_to_group(UserName="testUser", GroupName="testGroup") # Add things to the role: - conn.create_instance_profile(InstanceProfileName='ipn') - conn.add_role_to_instance_profile(InstanceProfileName='ipn', RoleName='my-role') - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) - conn.put_role_policy(RoleName='my-role', PolicyName='test-policy', PolicyDocument=test_policy) - conn.attach_role_policy(RoleName='my-role', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.create_instance_profile(InstanceProfileName="ipn") + conn.add_role_to_instance_profile(InstanceProfileName="ipn", RoleName="my-role") + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="test-policy", PolicyDocument=test_policy + ) + conn.attach_role_policy( + RoleName="my-role", + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + ) response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='Role' + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + EntityFilter="Role", ) - assert response['PolicyRoles'] == [{'RoleName': 'my-role'}] + assert response["PolicyRoles"] == [{"RoleName": "my-role"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='User', + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + EntityFilter="User", ) - assert response['PolicyUsers'] == [{'UserName': 'testUser'}] + assert response["PolicyUsers"] == [{"UserName": "testUser"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='Group', + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + EntityFilter="Group", ) - assert response['PolicyGroups'] == [{'GroupName': 'testGroup'}] + assert response["PolicyGroups"] == [{"GroupName": "testGroup"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='LocalManagedPolicy', + PolicyArn="arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID), + EntityFilter="LocalManagedPolicy", ) - assert response['PolicyGroups'] == [{'GroupName': 'testGroup'}] - assert response['PolicyUsers'] == [{'UserName': 'testUser'}] - assert response['PolicyRoles'] == [{'RoleName': 'my-role'}] + assert response["PolicyGroups"] == [{"GroupName": "testGroup"}] + assert response["PolicyUsers"] == [{"UserName": "testUser"}] + assert response["PolicyRoles"] == [{"RoleName": "my-role"}] @mock_iam() def test_create_role_no_path(): - conn = boto3.client('iam', region_name='us-east-1') - resp = conn.create_role(RoleName='my-role', AssumeRolePolicyDocument='some policy', Description='test') - resp.get('Role').get('Arn').should.equal('arn:aws:iam::123456789012:role/my-role') - resp.get('Role').should_not.have.key('PermissionsBoundary') - resp.get('Role').get('Description').should.equal('test') + conn = boto3.client("iam", region_name="us-east-1") + resp = conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Description="test" + ) + resp.get("Role").get("Arn").should.equal( + "arn:aws:iam::{}:role/my-role".format(ACCOUNT_ID) + ) + resp.get("Role").should_not.have.key("PermissionsBoundary") + resp.get("Role").get("Description").should.equal("test") @mock_iam() def test_create_role_with_permissions_boundary(): - conn = boto3.client('iam', region_name='us-east-1') - boundary = 'arn:aws:iam::123456789012:policy/boundary' - resp = conn.create_role(RoleName='my-role', AssumeRolePolicyDocument='some policy', Description='test', PermissionsBoundary=boundary) + conn = boto3.client("iam", region_name="us-east-1") + boundary = "arn:aws:iam::{}:policy/boundary".format(ACCOUNT_ID) + resp = conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Description="test", + PermissionsBoundary=boundary, + ) expected = { - 'PermissionsBoundaryType': 'PermissionsBoundaryPolicy', - 'PermissionsBoundaryArn': boundary + "PermissionsBoundaryType": "PermissionsBoundaryPolicy", + "PermissionsBoundaryArn": boundary, } - resp.get('Role').get('PermissionsBoundary').should.equal(expected) - resp.get('Role').get('Description').should.equal('test') + resp.get("Role").get("PermissionsBoundary").should.equal(expected) + resp.get("Role").get("Description").should.equal("test") - invalid_boundary_arn = 'arn:aws:iam::123456789:not_a_boundary' + invalid_boundary_arn = "arn:aws:iam::123456789:not_a_boundary" with assert_raises(ClientError): - conn.create_role(RoleName='bad-boundary', AssumeRolePolicyDocument='some policy', Description='test', PermissionsBoundary=invalid_boundary_arn) + conn.create_role( + RoleName="bad-boundary", + AssumeRolePolicyDocument="some policy", + Description="test", + PermissionsBoundary=invalid_boundary_arn, + ) # Ensure the PermissionsBoundary is included in role listing as well - conn.list_roles().get('Roles')[0].get('PermissionsBoundary').should.equal(expected) + conn.list_roles().get("Roles")[0].get("PermissionsBoundary").should.equal(expected) + + +@mock_iam +def test_create_role_with_same_name_should_fail(): + iam = boto3.client("iam", region_name="us-east-1") + test_role_name = str(uuid4()) + iam.create_role( + RoleName=test_role_name, AssumeRolePolicyDocument="policy", Description="test" + ) + # Create the role again, and verify that it fails + with assert_raises(ClientError) as err: + iam.create_role( + RoleName=test_role_name, + AssumeRolePolicyDocument="policy", + Description="test", + ) + err.exception.response["Error"]["Code"].should.equal("EntityAlreadyExists") + err.exception.response["Error"]["Message"].should.equal( + "Role with name {0} already exists.".format(test_role_name) + ) + + +@mock_iam +def test_create_policy_with_same_name_should_fail(): + iam = boto3.client("iam", region_name="us-east-1") + test_policy_name = str(uuid4()) + policy = iam.create_policy(PolicyName=test_policy_name, PolicyDocument=MOCK_POLICY) + # Create the role again, and verify that it fails + with assert_raises(ClientError) as err: + iam.create_policy(PolicyName=test_policy_name, PolicyDocument=MOCK_POLICY) + err.exception.response["Error"]["Code"].should.equal("EntityAlreadyExists") + err.exception.response["Error"]["Message"].should.equal( + "A policy called {0} already exists. Duplicate names are not allowed.".format( + test_policy_name + ) + ) + + +@mock_iam +def test_create_open_id_connect_provider(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_open_id_connect_provider( + Url="https://example.com", + ThumbprintList=[], # even it is required to provide at least one thumbprint, AWS accepts an empty list + ) + + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::{}:oidc-provider/example.com".format(ACCOUNT_ID) + ) + + response = client.create_open_id_connect_provider( + Url="http://example.org", ThumbprintList=["b" * 40], ClientIDList=["b"] + ) + + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::{}:oidc-provider/example.org".format(ACCOUNT_ID) + ) + + response = client.create_open_id_connect_provider( + Url="http://example.org/oidc", ThumbprintList=[] + ) + + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::{}:oidc-provider/example.org/oidc".format(ACCOUNT_ID) + ) + + response = client.create_open_id_connect_provider( + Url="http://example.org/oidc-query?test=true", ThumbprintList=[] + ) + + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::{}:oidc-provider/example.org/oidc-query".format(ACCOUNT_ID) + ) + + +@mock_iam +def test_create_open_id_connect_provider_errors(): + client = boto3.client("iam", region_name="us-east-1") + client.create_open_id_connect_provider(Url="https://example.com", ThumbprintList=[]) + + client.create_open_id_connect_provider.when.called_with( + Url="https://example.com", ThumbprintList=[] + ).should.throw(ClientError, "Unknown") + + client.create_open_id_connect_provider.when.called_with( + Url="example.org", ThumbprintList=[] + ).should.throw(ClientError, "Invalid Open ID Connect Provider URL") + + client.create_open_id_connect_provider.when.called_with( + Url="example", ThumbprintList=[] + ).should.throw(ClientError, "Invalid Open ID Connect Provider URL") + + client.create_open_id_connect_provider.when.called_with( + Url="http://example.org", + ThumbprintList=["a" * 40, "b" * 40, "c" * 40, "d" * 40, "e" * 40, "f" * 40], + ).should.throw(ClientError, "Thumbprint list must contain fewer than 5 entries.") + + too_many_client_ids = ["{}".format(i) for i in range(101)] + client.create_open_id_connect_provider.when.called_with( + Url="http://example.org", ThumbprintList=[], ClientIDList=too_many_client_ids + ).should.throw( + ClientError, "Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100" + ) + + too_long_url = "b" * 256 + too_long_thumbprint = "b" * 41 + too_long_client_id = "b" * 256 + client.create_open_id_connect_provider.when.called_with( + Url=too_long_url, + ThumbprintList=[too_long_thumbprint], + ClientIDList=[too_long_client_id], + ).should.throw( + ClientError, + "3 validation errors detected: " + 'Value "{0}" at "clientIDList" failed to satisfy constraint: ' + "Member must satisfy constraint: " + "[Member must have length less than or equal to 255, " + "Member must have length greater than or equal to 1]; " + 'Value "{1}" at "thumbprintList" failed to satisfy constraint: ' + "Member must satisfy constraint: " + "[Member must have length less than or equal to 40, " + "Member must have length greater than or equal to 40]; " + 'Value "{2}" at "url" failed to satisfy constraint: ' + "Member must have length less than or equal to 255".format( + [too_long_client_id], [too_long_thumbprint], too_long_url + ), + ) + + +@mock_iam +def test_delete_open_id_connect_provider(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_open_id_connect_provider( + Url="https://example.com", ThumbprintList=[] + ) + open_id_arn = response["OpenIDConnectProviderArn"] + + client.delete_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) + + client.get_open_id_connect_provider.when.called_with( + OpenIDConnectProviderArn=open_id_arn + ).should.throw( + ClientError, "OpenIDConnect Provider not found for arn {}".format(open_id_arn) + ) + + # deleting a non existing provider should be successful + client.delete_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) + + +@mock_iam +def test_get_open_id_connect_provider(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_open_id_connect_provider( + Url="https://example.com", ThumbprintList=["b" * 40], ClientIDList=["b"] + ) + open_id_arn = response["OpenIDConnectProviderArn"] + + response = client.get_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) + + response["Url"].should.equal("example.com") + response["ThumbprintList"].should.equal(["b" * 40]) + response["ClientIDList"].should.equal(["b"]) + response.should.have.key("CreateDate").should.be.a(datetime) + + +@mock_iam +def test_get_open_id_connect_provider_errors(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_open_id_connect_provider( + Url="https://example.com", ThumbprintList=["b" * 40], ClientIDList=["b"] + ) + open_id_arn = response["OpenIDConnectProviderArn"] + + client.get_open_id_connect_provider.when.called_with( + OpenIDConnectProviderArn=open_id_arn + "-not-existing" + ).should.throw( + ClientError, + "OpenIDConnect Provider not found for arn {}".format( + open_id_arn + "-not-existing" + ), + ) + + +@mock_iam +def test_list_open_id_connect_providers(): + client = boto3.client("iam", region_name="us-east-1") + response = client.create_open_id_connect_provider( + Url="https://example.com", ThumbprintList=[] + ) + open_id_arn_1 = response["OpenIDConnectProviderArn"] + + response = client.create_open_id_connect_provider( + Url="http://example.org", ThumbprintList=["b" * 40], ClientIDList=["b"] + ) + open_id_arn_2 = response["OpenIDConnectProviderArn"] + + response = client.create_open_id_connect_provider( + Url="http://example.org/oidc", ThumbprintList=[] + ) + open_id_arn_3 = response["OpenIDConnectProviderArn"] + + response = client.list_open_id_connect_providers() + + sorted(response["OpenIDConnectProviderList"], key=lambda i: i["Arn"]).should.equal( + [{"Arn": open_id_arn_1}, {"Arn": open_id_arn_2}, {"Arn": open_id_arn_3}] + ) + + +@mock_iam +def test_update_account_password_policy(): + client = boto3.client("iam", region_name="us-east-1") + + client.update_account_password_policy() + + response = client.get_account_password_policy() + response["PasswordPolicy"].should.equal( + { + "AllowUsersToChangePassword": False, + "ExpirePasswords": False, + "MinimumPasswordLength": 6, + "RequireLowercaseCharacters": False, + "RequireNumbers": False, + "RequireSymbols": False, + "RequireUppercaseCharacters": False, + } + ) + + +@mock_iam +def test_update_account_password_policy_errors(): + client = boto3.client("iam", region_name="us-east-1") + + client.update_account_password_policy.when.called_with( + MaxPasswordAge=1096, MinimumPasswordLength=129, PasswordReusePrevention=25 + ).should.throw( + ClientError, + "3 validation errors detected: " + 'Value "129" at "minimumPasswordLength" failed to satisfy constraint: ' + "Member must have value less than or equal to 128; " + 'Value "25" at "passwordReusePrevention" failed to satisfy constraint: ' + "Member must have value less than or equal to 24; " + 'Value "1096" at "maxPasswordAge" failed to satisfy constraint: ' + "Member must have value less than or equal to 1095", + ) + + +@mock_iam +def test_get_account_password_policy(): + client = boto3.client("iam", region_name="us-east-1") + client.update_account_password_policy( + AllowUsersToChangePassword=True, + HardExpiry=True, + MaxPasswordAge=60, + MinimumPasswordLength=10, + PasswordReusePrevention=3, + RequireLowercaseCharacters=True, + RequireNumbers=True, + RequireSymbols=True, + RequireUppercaseCharacters=True, + ) + + response = client.get_account_password_policy() + + response["PasswordPolicy"].should.equal( + { + "AllowUsersToChangePassword": True, + "ExpirePasswords": True, + "HardExpiry": True, + "MaxPasswordAge": 60, + "MinimumPasswordLength": 10, + "PasswordReusePrevention": 3, + "RequireLowercaseCharacters": True, + "RequireNumbers": True, + "RequireSymbols": True, + "RequireUppercaseCharacters": True, + } + ) + + +@mock_iam +def test_get_account_password_policy_errors(): + client = boto3.client("iam", region_name="us-east-1") + + client.get_account_password_policy.when.called_with().should.throw( + ClientError, + "The Password Policy with domain name {} cannot be found.".format(ACCOUNT_ID), + ) + + +@mock_iam +def test_delete_account_password_policy(): + client = boto3.client("iam", region_name="us-east-1") + client.update_account_password_policy() + + response = client.get_account_password_policy() + + response.should.have.key("PasswordPolicy").which.should.be.a(dict) + + client.delete_account_password_policy() + + client.get_account_password_policy.when.called_with().should.throw( + ClientError, + "The Password Policy with domain name {} cannot be found.".format(ACCOUNT_ID), + ) + + +@mock_iam +def test_delete_account_password_policy_errors(): + client = boto3.client("iam", region_name="us-east-1") + + client.delete_account_password_policy.when.called_with().should.throw( + ClientError, "The account policy with name PasswordPolicy cannot be found." + ) + + +@mock_iam +def test_get_account_summary(): + client = boto3.client("iam", region_name="us-east-1") + iam = boto3.resource("iam", region_name="us-east-1") + + account_summary = iam.AccountSummary() + + account_summary.summary_map.should.equal( + { + "GroupPolicySizeQuota": 5120, + "InstanceProfilesQuota": 1000, + "Policies": 0, + "GroupsPerUserQuota": 10, + "InstanceProfiles": 0, + "AttachedPoliciesPerUserQuota": 10, + "Users": 0, + "PoliciesQuota": 1500, + "Providers": 0, + "AccountMFAEnabled": 0, + "AccessKeysPerUserQuota": 2, + "AssumeRolePolicySizeQuota": 2048, + "PolicyVersionsInUseQuota": 10000, + "GlobalEndpointTokenVersion": 1, + "VersionsPerPolicyQuota": 5, + "AttachedPoliciesPerGroupQuota": 10, + "PolicySizeQuota": 6144, + "Groups": 0, + "AccountSigningCertificatesPresent": 0, + "UsersQuota": 5000, + "ServerCertificatesQuota": 20, + "MFADevices": 0, + "UserPolicySizeQuota": 2048, + "PolicyVersionsInUse": 0, + "ServerCertificates": 0, + "Roles": 0, + "RolesQuota": 1000, + "SigningCertificatesPerUserQuota": 2, + "MFADevicesInUse": 0, + "RolePolicySizeQuota": 10240, + "AttachedPoliciesPerRoleQuota": 10, + "AccountAccessKeysPresent": 0, + "GroupsQuota": 300, + } + ) + + client.create_instance_profile(InstanceProfileName="test-profile") + client.create_open_id_connect_provider(Url="https://example.com", ThumbprintList=[]) + response_policy = client.create_policy( + PolicyName="test-policy", PolicyDocument=MOCK_POLICY + ) + client.create_role(RoleName="test-role", AssumeRolePolicyDocument="test policy") + client.attach_role_policy( + RoleName="test-role", PolicyArn=response_policy["Policy"]["Arn"] + ) + client.create_saml_provider( + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 + ) + client.create_group(GroupName="test-group") + client.attach_group_policy( + GroupName="test-group", PolicyArn=response_policy["Policy"]["Arn"] + ) + client.create_user(UserName="test-user") + client.attach_user_policy( + UserName="test-user", PolicyArn=response_policy["Policy"]["Arn"] + ) + client.enable_mfa_device( + UserName="test-user", + SerialNumber="123456789", + AuthenticationCode1="234567", + AuthenticationCode2="987654", + ) + client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + client.upload_server_certificate( + ServerCertificateName="test-cert", + CertificateBody="cert-body", + PrivateKey="private-key", + ) + account_summary.load() + + account_summary.summary_map.should.equal( + { + "GroupPolicySizeQuota": 5120, + "InstanceProfilesQuota": 1000, + "Policies": 1, + "GroupsPerUserQuota": 10, + "InstanceProfiles": 1, + "AttachedPoliciesPerUserQuota": 10, + "Users": 1, + "PoliciesQuota": 1500, + "Providers": 2, + "AccountMFAEnabled": 0, + "AccessKeysPerUserQuota": 2, + "AssumeRolePolicySizeQuota": 2048, + "PolicyVersionsInUseQuota": 10000, + "GlobalEndpointTokenVersion": 1, + "VersionsPerPolicyQuota": 5, + "AttachedPoliciesPerGroupQuota": 10, + "PolicySizeQuota": 6144, + "Groups": 1, + "AccountSigningCertificatesPresent": 0, + "UsersQuota": 5000, + "ServerCertificatesQuota": 20, + "MFADevices": 1, + "UserPolicySizeQuota": 2048, + "PolicyVersionsInUse": 3, + "ServerCertificates": 1, + "Roles": 1, + "RolesQuota": 1000, + "SigningCertificatesPerUserQuota": 2, + "MFADevicesInUse": 1, + "RolePolicySizeQuota": 10240, + "AttachedPoliciesPerRoleQuota": 10, + "AccountAccessKeysPresent": 0, + "GroupsQuota": 300, + } + ) diff --git a/tests/test_iam/test_iam_account_aliases.py b/tests/test_iam/test_iam_account_aliases.py index 5d7dec408..d01a72106 100644 --- a/tests/test_iam/test_iam_account_aliases.py +++ b/tests/test_iam/test_iam_account_aliases.py @@ -1,20 +1,20 @@ -import boto3 -import sure # noqa -from moto import mock_iam - - -@mock_iam() -def test_account_aliases(): - client = boto3.client('iam', region_name='us-east-1') - - alias = 'my-account-name' - aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([]) - - client.create_account_alias(AccountAlias=alias) - aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([alias]) - - client.delete_account_alias(AccountAlias=alias) - aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([]) +import boto3 +import sure # noqa +from moto import mock_iam + + +@mock_iam() +def test_account_aliases(): + client = boto3.client("iam", region_name="us-east-1") + + alias = "my-account-name" + aliases = client.list_account_aliases() + aliases.should.have.key("AccountAliases").which.should.equal([]) + + client.create_account_alias(AccountAlias=alias) + aliases = client.list_account_aliases() + aliases.should.have.key("AccountAliases").which.should.equal([alias]) + + client.delete_account_alias(AccountAlias=alias) + aliases = client.list_account_aliases() + aliases.should.have.key("AccountAliases").which.should.equal([]) diff --git a/tests/test_iam/test_iam_groups.py b/tests/test_iam/test_iam_groups.py index 1ca9f2512..64d838e2b 100644 --- a/tests/test_iam/test_iam_groups.py +++ b/tests/test_iam/test_iam_groups.py @@ -8,7 +8,9 @@ import sure # noqa from nose.tools import assert_raises from boto.exception import BotoServerError +from botocore.exceptions import ClientError from moto import mock_iam, mock_iam_deprecated +from moto.core import ACCOUNT_ID MOCK_POLICY = """ { @@ -26,46 +28,49 @@ MOCK_POLICY = """ @mock_iam_deprecated() def test_create_group(): conn = boto.connect_iam() - conn.create_group('my-group') + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.create_group('my-group') + conn.create_group("my-group") @mock_iam_deprecated() def test_get_group(): conn = boto.connect_iam() - conn.create_group('my-group') - conn.get_group('my-group') + conn.create_group("my-group") + conn.get_group("my-group") with assert_raises(BotoServerError): - conn.get_group('not-group') + conn.get_group("not-group") @mock_iam() def test_get_group_current(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - result = conn.get_group(GroupName='my-group') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + result = conn.get_group(GroupName="my-group") - assert result['Group']['Path'] == '/' - assert result['Group']['GroupName'] == 'my-group' - assert isinstance(result['Group']['CreateDate'], datetime) - assert result['Group']['GroupId'] - assert result['Group']['Arn'] == 'arn:aws:iam::123456789012:group/my-group' - assert not result['Users'] + assert result["Group"]["Path"] == "/" + assert result["Group"]["GroupName"] == "my-group" + assert isinstance(result["Group"]["CreateDate"], datetime) + assert result["Group"]["GroupId"] + assert result["Group"]["Arn"] == "arn:aws:iam::{}:group/my-group".format(ACCOUNT_ID) + assert not result["Users"] # Make a group with a different path: - other_group = conn.create_group(GroupName='my-other-group', Path='some/location') - assert other_group['Group']['Path'] == 'some/location' - assert other_group['Group']['Arn'] == 'arn:aws:iam::123456789012:group/some/location/my-other-group' + other_group = conn.create_group(GroupName="my-other-group", Path="some/location") + assert other_group["Group"]["Path"] == "some/location" + assert other_group["Group"][ + "Arn" + ] == "arn:aws:iam::{}:group/some/location/my-other-group".format(ACCOUNT_ID) @mock_iam_deprecated() def test_get_all_groups(): conn = boto.connect_iam() - conn.create_group('my-group1') - conn.create_group('my-group2') - groups = conn.get_all_groups()['list_groups_response'][ - 'list_groups_result']['groups'] + conn.create_group("my-group1") + conn.create_group("my-group2") + groups = conn.get_all_groups()["list_groups_response"]["list_groups_result"][ + "groups" + ] groups.should.have.length_of(2) @@ -73,95 +78,130 @@ def test_get_all_groups(): def test_add_user_to_group(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.add_user_to_group('my-group', 'my-user') - conn.create_group('my-group') + conn.add_user_to_group("my-group", "my-user") + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.add_user_to_group('my-group', 'my-user') - conn.create_user('my-user') - conn.add_user_to_group('my-group', 'my-user') + conn.add_user_to_group("my-group", "my-user") + conn.create_user("my-user") + conn.add_user_to_group("my-group", "my-user") @mock_iam_deprecated() def test_remove_user_from_group(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.remove_user_from_group('my-group', 'my-user') - conn.create_group('my-group') - conn.create_user('my-user') + conn.remove_user_from_group("my-group", "my-user") + conn.create_group("my-group") + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.remove_user_from_group('my-group', 'my-user') - conn.add_user_to_group('my-group', 'my-user') - conn.remove_user_from_group('my-group', 'my-user') + conn.remove_user_from_group("my-group", "my-user") + conn.add_user_to_group("my-group", "my-user") + conn.remove_user_from_group("my-group", "my-user") @mock_iam_deprecated() def test_get_groups_for_user(): conn = boto.connect_iam() - conn.create_group('my-group1') - conn.create_group('my-group2') - conn.create_group('other-group') - conn.create_user('my-user') - conn.add_user_to_group('my-group1', 'my-user') - conn.add_user_to_group('my-group2', 'my-user') + conn.create_group("my-group1") + conn.create_group("my-group2") + conn.create_group("other-group") + conn.create_user("my-user") + conn.add_user_to_group("my-group1", "my-user") + conn.add_user_to_group("my-group2", "my-user") - groups = conn.get_groups_for_user( - 'my-user')['list_groups_for_user_response']['list_groups_for_user_result']['groups'] + groups = conn.get_groups_for_user("my-user")["list_groups_for_user_response"][ + "list_groups_for_user_result" + ]["groups"] groups.should.have.length_of(2) @mock_iam_deprecated() def test_put_group_policy(): conn = boto.connect_iam() - conn.create_group('my-group') - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) + conn.create_group("my-group") + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) @mock_iam def test_attach_group_policies(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty - policy_arn = 'arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceforEC2Role' - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty - conn.attach_group_policy(GroupName='my-group', PolicyArn=policy_arn) - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.equal( - [ - { - 'PolicyName': 'AmazonElasticMapReduceforEC2Role', - 'PolicyArn': policy_arn, - } - ]) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty + policy_arn = "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceforEC2Role" + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty + conn.attach_group_policy(GroupName="my-group", PolicyArn=policy_arn) + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.equal( + [{"PolicyName": "AmazonElasticMapReduceforEC2Role", "PolicyArn": policy_arn}] + ) - conn.detach_group_policy(GroupName='my-group', PolicyArn=policy_arn) - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty + conn.detach_group_policy(GroupName="my-group", PolicyArn=policy_arn) + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty @mock_iam_deprecated() def test_get_group_policy(): conn = boto.connect_iam() - conn.create_group('my-group') + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.get_group_policy('my-group', 'my-policy') + conn.get_group_policy("my-group", "my-policy") - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) - conn.get_group_policy('my-group', 'my-policy') + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) + conn.get_group_policy("my-group", "my-policy") @mock_iam_deprecated() def test_get_all_group_policies(): conn = boto.connect_iam() - conn.create_group('my-group') - policies = conn.get_all_group_policies('my-group')['list_group_policies_response']['list_group_policies_result']['policy_names'] + conn.create_group("my-group") + policies = conn.get_all_group_policies("my-group")["list_group_policies_response"][ + "list_group_policies_result" + ]["policy_names"] assert policies == [] - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) - policies = conn.get_all_group_policies('my-group')['list_group_policies_response']['list_group_policies_result']['policy_names'] - assert policies == ['my-policy'] + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) + policies = conn.get_all_group_policies("my-group")["list_group_policies_response"][ + "list_group_policies_result" + ]["policy_names"] + assert policies == ["my-policy"] @mock_iam() def test_list_group_policies(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - conn.list_group_policies(GroupName='my-group')['PolicyNames'].should.be.empty - conn.put_group_policy(GroupName='my-group', PolicyName='my-policy', PolicyDocument=MOCK_POLICY) - conn.list_group_policies(GroupName='my-group')['PolicyNames'].should.equal(['my-policy']) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + conn.list_group_policies(GroupName="my-group")["PolicyNames"].should.be.empty + conn.put_group_policy( + GroupName="my-group", PolicyName="my-policy", PolicyDocument=MOCK_POLICY + ) + conn.list_group_policies(GroupName="my-group")["PolicyNames"].should.equal( + ["my-policy"] + ) + + +@mock_iam +def test_delete_group(): + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + groups = conn.list_groups() + assert groups["Groups"][0]["GroupName"] == "my-group" + assert len(groups["Groups"]) == 1 + conn.delete_group(GroupName="my-group") + conn.list_groups()["Groups"].should.be.empty + + +@mock_iam +def test_delete_unknown_group(): + conn = boto3.client("iam", region_name="us-east-1") + with assert_raises(ClientError) as err: + conn.delete_group(GroupName="unknown-group") + err.exception.response["Error"]["Code"].should.equal("NoSuchEntity") + err.exception.response["Error"]["Message"].should.equal( + "The group with name unknown-group cannot be found." + ) diff --git a/tests/test_iam/test_iam_policies.py b/tests/test_iam/test_iam_policies.py index e1924a559..6348b0cba 100644 --- a/tests/test_iam/test_iam_policies.py +++ b/tests/test_iam/test_iam_policies.py @@ -9,17 +9,17 @@ from moto import mock_iam invalid_policy_document_test_cases = [ { "document": "This is not a json document", - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -27,10 +27,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -38,35 +38,18 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17" - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": ["afd"] - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - "Extra field": "value" }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17"}, + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17", "Statement": ["afd"]}, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -75,10 +58,22 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Extra field": "value" - } + }, + "Extra field": "value", }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Extra field": "value", + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -87,10 +82,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -99,10 +94,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -110,10 +105,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "invalid", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -121,46 +116,43 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "invalid", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.' + "error_message": "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.' + "error_message": "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "a a:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "a a:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Vendor a a is not valid' + "error_message": "Vendor a a is not valid", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:List:Bucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:List:Bucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Condition can contain only one colon.' + "error_message": "Actions/Condition can contain only one colon.", }, { "document": { @@ -169,16 +161,16 @@ invalid_policy_document_test_cases = [ { "Effect": "Allow", "Action": "s3s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, { "Effect": "Allow", "Action": "s:3s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + "Resource": "arn:aws:s3:::example_bucket", + }, + ], }, - "error_message": 'Actions/Condition can contain only one colon.' + "error_message": "Actions/Condition can contain only one colon.", }, { "document": { @@ -186,10 +178,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "invalid resource" - } + "Resource": "invalid resource", + }, }, - "error_message": 'Resource invalid resource must be in ARN format or "*".' + "error_message": 'Resource invalid resource must be in ARN format or "*".', }, { "document": { @@ -198,39 +190,32 @@ invalid_policy_document_test_cases = [ { "Sid": "EnableDisableHongKong", "Effect": "Allow", - "Action": [ - "account:EnableRegion", - "account:DisableRegion" - ], + "Action": ["account:EnableRegion", "account:DisableRegion"], "Resource": "", "Condition": { "StringEquals": {"account:TargetRegion": "ap-east-1"} - } + }, }, { "Sid": "ViewConsole", "Effect": "Allow", - "Action": [ - "aws-portal:ViewAccount", - "account:ListRegions" - ], - "Resource": "" - } - ] + "Action": ["aws-portal:ViewAccount", "account:ListRegions"], + "Resource": "", + }, + ], }, - "error_message": 'Resource must be in ARN format or "*".' + "error_message": 'Resource must be in ARN format or "*".', }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s:3:ListBucket", - "Resource": "sdfsadf" - } + "Statement": { + "Effect": "Allow", + "Action": "s:3:ListBucket", + "Resource": "sdfsadf", + }, }, - "error_message": 'Resource sdfsadf must be in ARN format or "*".' + "error_message": 'Resource sdfsadf must be in ARN format or "*".', }, { "document": { @@ -238,10 +223,50 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": ["adf"] - } + "Resource": ["adf"], + }, }, - "error_message": 'Resource adf must be in ARN format or "*".' + "error_message": 'Resource adf must be in ARN format or "*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket", "Resource": ""}, + }, + "error_message": 'Resource must be in ARN format or "*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:bsdfdsafsad", + }, + }, + "error_message": 'Partition "bsdfdsafsad" is not valid for resource "arn:bsdfdsafsad:*:*:*:*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:b:cadfsdf", + }, + }, + "error_message": 'Partition "b" is not valid for resource "arn:b:cadfsdf:*:*:*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:b:c:d:e:f:g:h", + }, + }, + "error_message": 'Partition "b" is not valid for resource "arn:b:c:d:e:f:g:h".', }, { "document": { @@ -249,57 +274,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "" - } + "Resource": "aws:s3:::example_bucket", + }, }, - "error_message": 'Resource must be in ARN format or "*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:bsdfdsafsad" - } - }, - "error_message": 'Partition "bsdfdsafsad" is not valid for resource "arn:bsdfdsafsad:*:*:*:*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:b:cadfsdf" - } - }, - "error_message": 'Partition "b" is not valid for resource "arn:b:cadfsdf:*:*:*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:b:c:d:e:f:g:h" - } - }, - "error_message": 'Partition "b" is not valid for resource "arn:b:c:d:e:f:g:h".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "aws:s3:::example_bucket" - } - }, - "error_message": 'Partition "s3" is not valid for resource "arn:s3:::example_bucket:*".' + "error_message": 'Partition "s3" is not valid for resource "arn:s3:::example_bucket:*".', }, { "document": { @@ -309,166 +287,133 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": [ "arn:error:s3:::example_bucket", - "arn:error:s3::example_bucket" - ] - } + "arn:error:s3::example_bucket", + ], + }, }, - "error_message": 'Partition "error" is not valid for resource "arn:error:s3:::example_bucket".' + "error_message": 'Partition "error" is not valid for resource "arn:error:s3:::example_bucket".', + }, + { + "document": {"Version": "2012-10-17", "Statement": []}, + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": [] + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket"}, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket" - } + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket", "Resource": []}, }, - "error_message": 'Policy statement must contain resources.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": [] - } + "Statement": {"Effect": "Allow", "Action": "invalid"}, }, - "error_message": 'Policy statement must contain resources.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "invalid" - } + "Statement": {"Effect": "Allow", "Resource": "arn:aws:s3:::example_bucket"}, }, - "error_message": 'Policy statement must contain resources.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Policy statement must contain actions.' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17", "Statement": {"Effect": "Allow"}}, + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { - "Effect": "Allow" - } + "Effect": "Allow", + "Action": [], + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Policy statement must contain actions.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": [], - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Policy statement must contain actions.' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": [ + {"Effect": "Deny"}, { - "Effect": "Deny" + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + ], }, - "error_message": 'Policy statement must contain actions.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:::example_bucket" - } - }, - "error_message": 'IAM resource path must either be "*" or start with user/, federated-user/, role/, group/, instance-profile/, mfa/, server-certificate/, policy/, sms-mfa/, saml-provider/, oidc-provider/, report/, access-report/.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3::example_bucket" - } - }, - "error_message": 'The policy failed legacy parsing' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Resource": "arn:aws:s3::example_bucket" - } + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": 'IAM resource path must either be "*" or start with user/, federated-user/, role/, group/, instance-profile/, mfa/, server-certificate/, policy/, sms-mfa/, saml-provider/, oidc-provider/, report/, access-report/.', }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3::example_bucket", + }, }, - "error_message": 'Resource vendor must be fully qualified and cannot contain regexes.' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": { - "a": "arn:aws:s3:::example_bucket" - } - } + "Statement": {"Effect": "Allow", "Resource": "arn:aws:s3::example_bucket"}, }, - "error_message": 'Syntax errors in policy.' + "error_message": "The policy failed legacy parsing", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws", + }, + }, + "error_message": "Resource vendor must be fully qualified and cannot contain regexes.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": {"a": "arn:aws:s3:::example_bucket"}, + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -476,23 +421,22 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Deny", "Action": "s3:ListBucket", - "Resource": ["adfdf", {}] - } + "Resource": ["adfdf", {}], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "NotResource": [] - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "NotResource": [], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -500,135 +444,33 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Deny", "Action": [[]], - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Action": [], - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Action": [], + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": {}, - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": {}, + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": [] - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": "a" - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "a": "b" - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": "b" - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": [] - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": {}} - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": {}} - } - } - }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -637,14 +479,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "x": { - "a": "1" - } - } - } + "Condition": [], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -653,79 +491,153 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "ForAnyValue::StringEqualsIfExists": { - "a": "asf" - } - } - } + "Condition": "a", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": [ - {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} - ] - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"a": "b"}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": "b"}, + }, }, - "error_message": 'IAM resource arn:aws:iam:us-east-1::example_bucket cannot contain region information.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": []}, + }, }, - "error_message": 'Resource arn:aws:s3:us-east-1::example_bucket can not contain region information.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Sid": {}, - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": {}}}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Sid": [], - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": {}}}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"x": {"a": "1"}}, + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAnyValue::StringEqualsIfExists": {"a": "asf"}}, + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": [ + {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} + ], + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:us-east-1::example_bucket", + }, + }, + "error_message": "IAM resource arn:aws:iam:us-east-1::example_bucket cannot contain region information.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:us-east-1::example_bucket", + }, + }, + "error_message": "Resource arn:aws:s3:us-east-1::example_bucket can not contain region information.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Sid": {}, + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Sid": [], + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -735,15 +647,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } - ] + {"Sid": "sdf", "Effect": "Allow"}, + ], }, - "error_message": 'Statement IDs (SID) in a single policy must be unique.' + "error_message": "Statement IDs (SID) in a single policy must be unique.", }, { "document": { @@ -752,15 +661,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } + {"Sid": "sdf", "Effect": "Allow"}, ] }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -769,10 +675,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "NotAction": "s3:ListBucket", "Action": "iam:dsf", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -781,10 +687,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "NotResource": "*" - } + "NotResource": "*", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -792,85 +698,74 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "denY", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { "Statement": { "Effect": "denY", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "Resource": "arn:aws::::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "Resource": "arn:aws::::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "allow", - "Resource": "arn:aws:s3:us-east-1::example_bucket" - } + "Statement": { + "Effect": "allow", + "Resource": "arn:aws:s3:us-east-1::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -880,15 +775,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "aLLow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } - ] + {"Sid": "sdf", "Effect": "Allow"}, + ], }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -896,10 +788,22 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "NotResource": "arn:aws:s3::example_bucket" - } + "NotResource": "arn:aws:s3::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateLessThanEquals": {"a": "234-13"}}, + }, + }, + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -909,13 +813,11 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateLessThanEquals": { - "a": "234-13" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13t2:00:00.593194+1"} + }, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -925,13 +827,11 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.593194+1" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13t2:00:00.1999999999+10:59"} + }, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -940,30 +840,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.1999999999+10:59" - } - } - } + "Condition": {"DateLessThan": {"a": "9223372036854775808"}}, + }, }, - "error_message": 'The policy failed legacy parsing' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThan": { - "a": "9223372036854775808" - } - } - } - }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -972,14 +852,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:error:s3:::example_bucket", - "Condition": { - "DateGreaterThan": { - "a": "sdfdsf" - } - } - } + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -987,11 +863,11 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws::fdsasf" - } + "Resource": "arn:aws::fdsasf", + }, }, - "error_message": 'The policy failed legacy parsing' - } + "error_message": "The policy failed legacy parsing", + }, ] valid_policy_documents = [ @@ -1000,37 +876,32 @@ valid_policy_documents = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::example_bucket" - ] - } + "Resource": ["arn:aws:s3:::example_bucket"], + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "iam: asdf safdsf af ", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::example_bucket", - "*" - ] - } + "Resource": ["arn:aws:s3:::example_bucket", "*"], + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "*", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1038,9 +909,9 @@ valid_policy_documents = [ { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } - ] + ], }, { "Version": "2012-10-17", @@ -1050,160 +921,139 @@ valid_policy_documents = [ "Resource": "*", "Condition": { "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, - "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"} - } - } + "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"}, + }, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "fsx:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "fsx:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:::user/example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:::user/example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s33:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s33:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:fdsasf" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:fdsasf", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": {} - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:cloudwatch:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:cloudwatch:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:ec2:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:ec2:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:invalid-service:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:invalid-service:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:invalid-service:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:invalid-service:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, - "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": { + "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, + "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"}, + }, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": []} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": []}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "a": {} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"a": {}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Sid": "dsfsdfsdfsdfsdfsadfsd", - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Sid": "dsfsdfsdfsdfsdfsadfsd", + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1217,37 +1067,29 @@ valid_policy_documents = [ "iam:ListRoles", "iam:ListRoleTags", "iam:ListUsers", - "iam:ListUserTags" + "iam:ListUserTags", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "AddTag", "Effect": "Allow", - "Action": [ - "iam:TagUser", - "iam:TagRole" - ], + "Action": ["iam:TagUser", "iam:TagRole"], "Resource": "*", "Condition": { - "StringEquals": { - "aws:RequestTag/CostCenter": [ - "A-123", - "B-456" - ] - }, - "ForAllValues:StringEquals": {"aws:TagKeys": "CostCenter"} - } - } - ] + "StringEquals": {"aws:RequestTag/CostCenter": ["A-123", "B-456"]}, + "ForAllValues:StringEquals": {"aws:TagKeys": "CostCenter"}, + }, + }, + ], }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1256,9 +1098,9 @@ valid_policy_documents = [ "Action": "s3:*", "NotResource": [ "arn:aws:s3:::HRBucket/Payroll", - "arn:aws:s3:::HRBucket/Payroll/*" - ] - } + "arn:aws:s3:::HRBucket/Payroll/*", + ], + }, }, { "Version": "2012-10-17", @@ -1266,44 +1108,40 @@ valid_policy_documents = [ "Statement": { "Effect": "Allow", "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "aaaaaadsfdsafsadfsadfaaaaa:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "aaaaaadsfdsafsadfsadfaaaaa:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3-s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3-s:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3.s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3.s:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "NotResource": "*" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "NotResource": "*", + }, }, { "Version": "2012-10-17", @@ -1312,14 +1150,59 @@ valid_policy_documents = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + "Resource": "arn:aws:s3:::example_bucket", + }, + ], + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "01T"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"x": {}, "y": {}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"StringEqualsIfExists": {"a": "asf"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAnyValue:StringEqualsIfExists": {"a": "asf"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateLessThanEquals": {"a": "2019-07-01T13:20:15Z"}}, + }, }, { "Version": "2012-10-17", @@ -1328,11 +1211,9 @@ valid_policy_documents = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateGreaterThan": { - "a": "01T" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13T21:20:37.593194+00:00"} + }, + }, }, { "Version": "2012-10-17", @@ -1340,12 +1221,8 @@ valid_policy_documents = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "x": { - }, - "y": {} - } - } + "Condition": {"DateLessThanEquals": {"a": "2016-12-13t2:00:00.593194+23"}}, + }, }, { "Version": "2012-10-17", @@ -1353,77 +1230,8 @@ valid_policy_documents = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "StringEqualsIfExists": { - "a": "asf" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "ForAnyValue:StringEqualsIfExists": { - "a": "asf" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2019-07-01T13:20:15Z" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13T21:20:37.593194+00:00" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.593194+23" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThan": { - "a": "-292275054" - } - } - } + "Condition": {"DateLessThan": {"a": "-292275054"}}, + }, }, { "Version": "2012-10-17", @@ -1434,18 +1242,15 @@ valid_policy_documents = [ "Action": [ "iam:GetAccountPasswordPolicy", "iam:GetAccountSummary", - "iam:ListVirtualMFADevices" + "iam:ListVirtualMFADevices", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "AllowManageOwnPasswords", "Effect": "Allow", - "Action": [ - "iam:ChangePassword", - "iam:GetUser" - ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Action": ["iam:ChangePassword", "iam:GetUser"], + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnAccessKeys", @@ -1454,9 +1259,9 @@ valid_policy_documents = [ "iam:CreateAccessKey", "iam:DeleteAccessKey", "iam:ListAccessKeys", - "iam:UpdateAccessKey" + "iam:UpdateAccessKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSigningCertificates", @@ -1465,9 +1270,9 @@ valid_policy_documents = [ "iam:DeleteSigningCertificate", "iam:ListSigningCertificates", "iam:UpdateSigningCertificate", - "iam:UploadSigningCertificate" + "iam:UploadSigningCertificate", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSSHPublicKeys", @@ -1477,9 +1282,9 @@ valid_policy_documents = [ "iam:GetSSHPublicKey", "iam:ListSSHPublicKeys", "iam:UpdateSSHPublicKey", - "iam:UploadSSHPublicKey" + "iam:UploadSSHPublicKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnGitCredentials", @@ -1489,18 +1294,15 @@ valid_policy_documents = [ "iam:DeleteServiceSpecificCredential", "iam:ListServiceSpecificCredentials", "iam:ResetServiceSpecificCredential", - "iam:UpdateServiceSpecificCredential" + "iam:UpdateServiceSpecificCredential", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnVirtualMFADevice", "Effect": "Allow", - "Action": [ - "iam:CreateVirtualMFADevice", - "iam:DeleteVirtualMFADevice" - ], - "Resource": "arn:aws:iam::*:mfa/${aws:username}" + "Action": ["iam:CreateVirtualMFADevice", "iam:DeleteVirtualMFADevice"], + "Resource": "arn:aws:iam::*:mfa/${aws:username}", }, { "Sid": "AllowManageOwnUserMFA", @@ -1509,9 +1311,9 @@ valid_policy_documents = [ "iam:DeactivateMFADevice", "iam:EnableMFADevice", "iam:ListMFADevices", - "iam:ResyncMFADevice" + "iam:ResyncMFADevice", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "DenyAllExceptListedIfNoMFA", @@ -1523,16 +1325,12 @@ valid_policy_documents = [ "iam:ListMFADevices", "iam:ListVirtualMFADevices", "iam:ResyncMFADevice", - "sts:GetSessionToken" + "sts:GetSessionToken", ], "Resource": "*", - "Condition": { - "BoolIfExists": { - "aws:MultiFactorAuthPresent": "false" - } - } - } - ] + "Condition": {"BoolIfExists": {"aws:MultiFactorAuthPresent": "false"}}, + }, + ], }, { "Version": "2012-10-17", @@ -1544,9 +1342,9 @@ valid_policy_documents = [ "dynamodb:List*", "dynamodb:DescribeReservedCapacity*", "dynamodb:DescribeLimits", - "dynamodb:DescribeTimeToLive" + "dynamodb:DescribeTimeToLive", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "SpecificTable", @@ -1562,57 +1360,47 @@ valid_policy_documents = [ "dynamodb:CreateTable", "dynamodb:Delete*", "dynamodb:Update*", - "dynamodb:PutItem" + "dynamodb:PutItem", ], - "Resource": "arn:aws:dynamodb:*:*:table/MyTable" - } - ] + "Resource": "arn:aws:dynamodb:*:*:table/MyTable", + }, + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], - "Resource": [ - "arn:aws:ec2:*:*:volume/*", - "arn:aws:ec2:*:*:instance/*" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], + "Resource": ["arn:aws:ec2:*:*:volume/*", "arn:aws:ec2:*:*:instance/*"], "Condition": { - "ArnEquals": {"ec2:SourceInstanceARN": "arn:aws:ec2:*:*:instance/instance-id"} - } + "ArnEquals": { + "ec2:SourceInstanceARN": "arn:aws:ec2:*:*:instance/instance-id" + } + }, } - ] + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], "Resource": "arn:aws:ec2:*:*:instance/*", "Condition": { "StringEquals": {"ec2:ResourceTag/Department": "Development"} - } + }, }, { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], "Resource": "arn:aws:ec2:*:*:volume/*", "Condition": { "StringEquals": {"ec2:ResourceTag/VolumeUser": "${aws:username}"} - } - } - ] + }, + }, + ], }, { "Version": "2012-10-17", @@ -1623,17 +1411,17 @@ valid_policy_documents = [ "Action": [ "ec2:StartInstances", "ec2:StopInstances", - "ec2:DescribeTags" + "ec2:DescribeTags", ], "Resource": "arn:aws:ec2:region:account-id:instance/*", "Condition": { "StringEquals": { "ec2:ResourceTag/Project": "DataAnalytics", - "aws:PrincipalTag/Department": "Data" + "aws:PrincipalTag/Department": "Data", } - } + }, } - ] + ], }, { "Version": "2012-10-17", @@ -1645,59 +1433,48 @@ valid_policy_documents = [ "Resource": ["arn:aws:s3:::bucket-name"], "Condition": { "StringLike": { - "s3:prefix": ["cognito/application-name/${cognito-identity.amazonaws.com:sub}"] + "s3:prefix": [ + "cognito/application-name/${cognito-identity.amazonaws.com:sub}" + ] } - } + }, }, { "Sid": "ReadWriteDeleteYourObjects", "Effect": "Allow", - "Action": [ - "s3:GetObject", - "s3:PutObject", - "s3:DeleteObject" - ], + "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"], "Resource": [ "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}", - "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}/*" - ] - } - ] + "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}/*", + ], + }, + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "s3:ListAllMyBuckets", - "s3:GetBucketLocation" - ], - "Resource": "*" + "Action": ["s3:ListAllMyBuckets", "s3:GetBucketLocation"], + "Resource": "*", }, { "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::bucket-name", "Condition": { - "StringLike": { - "s3:prefix": [ - "", - "home/", - "home/${aws:userid}/*" - ] - } - } + "StringLike": {"s3:prefix": ["", "home/", "home/${aws:userid}/*"]} + }, }, { "Effect": "Allow", "Action": "s3:*", "Resource": [ "arn:aws:s3:::bucket-name/home/${aws:userid}", - "arn:aws:s3:::bucket-name/home/${aws:userid}/*" - ] - } - ] + "arn:aws:s3:::bucket-name/home/${aws:userid}/*", + ], + }, + ], }, { "Version": "2012-10-17", @@ -1711,23 +1488,23 @@ valid_policy_documents = [ "s3:GetBucketLocation", "s3:GetBucketPolicyStatus", "s3:GetBucketPublicAccessBlock", - "s3:ListAllMyBuckets" + "s3:ListAllMyBuckets", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "ListObjectsInBucket", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": ["arn:aws:s3:::bucket-name"] + "Resource": ["arn:aws:s3:::bucket-name"], }, { "Sid": "AllObjectActions", "Effect": "Allow", "Action": "s3:*Object", - "Resource": ["arn:aws:s3:::bucket-name/*"] - } - ] + "Resource": ["arn:aws:s3:::bucket-name/*"], + }, + ], }, { "Version": "2012-10-17", @@ -1735,20 +1512,14 @@ valid_policy_documents = [ { "Sid": "AllowViewAccountInfo", "Effect": "Allow", - "Action": [ - "iam:GetAccountPasswordPolicy", - "iam:GetAccountSummary" - ], - "Resource": "*" + "Action": ["iam:GetAccountPasswordPolicy", "iam:GetAccountSummary"], + "Resource": "*", }, { "Sid": "AllowManageOwnPasswords", "Effect": "Allow", - "Action": [ - "iam:ChangePassword", - "iam:GetUser" - ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Action": ["iam:ChangePassword", "iam:GetUser"], + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnAccessKeys", @@ -1757,9 +1528,9 @@ valid_policy_documents = [ "iam:CreateAccessKey", "iam:DeleteAccessKey", "iam:ListAccessKeys", - "iam:UpdateAccessKey" + "iam:UpdateAccessKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSigningCertificates", @@ -1768,9 +1539,9 @@ valid_policy_documents = [ "iam:DeleteSigningCertificate", "iam:ListSigningCertificates", "iam:UpdateSigningCertificate", - "iam:UploadSigningCertificate" + "iam:UploadSigningCertificate", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSSHPublicKeys", @@ -1780,9 +1551,9 @@ valid_policy_documents = [ "iam:GetSSHPublicKey", "iam:ListSSHPublicKeys", "iam:UpdateSSHPublicKey", - "iam:UploadSSHPublicKey" + "iam:UploadSSHPublicKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnGitCredentials", @@ -1792,11 +1563,11 @@ valid_policy_documents = [ "iam:DeleteServiceSpecificCredential", "iam:ListServiceSpecificCredentials", "iam:ResetServiceSpecificCredential", - "iam:UpdateServiceSpecificCredential" + "iam:UpdateServiceSpecificCredential", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" - } - ] + "Resource": "arn:aws:iam::*:user/${aws:username}", + }, + ], }, { "Version": "2012-10-17", @@ -1805,13 +1576,9 @@ valid_policy_documents = [ "Action": "ec2:*", "Resource": "*", "Effect": "Allow", - "Condition": { - "StringEquals": { - "ec2:Region": "region" - } - } + "Condition": {"StringEquals": {"ec2:Region": "region"}}, } - ] + ], }, { "Version": "2012-10-17", @@ -1819,15 +1586,28 @@ valid_policy_documents = [ { "Effect": "Allow", "Action": "rds:*", - "Resource": ["arn:aws:rds:region:*:*"] + "Resource": ["arn:aws:rds:region:*:*"], + }, + {"Effect": "Allow", "Action": ["rds:Describe*"], "Resource": ["*"]}, + ], + }, + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "", + "Effect": "Allow", + "Action": "rds:*", + "Resource": ["arn:aws:rds:region:*:*"], }, { + "Sid": "", "Effect": "Allow", "Action": ["rds:Describe*"], - "Resource": ["*"] - } - ] - } + "Resource": ["*"], + }, + ], + }, ] @@ -1843,19 +1623,20 @@ def test_create_policy_with_valid_policy_documents(): @mock_iam def check_create_policy_with_invalid_policy_document(test_case): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError) as ex: conn.create_policy( PolicyName="TestCreatePolicy", - PolicyDocument=json.dumps(test_case["document"])) - ex.exception.response['Error']['Code'].should.equal('MalformedPolicyDocument') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal(test_case["error_message"]) + PolicyDocument=json.dumps(test_case["document"]), + ) + ex.exception.response["Error"]["Code"].should.equal("MalformedPolicyDocument") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal(test_case["error_message"]) @mock_iam def check_create_policy_with_valid_policy_document(valid_policy_document): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestCreatePolicy", - PolicyDocument=json.dumps(valid_policy_document)) + PolicyName="TestCreatePolicy", PolicyDocument=json.dumps(valid_policy_document) + ) diff --git a/tests/test_iam/test_server.py b/tests/test_iam/test_server.py index 80c15b59d..4d1698424 100644 --- a/tests/test_iam/test_server.py +++ b/tests/test_iam/test_server.py @@ -1,26 +1,27 @@ -from __future__ import unicode_literals - -import json - -import re -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_iam_server_get(): - backend = server.create_backend_app("iam") - test_client = backend.test_client() - - group_data = test_client.action_data( - "CreateGroup", GroupName="test group", Path="/") - group_id = re.search("(.*)", group_data).groups()[0] - - groups_data = test_client.action_data("ListGroups") - groups_ids = re.findall("(.*)", groups_data) - - assert group_id in groups_ids +from __future__ import unicode_literals + +import json + +import re +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_iam_server_get(): + backend = server.create_backend_app("iam") + test_client = backend.test_client() + + group_data = test_client.action_data( + "CreateGroup", GroupName="test group", Path="/" + ) + group_id = re.search("(.*)", group_data).groups()[0] + + groups_data = test_client.action_data("ListGroups") + groups_ids = re.findall("(.*)", groups_data) + + assert group_id in groups_ids diff --git a/tests/test_iot/test_iot.py b/tests/test_iot/test_iot.py index 4a142b292..49a0af974 100644 --- a/tests/test_iot/test_iot.py +++ b/tests/test_iot/test_iot.py @@ -10,403 +10,441 @@ from nose.tools import assert_raises @mock_iot def test_attach_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_attached_policies(target=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - res['policies'][0]['policyName'].should.equal('my-policy') + res.should.have.key("policies").which.should.have.length_of(1) + res["policies"][0]["policyName"].should.equal("my-policy") @mock_iot def test_detach_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_attached_policies(target=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - res['policies'][0]['policyName'].should.equal('my-policy') + res.should.have.key("policies").which.should.have.length_of(1) + res["policies"][0]["policyName"].should.equal("my-policy") client.detach_policy(policyName=policy_name, target=cert_arn) res = client.list_attached_policies(target=cert_arn) - res.should.have.key('policies').which.should.be.empty + res.should.have.key("policies").which.should.be.empty @mock_iot def test_list_attached_policies(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=True) - policies = client.list_attached_policies(target=cert['certificateArn']) - policies['policies'].should.be.empty + policies = client.list_attached_policies(target=cert["certificateArn"]) + policies["policies"].should.be.empty @mock_iot def test_policy_versions(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" policy = client.create_policy(policyName=policy_name, policyDocument=doc) - policy.should.have.key('policyName').which.should.equal(policy_name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(json.dumps({})) - policy.should.have.key('policyVersionId').which.should.equal('1') + policy.should.have.key("policyName").which.should.equal(policy_name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(json.dumps({})) + policy.should.have.key("policyVersionId").which.should.equal("1") policy = client.get_policy(policyName=policy_name) - policy.should.have.key('policyName').which.should.equal(policy_name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(json.dumps({})) - policy.should.have.key('defaultVersionId').which.should.equal(policy['defaultVersionId']) + policy.should.have.key("policyName").which.should.equal(policy_name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(json.dumps({})) + policy.should.have.key("defaultVersionId").which.should.equal(policy["defaultVersionId"]) - policy1 = client.create_policy_version(policyName=policy_name, policyDocument=json.dumps({'version': 'version_1'}), + policy1 = client.create_policy_version(policyName=policy_name, policyDocument=json.dumps({"version": "version_1"}), setAsDefault=True) - policy1.should.have.key('policyArn').which.should_not.be.none - policy1.should.have.key('policyDocument').which.should.equal(json.dumps({'version': 'version_1'})) - policy1.should.have.key('policyVersionId').which.should.equal('2') - policy1.should.have.key('isDefaultVersion').which.should.equal(True) + policy1.should.have.key("policyArn").which.should_not.be.none + policy1.should.have.key("policyDocument").which.should.equal(json.dumps({"version": "version_1"})) + policy1.should.have.key("policyVersionId").which.should.equal("2") + policy1.should.have.key("isDefaultVersion").which.should.equal(True) - policy2 = client.create_policy_version(policyName=policy_name, policyDocument=json.dumps({'version': 'version_2'}), + policy2 = client.create_policy_version(policyName=policy_name, policyDocument=json.dumps({"version": "version_2"}), setAsDefault=False) - policy2.should.have.key('policyArn').which.should_not.be.none - policy2.should.have.key('policyDocument').which.should.equal(json.dumps({'version': 'version_2'})) - policy2.should.have.key('policyVersionId').which.should.equal('3') - policy2.should.have.key('isDefaultVersion').which.should.equal(False) + policy2.should.have.key("policyArn").which.should_not.be.none + policy2.should.have.key("policyDocument").which.should.equal(json.dumps({"version": "version_2"})) + policy2.should.have.key("policyVersionId").which.should.equal("3") + policy2.should.have.key("isDefaultVersion").which.should.equal(False) policy = client.get_policy(policyName=policy_name) - policy.should.have.key('policyName').which.should.equal(policy_name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(json.dumps({'version': 'version_1'})) - policy.should.have.key('defaultVersionId').which.should.equal(policy1['policyVersionId']) + policy.should.have.key("policyName").which.should.equal(policy_name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(json.dumps({"version": "version_1"})) + policy.should.have.key("defaultVersionId").which.should.equal(policy1["policyVersionId"]) policy_versions = client.list_policy_versions(policyName=policy_name) - policy_versions.should.have.key('policyVersions').which.should.have.length_of(3) - list(map(lambda item: item['isDefaultVersion'], policy_versions['policyVersions'])).count(True).should.equal(1) - default_policy = list(filter(lambda item: item['isDefaultVersion'], policy_versions['policyVersions'])) - default_policy[0].should.have.key('versionId').should.equal(policy1['policyVersionId']) + policy_versions.should.have.key("policyVersions").which.should.have.length_of(3) + list(map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])).count(True).should.equal(1) + default_policy = list(filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])) + default_policy[0].should.have.key("versionId").should.equal(policy1["policyVersionId"]) policy = client.get_policy(policyName=policy_name) - policy.should.have.key('policyName').which.should.equal(policy_name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(json.dumps({'version': 'version_1'})) - policy.should.have.key('defaultVersionId').which.should.equal(policy1['policyVersionId']) + policy.should.have.key("policyName").which.should.equal(policy_name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(json.dumps({"version": "version_1"})) + policy.should.have.key("defaultVersionId").which.should.equal(policy1["policyVersionId"]) - client.set_default_policy_version(policyName=policy_name, policyVersionId=policy2['policyVersionId']) + client.set_default_policy_version(policyName=policy_name, policyVersionId=policy2["policyVersionId"]) policy_versions = client.list_policy_versions(policyName=policy_name) - policy_versions.should.have.key('policyVersions').which.should.have.length_of(3) - list(map(lambda item: item['isDefaultVersion'], policy_versions['policyVersions'])).count(True).should.equal(1) - default_policy = list(filter(lambda item: item['isDefaultVersion'], policy_versions['policyVersions'])) - default_policy[0].should.have.key('versionId').should.equal(policy2['policyVersionId']) + policy_versions.should.have.key("policyVersions").which.should.have.length_of(3) + list(map(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])).count(True).should.equal(1) + default_policy = list(filter(lambda item: item["isDefaultVersion"], policy_versions["policyVersions"])) + default_policy[0].should.have.key("versionId").should.equal(policy2["policyVersionId"]) policy = client.get_policy(policyName=policy_name) - policy.should.have.key('policyName').which.should.equal(policy_name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(json.dumps({'version': 'version_2'})) - policy.should.have.key('defaultVersionId').which.should.equal(policy2['policyVersionId']) + policy.should.have.key("policyName").which.should.equal(policy_name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(json.dumps({"version": "version_2"})) + policy.should.have.key("defaultVersionId").which.should.equal(policy2["policyVersionId"]) - client.delete_policy_version(policyName=policy_name, policyVersionId='1') + client.delete_policy_version(policyName=policy_name, policyVersionId="1") policy_versions = client.list_policy_versions(policyName=policy_name) - policy_versions.should.have.key('policyVersions').which.should.have.length_of(2) + policy_versions.should.have.key("policyVersions").which.should.have.length_of(2) - client.delete_policy_version(policyName=policy_name, policyVersionId=policy1['policyVersionId']) + client.delete_policy_version(policyName=policy_name, policyVersionId=policy1["policyVersionId"]) policy_versions = client.list_policy_versions(policyName=policy_name) - policy_versions.should.have.key('policyVersions').which.should.have.length_of(1) + policy_versions.should.have.key("policyVersions").which.should.have.length_of(1) - # should fail as it's the default policy. Should use delete_policy instead + # should fail as it"s the default policy. Should use delete_policy instead try: - client.delete_policy_version(policyName=policy_name, policyVersionId=policy2['policyVersionId']) - assert False, 'Should have failed in previous call' + client.delete_policy_version(policyName=policy_name, policyVersionId=policy2["policyVersionId"]) + assert False, "Should have failed in previous call" except Exception as exception: - exception.response['Error']['Message'].should.equal('Cannot delete the default version of a policy') + exception.response["Error"]["Message"].should.equal("Cannot delete the default version of a policy") @mock_iot def test_things(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-thing' - type_name = 'my-type-name' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-thing" + type_name = "my-type-name" # thing type thing_type = client.create_thing_type(thingTypeName=type_name) - thing_type.should.have.key('thingTypeName').which.should.equal(type_name) - thing_type.should.have.key('thingTypeArn') + thing_type.should.have.key("thingTypeName").which.should.equal(type_name) + thing_type.should.have.key("thingTypeArn") res = client.list_thing_types() - res.should.have.key('thingTypes').which.should.have.length_of(1) - for thing_type in res['thingTypes']: - thing_type.should.have.key('thingTypeName').which.should_not.be.none + res.should.have.key("thingTypes").which.should.have.length_of(1) + for thing_type in res["thingTypes"]: + thing_type.should.have.key("thingTypeName").which.should_not.be.none thing_type = client.describe_thing_type(thingTypeName=type_name) - thing_type.should.have.key('thingTypeName').which.should.equal(type_name) - thing_type.should.have.key('thingTypeProperties') - thing_type.should.have.key('thingTypeMetadata') + thing_type.should.have.key("thingTypeName").which.should.equal(type_name) + thing_type.should.have.key("thingTypeProperties") + thing_type.should.have.key("thingTypeMetadata") # thing thing = client.create_thing(thingName=name, thingTypeName=type_name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") res = client.list_things() - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: - thing.should.have.key('thingName').which.should_not.be.none - thing.should.have.key('thingArn').which.should_not.be.none + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: + thing.should.have.key("thingName").which.should_not.be.none + thing.should.have.key("thingArn").which.should_not.be.none - thing = client.update_thing(thingName=name, attributePayload={'attributes': {'k1': 'v1'}}) + thing = client.update_thing( + thingName=name, attributePayload={"attributes": {"k1": "v1"}} + ) res = client.list_things() - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: - thing.should.have.key('thingName').which.should_not.be.none - thing.should.have.key('thingArn').which.should_not.be.none - res['things'][0]['attributes'].should.have.key('k1').which.should.equal('v1') + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: + thing.should.have.key("thingName").which.should_not.be.none + thing.should.have.key("thingArn").which.should_not.be.none + res["things"][0]["attributes"].should.have.key("k1").which.should.equal("v1") thing = client.describe_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('defaultClientId') - thing.should.have.key('thingTypeName') - thing.should.have.key('attributes') - thing.should.have.key('version') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("defaultClientId") + thing.should.have.key("thingTypeName") + thing.should.have.key("attributes") + thing.should.have.key("version") # delete thing client.delete_thing(thingName=name) res = client.list_things() - res.should.have.key('things').which.should.have.length_of(0) + res.should.have.key("things").which.should.have.length_of(0) # delete thing type client.delete_thing_type(thingTypeName=type_name) res = client.list_thing_types() - res.should.have.key('thingTypes').which.should.have.length_of(0) + res.should.have.key("thingTypes").which.should.have.length_of(0) @mock_iot def test_list_thing_types(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") for i in range(0, 100): client.create_thing_type(thingTypeName=str(i + 1)) thing_types = client.list_thing_types() - thing_types.should.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(50) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('1') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('50') + thing_types.should.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(50) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("1") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("50") - thing_types = client.list_thing_types(nextToken=thing_types['nextToken']) - thing_types.should.have.key('thingTypes').which.should.have.length_of(50) - thing_types.should_not.have.key('nextToken') - thing_types['thingTypes'][0]['thingTypeName'].should.equal('51') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('100') + thing_types = client.list_thing_types(nextToken=thing_types["nextToken"]) + thing_types.should.have.key("thingTypes").which.should.have.length_of(50) + thing_types.should_not.have.key("nextToken") + thing_types["thingTypes"][0]["thingTypeName"].should.equal("51") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("100") @mock_iot def test_list_thing_types_with_typename_filter(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") - client.create_thing_type(thingTypeName='thing') - client.create_thing_type(thingTypeName='thingType') - client.create_thing_type(thingTypeName='thingTypeName') - client.create_thing_type(thingTypeName='thingTypeNameGroup') - client.create_thing_type(thingTypeName='shouldNotFind') - client.create_thing_type(thingTypeName='find me it shall not') + client.create_thing_type(thingTypeName="thing") + client.create_thing_type(thingTypeName="thingType") + client.create_thing_type(thingTypeName="thingTypeName") + client.create_thing_type(thingTypeName="thingTypeNameGroup") + client.create_thing_type(thingTypeName="shouldNotFind") + client.create_thing_type(thingTypeName="find me it shall not") - thing_types = client.list_thing_types(thingTypeName='thing') - thing_types.should_not.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(4) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('thing') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('thingTypeNameGroup') + thing_types = client.list_thing_types(thingTypeName="thing") + thing_types.should_not.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(4) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("thing") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("thingTypeNameGroup") - thing_types = client.list_thing_types(thingTypeName='thingTypeName') - thing_types.should_not.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(2) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('thingTypeName') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('thingTypeNameGroup') + thing_types = client.list_thing_types(thingTypeName="thingTypeName") + thing_types.should_not.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(2) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("thingTypeName") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("thingTypeNameGroup") @mock_iot def test_list_things_with_next_token(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") for i in range(0, 200): client.create_thing(thingName=str(i + 1)) things = client.list_things() - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('1') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/1') - things['things'][-1]['thingName'].should.equal('50') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/50') + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("1") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/1") + things["things"][-1]["thingName"].should.equal("50") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/50" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('51') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/51') - things['things'][-1]['thingName'].should.equal('100') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/100') + things = client.list_things(nextToken=things["nextToken"]) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("51") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/51" + ) + things["things"][-1]["thingName"].should.equal("100") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/100" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('101') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/101') - things['things'][-1]['thingName'].should.equal('150') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/150') + things = client.list_things(nextToken=things["nextToken"]) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("101") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/101" + ) + things["things"][-1]["thingName"].should.equal("150") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/150" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('151') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/151') - things['things'][-1]['thingName'].should.equal('200') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/200') + things = client.list_things(nextToken=things["nextToken"]) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("151") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/151" + ) + things["things"][-1]["thingName"].should.equal("200") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/200" + ) @mock_iot def test_list_things_with_attribute_and_thing_type_filter_and_next_token(): - client = boto3.client('iot', region_name='ap-northeast-1') - client.create_thing_type(thingTypeName='my-thing-type') + client = boto3.client("iot", region_name="ap-northeast-1") + client.create_thing_type(thingTypeName="my-thing-type") for i in range(0, 200): if not (i + 1) % 3: - attribute_payload = { - 'attributes': { - 'foo': 'bar' - } - } + attribute_payload = {"attributes": {"foo": "bar"}} elif not (i + 1) % 5: - attribute_payload = { - 'attributes': { - 'bar': 'foo' - } - } + attribute_payload = {"attributes": {"bar": "foo"}} else: attribute_payload = {} if not (i + 1) % 2: - thing_type_name = 'my-thing-type' - client.create_thing(thingName=str(i + 1), thingTypeName=thing_type_name, attributePayload=attribute_payload) + thing_type_name = "my-thing-type" + client.create_thing( + thingName=str(i + 1), + thingTypeName=thing_type_name, + attributePayload=attribute_payload, + ) else: - client.create_thing(thingName=str(i + 1), attributePayload=attribute_payload) + client.create_thing( + thingName=str(i + 1), attributePayload=attribute_payload + ) # Test filter for thingTypeName things = client.list_things(thingTypeName=thing_type_name) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('2') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/2') - things['things'][-1]['thingName'].should.equal('100') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/100') - all(item['thingTypeName'] == thing_type_name for item in things['things']) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("2") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/2") + things["things"][-1]["thingName"].should.equal("100") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/100" + ) + all(item["thingTypeName"] == thing_type_name for item in things["things"]) - things = client.list_things(nextToken=things['nextToken'], thingTypeName=thing_type_name) - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('102') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/102') - things['things'][-1]['thingName'].should.equal('200') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/200') - all(item['thingTypeName'] == thing_type_name for item in things['things']) + things = client.list_things( + nextToken=things["nextToken"], thingTypeName=thing_type_name + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("102") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/102" + ) + things["things"][-1]["thingName"].should.equal("200") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/200" + ) + all(item["thingTypeName"] == thing_type_name for item in things["things"]) # Test filter for attributes - things = client.list_things(attributeName='foo', attributeValue='bar') - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('3') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/3') - things['things'][-1]['thingName'].should.equal('150') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/150') - all(item['attributes'] == {'foo': 'bar'} for item in things['things']) + things = client.list_things(attributeName="foo", attributeValue="bar") + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("3") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/3") + things["things"][-1]["thingName"].should.equal("150") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/150" + ) + all(item["attributes"] == {"foo": "bar"} for item in things["things"]) - things = client.list_things(nextToken=things['nextToken'], attributeName='foo', attributeValue='bar') - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(16) - things['things'][0]['thingName'].should.equal('153') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/153') - things['things'][-1]['thingName'].should.equal('198') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/198') - all(item['attributes'] == {'foo': 'bar'} for item in things['things']) + things = client.list_things( + nextToken=things["nextToken"], attributeName="foo", attributeValue="bar" + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(16) + things["things"][0]["thingName"].should.equal("153") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/153" + ) + things["things"][-1]["thingName"].should.equal("198") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/198" + ) + all(item["attributes"] == {"foo": "bar"} for item in things["things"]) # Test filter for attributes and thingTypeName - things = client.list_things(thingTypeName=thing_type_name, attributeName='foo', attributeValue='bar') - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(33) - things['things'][0]['thingName'].should.equal('6') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/6') - things['things'][-1]['thingName'].should.equal('198') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/198') - all(item['attributes'] == {'foo': 'bar'} and item['thingTypeName'] == thing_type_name for item in things['things']) + things = client.list_things( + thingTypeName=thing_type_name, attributeName="foo", attributeValue="bar" + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(33) + things["things"][0]["thingName"].should.equal("6") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/6") + things["things"][-1]["thingName"].should.equal("198") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/198" + ) + all( + item["attributes"] == {"foo": "bar"} + and item["thingTypeName"] == thing_type_name + for item in things["things"] + ) @mock_iot def test_certs(): - client = boto3.client('iot', region_name='us-east-1') + client = boto3.client("iot", region_name="us-east-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('certificatePem').which.should_not.be.none - cert.should.have.key('keyPair') - cert['keyPair'].should.have.key('PublicKey').which.should_not.be.none - cert['keyPair'].should.have.key('PrivateKey').which.should_not.be.none - cert_id = cert['certificateId'] + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("certificatePem").which.should_not.be.none + cert.should.have.key("keyPair") + cert["keyPair"].should.have.key("PublicKey").which.should_not.be.none + cert["keyPair"].should.have.key("PrivateKey").which.should_not.be.none + cert_id = cert["certificateId"] cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('certificateArn').which.should_not.be.none - cert_desc.should.have.key('certificateId').which.should_not.be.none - cert_desc.should.have.key('certificatePem').which.should_not.be.none - cert_desc.should.have.key('status').which.should.equal('ACTIVE') - cert_pem = cert_desc['certificatePem'] + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("certificateArn").which.should_not.be.none + cert_desc.should.have.key("certificateId").which.should_not.be.none + cert_desc.should.have.key("certificatePem").which.should_not.be.none + cert_desc.should.have.key("status").which.should.equal("ACTIVE") + cert_pem = cert_desc["certificatePem"] res = client.list_certificates() - for cert in res['certificates']: - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('status').which.should_not.be.none - cert.should.have.key('creationDate').which.should_not.be.none + for cert in res["certificates"]: + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("status").which.should_not.be.none + cert.should.have.key("creationDate").which.should_not.be.none - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") cert = client.describe_certificate(certificateId=cert_id) - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('REVOKED') + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("REVOKED") client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates') + res.should.have.key("certificates") # Test register_certificate flow cert = client.register_certificate(certificatePem=cert_pem, setAsActive=True) - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('certificateArn').which.should_not.be.none - cert_id = cert['certificateId'] + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("certificateArn").which.should_not.be.none + cert_id = cert["certificateId"] res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) - for cert in res['certificates']: - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('status').which.should_not.be.none - cert.should.have.key('creationDate').which.should_not.be.none + res.should.have.key("certificates").which.should.have.length_of(1) + for cert in res["certificates"]: + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("status").which.should_not.be.none + cert.should.have.key("creationDate").which.should_not.be.none - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") cert = client.describe_certificate(certificateId=cert_id) - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('REVOKED') + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("REVOKED") client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates') + res.should.have.key("certificates") @mock_iot @@ -424,24 +462,26 @@ def test_delete_policy_validation(): ] } """ - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] - policy_name = 'my-policy' + cert_arn = cert["certificateArn"] + policy_name = "my-policy" client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_principal_policy(policyName=policy_name, principal=cert_arn) with assert_raises(ClientError) as e: client.delete_policy(policyName=policy_name) - e.exception.response['Error']['Message'].should.contain( - 'The policy cannot be deleted as the policy is attached to one or more principals (name=%s)' % policy_name) + e.exception.response["Error"]["Message"].should.contain( + "The policy cannot be deleted as the policy is attached to one or more principals (name=%s)" + % policy_name + ) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(1) + res.should.have.key("policies").which.should.have.length_of(1) client.detach_principal_policy(policyName=policy_name, principal=cert_arn) client.delete_policy(policyName=policy_name) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) @mock_iot @@ -459,12 +499,12 @@ def test_delete_certificate_validation(): ] } """ - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert_id = cert['certificateId'] - cert_arn = cert['certificateArn'] - policy_name = 'my-policy' - thing_name = 'thing-1' + cert_id = cert["certificateId"] + cert_arn = cert["certificateArn"] + policy_name = "my-policy" + thing_name = "thing-1" client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_principal_policy(policyName=policy_name, principal=cert_arn) client.create_thing(thingName=thing_name) @@ -472,366 +512,608 @@ def test_delete_certificate_validation(): with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Certificate must be deactivated (not ACTIVE) before deletion.') + e.exception.response["Error"]["Message"].should.contain( + "Certificate must be deactivated (not ACTIVE) before deletion." + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Things must be detached before deletion (arn: %s)' % cert_arn) + e.exception.response["Error"]["Message"].should.contain( + "Things must be detached before deletion (arn: %s)" % cert_arn + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) client.detach_thing_principal(thingName=thing_name, principal=cert_arn) with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Certificate policies must be detached before deletion (arn: %s)' % cert_arn) + e.exception.response["Error"]["Message"].should.contain( + "Certificate policies must be detached before deletion (arn: %s)" % cert_arn + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) client.detach_principal_policy(policyName=policy_name, principal=cert_arn) client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(0) + res.should.have.key("certificates").which.should.have.length_of(0) @mock_iot def test_certs_create_inactive(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=False) - cert_id = cert['certificateId'] + cert_id = cert["certificateId"] cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('INACTIVE') + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("INACTIVE") - client.update_certificate(certificateId=cert_id, newStatus='ACTIVE') + client.update_certificate(certificateId=cert_id, newStatus="ACTIVE") cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('ACTIVE') + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("ACTIVE") @mock_iot def test_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-policy" + doc = "{}" policy = client.create_policy(policyName=name, policyDocument=doc) - policy.should.have.key('policyName').which.should.equal(name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(doc) - policy.should.have.key('policyVersionId').which.should.equal('1') + policy.should.have.key("policyName").which.should.equal(name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(doc) + policy.should.have.key("policyVersionId").which.should.equal("1") policy = client.get_policy(policyName=name) - policy.should.have.key('policyName').which.should.equal(name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(doc) - policy.should.have.key('defaultVersionId').which.should.equal('1') + policy.should.have.key("policyName").which.should.equal(name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(doc) + policy.should.have.key("defaultVersionId").which.should.equal("1") res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none client.delete_policy(policyName=name) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) @mock_iot def test_principal_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" client.create_policy(policyName=policy_name, policyDocument=doc) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none # do nothing if policy have already attached to certificate client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) with assert_raises(ClientError) as e: client.detach_policy(policyName=policy_name, target=cert_arn) - e.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") @mock_iot def test_principal_policy_deprecated(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" policy = client.create_policy(policyName=policy_name, policyDocument=doc) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_principal_policy(policyName=policy_name, principal=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_principal_policy(policyName=policy_name, principal=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) @mock_iot def test_principal_thing(): - client = boto3.client('iot', region_name='ap-northeast-1') - thing_name = 'my-thing' + client = boto3.client("iot", region_name="ap-northeast-1") + thing_name = "my-thing" thing = client.create_thing(thingName=thing_name) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_thing_principal(thingName=thing_name, principal=cert_arn) res = client.list_principal_things(principal=cert_arn) - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: thing.should_not.be.none res = client.list_thing_principals(thingName=thing_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_thing_principal(thingName=thing_name, principal=cert_arn) res = client.list_principal_things(principal=cert_arn) - res.should.have.key('things').which.should.have.length_of(0) + res.should.have.key("things").which.should.have.length_of(0) res = client.list_thing_principals(thingName=thing_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) + + +@mock_iot +def test_delete_principal_thing(): + client = boto3.client("iot", region_name="ap-northeast-1") + thing_name = "my-thing" + thing = client.create_thing(thingName=thing_name) + cert = client.create_keys_and_certificate(setAsActive=True) + cert_arn = cert["certificateArn"] + cert_id = cert["certificateId"] + + client.attach_thing_principal(thingName=thing_name, principal=cert_arn) + + client.delete_thing(thingName=thing_name) + res = client.list_principal_things(principal=cert_arn) + res.should.have.key("things").which.should.have.length_of(0) + + client.update_certificate(certificateId=cert_id, newStatus="INACTIVE") + client.delete_certificate(certificateId=cert_id) + + +@mock_iot +def test_describe_thing_group_metadata_hierarchy(): + client = boto3.client("iot", region_name="ap-northeast-1") + group_name_1a = "my-group-name-1a" + group_name_1b = "my-group-name-1b" + group_name_2a = "my-group-name-2a" + group_name_2b = "my-group-name-2b" + group_name_3a = "my-group-name-3a" + group_name_3b = "my-group-name-3b" + group_name_3c = "my-group-name-3c" + group_name_3d = "my-group-name-3d" + + # --1a + # |--2a + # | |--3a + # | |--3b + # | + # |--2b + # |--3c + # |--3d + # --1b + + # create thing groups tree + # 1 + thing_group1a = client.create_thing_group(thingGroupName=group_name_1a) + thing_group1a.should.have.key("thingGroupName").which.should.equal(group_name_1a) + thing_group1a.should.have.key("thingGroupArn") + thing_group1b = client.create_thing_group(thingGroupName=group_name_1b) + thing_group1b.should.have.key("thingGroupName").which.should.equal(group_name_1b) + thing_group1b.should.have.key("thingGroupArn") + # 2 + thing_group2a = client.create_thing_group( + thingGroupName=group_name_2a, parentGroupName=group_name_1a + ) + thing_group2a.should.have.key("thingGroupName").which.should.equal(group_name_2a) + thing_group2a.should.have.key("thingGroupArn") + thing_group2b = client.create_thing_group( + thingGroupName=group_name_2b, parentGroupName=group_name_1a + ) + thing_group2b.should.have.key("thingGroupName").which.should.equal(group_name_2b) + thing_group2b.should.have.key("thingGroupArn") + # 3 + thing_group3a = client.create_thing_group( + thingGroupName=group_name_3a, parentGroupName=group_name_2a + ) + thing_group3a.should.have.key("thingGroupName").which.should.equal(group_name_3a) + thing_group3a.should.have.key("thingGroupArn") + thing_group3b = client.create_thing_group( + thingGroupName=group_name_3b, parentGroupName=group_name_2a + ) + thing_group3b.should.have.key("thingGroupName").which.should.equal(group_name_3b) + thing_group3b.should.have.key("thingGroupArn") + thing_group3c = client.create_thing_group( + thingGroupName=group_name_3c, parentGroupName=group_name_2b + ) + thing_group3c.should.have.key("thingGroupName").which.should.equal(group_name_3c) + thing_group3c.should.have.key("thingGroupArn") + thing_group3d = client.create_thing_group( + thingGroupName=group_name_3d, parentGroupName=group_name_2b + ) + thing_group3d.should.have.key("thingGroupName").which.should.equal(group_name_3d) + thing_group3d.should.have.key("thingGroupArn") + + # describe groups + # groups level 1 + # 1a + thing_group_description1a = client.describe_thing_group( + thingGroupName=group_name_1a + ) + thing_group_description1a.should.have.key("thingGroupName").which.should.equal( + group_name_1a + ) + thing_group_description1a.should.have.key("thingGroupProperties") + thing_group_description1a.should.have.key("thingGroupMetadata") + thing_group_description1a["thingGroupMetadata"].should.have.key("creationDate") + thing_group_description1a.should.have.key("version") + # 1b + thing_group_description1b = client.describe_thing_group( + thingGroupName=group_name_1b + ) + thing_group_description1b.should.have.key("thingGroupName").which.should.equal( + group_name_1b + ) + thing_group_description1b.should.have.key("thingGroupProperties") + thing_group_description1b.should.have.key("thingGroupMetadata") + thing_group_description1b["thingGroupMetadata"].should.have.length_of(1) + thing_group_description1b["thingGroupMetadata"].should.have.key("creationDate") + thing_group_description1b.should.have.key("version") + # groups level 2 + # 2a + thing_group_description2a = client.describe_thing_group( + thingGroupName=group_name_2a + ) + thing_group_description2a.should.have.key("thingGroupName").which.should.equal( + group_name_2a + ) + thing_group_description2a.should.have.key("thingGroupProperties") + thing_group_description2a.should.have.key("thingGroupMetadata") + thing_group_description2a["thingGroupMetadata"].should.have.length_of(3) + thing_group_description2a["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_1a) + thing_group_description2a["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description2a["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(1) + thing_group_description2a["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description2a["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description2a.should.have.key("version") + # 2b + thing_group_description2b = client.describe_thing_group( + thingGroupName=group_name_2b + ) + thing_group_description2b.should.have.key("thingGroupName").which.should.equal( + group_name_2b + ) + thing_group_description2b.should.have.key("thingGroupProperties") + thing_group_description2b.should.have.key("thingGroupMetadata") + thing_group_description2b["thingGroupMetadata"].should.have.length_of(3) + thing_group_description2b["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_1a) + thing_group_description2b["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description2b["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(1) + thing_group_description2b["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description2b["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description2b.should.have.key("version") + # groups level 3 + # 3a + thing_group_description3a = client.describe_thing_group( + thingGroupName=group_name_3a + ) + thing_group_description3a.should.have.key("thingGroupName").which.should.equal( + group_name_3a + ) + thing_group_description3a.should.have.key("thingGroupProperties") + thing_group_description3a.should.have.key("thingGroupMetadata") + thing_group_description3a["thingGroupMetadata"].should.have.length_of(3) + thing_group_description3a["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_2a) + thing_group_description3a["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description3a["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(2) + thing_group_description3a["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description3a["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description3a["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupName" + ].should.match(group_name_2a) + thing_group_description3a["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupArn" + ].should.match(thing_group2a["thingGroupArn"]) + thing_group_description3a.should.have.key("version") + # 3b + thing_group_description3b = client.describe_thing_group( + thingGroupName=group_name_3b + ) + thing_group_description3b.should.have.key("thingGroupName").which.should.equal( + group_name_3b + ) + thing_group_description3b.should.have.key("thingGroupProperties") + thing_group_description3b.should.have.key("thingGroupMetadata") + thing_group_description3b["thingGroupMetadata"].should.have.length_of(3) + thing_group_description3b["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_2a) + thing_group_description3b["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description3b["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(2) + thing_group_description3b["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description3b["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description3b["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupName" + ].should.match(group_name_2a) + thing_group_description3b["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupArn" + ].should.match(thing_group2a["thingGroupArn"]) + thing_group_description3b.should.have.key("version") + # 3c + thing_group_description3c = client.describe_thing_group( + thingGroupName=group_name_3c + ) + thing_group_description3c.should.have.key("thingGroupName").which.should.equal( + group_name_3c + ) + thing_group_description3c.should.have.key("thingGroupProperties") + thing_group_description3c.should.have.key("thingGroupMetadata") + thing_group_description3c["thingGroupMetadata"].should.have.length_of(3) + thing_group_description3c["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_2b) + thing_group_description3c["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description3c["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(2) + thing_group_description3c["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description3c["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description3c["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupName" + ].should.match(group_name_2b) + thing_group_description3c["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupArn" + ].should.match(thing_group2b["thingGroupArn"]) + thing_group_description3c.should.have.key("version") + # 3d + thing_group_description3d = client.describe_thing_group( + thingGroupName=group_name_3d + ) + thing_group_description3d.should.have.key("thingGroupName").which.should.equal( + group_name_3d + ) + thing_group_description3d.should.have.key("thingGroupProperties") + thing_group_description3d.should.have.key("thingGroupMetadata") + thing_group_description3d["thingGroupMetadata"].should.have.length_of(3) + thing_group_description3d["thingGroupMetadata"].should.have.key( + "parentGroupName" + ).being.equal(group_name_2b) + thing_group_description3d["thingGroupMetadata"].should.have.key( + "rootToParentThingGroups" + ) + thing_group_description3d["thingGroupMetadata"][ + "rootToParentThingGroups" + ].should.have.length_of(2) + thing_group_description3d["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupName" + ].should.match(group_name_1a) + thing_group_description3d["thingGroupMetadata"]["rootToParentThingGroups"][0][ + "groupArn" + ].should.match(thing_group1a["thingGroupArn"]) + thing_group_description3d["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupName" + ].should.match(group_name_2b) + thing_group_description3d["thingGroupMetadata"]["rootToParentThingGroups"][1][ + "groupArn" + ].should.match(thing_group2b["thingGroupArn"]) + thing_group_description3d.should.have.key("version") @mock_iot def test_thing_groups(): - client = boto3.client('iot', region_name='ap-northeast-1') - group_name = 'my-group-name' + client = boto3.client("iot", region_name="ap-northeast-1") + group_name = "my-group-name" # thing group thing_group = client.create_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") res = client.list_thing_groups() - res.should.have.key('thingGroups').which.should.have.length_of(1) - for thing_group in res['thingGroups']: - thing_group.should.have.key('groupName').which.should_not.be.none - thing_group.should.have.key('groupArn').which.should_not.be.none + res.should.have.key("thingGroups").which.should.have.length_of(1) + for thing_group in res["thingGroups"]: + thing_group.should.have.key("groupName").which.should_not.be.none + thing_group.should.have.key("groupArn").which.should_not.be.none thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupProperties') - thing_group.should.have.key('thingGroupMetadata') - thing_group.should.have.key('version') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupProperties") + thing_group.should.have.key("thingGroupMetadata") + thing_group.should.have.key("version") # delete thing group client.delete_thing_group(thingGroupName=group_name) res = client.list_thing_groups() - res.should.have.key('thingGroups').which.should.have.length_of(0) + res.should.have.key("thingGroups").which.should.have.length_of(0) # props create test props = { - 'thingGroupDescription': 'my first thing group', - 'attributePayload': { - 'attributes': { - 'key1': 'val01', - 'Key02': 'VAL2' - } - } + "thingGroupDescription": "my first thing group", + "attributePayload": {"attributes": {"key1": "val01", "Key02": "VAL2"}}, } - thing_group = client.create_thing_group(thingGroupName=group_name, thingGroupProperties=props) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group = client.create_thing_group( + thingGroupName=group_name, thingGroupProperties=props + ) + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('key1').which.should.equal('val01') - res_props.should.have.key('Key02').which.should.equal('VAL2') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("key1").which.should.equal("val01") + res_props.should.have.key("Key02").which.should.equal("VAL2") # props update test with merge - new_props = { - 'attributePayload': { - 'attributes': { - 'k3': 'v3' - }, - 'merge': True - } - } - client.update_thing_group( - thingGroupName=group_name, - thingGroupProperties=new_props - ) + new_props = {"attributePayload": {"attributes": {"k3": "v3"}, "merge": True}} + client.update_thing_group(thingGroupName=group_name, thingGroupProperties=new_props) thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('key1').which.should.equal('val01') - res_props.should.have.key('Key02').which.should.equal('VAL2') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("key1").which.should.equal("val01") + res_props.should.have.key("Key02").which.should.equal("VAL2") - res_props.should.have.key('k3').which.should.equal('v3') + res_props.should.have.key("k3").which.should.equal("v3") # props update test - new_props = { - 'attributePayload': { - 'attributes': { - 'k4': 'v4' - } - } - } - client.update_thing_group( - thingGroupName=group_name, - thingGroupProperties=new_props - ) + new_props = {"attributePayload": {"attributes": {"k4": "v4"}}} + client.update_thing_group(thingGroupName=group_name, thingGroupProperties=new_props) thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('k4').which.should.equal('v4') - res_props.should_not.have.key('key1') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("k4").which.should.equal("v4") + res_props.should_not.have.key("key1") @mock_iot def test_thing_group_relations(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-thing' - group_name = 'my-group-name' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-thing" + group_name = "my-group-name" # thing group thing_group = client.create_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # add in 4 way + client.add_thing_to_thing_group(thingGroupName=group_name, thingName=name) client.add_thing_to_thing_group( - thingGroupName=group_name, - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingArn=thing["thingArn"] ) client.add_thing_to_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingArn=thing['thingArn'] + thingGroupName=group_name, thingArn=thing["thingArn"] ) client.add_thing_to_thing_group( - thingGroupName=group_name, - thingArn=thing['thingArn'] - ) - client.add_thing_to_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingName=name ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(1) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(1) - thing_groups = client.list_thing_groups_for_thing( - thingName=name - ) - thing_groups.should.have.key('thingGroups') - thing_groups['thingGroups'].should.have.length_of(1) + thing_groups = client.list_thing_groups_for_thing(thingName=name) + thing_groups.should.have.key("thingGroups") + thing_groups["thingGroups"].should.have.length_of(1) # remove in 4 way + client.remove_thing_from_thing_group(thingGroupName=group_name, thingName=name) client.remove_thing_from_thing_group( - thingGroupName=group_name, - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingArn=thing["thingArn"] ) client.remove_thing_from_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingArn=thing['thingArn'] + thingGroupName=group_name, thingArn=thing["thingArn"] ) client.remove_thing_from_thing_group( - thingGroupName=group_name, - thingArn=thing['thingArn'] + thingGroupArn=thing_group["thingGroupArn"], thingName=name ) - client.remove_thing_from_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingName=name - ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(0) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(0) # update thing group for thing - client.update_thing_groups_for_thing( - thingName=name, - thingGroupsToAdd=[ - group_name - ] - ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(1) + client.update_thing_groups_for_thing(thingName=name, thingGroupsToAdd=[group_name]) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(1) client.update_thing_groups_for_thing( - thingName=name, - thingGroupsToRemove=[ - group_name - ] + thingName=name, thingGroupsToRemove=[group_name] ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(0) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(0) @mock_iot def test_create_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing# job document @@ -839,13 +1121,11 @@ def test_create_job(): # "field": "value" # } thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document - job_document = { - "field": "value" - } + job_document = {"field": "value"} job = client.create_job( jobId=job_id, @@ -853,23 +1133,21 @@ def test_create_job(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") @mock_iot def test_list_jobs(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing# job document @@ -877,8 +1155,8 @@ def test_list_jobs(): # "field": "value" # } thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -891,18 +1169,18 @@ def test_list_jobs(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job1.should.have.key('jobId').which.should.equal(job_id) - job1.should.have.key('jobArn') - job1.should.have.key('description') + job1.should.have.key("jobId").which.should.equal(job_id) + job1.should.have.key("jobArn") + job1.should.have.key("description") job2 = client.create_job( jobId=job_id+"1", @@ -910,244 +1188,245 @@ def test_list_jobs(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job2.should.have.key('jobId').which.should.equal(job_id+"1") - job2.should.have.key('jobArn') - job2.should.have.key('description') + job2.should.have.key("jobId").which.should.equal(job_id+"1") + job2.should.have.key("jobArn") + job2.should.have.key("description") jobs = client.list_jobs() - jobs.should.have.key('jobs') - jobs.should_not.have.key('nextToken') - jobs['jobs'][0].should.have.key('jobId').which.should.equal(job_id) - jobs['jobs'][1].should.have.key('jobId').which.should.equal(job_id+"1") + jobs.should.have.key("jobs") + jobs.should_not.have.key("nextToken") + jobs["jobs"][0].should.have.key("jobId").which.should.equal(job_id) + jobs["jobs"][1].should.have.key("jobId").which.should.equal(job_id+"1") @mock_iot def test_describe_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('documentSource') - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobArn") - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) - job.should.have.key('job').which.should.have.key("targets") - job.should.have.key('job').which.should.have.key("jobProcessDetails") - job.should.have.key('job').which.should.have.key("lastUpdatedAt") - job.should.have.key('job').which.should.have.key("createdAt") - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig") - job.should.have.key('job').which.should.have.key("targetSelection").which.should.equal("CONTINUOUS") - job.should.have.key('job').which.should.have.key("presignedUrlConfig") - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "roleArn").which.should.equal('arn:aws:iam::1:role/service-role/iot_job_role') - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "expiresInSec").which.should.equal(123) - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig").which.should.have.key( - "maximumPerMinute").which.should.equal(10) + job.should.have.key("documentSource") + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobArn") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job").which.should.have.key("targets") + job.should.have.key("job").which.should.have.key("jobProcessDetails") + job.should.have.key("job").which.should.have.key("lastUpdatedAt") + job.should.have.key("job").which.should.have.key("createdAt") + job.should.have.key("job").which.should.have.key("jobExecutionsRolloutConfig") + job.should.have.key("job").which.should.have.key( + "targetSelection" + ).which.should.equal("CONTINUOUS") + job.should.have.key("job").which.should.have.key("presignedUrlConfig") + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("roleArn").which.should.equal( + "arn:aws:iam::1:role/service-role/iot_job_role" + ) + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("expiresInSec").which.should.equal(123) + job.should.have.key("job").which.should.have.key( + "jobExecutionsRolloutConfig" + ).which.should.have.key("maximumPerMinute").which.should.equal(10) @mock_iot def test_describe_job_1(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document - job_document = { - "field": "value" - } + job_document = {"field": "value"} job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], document=json.dumps(job_document), presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobArn") - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) - job.should.have.key('job').which.should.have.key("targets") - job.should.have.key('job').which.should.have.key("jobProcessDetails") - job.should.have.key('job').which.should.have.key("lastUpdatedAt") - job.should.have.key('job').which.should.have.key("createdAt") - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig") - job.should.have.key('job').which.should.have.key("targetSelection").which.should.equal("CONTINUOUS") - job.should.have.key('job').which.should.have.key("presignedUrlConfig") - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "roleArn").which.should.equal('arn:aws:iam::1:role/service-role/iot_job_role') - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobArn") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job").which.should.have.key("targets") + job.should.have.key("job").which.should.have.key("jobProcessDetails") + job.should.have.key("job").which.should.have.key("lastUpdatedAt") + job.should.have.key("job").which.should.have.key("createdAt") + job.should.have.key("job").which.should.have.key("jobExecutionsRolloutConfig") + job.should.have.key("job").which.should.have.key("targetSelection").which.should.equal("CONTINUOUS") + job.should.have.key("job").which.should.have.key("presignedUrlConfig") + job.should.have.key("job").which.should.have.key("presignedUrlConfig").which.should.have.key( + "roleArn").which.should.equal("arn:aws:iam::1:role/service-role/iot_job_role") + job.should.have.key("job").which.should.have.key("presignedUrlConfig").which.should.have.key( "expiresInSec").which.should.equal(123) - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig").which.should.have.key( + job.should.have.key("job").which.should.have.key("jobExecutionsRolloutConfig").which.should.have.key( "maximumPerMinute").which.should.equal(10) @mock_iot def test_delete_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) client.delete_job(jobId=job_id) - client.list_jobs()['jobs'].should.have.length_of(0) + client.list_jobs()["jobs"].should.have.length_of(0) @mock_iot def test_cancel_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) - job = client.cancel_job(jobId=job_id, reasonCode='Because', comment='You are') - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job = client.cancel_job(jobId=job_id, reasonCode="Because", comment="You are") + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) - job.should.have.key('job').which.should.have.key("status").which.should.equal('CANCELED') - job.should.have.key('job').which.should.have.key("forceCanceled").which.should.equal(False) - job.should.have.key('job').which.should.have.key("reasonCode").which.should.equal('Because') - job.should.have.key('job').which.should.have.key("comment").which.should.equal('You are') + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job").which.should.have.key("status").which.should.equal("CANCELED") + job.should.have.key("job").which.should.have.key("forceCanceled").which.should.equal(False) + job.should.have.key("job").which.should.have.key("reasonCode").which.should.equal("Because") + job.should.have.key("job").which.should.have.key("comment").which.should.equal("You are") @mock_iot def test_get_job_document_with_document_source(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job_document = client.get_job_document(jobId=job_id) - job_document.should.have.key('document').which.should.equal('') + job_document.should.have.key("document").which.should.equal("") @mock_iot def test_get_job_document_with_document(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1159,31 +1438,31 @@ def test_get_job_document_with_document(): targets=[thing["thingArn"]], document=json.dumps(job_document), presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job_document = client.get_job_document(jobId=job_id) - job_document.should.have.key('document').which.should.equal("{\"field\": \"value\"}") + job_document.should.have.key("document").which.should.equal("{\"field\": \"value\"}") @mock_iot def test_describe_job_execution(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1196,65 +1475,65 @@ def test_describe_job_execution(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") job_execution = client.describe_job_execution(jobId=job_id, thingName=name) - job_execution.should.have.key('execution') - job_execution['execution'].should.have.key('jobId').which.should.equal(job_id) - job_execution['execution'].should.have.key('status').which.should.equal('QUEUED') - job_execution['execution'].should.have.key('forceCanceled').which.should.equal(False) - job_execution['execution'].should.have.key('statusDetails').which.should.equal({'detailsMap': {}}) - job_execution['execution'].should.have.key('thingArn').which.should.equal(thing["thingArn"]) - job_execution['execution'].should.have.key('queuedAt') - job_execution['execution'].should.have.key('startedAt') - job_execution['execution'].should.have.key('lastUpdatedAt') - job_execution['execution'].should.have.key('executionNumber').which.should.equal(123) - job_execution['execution'].should.have.key('versionNumber').which.should.equal(123) - job_execution['execution'].should.have.key('approximateSecondsBeforeTimedOut').which.should.equal(123) + job_execution.should.have.key("execution") + job_execution["execution"].should.have.key("jobId").which.should.equal(job_id) + job_execution["execution"].should.have.key("status").which.should.equal("QUEUED") + job_execution["execution"].should.have.key("forceCanceled").which.should.equal(False) + job_execution["execution"].should.have.key("statusDetails").which.should.equal({"detailsMap": {}}) + job_execution["execution"].should.have.key("thingArn").which.should.equal(thing["thingArn"]) + job_execution["execution"].should.have.key("queuedAt") + job_execution["execution"].should.have.key("startedAt") + job_execution["execution"].should.have.key("lastUpdatedAt") + job_execution["execution"].should.have.key("executionNumber").which.should.equal(123) + job_execution["execution"].should.have.key("versionNumber").which.should.equal(123) + job_execution["execution"].should.have.key("approximateSecondsBeforeTimedOut").which.should.equal(123) job_execution = client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=123) - job_execution.should.have.key('execution') - job_execution['execution'].should.have.key('jobId').which.should.equal(job_id) - job_execution['execution'].should.have.key('status').which.should.equal('QUEUED') - job_execution['execution'].should.have.key('forceCanceled').which.should.equal(False) - job_execution['execution'].should.have.key('statusDetails').which.should.equal({'detailsMap': {}}) - job_execution['execution'].should.have.key('thingArn').which.should.equal(thing["thingArn"]) - job_execution['execution'].should.have.key('queuedAt') - job_execution['execution'].should.have.key('startedAt') - job_execution['execution'].should.have.key('lastUpdatedAt') - job_execution['execution'].should.have.key('executionNumber').which.should.equal(123) - job_execution['execution'].should.have.key('versionNumber').which.should.equal(123) - job_execution['execution'].should.have.key('approximateSecondsBeforeTimedOut').which.should.equal(123) + job_execution.should.have.key("execution") + job_execution["execution"].should.have.key("jobId").which.should.equal(job_id) + job_execution["execution"].should.have.key("status").which.should.equal("QUEUED") + job_execution["execution"].should.have.key("forceCanceled").which.should.equal(False) + job_execution["execution"].should.have.key("statusDetails").which.should.equal({"detailsMap": {}}) + job_execution["execution"].should.have.key("thingArn").which.should.equal(thing["thingArn"]) + job_execution["execution"].should.have.key("queuedAt") + job_execution["execution"].should.have.key("startedAt") + job_execution["execution"].should.have.key("lastUpdatedAt") + job_execution["execution"].should.have.key("executionNumber").which.should.equal(123) + job_execution["execution"].should.have.key("versionNumber").which.should.equal(123) + job_execution["execution"].should.have.key("approximateSecondsBeforeTimedOut").which.should.equal(123) try: client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=456) except ClientError as exc: - error_code = exc.response['Error']['Code'] - error_code.should.equal('ResourceNotFoundException') + error_code = exc.response["Error"]["Code"] + error_code.should.equal("ResourceNotFoundException") else: raise Exception("Should have raised error") @mock_iot def test_cancel_job_execution(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1267,34 +1546,34 @@ def test_cancel_job_execution(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") client.cancel_job_execution(jobId=job_id, thingName=name) job_execution = client.describe_job_execution(jobId=job_id, thingName=name) - job_execution.should.have.key('execution') - job_execution['execution'].should.have.key('status').which.should.equal('CANCELED') + job_execution.should.have.key("execution") + job_execution["execution"].should.have.key("status").which.should.equal("CANCELED") @mock_iot def test_delete_job_execution(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1307,38 +1586,38 @@ def test_delete_job_execution(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") client.delete_job_execution(jobId=job_id, thingName=name, executionNumber=123) try: client.describe_job_execution(jobId=job_id, thingName=name, executionNumber=123) except ClientError as exc: - error_code = exc.response['Error']['Code'] - error_code.should.equal('ResourceNotFoundException') + error_code = exc.response["Error"]["Code"] + error_code.should.equal("ResourceNotFoundException") else: raise Exception("Should have raised error") @mock_iot def test_list_job_executions_for_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1351,33 +1630,33 @@ def test_list_job_executions_for_job(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") job_execution = client.list_job_executions_for_job(jobId=job_id) - job_execution.should.have.key('executionSummaries') - job_execution['executionSummaries'][0].should.have.key('thingArn').which.should.equal(thing["thingArn"]) + job_execution.should.have.key("executionSummaries") + job_execution["executionSummaries"][0].should.have.key("thingArn").which.should.equal(thing["thingArn"]) @mock_iot def test_list_job_executions_for_thing(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document job_document = { @@ -1390,20 +1669,20 @@ def test_list_job_executions_for_thing(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123 }, targetSelection="CONTINUOUS", jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 + "maximumPerMinute": 10 } ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") job_execution = client.list_job_executions_for_thing(thingName=name) - job_execution.should.have.key('executionSummaries') - job_execution['executionSummaries'][0].should.have.key('jobId').which.should.equal(job_id) + job_execution.should.have.key("executionSummaries") + job_execution["executionSummaries"][0].should.have.key("jobId").which.should.equal(job_id) diff --git a/tests/test_iot/test_server.py b/tests/test_iot/test_server.py index 60e81435a..b04f4d8ea 100644 --- a/tests/test_iot/test_server.py +++ b/tests/test_iot/test_server.py @@ -1,19 +1,20 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server -from moto import mock_iot - -''' -Test the different server responses -''' - -@mock_iot -def test_iot_list(): - backend = server.create_backend_app("iot") - test_client = backend.test_client() - - # just making sure that server is up - res = test_client.get('/things') - res.status_code.should.equal(404) +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server +from moto import mock_iot + +""" +Test the different server responses +""" + + +@mock_iot +def test_iot_list(): + backend = server.create_backend_app("iot") + test_client = backend.test_client() + + # just making sure that server is up + res = test_client.get("/things") + res.status_code.should.equal(404) diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index 8c03521f1..ac0a04244 100644 --- a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -1,93 +1,111 @@ -from __future__ import unicode_literals - -import json -import boto3 -import sure # noqa -from nose.tools import assert_raises -from botocore.exceptions import ClientError -from moto import mock_iotdata, mock_iot - - -@mock_iot -@mock_iotdata -def test_basic(): - iot_client = boto3.client('iot', region_name='ap-northeast-1') - client = boto3.client('iot-data', region_name='ap-northeast-1') - name = 'my-thing' - raw_payload = b'{"state": {"desired": {"led": "on"}}}' - iot_client.create_thing(thingName=name) - - with assert_raises(ClientError): - client.get_thing_shadow(thingName=name) - - res = client.update_thing_shadow(thingName=name, payload=raw_payload) - - payload = json.loads(res['payload'].read()) - expected_state = '{"desired": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') - - res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) - expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') - - client.delete_thing_shadow(thingName=name) - with assert_raises(ClientError): - client.get_thing_shadow(thingName=name) - - -@mock_iot -@mock_iotdata -def test_update(): - iot_client = boto3.client('iot', region_name='ap-northeast-1') - client = boto3.client('iot-data', region_name='ap-northeast-1') - name = 'my-thing' - raw_payload = b'{"state": {"desired": {"led": "on"}}}' - iot_client.create_thing(thingName=name) - - # first update - res = client.update_thing_shadow(thingName=name, payload=raw_payload) - payload = json.loads(res['payload'].read()) - expected_state = '{"desired": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') - - res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) - expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') - - # reporting new state - new_payload = b'{"state": {"reported": {"led": "on"}}}' - res = client.update_thing_shadow(thingName=name, payload=new_payload) - payload = json.loads(res['payload'].read()) - expected_state = '{"reported": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('reported').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(2) - payload.should.have.key('timestamp') - - res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) - expected_state = b'{"desired": {"led": "on"}, "reported": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(2) - payload.should.have.key('timestamp') - - -@mock_iotdata -def test_publish(): - client = boto3.client('iot-data', region_name='ap-northeast-1') - client.publish(topic='test/topic', qos=1, payload=b'') +from __future__ import unicode_literals + +import json +import boto3 +import sure # noqa +from nose.tools import assert_raises +from botocore.exceptions import ClientError +from moto import mock_iotdata, mock_iot + + +@mock_iot +@mock_iotdata +def test_basic(): + iot_client = boto3.client("iot", region_name="ap-northeast-1") + client = boto3.client("iot-data", region_name="ap-northeast-1") + name = "my-thing" + raw_payload = b'{"state": {"desired": {"led": "on"}}}' + iot_client.create_thing(thingName=name) + + with assert_raises(ClientError): + client.get_thing_shadow(thingName=name) + + res = client.update_thing_shadow(thingName=name, payload=raw_payload) + + payload = json.loads(res["payload"].read()) + expected_state = '{"desired": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") + + res = client.get_thing_shadow(thingName=name) + payload = json.loads(res["payload"].read()) + expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") + + client.delete_thing_shadow(thingName=name) + with assert_raises(ClientError): + client.get_thing_shadow(thingName=name) + + +@mock_iot +@mock_iotdata +def test_update(): + iot_client = boto3.client("iot", region_name="ap-northeast-1") + client = boto3.client("iot-data", region_name="ap-northeast-1") + name = "my-thing" + raw_payload = b'{"state": {"desired": {"led": "on"}}}' + iot_client.create_thing(thingName=name) + + # first update + res = client.update_thing_shadow(thingName=name, payload=raw_payload) + payload = json.loads(res["payload"].read()) + expected_state = '{"desired": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") + + res = client.get_thing_shadow(thingName=name) + payload = json.loads(res["payload"].read()) + expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") + + # reporting new state + new_payload = b'{"state": {"reported": {"led": "on"}}}' + res = client.update_thing_shadow(thingName=name, payload=new_payload) + payload = json.loads(res["payload"].read()) + expected_state = '{"reported": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "reported" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(2) + payload.should.have.key("timestamp") + + res = client.get_thing_shadow(thingName=name) + payload = json.loads(res["payload"].read()) + expected_state = b'{"desired": {"led": "on"}, "reported": {"led": "on"}}' + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(2) + payload.should.have.key("timestamp") + + raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}' + with assert_raises(ClientError) as ex: + client.update_thing_shadow(thingName=name, payload=raw_payload) + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(409) + ex.exception.response["Error"]["Message"].should.equal("Version conflict") + + +@mock_iotdata +def test_publish(): + client = boto3.client("iot-data", region_name="ap-northeast-1") + client.publish(topic="test/topic", qos=1, payload=b"") diff --git a/tests/test_iotdata/test_server.py b/tests/test_iotdata/test_server.py index edcd92a33..bbced67b6 100644 --- a/tests/test_iotdata/test_server.py +++ b/tests/test_iotdata/test_server.py @@ -1,20 +1,21 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server -from moto import mock_iotdata - -''' -Test the different server responses -''' - -@mock_iotdata -def test_iotdata_list(): - backend = server.create_backend_app("iot-data") - test_client = backend.test_client() - - # just making sure that server is up - thing_name = 'nothing' - res = test_client.get('/things/{}/shadow'.format(thing_name)) - res.status_code.should.equal(404) +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server +from moto import mock_iotdata + +""" +Test the different server responses +""" + + +@mock_iotdata +def test_iotdata_list(): + backend = server.create_backend_app("iot-data") + test_client = backend.test_client() + + # just making sure that server is up + thing_name = "nothing" + res = test_client.get("/things/{}/shadow".format(thing_name)) + res.status_code.should.equal(404) diff --git a/tests/test_kinesis/test_firehose.py b/tests/test_kinesis/test_firehose.py index b13672e26..5e8c4aa08 100644 --- a/tests/test_kinesis/test_firehose.py +++ b/tests/test_kinesis/test_firehose.py @@ -1,188 +1,268 @@ -from __future__ import unicode_literals - -import datetime - -from botocore.exceptions import ClientError -import boto3 -import sure # noqa - -from moto import mock_kinesis - - -def create_stream(client, stream_name): - return client.create_delivery_stream( - DeliveryStreamName=stream_name, - RedshiftDestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'ClusterJDBCURL': 'jdbc:redshift://host.amazonaws.com:5439/database', - 'CopyCommand': { - 'DataTableName': 'outputTable', - 'CopyOptions': "CSV DELIMITER ',' NULL '\\0'" - }, - 'Username': 'username', - 'Password': 'password', - 'S3Configuration': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } - } - ) - - -@mock_kinesis -def test_create_stream(): - client = boto3.client('firehose', region_name='us-east-1') - - response = create_stream(client, 'stream1') - stream_arn = response['DeliveryStreamARN'] - - response = client.describe_delivery_stream(DeliveryStreamName='stream1') - stream_description = response['DeliveryStreamDescription'] - - # Sure and Freezegun don't play nicely together - _ = stream_description.pop('CreateTimestamp') - _ = stream_description.pop('LastUpdateTimestamp') - - stream_description.should.equal({ - 'DeliveryStreamName': 'stream1', - 'DeliveryStreamARN': stream_arn, - 'DeliveryStreamStatus': 'ACTIVE', - 'VersionId': 'string', - 'Destinations': [ - { - 'DestinationId': 'string', - 'RedshiftDestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'ClusterJDBCURL': 'jdbc:redshift://host.amazonaws.com:5439/database', - 'CopyCommand': { - 'DataTableName': 'outputTable', - 'CopyOptions': "CSV DELIMITER ',' NULL '\\0'" - }, - 'Username': 'username', - 'S3DestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } - } - }, - ], - "HasMoreDestinations": False, - }) - - -@mock_kinesis -def test_create_stream_without_redshift(): - client = boto3.client('firehose', region_name='us-east-1') - - response = client.create_delivery_stream( - DeliveryStreamName="stream1", - S3DestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } - ) - stream_arn = response['DeliveryStreamARN'] - - response = client.describe_delivery_stream(DeliveryStreamName='stream1') - stream_description = response['DeliveryStreamDescription'] - - # Sure and Freezegun don't play nicely together - _ = stream_description.pop('CreateTimestamp') - _ = stream_description.pop('LastUpdateTimestamp') - - stream_description.should.equal({ - 'DeliveryStreamName': 'stream1', - 'DeliveryStreamARN': stream_arn, - 'DeliveryStreamStatus': 'ACTIVE', - 'VersionId': 'string', - 'Destinations': [ - { - 'DestinationId': 'string', - 'S3DestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } - }, - ], - "HasMoreDestinations": False, - }) - - -@mock_kinesis -def test_deescribe_non_existant_stream(): - client = boto3.client('firehose', region_name='us-east-1') - - client.describe_delivery_stream.when.called_with( - DeliveryStreamName='not-a-stream').should.throw(ClientError) - - -@mock_kinesis -def test_list_and_delete_stream(): - client = boto3.client('firehose', region_name='us-east-1') - - create_stream(client, 'stream1') - create_stream(client, 'stream2') - - set(client.list_delivery_streams()['DeliveryStreamNames']).should.equal( - set(['stream1', 'stream2'])) - - client.delete_delivery_stream(DeliveryStreamName='stream1') - - set(client.list_delivery_streams()[ - 'DeliveryStreamNames']).should.equal(set(['stream2'])) - - -@mock_kinesis -def test_put_record(): - client = boto3.client('firehose', region_name='us-east-1') - - create_stream(client, 'stream1') - client.put_record( - DeliveryStreamName='stream1', - Record={ - 'Data': 'some data' - } - ) - - -@mock_kinesis -def test_put_record_batch(): - client = boto3.client('firehose', region_name='us-east-1') - - create_stream(client, 'stream1') - client.put_record_batch( - DeliveryStreamName='stream1', - Records=[ - { - 'Data': 'some data1' - }, - { - 'Data': 'some data2' - }, - ] - ) +from __future__ import unicode_literals + +import datetime + +from botocore.exceptions import ClientError +import boto3 +import sure # noqa + +from moto import mock_kinesis +from moto.core import ACCOUNT_ID + + +def create_s3_delivery_stream(client, stream_name): + return client.create_delivery_stream( + DeliveryStreamName=stream_name, + DeliveryStreamType="DirectPut", + ExtendedS3DestinationConfiguration={ + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format(ACCOUNT_ID), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "CompressionFormat": "UNCOMPRESSED", + "DataFormatConversionConfiguration": { + "Enabled": True, + "InputFormatConfiguration": {"Deserializer": {"HiveJsonSerDe": {}}}, + "OutputFormatConfiguration": { + "Serializer": {"ParquetSerDe": {"Compression": "SNAPPY"}} + }, + "SchemaConfiguration": { + "DatabaseName": stream_name, + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "TableName": "outputTable", + }, + }, + }, + ) + + +def create_redshift_delivery_stream(client, stream_name): + return client.create_delivery_stream( + DeliveryStreamName=stream_name, + RedshiftDestinationConfiguration={ + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format(ACCOUNT_ID), + "ClusterJDBCURL": "jdbc:redshift://host.amazonaws.com:5439/database", + "CopyCommand": { + "DataTableName": "outputTable", + "CopyOptions": "CSV DELIMITER ',' NULL '\\0'", + }, + "Username": "username", + "Password": "password", + "S3Configuration": { + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", + }, + }, + ) + + +@mock_kinesis +def test_create_redshift_delivery_stream(): + client = boto3.client("firehose", region_name="us-east-1") + + response = create_redshift_delivery_stream(client, "stream1") + stream_arn = response["DeliveryStreamARN"] + + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] + + # Sure and Freezegun don't play nicely together + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") + + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "RedshiftDestinationDescription": { + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "ClusterJDBCURL": "jdbc:redshift://host.amazonaws.com:5439/database", + "CopyCommand": { + "DataTableName": "outputTable", + "CopyOptions": "CSV DELIMITER ',' NULL '\\0'", + }, + "Username": "username", + "S3DestinationDescription": { + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": { + "SizeInMBs": 123, + "IntervalInSeconds": 124, + }, + "CompressionFormat": "UNCOMPRESSED", + }, + }, + } + ], + "HasMoreDestinations": False, + } + ) + + +@mock_kinesis +def test_create_s3_delivery_stream(): + client = boto3.client("firehose", region_name="us-east-1") + + response = create_s3_delivery_stream(client, "stream1") + stream_arn = response["DeliveryStreamARN"] + + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] + + # Sure and Freezegun don't play nicely together + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") + + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "ExtendedS3DestinationDescription": { + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "CompressionFormat": "UNCOMPRESSED", + "DataFormatConversionConfiguration": { + "Enabled": True, + "InputFormatConfiguration": { + "Deserializer": {"HiveJsonSerDe": {}} + }, + "OutputFormatConfiguration": { + "Serializer": { + "ParquetSerDe": {"Compression": "SNAPPY"} + } + }, + "SchemaConfiguration": { + "DatabaseName": "stream1", + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "TableName": "outputTable", + }, + }, + }, + } + ], + "HasMoreDestinations": False, + } + ) + + +@mock_kinesis +def test_create_stream_without_redshift(): + client = boto3.client("firehose", region_name="us-east-1") + + response = client.create_delivery_stream( + DeliveryStreamName="stream1", + S3DestinationConfiguration={ + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format(ACCOUNT_ID), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", + }, + ) + stream_arn = response["DeliveryStreamARN"] + + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] + + # Sure and Freezegun don't play nicely together + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") + + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "S3DestinationDescription": { + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "RoleARN": "arn:aws:iam::{}:role/firehose_delivery_role".format( + ACCOUNT_ID + ), + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", + }, + } + ], + "HasMoreDestinations": False, + } + ) + + +@mock_kinesis +def test_deescribe_non_existant_stream(): + client = boto3.client("firehose", region_name="us-east-1") + + client.describe_delivery_stream.when.called_with( + DeliveryStreamName="not-a-stream" + ).should.throw(ClientError) + + +@mock_kinesis +def test_list_and_delete_stream(): + client = boto3.client("firehose", region_name="us-east-1") + + create_redshift_delivery_stream(client, "stream1") + create_redshift_delivery_stream(client, "stream2") + + set(client.list_delivery_streams()["DeliveryStreamNames"]).should.equal( + set(["stream1", "stream2"]) + ) + + client.delete_delivery_stream(DeliveryStreamName="stream1") + + set(client.list_delivery_streams()["DeliveryStreamNames"]).should.equal( + set(["stream2"]) + ) + + +@mock_kinesis +def test_put_record(): + client = boto3.client("firehose", region_name="us-east-1") + + create_redshift_delivery_stream(client, "stream1") + client.put_record(DeliveryStreamName="stream1", Record={"Data": "some data"}) + + +@mock_kinesis +def test_put_record_batch(): + client = boto3.client("firehose", region_name="us-east-1") + + create_redshift_delivery_stream(client, "stream1") + client.put_record_batch( + DeliveryStreamName="stream1", + Records=[{"Data": "some data1"}, {"Data": "some data2"}], + ) diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py index e2de866fc..de1764892 100644 --- a/tests/test_kinesis/test_kinesis.py +++ b/tests/test_kinesis/test_kinesis.py @@ -5,10 +5,10 @@ import time import boto.kinesis import boto3 -from boto.kinesis.exceptions import ResourceNotFoundException, \ - InvalidArgumentException +from boto.kinesis.exceptions import ResourceNotFoundException, InvalidArgumentException from moto import mock_kinesis, mock_kinesis_deprecated +from moto.core import ACCOUNT_ID @mock_kinesis_deprecated @@ -23,18 +23,20 @@ def test_create_cluster(): stream["StreamName"].should.equal("my_stream") stream["HasMoreShards"].should.equal(False) stream["StreamARN"].should.equal( - "arn:aws:kinesis:us-west-2:123456789012:my_stream") + "arn:aws:kinesis:us-west-2:{}:my_stream".format(ACCOUNT_ID) + ) stream["StreamStatus"].should.equal("ACTIVE") - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) @mock_kinesis_deprecated def test_describe_non_existant_stream(): conn = boto.kinesis.connect_to_region("us-east-1") - conn.describe_stream.when.called_with( - "not-a-stream").should.throw(ResourceNotFoundException) + conn.describe_stream.when.called_with("not-a-stream").should.throw( + ResourceNotFoundException + ) @mock_kinesis_deprecated @@ -44,20 +46,21 @@ def test_list_and_delete_stream(): conn.create_stream("stream1", 1) conn.create_stream("stream2", 1) - conn.list_streams()['StreamNames'].should.have.length_of(2) + conn.list_streams()["StreamNames"].should.have.length_of(2) conn.delete_stream("stream2") - conn.list_streams()['StreamNames'].should.have.length_of(1) + conn.list_streams()["StreamNames"].should.have.length_of(1) # Delete invalid id - conn.delete_stream.when.called_with( - "not-a-stream").should.throw(ResourceNotFoundException) + conn.delete_stream.when.called_with("not-a-stream").should.throw( + ResourceNotFoundException + ) @mock_kinesis def test_list_many_streams(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") for i in range(11): conn.create_stream(StreamName="stream%d" % i, ShardCount=1) @@ -76,8 +79,8 @@ def test_list_many_streams(): @mock_kinesis def test_describe_stream_summary(): - conn = boto3.client('kinesis', region_name="us-west-2") - stream_name = 'my_stream_summary' + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "my_stream_summary" shard_count = 5 conn.create_stream(StreamName=stream_name, ShardCount=shard_count) @@ -87,7 +90,8 @@ def test_describe_stream_summary(): stream["StreamName"].should.equal(stream_name) stream["OpenShardCount"].should.equal(shard_count) stream["StreamARN"].should.equal( - "arn:aws:kinesis:us-west-2:123456789012:{}".format(stream_name)) + "arn:aws:kinesis:us-west-2:{}:{}".format(ACCOUNT_ID, stream_name) + ) stream["StreamStatus"].should.equal("ACTIVE") @@ -99,15 +103,15 @@ def test_basic_shard_iterator(): conn.create_stream(stream_name, 1) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) - shard_iterator = response['NextShardIterator'] - response['Records'].should.equal([]) - response['MillisBehindLatest'].should.equal(0) + shard_iterator = response["NextShardIterator"] + response["Records"].should.equal([]) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -118,8 +122,8 @@ def test_get_invalid_shard_iterator(): conn.create_stream(stream_name, 1) conn.get_shard_iterator.when.called_with( - stream_name, "123", 'TRIM_HORIZON').should.throw( - ResourceNotFoundException) + stream_name, "123", "TRIM_HORIZON" + ).should.throw(ResourceNotFoundException) @mock_kinesis_deprecated @@ -132,21 +136,22 @@ def test_put_records(): data = "hello world" partition_key = "1234" - conn.put_record.when.called_with( - stream_name, data, 1234).should.throw(InvalidArgumentException) + conn.put_record.when.called_with(stream_name, data, 1234).should.throw( + InvalidArgumentException + ) conn.put_record(stream_name, data, partition_key) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) - shard_iterator = response['NextShardIterator'] - response['Records'].should.have.length_of(1) - record = response['Records'][0] + shard_iterator = response["NextShardIterator"] + response["Records"].should.have.length_of(1) + record = response["Records"][0] record["Data"].should.equal("hello world") record["PartitionKey"].should.equal("1234") @@ -168,18 +173,18 @@ def test_get_records_limit(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Retrieve only 3 records response = conn.get_records(shard_iterator, limit=3) - response['Records'].should.have.length_of(3) + response["Records"].should.have.length_of(3) # Then get the rest of the results - next_shard_iterator = response['NextShardIterator'] + next_shard_iterator = response["NextShardIterator"] response = conn.get_records(next_shard_iterator) - response['Records'].should.have.length_of(2) + response["Records"].should.have.length_of(2) @mock_kinesis_deprecated @@ -196,23 +201,24 @@ def test_get_records_at_sequence_number(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting at that id response = conn.get_shard_iterator( - stream_name, shard_id, 'AT_SEQUENCE_NUMBER', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "AT_SEQUENCE_NUMBER", second_sequence_id + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) # And the first result returned should be the second item - response['Records'][0]['SequenceNumber'].should.equal(second_sequence_id) - response['Records'][0]['Data'].should.equal('2') + response["Records"][0]["SequenceNumber"].should.equal(second_sequence_id) + response["Records"][0]["Data"].should.equal("2") @mock_kinesis_deprecated @@ -229,23 +235,24 @@ def test_get_records_after_sequence_number(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting after that id response = conn.get_shard_iterator( - stream_name, shard_id, 'AFTER_SEQUENCE_NUMBER', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "AFTER_SEQUENCE_NUMBER", second_sequence_id + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) # And the first result returned should be the third item - response['Records'][0]['Data'].should.equal('3') - response['MillisBehindLatest'].should.equal(0) + response["Records"][0]["Data"].should.equal("3") + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -262,42 +269,43 @@ def test_get_records_latest(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting after that id response = conn.get_shard_iterator( - stream_name, shard_id, 'LATEST', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "LATEST", second_sequence_id + ) + shard_iterator = response["ShardIterator"] # Write some more data conn.put_record(stream_name, "last_record", "last_record") response = conn.get_records(shard_iterator) # And the only result returned should be the new item - response['Records'].should.have.length_of(1) - response['Records'][0]['PartitionKey'].should.equal('last_record') - response['Records'][0]['Data'].should.equal('last_record') - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(1) + response["Records"][0]["PartitionKey"].should.equal("last_record") + response["Records"][0]["Data"].should.equal("last_record") + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_at_timestamp(): # AT_TIMESTAMP - Read the first record at or after the specified timestamp - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data for index in range(1, 5): - conn.put_record(StreamName=stream_name, - Data=str(index), - PartitionKey=str(index)) + conn.put_record( + StreamName=stream_name, Data=str(index), PartitionKey=str(index) + ) # When boto3 floors the timestamp that we pass to get_shard_iterator to # second precision even though AWS supports ms precision: @@ -309,148 +317,143 @@ def test_get_records_at_timestamp(): keys = [str(i) for i in range(5, 10)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(len(keys)) - partition_keys = [r['PartitionKey'] for r in response['Records']] + response["Records"].should.have.length_of(len(keys)) + partition_keys = [r["PartitionKey"] for r in response["Records"]] partition_keys.should.equal(keys) - response['MillisBehindLatest'].should.equal(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_at_very_old_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data keys = [str(i) for i in range(1, 5)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=1) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=1, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(len(keys)) - partition_keys = [r['PartitionKey'] for r in response['Records']] + response["Records"].should.have.length_of(len(keys)) + partition_keys = [r["PartitionKey"] for r in response["Records"]] partition_keys.should.equal(keys) - response['MillisBehindLatest'].should.equal(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_timestamp_filtering(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) - conn.put_record(StreamName=stream_name, - Data='0', - PartitionKey='0') + conn.put_record(StreamName=stream_name, Data="0", PartitionKey="0") time.sleep(1.0) timestamp = datetime.datetime.utcnow() - conn.put_record(StreamName=stream_name, - Data='1', - PartitionKey='1') + conn.put_record(StreamName=stream_name, Data="1", PartitionKey="1") response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(1) - response['Records'][0]['PartitionKey'].should.equal('1') - response['Records'][0]['ApproximateArrivalTimestamp'].should.be. \ - greater_than(timestamp) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(1) + response["Records"][0]["PartitionKey"].should.equal("1") + response["Records"][0]["ApproximateArrivalTimestamp"].should.be.greater_than( + timestamp + ) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_millis_behind_latest(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) - conn.put_record(StreamName=stream_name, - Data='0', - PartitionKey='0') + conn.put_record(StreamName=stream_name, Data="0", PartitionKey="0") time.sleep(1.0) - conn.put_record(StreamName=stream_name, - Data='1', - PartitionKey='1') + conn.put_record(StreamName=stream_name, Data="1", PartitionKey="1") response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON" + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator, Limit=1) - response['Records'].should.have.length_of(1) - response['MillisBehindLatest'].should.be.greater_than(0) + response["Records"].should.have.length_of(1) + response["MillisBehindLatest"].should.be.greater_than(0) @mock_kinesis def test_get_records_at_very_new_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data keys = [str(i) for i in range(1, 5)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) timestamp = datetime.datetime.utcnow() + datetime.timedelta(seconds=1) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(0) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_from_empty_stream_at_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) @@ -458,17 +461,19 @@ def test_get_records_from_empty_stream_at_timestamp(): # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(0) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -478,10 +483,10 @@ def test_invalid_shard_iterator_type(): conn.create_stream(stream_name, 1) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] response = conn.get_shard_iterator.when.called_with( - stream_name, shard_id, 'invalid-type').should.throw( - InvalidArgumentException) + stream_name, shard_id, "invalid-type" + ).should.throw(InvalidArgumentException) @mock_kinesis_deprecated @@ -491,10 +496,10 @@ def test_add_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - conn.add_tags_to_stream(stream_name, {'tag1': 'val3'}) - conn.add_tags_to_stream(stream_name, {'tag2': 'val4'}) + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + conn.add_tags_to_stream(stream_name, {"tag1": "val3"}) + conn.add_tags_to_stream(stream_name, {"tag2": "val4"}) @mock_kinesis_deprecated @@ -504,22 +509,38 @@ def test_list_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val1') - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val2') - conn.add_tags_to_stream(stream_name, {'tag1': 'val3'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val3') - conn.add_tags_to_stream(stream_name, {'tag2': 'val4'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val4') + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val1") + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val2") + conn.add_tags_to_stream(stream_name, {"tag1": "val3"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val3") + conn.add_tags_to_stream(stream_name, {"tag2": "val4"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val4") @mock_kinesis_deprecated @@ -529,29 +550,45 @@ def test_remove_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val1') - conn.remove_tags_from_stream(stream_name, ['tag1']) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal(None) + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val1") + conn.remove_tags_from_stream(stream_name, ["tag1"]) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal(None) - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val2') - conn.remove_tags_from_stream(stream_name, ['tag2']) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal(None) + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val2") + conn.remove_tags_from_stream(stream_name, ["tag2"]) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal(None) @mock_kinesis_deprecated def test_split_shard(): conn = boto.kinesis.connect_to_region("us-west-2") - stream_name = 'my_stream' + stream_name = "my_stream" conn.create_stream(stream_name, 2) @@ -562,44 +599,47 @@ def test_split_shard(): stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(2) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - shard_range = shards[0]['HashKeyRange'] + shard_range = shards[0]["HashKeyRange"] new_starting_hash = ( - int(shard_range['EndingHashKey']) + int( - shard_range['StartingHashKey'])) // 2 - conn.split_shard("my_stream", shards[0]['ShardId'], str(new_starting_hash)) + int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"]) + ) // 2 + conn.split_shard("my_stream", shards[0]["ShardId"], str(new_starting_hash)) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - shard_range = shards[2]['HashKeyRange'] + shard_range = shards[2]["HashKeyRange"] new_starting_hash = ( - int(shard_range['EndingHashKey']) + int( - shard_range['StartingHashKey'])) // 2 - conn.split_shard("my_stream", shards[2]['ShardId'], str(new_starting_hash)) + int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"]) + ) // 2 + conn.split_shard("my_stream", shards[2]["ShardId"], str(new_starting_hash)) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) @mock_kinesis_deprecated def test_merge_shards(): conn = boto.kinesis.connect_to_region("us-west-2") - stream_name = 'my_stream' + stream_name = "my_stream" conn.create_stream(stream_name, 4) @@ -610,38 +650,39 @@ def test_merge_shards(): stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) conn.merge_shards.when.called_with( - stream_name, 'shardId-000000000000', - 'shardId-000000000002').should.throw(InvalidArgumentException) + stream_name, "shardId-000000000000", "shardId-000000000002" + ).should.throw(InvalidArgumentException) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - conn.merge_shards(stream_name, 'shardId-000000000000', - 'shardId-000000000001') + conn.merge_shards(stream_name, "shardId-000000000000", "shardId-000000000001") stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) - conn.merge_shards(stream_name, 'shardId-000000000002', - 'shardId-000000000000') + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) + conn.merge_shards(stream_name, "shardId-000000000002", "shardId-000000000000") stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(2) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) diff --git a/tests/test_kinesis/test_server.py b/tests/test_kinesis/test_server.py index b88ab1bb2..3d7fdeee4 100644 --- a/tests/test_kinesis/test_server.py +++ b/tests/test_kinesis/test_server.py @@ -1,25 +1,22 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_kinesis - -''' -Test the different server responses -''' - - -@mock_kinesis -def test_list_streams(): - backend = server.create_backend_app("kinesis") - test_client = backend.test_client() - - res = test_client.get('/?Action=ListStreams') - - json_data = json.loads(res.data.decode("utf-8")) - json_data.should.equal({ - "HasMoreStreams": False, - "StreamNames": [], - }) +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_kinesis + +""" +Test the different server responses +""" + + +@mock_kinesis +def test_list_streams(): + backend = server.create_backend_app("kinesis") + test_client = backend.test_client() + + res = test_client.get("/?Action=ListStreams") + + json_data = json.loads(res.data.decode("utf-8")) + json_data.should.equal({"HasMoreStreams": False, "StreamNames": []}) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index f189fbe41..70fa68787 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -1,333 +1,454 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals -import os, re -import boto3 -import boto.kms -import botocore.exceptions -from boto.exception import JSONResponseError -from boto.kms.exceptions import AlreadyExistsException, NotFoundException - -from moto.kms.exceptions import NotFoundException as MotoNotFoundException -import sure # noqa -from moto import mock_kms, mock_kms_deprecated -from nose.tools import assert_raises -from freezegun import freeze_time from datetime import date from datetime import datetime from dateutil.tz import tzutc +import base64 +import os +import re + +import boto3 +import boto.kms +import botocore.exceptions +import six +import sure # noqa +from boto.exception import JSONResponseError +from boto.kms.exceptions import AlreadyExistsException, NotFoundException +from freezegun import freeze_time +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import NotFoundException as MotoNotFoundException +from moto import mock_kms, mock_kms_deprecated + +PLAINTEXT_VECTORS = ( + (b"some encodeable plaintext",), + (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), + ("some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",), +) + + +def _get_encoded_value(plaintext): + if isinstance(plaintext, six.binary_type): + return plaintext + + return plaintext.encode("utf-8") @mock_kms def test_create_key(): - conn = boto3.client('kms', region_name='us-east-1') + conn = boto3.client("kms", region_name="us-east-1") with freeze_time("2015-01-01 00:00:00"): - key = conn.create_key(Policy="my policy", - Description="my key", - KeyUsage='ENCRYPT_DECRYPT', - Tags=[ - { - 'TagKey': 'project', - 'TagValue': 'moto', - }, - ]) + key = conn.create_key( + Policy="my policy", + Description="my key", + KeyUsage="ENCRYPT_DECRYPT", + Tags=[{"TagKey": "project", "TagValue": "moto"}], + ) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - key['KeyMetadata']['Enabled'].should.equal(True) - key['KeyMetadata']['CreationDate'].should.be.a(date) + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Enabled"].should.equal(True) + key["KeyMetadata"]["CreationDate"].should.be.a(date) @mock_kms_deprecated def test_describe_key(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] key = conn.describe_key(key_id) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") @mock_kms_deprecated def test_describe_key_via_alias(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Description'].should.equal("my key") - alias_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Description"].should.equal("my key") + alias_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) @mock_kms_deprecated def test_describe_key_via_alias_not_found(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) - conn.describe_key.when.called_with( - 'alias/not-found-alias').should.throw(JSONResponseError) + conn.describe_key.when.called_with("alias/not-found-alias").should.throw( + NotFoundException + ) + + +@parameterized( + ( + ("alias/does-not-exist",), + ("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",), + ("invalid",), + ) +) +@mock_kms +def test_describe_key_via_alias_invalid_alias(key_id): + client = boto3.client("kms", region_name="us-east-1") + client.create_key(Description="key") + + with assert_raises(client.exceptions.NotFoundException): + client.describe_key(KeyId=key_id) @mock_kms_deprecated def test_describe_key_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - arn = key['KeyMetadata']['Arn'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + arn = key["KeyMetadata"]["Arn"] the_key = conn.describe_key(arn) - the_key['KeyMetadata']['Description'].should.equal("my key") - the_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - the_key['KeyMetadata']['KeyId'].should.equal(key['KeyMetadata']['KeyId']) + the_key["KeyMetadata"]["Description"].should.equal("my key") + the_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + the_key["KeyMetadata"]["KeyId"].should.equal(key["KeyMetadata"]["KeyId"]) @mock_kms_deprecated def test_describe_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.describe_key.when.called_with( - "not-a-key").should.throw(JSONResponseError) + conn.describe_key.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_list_keys(): conn = boto.kms.connect_to_region("us-west-2") - conn.create_key(policy="my policy", description="my key1", - key_usage='ENCRYPT_DECRYPT') - conn.create_key(policy="my policy", description="my key2", - key_usage='ENCRYPT_DECRYPT') + conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_key( + policy="my policy", description="my key2", key_usage="ENCRYPT_DECRYPT" + ) keys = conn.list_keys() - keys['Keys'].should.have.length_of(2) + keys["Keys"].should.have.length_of(2) @mock_kms_deprecated def test_enable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["Arn"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.enable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_enable_key_rotation_with_alias_name_should_fail(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) - conn.enable_key_rotation.when.called_with( - 'alias/my-alias').should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("alias/my-alias").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_disable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) conn.disable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated -def test_encrypt(): - """ - test_encrypt - Using base64 encoding to merely test that the endpoint was called - """ +def test_generate_data_key(): conn = boto.kms.connect_to_region("us-west-2") - response = conn.encrypt('key_id', 'encryptme'.encode('utf-8')) - response['CiphertextBlob'].should.equal(b'ZW5jcnlwdG1l') - response['KeyId'].should.equal('key_id') + + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = conn.generate_data_key(key_id=key_id, number_of_bytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) -@mock_kms_deprecated -def test_decrypt(): - conn = boto.kms.connect_to_region('us-west-2') - response = conn.decrypt('ZW5jcnlwdG1l'.encode('utf-8')) - response['Plaintext'].should.equal(b'encryptme') - response['KeyId'].should.equal('key_id') +@mock_kms +def test_boto3_generate_data_key(): + kms = boto3.client("kms", region_name="us-west-2") + + key = kms.create_key() + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_encrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = client.encrypt(KeyId=key_id, Plaintext=plaintext) + response["CiphertextBlob"].should_not.equal(plaintext) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_decrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext) + + client.create_key(Description="key") + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(encrypt_response["CiphertextBlob"], validate=True) + + decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"]) + + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(decrypt_response["Plaintext"], validate=True) + + decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response["KeyId"].should.equal(key_arn) @mock_kms_deprecated def test_disable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.disable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.disable_key_rotation.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_get_key_rotation_status_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.get_key_rotation_status.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.get_key_rotation_status.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_get_key_rotation_status(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_create_key_defaults_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_get_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_get_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - policy = conn.get_key_policy(key['KeyMetadata']['Arn'], 'default') + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + policy = conn.get_key_policy(key["KeyMetadata"]["Arn"], "default") - policy['Policy'].should.equal('my policy') + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["Arn"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_alias_should_not_update(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) conn.put_key_policy.when.called_with( - 'alias/my-key-alias', 'default', 'new policy').should.throw(NotFoundException) + "alias/my-key-alias", "default", "new policy" + ).should.throw(NotFoundException) - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.put_key_policy(key['KeyMetadata']['Arn'], 'default', 'new policy') + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + conn.put_key_policy(key["KeyMetadata"]["Arn"], "default", "new policy") - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('new policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_list_key_policies(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + key_id = key["KeyMetadata"]["KeyId"] policies = conn.list_key_policies(key_id) - policies['PolicyNames'].should.equal(['default']) + policies["PolicyNames"].should.equal(["default"]) @mock_kms_deprecated def test__create_alias__returns_none_if_correct(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - resp = kms.create_alias('alias/my-alias', key_id) + resp = kms.create_alias("alias/my-alias", key_id) resp.should.be.none @@ -336,13 +457,13 @@ def test__create_alias__returns_none_if_correct(): def test__create_alias__raises_if_reserved_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] reserved_aliases = [ - 'alias/aws/ebs', - 'alias/aws/s3', - 'alias/aws/redshift', - 'alias/aws/rds', + "alias/aws/ebs", + "alias/aws/s3", + "alias/aws/redshift", + "alias/aws/rds", ] for alias_name in reserved_aliases: @@ -351,9 +472,9 @@ def test__create_alias__raises_if_reserved_alias(): ex = err.exception ex.error_message.should.be.none - ex.error_code.should.equal('NotAuthorizedException') - ex.body.should.equal({'__type': 'NotAuthorizedException'}) - ex.reason.should.equal('Bad Request') + ex.error_code.should.equal("NotAuthorizedException") + ex.body.should.equal({"__type": "NotAuthorizedException"}) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -361,38 +482,39 @@ def test__create_alias__raises_if_reserved_alias(): def test__create_alias__can_create_multiple_aliases_for_same_key_id(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - kms.create_alias('alias/my-alias3', key_id).should.be.none - kms.create_alias('alias/my-alias4', key_id).should.be.none - kms.create_alias('alias/my-alias5', key_id).should.be.none + kms.create_alias("alias/my-alias3", key_id).should.be.none + kms.create_alias("alias/my-alias4", key_id).should.be.none + kms.create_alias("alias/my-alias5", key_id).should.be.none @mock_kms_deprecated def test__create_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] with assert_raises(JSONResponseError) as err: - kms.create_alias('wrongprefix/my-alias', key_id) + kms.create_alias("wrongprefix/my-alias", key_id) ex = err.exception - ex.error_message.should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.body.should.equal({'message': 'Invalid identifier', - '__type': 'ValidationException'}) - ex.reason.should.equal('Bad Request') + ex.error_message.should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.body.should.equal( + {"message": "Invalid identifier", "__type": "ValidationException"} + ) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__create_alias__raises_if_duplicate(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -400,15 +522,21 @@ def test__create_alias__raises_if_duplicate(): kms.create_alias(alias, key_id) ex = err.exception - ex.error_message.should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) + ex.error_message.should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format( + **locals() + ) + ) ex.error_code.should.be.none ex.box_usage.should.be.none ex.request_id.should.be.none - ex.body['message'].should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) - ex.body['__type'].should.equal('AlreadyExistsException') - ex.reason.should.equal('Bad Request') + ex.body["message"].should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format( + **locals() + ) + ) + ex.body["__type"].should.equal("AlreadyExistsException") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -416,25 +544,31 @@ def test__create_alias__raises_if_duplicate(): def test__create_alias__raises_if_alias_has_restricted_characters(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] alias_names_with_restricted_characters = [ - 'alias/my-alias!', - 'alias/my-alias$', - 'alias/my-alias@', + "alias/my-alias!", + "alias/my-alias$", + "alias/my-alias@", ] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.error_code.should.equal('ValidationException') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal( + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.error_code.should.equal("ValidationException") ex.message.should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.reason.should.equal('Bad Request') + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -444,47 +578,42 @@ def test__create_alias__raises_if_alias_has_colon_character(): # are accepted by regex ^[a-zA-Z0-9:/_-]+$ kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = [ - 'alias/my:alias', - ] + alias_names_with_restricted_characters = ["alias/my:alias"] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.error_code.should.equal('ValidationException') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal( + "{alias_name} contains invalid characters for an alias".format(**locals()) + ) + ex.error_code.should.equal("ValidationException") ex.message.should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.reason.should.equal('Bad Request') + "{alias_name} contains invalid characters for an alias".format(**locals()) + ) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) +@parameterized((("alias/my-alias_/",), ("alias/my_alias-/",))) @mock_kms_deprecated -def test__create_alias__accepted_characters(): +def test__create_alias__accepted_characters(alias_name): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_accepted_characters = [ - 'alias/my-alias_/', - 'alias/my_alias-/', - ] - - for alias_name in alias_names_with_accepted_characters: - kms.create_alias(alias_name, key_id) + kms.create_alias(alias_name, key_id) @mock_kms_deprecated def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -492,11 +621,11 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms.create_alias(alias, alias) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Aliases must refer to keys. Not aliases') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Aliases must refer to keys. Not aliases') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Aliases must refer to keys. Not aliases") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Aliases must refer to keys. Not aliases") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -504,14 +633,14 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): def test__delete_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" # added another alias here to make sure that the deletion of the alias can # be done when there are multiple existing aliases. another_create_resp = kms.create_key() - another_key_id = create_resp['KeyMetadata']['KeyId'] - another_alias = 'alias/another-alias' + another_key_id = create_resp["KeyMetadata"]["KeyId"] + another_alias = "alias/another-alias" kms.create_alias(alias, key_id) kms.create_alias(another_alias, another_key_id) @@ -529,35 +658,36 @@ def test__delete_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() with assert_raises(JSONResponseError) as err: - kms.delete_alias('wrongprefix/my-alias') + kms.delete_alias("wrongprefix/my-alias") ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Invalid identifier') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Invalid identifier") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__delete_alias__raises_if_alias_is_not_found(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) - alias_name = 'alias/unexisting-alias' + alias_name = "alias/unexisting-alias" with assert_raises(NotFoundException) as err: kms.delete_alias(alias_name) + expected_message_match = r"Alias arn:aws:kms:{region}:[0-9]{{12}}:{alias_name} is not found.".format( + region=region, alias_name=alias_name + ) ex = err.exception - ex.body['__type'].should.equal('NotFoundException') - ex.body['message'].should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) + ex.body["__type"].should.equal("NotFoundException") + ex.body["message"].should.match(expected_message_match) ex.box_usage.should.be.none ex.error_code.should.be.none - ex.message.should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) - ex.reason.should.equal('Bad Request') + ex.message.should.match(expected_message_match) + ex.reason.should.equal("Bad Request") ex.request_id.should.be.none ex.status.should.equal(400) @@ -568,197 +698,228 @@ def test__list_aliases(): kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - kms.create_alias('alias/my-alias1', key_id) - kms.create_alias('alias/my-alias2', key_id) - kms.create_alias('alias/my-alias3', key_id) + key_id = create_resp["KeyMetadata"]["KeyId"] + kms.create_alias("alias/my-alias1", key_id) + kms.create_alias("alias/my-alias2", key_id) + kms.create_alias("alias/my-alias3", key_id) resp = kms.list_aliases() - resp['Truncated'].should.be.false + resp["Truncated"].should.be.false - aliases = resp['Aliases'] + aliases = resp["Aliases"] def has_correct_arn(alias_obj): - alias_name = alias_obj['AliasName'] - alias_arn = alias_obj['AliasArn'] - return re.match(r'arn:aws:kms:{region}:\d{{12}}:{alias_name}'.format(region=region, alias_name=alias_name), - alias_arn) + alias_name = alias_obj["AliasName"] + alias_arn = alias_obj["AliasArn"] + return re.match( + r"arn:aws:kms:{region}:\d{{12}}:{alias_name}".format( + region=region, alias_name=alias_name + ), + alias_arn, + ) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/ebs' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/rds' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/redshift' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/s3' == alias['AliasName']]).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/ebs" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/rds" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/redshift" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/s3" == alias["AliasName"] + ] + ).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias1' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias2' == alias['AliasName']]).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/my-alias1" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/my-alias2" == alias["AliasName"] + ] + ).should.equal(1) - len([alias for alias in aliases if 'TargetKeyId' in alias and key_id == - alias['TargetKeyId']]).should.equal(3) + len( + [ + alias + for alias in aliases + if "TargetKeyId" in alias and key_id == alias["TargetKeyId"] + ] + ).should.equal(3) len(aliases).should.equal(7) -@mock_kms_deprecated -def test__assert_valid_key_id(): - from moto.kms.responses import _assert_valid_key_id - import uuid +@parameterized( + ( + ("not-a-uuid",), + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ( + "arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6", + ), + ) +) +@mock_kms +def test_invalid_key_ids(key_id): + client = boto3.client("kms", region_name="us-east-1") - _assert_valid_key_id.when.called_with( - "not-a-key").should.throw(MotoNotFoundException) - _assert_valid_key_id.when.called_with( - str(uuid.uuid4())).should_not.throw(MotoNotFoundException) + with assert_raises(client.exceptions.NotFoundException): + client.generate_data_key(KeyId=key_id, NumberOfBytes=5) @mock_kms_deprecated def test__assert_default_policy(): from moto.kms.responses import _assert_default_policy - _assert_default_policy.when.called_with( - "not-default").should.throw(MotoNotFoundException) - _assert_default_policy.when.called_with( - "default").should_not.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("not-default").should.throw( + MotoNotFoundException + ) + _assert_default_policy.when.called_with("default").should_not.throw( + MotoNotFoundException + ) +@parameterized(PLAINTEXT_VECTORS) @mock_kms -def test_kms_encrypt_boto3(): - client = boto3.client('kms', region_name='us-east-1') - response = client.encrypt(KeyId='foo', Plaintext=b'bar') +def test_kms_encrypt_boto3(plaintext): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="key") + response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) - response = client.decrypt(CiphertextBlob=response['CiphertextBlob']) - response['Plaintext'].should.equal(b'bar') + response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) + response["Plaintext"].should.equal(_get_encoded_value(plaintext)) @mock_kms def test_disable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='disable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="disable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' + assert result["KeyMetadata"]["KeyState"] == "Disabled" @mock_kms def test_enable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='enable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) - client.enable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="enable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) + client.enable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == True - assert result["KeyMetadata"]["KeyState"] == 'Enabled' + assert result["KeyMetadata"]["KeyState"] == "Enabled" @mock_kms def test_schedule_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime( + 2015, 1, 31, 12, 0, tzinfo=tzutc() ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_schedule_key_deletion_custom(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 + KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7 + ) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime( + 2015, 1, 8, 12, 0, tzinfo=tzutc() ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 + KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7 ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_cancel_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - response = client.cancel_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' - assert 'DeletionDate' not in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "Disabled" + assert "DeletionDate" not in result["KeyMetadata"] @mock_kms def test_update_key_description(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='old_description') - key_id = key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="old_description") + key_id = key["KeyMetadata"]["KeyId"] - result = client.update_key_description(KeyId=key_id, Description='new_description') - assert 'ResponseMetadata' in result + result = client.update_key_description(KeyId=key_id, Description="new_description") + assert "ResponseMetadata" in result @mock_kms def test_tag_resource(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] + keyid = response["KeyId"] response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] + KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] ) # Shouldn't have any data, just header @@ -767,226 +928,296 @@ def test_tag_resource(): @mock_kms def test_list_resource_tags(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] + keyid = response["KeyId"] response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] + KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] ) response = client.list_resource_tags(KeyId=keyid) - assert response['Tags'][0]['TagKey'] == 'string' - assert response['Tags'][0]['TagValue'] == 'string' + assert response["Tags"][0]["TagKey"] == "string" + assert response["Tags"][0]["TagValue"] == "string" +@parameterized( + ( + (dict(KeySpec="AES_256"), 32), + (dict(KeySpec="AES_128"), 16), + (dict(NumberOfBytes=64), 64), + (dict(NumberOfBytes=1), 1), + (dict(NumberOfBytes=1024), 1024), + ) +) @mock_kms -def test_generate_data_key_sizes(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_sizes(kwargs, expected_key_length): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") - resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) - resp2 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128' - ) - resp3 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=64 - ) + response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) - assert len(resp1['Plaintext']) == 32 - assert len(resp2['Plaintext']) == 16 - assert len(resp3['Plaintext']) == 64 + assert len(response["Plaintext"]) == expected_key_length @mock_kms def test_generate_data_key_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' + KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256" ) - resp2 = client.decrypt( - CiphertextBlob=resp1['CiphertextBlob'] + resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"]) + + assert resp1["Plaintext"] == resp2["Plaintext"] + + +@parameterized( + ( + (dict(KeySpec="AES_257"),), + (dict(KeySpec="AES_128", NumberOfBytes=16),), + (dict(NumberOfBytes=2048),), + (dict(NumberOfBytes=0),), + (dict(),), ) - - assert resp1['Plaintext'] == resp2['Plaintext'] - - +) @mock_kms -def test_generate_data_key_invalid_size_params(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_invalid_size_params(kwargs): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_257' - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128', - NumberOfBytes=16 - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=2048 - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + with assert_raises( + (botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError) + ) as err: + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) +@parameterized( + ( + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ( + "arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6", + ), + ) +) @mock_kms -def test_generate_data_key_invalid_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_invalid_key(key_id): + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId='alias/randomnonexistantkey', - KeySpec='AES_256' - ) + client.generate_data_key(KeyId=key_id, KeySpec="AES_256") - with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] + '4', - KeySpec='AES_256' - ) + +@parameterized( + ( + ("alias/DoesExist", False), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False), + ("", True), + ("arn:aws:kms:us-east-1:012345678912:key/", True), + ) +) +@mock_kms +def test_generate_data_key_all_valid_key_ids(prefix, append_key_id): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key() + key_id = key["KeyMetadata"]["KeyId"] + client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id) + + target_id = prefix + if append_key_id: + target_id += key_id + + client.generate_data_key(KeyId=key_id, NumberOfBytes=32) @mock_kms def test_generate_data_key_without_plaintext_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") resp1 = client.generate_data_key_without_plaintext( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' + KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256" ) - assert 'Plaintext' not in resp1 + assert "Plaintext" not in resp1 + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_re_encrypt_decrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key_1 = client.create_key(Description="key 1") + key_1_id = key_1["KeyMetadata"]["KeyId"] + key_1_arn = key_1["KeyMetadata"]["Arn"] + key_2 = client.create_key(Description="key 2") + key_2_id = key_2["KeyMetadata"]["KeyId"] + key_2_arn = key_2["KeyMetadata"]["Arn"] + + encrypt_response = client.encrypt( + KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"} + ) + + re_encrypt_response = client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + SourceEncryptionContext={"encryption": "context"}, + DestinationKeyId=key_2_id, + DestinationEncryptionContext={"another": "context"}, + ) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True) + + re_encrypt_response["SourceKeyId"].should.equal(key_1_arn) + re_encrypt_response["KeyId"].should.equal(key_2_arn) + + decrypt_response_1 = client.decrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + EncryptionContext={"encryption": "context"}, + ) + decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response_1["KeyId"].should.equal(key_1_arn) + + decrypt_response_2 = client.decrypt( + CiphertextBlob=re_encrypt_response["CiphertextBlob"], + EncryptionContext={"another": "context"}, + ) + decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response_2["KeyId"].should.equal(key_2_arn) + + decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) + + +@mock_kms +def test_re_encrypt_to_invalid_destination(): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key 1") + key_id = key["KeyMetadata"]["KeyId"] + + encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext") + + with assert_raises(client.exceptions.NotFoundException): + client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + DestinationKeyId="alias/DoesNotExist", + ) + + +@parameterized(((12,), (44,), (91,), (1,), (1024,))) +@mock_kms +def test_generate_random(number_of_bytes): + client = boto3.client("kms", region_name="us-west-2") + + response = client.generate_random(NumberOfBytes=number_of_bytes) + + response["Plaintext"].should.be.a(bytes) + len(response["Plaintext"]).should.equal(number_of_bytes) + + +@parameterized( + ( + (2048, botocore.exceptions.ClientError), + (1025, botocore.exceptions.ClientError), + (0, botocore.exceptions.ParamValidationError), + (-1, botocore.exceptions.ParamValidationError), + (-1024, botocore.exceptions.ParamValidationError), + ) +) +@mock_kms +def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type): + client = boto3.client("kms", region_name="us-west-2") + + with assert_raises(error_type): + client.generate_random(NumberOfBytes=number_of_bytes) @mock_kms def test_enable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_enable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_cancel_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.cancel_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_schedule_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.schedule_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_rotation_status_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_rotation_status( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): client.get_key_policy( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02', - PolicyName='default' + KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default" ) @mock_kms def test_list_key_policies_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.list_key_policies( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_put_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): client.put_key_policy( - KeyId='00000000-0000-0000-0000-000000000000', - PolicyName='default', - Policy='new policy' + KeyId="00000000-0000-0000-0000-000000000000", + PolicyName="default", + Policy="new policy", ) diff --git a/tests/test_kms/test_server.py b/tests/test_kms/test_server.py index a5aac7d94..083f9d18a 100644 --- a/tests/test_kms/test_server.py +++ b/tests/test_kms/test_server.py @@ -1,25 +1,23 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_kms - -''' -Test the different server responses -''' - - -@mock_kms -def test_list_keys(): - backend = server.create_backend_app("kms") - test_client = backend.test_client() - - res = test_client.get('/?Action=ListKeys') - - json.loads(res.data.decode("utf-8")).should.equal({ - "Keys": [], - "NextMarker": None, - "Truncated": False, - }) +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_kms + +""" +Test the different server responses +""" + + +@mock_kms +def test_list_keys(): + backend = server.create_backend_app("kms") + test_client = backend.test_client() + + res = test_client.get("/?Action=ListKeys") + + json.loads(res.data.decode("utf-8")).should.equal( + {"Keys": [], "NextMarker": None, "Truncated": False} + ) diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py new file mode 100644 index 000000000..f5478e0ef --- /dev/null +++ b/tests/test_kms/test_utils.py @@ -0,0 +1,189 @@ +from __future__ import unicode_literals + +import sure # noqa +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import ( + AccessDeniedException, + InvalidCiphertextException, + NotFoundException, +) +from moto.kms.models import Key +from moto.kms.utils import ( + _deserialize_ciphertext_blob, + _serialize_ciphertext_blob, + _serialize_encryption_context, + generate_data_key, + generate_master_key, + MASTER_KEY_LEN, + encrypt, + decrypt, + Ciphertext, +) + +ENCRYPTION_CONTEXT_VECTORS = ( + ( + {"this": "is", "an": "encryption", "context": "example"}, + b"an" b"encryption" b"context" b"example" b"this" b"is", + ), + ( + {"a_this": "one", "b_is": "actually", "c_in": "order"}, + b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order", + ), +) +CIPHERTEXT_BLOB_VECTORS = ( + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext", + ), + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext that is much longer now", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext that is much longer now", + ), +) + + +def test_generate_data_key(): + test = generate_data_key(123) + + test.should.be.a(bytes) + len(test).should.equal(123) + + +def test_generate_master_key(): + test = generate_master_key() + + test.should.be.a(bytes) + len(test).should.equal(MASTER_KEY_LEN) + + +@parameterized(ENCRYPTION_CONTEXT_VECTORS) +def test_serialize_encryption_context(raw, serialized): + test = _serialize_encryption_context(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_cycle_ciphertext_blob(raw, _serialized): + test_serialized = _serialize_ciphertext_blob(raw) + test_deserialized = _deserialize_ciphertext_blob(test_serialized) + test_deserialized.should.equal(raw) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_serialize_ciphertext_blob(raw, serialized): + test = _serialize_ciphertext_blob(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_deserialize_ciphertext_blob(raw, serialized): + test = _deserialize_ciphertext_blob(serialized) + test.should.equal(raw) + + +@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) +def test_encrypt_decrypt_cycle(encryption_context): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, + key_id=master_key.id, + plaintext=plaintext, + encryption_context=encryption_context, + ) + ciphertext_blob.should_not.equal(plaintext) + + decrypted, decrypting_key_id = decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, + ) + decrypted.should.equal(plaintext) + decrypting_key_id.should.equal(master_key.id) + + +def test_encrypt_unknown_key_id(): + with assert_raises(NotFoundException): + encrypt( + master_keys={}, + key_id="anything", + plaintext=b"secrets", + encryption_context={}, + ) + + +def test_decrypt_invalid_ciphertext_format(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + with assert_raises(InvalidCiphertextException): + decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={}) + + +def test_decrypt_unknwown_key_id(): + ciphertext_blob = ( + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext" + ) + + with assert_raises(AccessDeniedException): + decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={}) + + +def test_decrypt_invalid_ciphertext(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + ciphertext_blob = ( + master_key.id.encode("utf-8") + b"123456789012" + b"1234567890123456" + b"some ciphertext" + ) + + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) + + +def test_decrypt_invalid_encryption_context(): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, + key_id=master_key.id, + plaintext=plaintext, + encryption_context={"some": "encryption", "context": "here"}, + ) + + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) diff --git a/tests/test_logs/test_logs.py b/tests/test_logs/test_logs.py index 49e593fdc..e8f60ff03 100644 --- a/tests/test_logs/test_logs.py +++ b/tests/test_logs/test_logs.py @@ -1,98 +1,80 @@ import boto3 +import os import sure # noqa import six from botocore.exceptions import ClientError from moto import mock_logs, settings from nose.tools import assert_raises +from nose import SkipTest -_logs_region = 'us-east-1' if settings.TEST_SERVER_MODE else 'us-west-2' +_logs_region = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2" @mock_logs def test_log_group_create(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 + assert len(response["logGroups"]) == 1 # AWS defaults to Never Expire for log group retention - assert response['logGroups'][0].get('retentionInDays') == None + assert response["logGroups"][0].get("retentionInDays") == None response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_exceptions(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'dummp-stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "dummp-stream" conn.create_log_group(logGroupName=log_group_name) with assert_raises(ClientError): conn.create_log_group(logGroupName=log_group_name) # descrine_log_groups is not implemented yet - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) with assert_raises(ClientError): conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name + logGroupName=log_group_name, logStreamName=log_stream_name ) conn.put_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - logEvents=[ - { - 'timestamp': 0, - 'message': 'line' - }, - ], + logEvents=[{"timestamp": 0, "message": "line"}], ) with assert_raises(ClientError): conn.put_log_events( logGroupName=log_group_name, logStreamName="invalid-stream", - logEvents=[ - { - 'timestamp': 0, - 'message': 'line' - }, - ], + logEvents=[{"timestamp": 0, "message": "line"}], ) @mock_logs def test_put_logs(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) messages = [ - {'timestamp': 0, 'message': 'hello'}, - {'timestamp': 0, 'message': 'world'} + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, ] putRes = conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=messages + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages ) res = conn.get_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name + logGroupName=log_group_name, logStreamName=log_stream_name ) - events = res['events'] - nextSequenceToken = putRes['nextSequenceToken'] + events = res["events"] + nextSequenceToken = putRes["nextSequenceToken"] assert isinstance(nextSequenceToken, six.string_types) == True assert len(nextSequenceToken) == 56 events.should.have.length_of(2) @@ -100,125 +82,349 @@ def test_put_logs(): @mock_logs def test_filter_logs_interleaved(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) messages = [ - {'timestamp': 0, 'message': 'hello'}, - {'timestamp': 0, 'message': 'world'} + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, ] conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=messages + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages ) res = conn.filter_log_events( - logGroupName=log_group_name, - logStreamNames=[log_stream_name], - interleaved=True, + logGroupName=log_group_name, logStreamNames=[log_stream_name], interleaved=True ) - events = res['events'] + events = res["events"] for original_message, resulting_event in zip(messages, events): - resulting_event['eventId'].should.equal(str(resulting_event['eventId'])) - resulting_event['timestamp'].should.equal(original_message['timestamp']) - resulting_event['message'].should.equal(original_message['message']) + resulting_event["eventId"].should.equal(str(resulting_event["eventId"])) + resulting_event["timestamp"].should.equal(original_message["timestamp"]) + resulting_event["message"].should.equal(original_message["message"]) + + +@mock_logs +def test_filter_logs_raises_if_filter_pattern(): + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Does not work in server mode due to error in Workzeug") + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" + conn.create_log_group(logGroupName=log_group_name) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) + messages = [ + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, + ] + conn.put_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages + ) + with assert_raises(NotImplementedError): + conn.filter_log_events( + logGroupName=log_group_name, + logStreamNames=[log_stream_name], + filterPattern='{$.message = "hello"}', + ) + @mock_logs def test_put_retention_policy(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.put_retention_policy(logGroupName=log_group_name, retentionInDays=7) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == 7 + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == 7 response = conn.delete_log_group(logGroupName=log_group_name) + @mock_logs def test_delete_retention_policy(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.put_retention_policy(logGroupName=log_group_name, retentionInDays=7) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == 7 + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == 7 response = conn.delete_retention_policy(logGroupName=log_group_name) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == None + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == None response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_get_log_events(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'test' - log_stream_name = 'stream' - conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name + client = boto3.client("logs", "us-west-2") + log_group_name = "test" + log_stream_name = "stream" + client.create_log_group(logGroupName=log_group_name) + client.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) + + events = [{"timestamp": x, "message": str(x)} for x in range(20)] + + client.put_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=events ) - events = [{'timestamp': x, 'message': str(x)} for x in range(20)] - - conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=events + resp = client.get_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, limit=10 ) - resp = conn.get_log_events( + resp["events"].should.have.length_of(10) + for i in range(10): + resp["events"][i]["timestamp"].should.equal(i + 10) + resp["events"][i]["message"].should.equal(str(i + 10)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000019" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000010" + ) + + resp = client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - limit=10) + nextToken=resp["nextBackwardToken"], + limit=20, + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') + resp["events"].should.have.length_of(10) for i in range(10): - resp['events'][i]['timestamp'].should.equal(i) - resp['events'][i]['message'].should.equal(str(i)) + resp["events"][i]["timestamp"].should.equal(i) + resp["events"][i]["message"].should.equal(str(i)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000009" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000000" + ) - next_token = resp['nextForwardToken'] - - resp = conn.get_log_events( + resp = client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - nextToken=next_token, - limit=10) + nextToken=resp["nextBackwardToken"], + limit=10, + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') - resp['nextForwardToken'].should.equal(next_token) - for i in range(10): - resp['events'][i]['timestamp'].should.equal(i+10) - resp['events'][i]['message'].should.equal(str(i+10)) + resp["events"].should.have.length_of(0) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000000" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000000" + ) - resp = conn.get_log_events( + resp = client.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - nextToken=resp['nextBackwardToken'], - limit=10) + nextToken=resp["nextForwardToken"], + limit=1, + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') + resp["events"].should.have.length_of(1) + resp["events"][0]["timestamp"].should.equal(1) + resp["events"][0]["message"].should.equal(str(1)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000001" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000001" + ) + + +@mock_logs +def test_get_log_events_with_start_from_head(): + client = boto3.client("logs", "us-west-2") + log_group_name = "test" + log_stream_name = "stream" + client.create_log_group(logGroupName=log_group_name) + client.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) + + events = [{"timestamp": x, "message": str(x)} for x in range(20)] + + client.put_log_events( + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=events + ) + + resp = client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + limit=10, + startFromHead=True, # this parameter is only relevant without the usage of nextToken + ) + + resp["events"].should.have.length_of(10) for i in range(10): - resp['events'][i]['timestamp'].should.equal(i) - resp['events'][i]['message'].should.equal(str(i)) + resp["events"][i]["timestamp"].should.equal(i) + resp["events"][i]["message"].should.equal(str(i)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000009" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000000" + ) + + resp = client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken=resp["nextForwardToken"], + limit=20, + ) + + resp["events"].should.have.length_of(10) + for i in range(10): + resp["events"][i]["timestamp"].should.equal(i + 10) + resp["events"][i]["message"].should.equal(str(i + 10)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000019" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000010" + ) + + resp = client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken=resp["nextForwardToken"], + limit=10, + ) + + resp["events"].should.have.length_of(0) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000019" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000019" + ) + + resp = client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken=resp["nextBackwardToken"], + limit=1, + ) + + resp["events"].should.have.length_of(1) + resp["events"][0]["timestamp"].should.equal(18) + resp["events"][0]["message"].should.equal(str(18)) + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000018" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000018" + ) + + +@mock_logs +def test_get_log_events_errors(): + client = boto3.client("logs", "us-west-2") + log_group_name = "test" + log_stream_name = "stream" + client.create_log_group(logGroupName=log_group_name) + client.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) + + with assert_raises(ClientError) as e: + client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken="n/00000000000000000000000000000000000000000000000000000000", + ) + ex = e.exception + ex.operation_name.should.equal("GetLogEvents") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.equal("InvalidParameterException") + ex.response["Error"]["Message"].should.contain( + "The specified nextToken is invalid." + ) + + with assert_raises(ClientError) as e: + client.get_log_events( + logGroupName=log_group_name, + logStreamName=log_stream_name, + nextToken="not-existing-token", + ) + ex = e.exception + ex.operation_name.should.equal("GetLogEvents") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.equal("InvalidParameterException") + ex.response["Error"]["Message"].should.contain( + "The specified nextToken is invalid." + ) + + +@mock_logs +def test_list_tags_log_group(): + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + tags = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} + + response = conn.create_log_group(logGroupName=log_group_name) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == {} + + response = conn.delete_log_group(logGroupName=log_group_name) + response = conn.create_log_group(logGroupName=log_group_name, tags=tags) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == tags + + response = conn.delete_log_group(logGroupName=log_group_name) + + +@mock_logs +def test_tag_log_group(): + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + tags = {"tag_key_1": "tag_value_1"} + response = conn.create_log_group(logGroupName=log_group_name) + + response = conn.tag_log_group(logGroupName=log_group_name, tags=tags) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == tags + + tags_with_added_value = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} + response = conn.tag_log_group( + logGroupName=log_group_name, tags={"tag_key_2": "tag_value_2"} + ) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == tags_with_added_value + + tags_with_updated_value = {"tag_key_1": "tag_value_XX", "tag_key_2": "tag_value_2"} + response = conn.tag_log_group( + logGroupName=log_group_name, tags={"tag_key_1": "tag_value_XX"} + ) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == tags_with_updated_value + + response = conn.delete_log_group(logGroupName=log_group_name) + + +@mock_logs +def test_untag_log_group(): + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + response = conn.create_log_group(logGroupName=log_group_name) + + tags = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} + response = conn.tag_log_group(logGroupName=log_group_name, tags=tags) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == tags + + tags_to_remove = ["tag_key_1"] + remaining_tags = {"tag_key_2": "tag_value_2"} + response = conn.untag_log_group(logGroupName=log_group_name, tags=tags_to_remove) + response = conn.list_tags_log_group(logGroupName=log_group_name) + assert response["tags"] == remaining_tags + + response = conn.delete_log_group(logGroupName=log_group_name) diff --git a/tests/test_opsworks/test_apps.py b/tests/test_opsworks/test_apps.py index d13ce8eaf..1d3445c7d 100644 --- a/tests/test_opsworks/test_apps.py +++ b/tests/test_opsworks/test_apps.py @@ -10,19 +10,15 @@ from moto import mock_opsworks @freeze_time("2015-01-01") @mock_opsworks def test_create_app_response(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] - response = client.create_app( - StackId=stack_id, - Type="other", - Name="TestApp" - ) + response = client.create_app(StackId=stack_id, Type="other", Name="TestApp") response.should.contain("AppId") @@ -30,73 +26,51 @@ def test_create_app_response(): Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] - response = client.create_app( - StackId=second_stack_id, - Type="other", - Name="TestApp" - ) + response = client.create_app(StackId=second_stack_id, Type="other", Name="TestApp") response.should.contain("AppId") # ClientError client.create_app.when.called_with( - StackId=stack_id, - Type="other", - Name="TestApp" - ).should.throw( - Exception, re.compile(r'already an app named "TestApp"') - ) + StackId=stack_id, Type="other", Name="TestApp" + ).should.throw(Exception, re.compile(r'already an app named "TestApp"')) # ClientError client.create_app.when.called_with( - StackId="nothere", - Type="other", - Name="TestApp" - ).should.throw( - Exception, "nothere" - ) + StackId="nothere", Type="other", Name="TestApp" + ).should.throw(Exception, "nothere") + @freeze_time("2015-01-01") @mock_opsworks def test_describe_apps(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] - app_id = client.create_app( - StackId=stack_id, - Type="other", - Name="TestApp" - )['AppId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] + app_id = client.create_app(StackId=stack_id, Type="other", Name="TestApp")["AppId"] rv1 = client.describe_apps(StackId=stack_id) rv2 = client.describe_apps(AppIds=[app_id]) - rv1['Apps'].should.equal(rv2['Apps']) + rv1["Apps"].should.equal(rv2["Apps"]) - rv1['Apps'][0]['Name'].should.equal("TestApp") + rv1["Apps"][0]["Name"].should.equal("TestApp") # ClientError client.describe_apps.when.called_with( - StackId=stack_id, - AppIds=[app_id] - ).should.throw( - Exception, "Please provide one or more app IDs or a stack ID" - ) + StackId=stack_id, AppIds=[app_id] + ).should.throw(Exception, "Please provide one or more app IDs or a stack ID") # ClientError - client.describe_apps.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_apps.when.called_with(StackId="nothere").should.throw( Exception, "Unable to find stack with ID nothere" ) # ClientError - client.describe_apps.when.called_with( - AppIds=["nothere"] - ).should.throw( + client.describe_apps.when.called_with(AppIds=["nothere"]).should.throw( Exception, "nothere" ) diff --git a/tests/test_opsworks/test_instances.py b/tests/test_opsworks/test_instances.py index 25260ad78..5f0dc2040 100644 --- a/tests/test_opsworks/test_instances.py +++ b/tests/test_opsworks/test_instances.py @@ -8,34 +8,34 @@ from moto import mock_ec2 @mock_opsworks def test_create_instance(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" - )['LayerId'] + Shortname="TestLayerShortName", + )["LayerId"] second_stack_id = client.create_stack( Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] second_layer_id = client.create_layer( StackId=second_stack_id, Type="custom", Name="SecondTestLayer", - Shortname="SecondTestLayerShortName" - )['LayerId'] + Shortname="SecondTestLayerShortName", + )["LayerId"] response = client.create_instance( StackId=stack_id, LayerIds=[layer_id], InstanceType="t2.micro" @@ -55,9 +55,9 @@ def test_create_instance(): StackId=stack_id, LayerIds=[second_layer_id], InstanceType="t2.micro" ).should.throw(Exception, "Please only provide layer IDs from the same stack") # ClientError - client.start_instance.when.called_with( - InstanceId="nothere" - ).should.throw(Exception, "Unable to find instance with ID nothere") + client.start_instance.when.called_with(InstanceId="nothere").should.throw( + Exception, "Unable to find instance with ID nothere" + ) @mock_opsworks @@ -70,112 +70,95 @@ def test_describe_instances(): populate S2L2 with 3 instances (S2L2_i1..2) """ - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") S1 = client.create_stack( Name="S1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] S1L1 = client.create_layer( - StackId=S1, - Type="custom", - Name="S1L1", - Shortname="S1L1" - )['LayerId'] + StackId=S1, Type="custom", Name="S1L1", Shortname="S1L1" + )["LayerId"] S2 = client.create_stack( Name="S2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] S2L1 = client.create_layer( - StackId=S2, - Type="custom", - Name="S2L1", - Shortname="S2L1" - )['LayerId'] + StackId=S2, Type="custom", Name="S2L1", Shortname="S2L1" + )["LayerId"] S2L2 = client.create_layer( - StackId=S2, - Type="custom", - Name="S2L2", - Shortname="S2L2" - )['LayerId'] + StackId=S2, Type="custom", Name="S2L2", Shortname="S2L2" + )["LayerId"] S1L1_i1 = client.create_instance( StackId=S1, LayerIds=[S1L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S1L1_i2 = client.create_instance( StackId=S1, LayerIds=[S1L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L1_i1 = client.create_instance( StackId=S2, LayerIds=[S2L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L2_i1 = client.create_instance( StackId=S2, LayerIds=[S2L2], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L2_i2 = client.create_instance( StackId=S2, LayerIds=[S2L2], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] # instances in Stack 1 - response = client.describe_instances(StackId=S1)['Instances'] + response = client.describe_instances(StackId=S1)["Instances"] response.should.have.length_of(2) S1L1_i1.should.be.within([i["InstanceId"] for i in response]) S1L1_i2.should.be.within([i["InstanceId"] for i in response]) - response2 = client.describe_instances( - InstanceIds=[S1L1_i1, S1L1_i2])['Instances'] - sorted(response2, key=lambda d: d['InstanceId']).should.equal( - sorted(response, key=lambda d: d['InstanceId'])) + response2 = client.describe_instances(InstanceIds=[S1L1_i1, S1L1_i2])["Instances"] + sorted(response2, key=lambda d: d["InstanceId"]).should.equal( + sorted(response, key=lambda d: d["InstanceId"]) + ) - response3 = client.describe_instances(LayerId=S1L1)['Instances'] - sorted(response3, key=lambda d: d['InstanceId']).should.equal( - sorted(response, key=lambda d: d['InstanceId'])) + response3 = client.describe_instances(LayerId=S1L1)["Instances"] + sorted(response3, key=lambda d: d["InstanceId"]).should.equal( + sorted(response, key=lambda d: d["InstanceId"]) + ) - response = client.describe_instances(StackId=S1)['Instances'] + response = client.describe_instances(StackId=S1)["Instances"] response.should.have.length_of(2) S1L1_i1.should.be.within([i["InstanceId"] for i in response]) S1L1_i2.should.be.within([i["InstanceId"] for i in response]) # instances in Stack 2 - response = client.describe_instances(StackId=S2)['Instances'] + response = client.describe_instances(StackId=S2)["Instances"] response.should.have.length_of(3) S2L1_i1.should.be.within([i["InstanceId"] for i in response]) S2L2_i1.should.be.within([i["InstanceId"] for i in response]) S2L2_i2.should.be.within([i["InstanceId"] for i in response]) - response = client.describe_instances(LayerId=S2L1)['Instances'] + response = client.describe_instances(LayerId=S2L1)["Instances"] response.should.have.length_of(1) S2L1_i1.should.be.within([i["InstanceId"] for i in response]) - response = client.describe_instances(LayerId=S2L2)['Instances'] + response = client.describe_instances(LayerId=S2L2)["Instances"] response.should.have.length_of(2) S2L1_i1.should_not.be.within([i["InstanceId"] for i in response]) # ClientError - client.describe_instances.when.called_with( - StackId=S1, - LayerId=S1L1 - ).should.throw( + client.describe_instances.when.called_with(StackId=S1, LayerId=S1L1).should.throw( Exception, "Please provide either one or more" ) # ClientError - client.describe_instances.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_instances.when.called_with(StackId="nothere").should.throw( Exception, "nothere" ) # ClientError - client.describe_instances.when.called_with( - LayerId="nothere" - ).should.throw( + client.describe_instances.when.called_with(LayerId="nothere").should.throw( Exception, "nothere" ) # ClientError - client.describe_instances.when.called_with( - InstanceIds=["nothere"] - ).should.throw( + client.describe_instances.when.called_with(InstanceIds=["nothere"]).should.throw( Exception, "nothere" ) @@ -187,38 +170,37 @@ def test_ec2_integration(): instances created via OpsWorks should be discoverable via ec2 """ - opsworks = boto3.client('opsworks', region_name='us-east-1') + opsworks = boto3.client("opsworks", region_name="us-east-1") stack_id = opsworks.create_stack( Name="S1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = opsworks.create_layer( - StackId=stack_id, - Type="custom", - Name="S1L1", - Shortname="S1L1" - )['LayerId'] + StackId=stack_id, Type="custom", Name="S1L1", Shortname="S1L1" + )["LayerId"] instance_id = opsworks.create_instance( - StackId=stack_id, LayerIds=[layer_id], InstanceType="t2.micro", SshKeyName="testSSH" - )['InstanceId'] + StackId=stack_id, + LayerIds=[layer_id], + InstanceType="t2.micro", + SshKeyName="testSSH", + )["InstanceId"] - ec2 = boto3.client('ec2', region_name='us-east-1') + ec2 = boto3.client("ec2", region_name="us-east-1") # Before starting the instance, it shouldn't be discoverable via ec2 - reservations = ec2.describe_instances()['Reservations'] + reservations = ec2.describe_instances()["Reservations"] assert reservations.should.be.empty # After starting the instance, it should be discoverable via ec2 opsworks.start_instance(InstanceId=instance_id) - reservations = ec2.describe_instances()['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) - instance = reservations[0]['Instances'][0] - opsworks_instance = opsworks.describe_instances(StackId=stack_id)[ - 'Instances'][0] + reservations = ec2.describe_instances()["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) + instance = reservations[0]["Instances"][0] + opsworks_instance = opsworks.describe_instances(StackId=stack_id)["Instances"][0] - instance['InstanceId'].should.equal(opsworks_instance['Ec2InstanceId']) - instance['PrivateIpAddress'].should.equal(opsworks_instance['PrivateIp']) + instance["InstanceId"].should.equal(opsworks_instance["Ec2InstanceId"]) + instance["PrivateIpAddress"].should.equal(opsworks_instance["PrivateIp"]) diff --git a/tests/test_opsworks/test_layers.py b/tests/test_opsworks/test_layers.py index 035c246e2..08d5a1ce4 100644 --- a/tests/test_opsworks/test_layers.py +++ b/tests/test_opsworks/test_layers.py @@ -10,19 +10,19 @@ from moto import mock_opsworks @freeze_time("2015-01-01") @mock_opsworks def test_create_layer_response(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] response = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" + Shortname="TestLayerShortName", ) response.should.contain("LayerId") @@ -31,87 +31,66 @@ def test_create_layer_response(): Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] response = client.create_layer( StackId=second_stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" + Shortname="TestLayerShortName", ) response.should.contain("LayerId") # ClientError client.create_layer.when.called_with( - StackId=stack_id, - Type="custom", - Name="TestLayer", - Shortname="_" + StackId=stack_id, Type="custom", Name="TestLayer", Shortname="_" + ).should.throw(Exception, re.compile(r'already a layer named "TestLayer"')) + # ClientError + client.create_layer.when.called_with( + StackId=stack_id, Type="custom", Name="_", Shortname="TestLayerShortName" ).should.throw( - Exception, re.compile(r'already a layer named "TestLayer"') + Exception, re.compile(r'already a layer with shortname "TestLayerShortName"') ) # ClientError client.create_layer.when.called_with( - StackId=stack_id, - Type="custom", - Name="_", - Shortname="TestLayerShortName" - ).should.throw( - Exception, re.compile( - r'already a layer with shortname "TestLayerShortName"') - ) - # ClientError - client.create_layer.when.called_with( - StackId="nothere", - Type="custom", - Name="TestLayer", - Shortname="_" - ).should.throw( - Exception, "nothere" - ) + StackId="nothere", Type="custom", Name="TestLayer", Shortname="_" + ).should.throw(Exception, "nothere") @freeze_time("2015-01-01") @mock_opsworks def test_describe_layers(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" - )['LayerId'] + Shortname="TestLayerShortName", + )["LayerId"] rv1 = client.describe_layers(StackId=stack_id) rv2 = client.describe_layers(LayerIds=[layer_id]) - rv1['Layers'].should.equal(rv2['Layers']) + rv1["Layers"].should.equal(rv2["Layers"]) - rv1['Layers'][0]['Name'].should.equal("TestLayer") + rv1["Layers"][0]["Name"].should.equal("TestLayer") # ClientError client.describe_layers.when.called_with( - StackId=stack_id, - LayerIds=[layer_id] - ).should.throw( - Exception, "Please provide one or more layer IDs or a stack ID" - ) + StackId=stack_id, LayerIds=[layer_id] + ).should.throw(Exception, "Please provide one or more layer IDs or a stack ID") # ClientError - client.describe_layers.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_layers.when.called_with(StackId="nothere").should.throw( Exception, "Unable to find stack with ID nothere" ) # ClientError - client.describe_layers.when.called_with( - LayerIds=["nothere"] - ).should.throw( + client.describe_layers.when.called_with(LayerIds=["nothere"]).should.throw( Exception, "nothere" ) diff --git a/tests/test_opsworks/test_stack.py b/tests/test_opsworks/test_stack.py index 2a1b6cc67..277eda1ec 100644 --- a/tests/test_opsworks/test_stack.py +++ b/tests/test_opsworks/test_stack.py @@ -1,46 +1,46 @@ -from __future__ import unicode_literals -import boto3 -import sure # noqa -import re - -from moto import mock_opsworks - - -@mock_opsworks -def test_create_stack_response(): - client = boto3.client('opsworks', region_name='us-east-1') - response = client.create_stack( - Name="test_stack_1", - Region="us-east-1", - ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - ) - response.should.contain("StackId") - - -@mock_opsworks -def test_describe_stacks(): - client = boto3.client('opsworks', region_name='us-east-1') - for i in range(1, 4): - client.create_stack( - Name="test_stack_{0}".format(i), - Region="us-east-1", - ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - ) - - response = client.describe_stacks() - response['Stacks'].should.have.length_of(3) - for stack in response['Stacks']: - stack['ServiceRoleArn'].should.equal("service_arn") - stack['DefaultInstanceProfileArn'].should.equal("profile_arn") - - _id = response['Stacks'][0]['StackId'] - response = client.describe_stacks(StackIds=[_id]) - response['Stacks'].should.have.length_of(1) - response['Stacks'][0]['Arn'].should.contain(_id) - - # ClientError/ResourceNotFoundException - client.describe_stacks.when.called_with(StackIds=["foo"]).should.throw( - Exception, re.compile(r'foo') - ) +from __future__ import unicode_literals +import boto3 +import sure # noqa +import re + +from moto import mock_opsworks + + +@mock_opsworks +def test_create_stack_response(): + client = boto3.client("opsworks", region_name="us-east-1") + response = client.create_stack( + Name="test_stack_1", + Region="us-east-1", + ServiceRoleArn="service_arn", + DefaultInstanceProfileArn="profile_arn", + ) + response.should.contain("StackId") + + +@mock_opsworks +def test_describe_stacks(): + client = boto3.client("opsworks", region_name="us-east-1") + for i in range(1, 4): + client.create_stack( + Name="test_stack_{0}".format(i), + Region="us-east-1", + ServiceRoleArn="service_arn", + DefaultInstanceProfileArn="profile_arn", + ) + + response = client.describe_stacks() + response["Stacks"].should.have.length_of(3) + for stack in response["Stacks"]: + stack["ServiceRoleArn"].should.equal("service_arn") + stack["DefaultInstanceProfileArn"].should.equal("profile_arn") + + _id = response["Stacks"][0]["StackId"] + response = client.describe_stacks(StackIds=[_id]) + response["Stacks"].should.have.length_of(1) + response["Stacks"][0]["Arn"].should.contain(_id) + + # ClientError/ResourceNotFoundException + client.describe_stacks.when.called_with(StackIds=["foo"]).should.throw( + Exception, re.compile(r"foo") + ) diff --git a/tests/test_organizations/organizations_test_utils.py b/tests/test_organizations/organizations_test_utils.py index 83b60b877..12189c530 100644 --- a/tests/test_organizations/organizations_test_utils.py +++ b/tests/test_organizations/organizations_test_utils.py @@ -37,115 +37,108 @@ def test_make_random_service_control_policy_id(): def validate_organization(response): - org = response['Organization'] - sorted(org.keys()).should.equal([ - 'Arn', - 'AvailablePolicyTypes', - 'FeatureSet', - 'Id', - 'MasterAccountArn', - 'MasterAccountEmail', - 'MasterAccountId', - ]) - org['Id'].should.match(utils.ORG_ID_REGEX) - org['MasterAccountId'].should.equal(utils.MASTER_ACCOUNT_ID) - org['MasterAccountArn'].should.equal(utils.MASTER_ACCOUNT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - )) - org['Arn'].should.equal(utils.ORGANIZATION_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - )) - org['MasterAccountEmail'].should.equal(utils.MASTER_ACCOUNT_EMAIL) - org['FeatureSet'].should.be.within(['ALL', 'CONSOLIDATED_BILLING']) - org['AvailablePolicyTypes'].should.equal([{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }]) + org = response["Organization"] + sorted(org.keys()).should.equal( + [ + "Arn", + "AvailablePolicyTypes", + "FeatureSet", + "Id", + "MasterAccountArn", + "MasterAccountEmail", + "MasterAccountId", + ] + ) + org["Id"].should.match(utils.ORG_ID_REGEX) + org["MasterAccountId"].should.equal(utils.MASTER_ACCOUNT_ID) + org["MasterAccountArn"].should.equal( + utils.MASTER_ACCOUNT_ARN_FORMAT.format(org["MasterAccountId"], org["Id"]) + ) + org["Arn"].should.equal( + utils.ORGANIZATION_ARN_FORMAT.format(org["MasterAccountId"], org["Id"]) + ) + org["MasterAccountEmail"].should.equal(utils.MASTER_ACCOUNT_EMAIL) + org["FeatureSet"].should.be.within(["ALL", "CONSOLIDATED_BILLING"]) + org["AvailablePolicyTypes"].should.equal( + [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}] + ) def validate_roots(org, response): - response.should.have.key('Roots').should.be.a(list) - response['Roots'].should_not.be.empty - root = response['Roots'][0] - root.should.have.key('Id').should.match(utils.ROOT_ID_REGEX) - root.should.have.key('Arn').should.equal(utils.ROOT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - root['Id'], - )) - root.should.have.key('Name').should.be.a(six.string_types) - root.should.have.key('PolicyTypes').should.be.a(list) - root['PolicyTypes'][0].should.have.key('Type').should.equal('SERVICE_CONTROL_POLICY') - root['PolicyTypes'][0].should.have.key('Status').should.equal('ENABLED') + response.should.have.key("Roots").should.be.a(list) + response["Roots"].should_not.be.empty + root = response["Roots"][0] + root.should.have.key("Id").should.match(utils.ROOT_ID_REGEX) + root.should.have.key("Arn").should.equal( + utils.ROOT_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], root["Id"]) + ) + root.should.have.key("Name").should.be.a(six.string_types) + root.should.have.key("PolicyTypes").should.be.a(list) + root["PolicyTypes"][0].should.have.key("Type").should.equal( + "SERVICE_CONTROL_POLICY" + ) + root["PolicyTypes"][0].should.have.key("Status").should.equal("ENABLED") def validate_organizational_unit(org, response): - response.should.have.key('OrganizationalUnit').should.be.a(dict) - ou = response['OrganizationalUnit'] - ou.should.have.key('Id').should.match(utils.OU_ID_REGEX) - ou.should.have.key('Arn').should.equal(utils.OU_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - ou['Id'], - )) - ou.should.have.key('Name').should.be.a(six.string_types) + response.should.have.key("OrganizationalUnit").should.be.a(dict) + ou = response["OrganizationalUnit"] + ou.should.have.key("Id").should.match(utils.OU_ID_REGEX) + ou.should.have.key("Arn").should.equal( + utils.OU_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], ou["Id"]) + ) + ou.should.have.key("Name").should.be.a(six.string_types) def validate_account(org, account): - sorted(account.keys()).should.equal([ - 'Arn', - 'Email', - 'Id', - 'JoinedMethod', - 'JoinedTimestamp', - 'Name', - 'Status', - ]) - account['Id'].should.match(utils.ACCOUNT_ID_REGEX) - account['Arn'].should.equal(utils.ACCOUNT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - account['Id'], - )) - account['Email'].should.match(utils.EMAIL_REGEX) - account['JoinedMethod'].should.be.within(['INVITED', 'CREATED']) - account['Status'].should.be.within(['ACTIVE', 'SUSPENDED']) - account['Name'].should.be.a(six.string_types) - account['JoinedTimestamp'].should.be.a(datetime.datetime) + sorted(account.keys()).should.equal( + ["Arn", "Email", "Id", "JoinedMethod", "JoinedTimestamp", "Name", "Status"] + ) + account["Id"].should.match(utils.ACCOUNT_ID_REGEX) + account["Arn"].should.equal( + utils.ACCOUNT_ARN_FORMAT.format( + org["MasterAccountId"], org["Id"], account["Id"] + ) + ) + account["Email"].should.match(utils.EMAIL_REGEX) + account["JoinedMethod"].should.be.within(["INVITED", "CREATED"]) + account["Status"].should.be.within(["ACTIVE", "SUSPENDED"]) + account["Name"].should.be.a(six.string_types) + account["JoinedTimestamp"].should.be.a(datetime.datetime) def validate_create_account_status(create_status): - sorted(create_status.keys()).should.equal([ - 'AccountId', - 'AccountName', - 'CompletedTimestamp', - 'Id', - 'RequestedTimestamp', - 'State', - ]) - create_status['Id'].should.match(utils.CREATE_ACCOUNT_STATUS_ID_REGEX) - create_status['AccountId'].should.match(utils.ACCOUNT_ID_REGEX) - create_status['AccountName'].should.be.a(six.string_types) - create_status['State'].should.equal('SUCCEEDED') - create_status['RequestedTimestamp'].should.be.a(datetime.datetime) - create_status['CompletedTimestamp'].should.be.a(datetime.datetime) + sorted(create_status.keys()).should.equal( + [ + "AccountId", + "AccountName", + "CompletedTimestamp", + "Id", + "RequestedTimestamp", + "State", + ] + ) + create_status["Id"].should.match(utils.CREATE_ACCOUNT_STATUS_ID_REGEX) + create_status["AccountId"].should.match(utils.ACCOUNT_ID_REGEX) + create_status["AccountName"].should.be.a(six.string_types) + create_status["State"].should.equal("SUCCEEDED") + create_status["RequestedTimestamp"].should.be.a(datetime.datetime) + create_status["CompletedTimestamp"].should.be.a(datetime.datetime) + def validate_policy_summary(org, summary): summary.should.be.a(dict) - summary.should.have.key('Id').should.match(utils.SCP_ID_REGEX) - summary.should.have.key('Arn').should.equal(utils.SCP_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - summary['Id'], - )) - summary.should.have.key('Name').should.be.a(six.string_types) - summary.should.have.key('Description').should.be.a(six.string_types) - summary.should.have.key('Type').should.equal('SERVICE_CONTROL_POLICY') - summary.should.have.key('AwsManaged').should.be.a(bool) + summary.should.have.key("Id").should.match(utils.SCP_ID_REGEX) + summary.should.have.key("Arn").should.equal( + utils.SCP_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], summary["Id"]) + ) + summary.should.have.key("Name").should.be.a(six.string_types) + summary.should.have.key("Description").should.be.a(six.string_types) + summary.should.have.key("Type").should.equal("SERVICE_CONTROL_POLICY") + summary.should.have.key("AwsManaged").should.be.a(bool) + def validate_service_control_policy(org, response): - response.should.have.key('PolicySummary').should.be.a(dict) - response.should.have.key('Content').should.be.a(six.string_types) - validate_policy_summary(org, response['PolicySummary']) + response.should.have.key("PolicySummary").should.be.a(dict) + response.should.have.key("Content").should.be.a(six.string_types) + validate_policy_summary(org, response["PolicySummary"]) diff --git a/tests/test_organizations/test_organizations_boto3.py b/tests/test_organizations/test_organizations_boto3.py index 28f8cca91..dd79ae787 100644 --- a/tests/test_organizations/test_organizations_boto3.py +++ b/tests/test_organizations/test_organizations_boto3.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals import boto3 import json import six +import sure # noqa from botocore.exceptions import ClientError from nose.tools import assert_raises @@ -21,593 +22,694 @@ from .organizations_test_utils import ( @mock_organizations def test_create_organization(): - client = boto3.client('organizations', region_name='us-east-1') - response = client.create_organization(FeatureSet='ALL') + client = boto3.client("organizations", region_name="us-east-1") + response = client.create_organization(FeatureSet="ALL") validate_organization(response) - response['Organization']['FeatureSet'].should.equal('ALL') + response["Organization"]["FeatureSet"].should.equal("ALL") response = client.list_accounts() - len(response['Accounts']).should.equal(1) - response['Accounts'][0]['Name'].should.equal('master') - response['Accounts'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID) - response['Accounts'][0]['Email'].should.equal(utils.MASTER_ACCOUNT_EMAIL) + len(response["Accounts"]).should.equal(1) + response["Accounts"][0]["Name"].should.equal("master") + response["Accounts"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response["Accounts"][0]["Email"].should.equal(utils.MASTER_ACCOUNT_EMAIL) - response = client.list_policies(Filter='SERVICE_CONTROL_POLICY') - len(response['Policies']).should.equal(1) - response['Policies'][0]['Name'].should.equal('FullAWSAccess') - response['Policies'][0]['Id'].should.equal(utils.DEFAULT_POLICY_ID) - response['Policies'][0]['AwsManaged'].should.equal(True) + response = client.list_policies(Filter="SERVICE_CONTROL_POLICY") + len(response["Policies"]).should.equal(1) + response["Policies"][0]["Name"].should.equal("FullAWSAccess") + response["Policies"][0]["Id"].should.equal(utils.DEFAULT_POLICY_ID) + response["Policies"][0]["AwsManaged"].should.equal(True) response = client.list_targets_for_policy(PolicyId=utils.DEFAULT_POLICY_ID) - len(response['Targets']).should.equal(2) - root_ou = [t for t in response['Targets'] if t['Type'] == 'ROOT'][0] - root_ou['Name'].should.equal('Root') - master_account = [t for t in response['Targets'] if t['Type'] == 'ACCOUNT'][0] - master_account['Name'].should.equal('master') + len(response["Targets"]).should.equal(2) + root_ou = [t for t in response["Targets"] if t["Type"] == "ROOT"][0] + root_ou["Name"].should.equal("Root") + master_account = [t for t in response["Targets"] if t["Type"] == "ACCOUNT"][0] + master_account["Name"].should.equal("master") @mock_organizations def test_describe_organization(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL') + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") response = client.describe_organization() validate_organization(response) @mock_organizations def test_describe_organization_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.describe_organization() ex = e.exception - ex.operation_name.should.equal('DescribeOrganization') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AWSOrganizationsNotInUseException') + ex.operation_name.should.equal("DescribeOrganization") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AWSOrganizationsNotInUseException") # Organizational Units + @mock_organizations def test_list_roots(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] response = client.list_roots() validate_roots(org, response) @mock_organizations def test_create_organizational_unit(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_name = 'ou01' - response = client.create_organizational_unit( - ParentId=root_id, - Name=ou_name, - ) + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_name = "ou01" + response = client.create_organizational_unit(ParentId=root_id, Name=ou_name) validate_organizational_unit(org, response) - response['OrganizationalUnit']['Name'].should.equal(ou_name) + response["OrganizationalUnit"]["Name"].should.equal(ou_name) @mock_organizations def test_describe_organizational_unit(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] response = client.describe_organizational_unit(OrganizationalUnitId=ou_id) validate_organizational_unit(org, response) @mock_organizations def test_describe_organizational_unit_exception(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] with assert_raises(ClientError) as e: response = client.describe_organizational_unit( OrganizationalUnitId=utils.make_random_root_id() ) ex = e.exception - ex.operation_name.should.equal('DescribeOrganizationalUnit') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("DescribeOrganizationalUnit") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) @mock_organizations def test_list_organizational_units_for_parent(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - client.create_organizational_unit(ParentId=root_id, Name='ou01') - client.create_organizational_unit(ParentId=root_id, Name='ou02') - client.create_organizational_unit(ParentId=root_id, Name='ou03') + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + client.create_organizational_unit(ParentId=root_id, Name="ou01") + client.create_organizational_unit(ParentId=root_id, Name="ou02") + client.create_organizational_unit(ParentId=root_id, Name="ou03") response = client.list_organizational_units_for_parent(ParentId=root_id) - response.should.have.key('OrganizationalUnits').should.be.a(list) - for ou in response['OrganizationalUnits']: + response.should.have.key("OrganizationalUnits").should.be.a(list) + for ou in response["OrganizationalUnits"]: validate_organizational_unit(org, dict(OrganizationalUnit=ou)) @mock_organizations def test_list_organizational_units_for_parent_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.list_organizational_units_for_parent( ParentId=utils.make_random_root_id() ) ex = e.exception - ex.operation_name.should.equal('ListOrganizationalUnitsForParent') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('ParentNotFoundException') + ex.operation_name.should.equal("ListOrganizationalUnitsForParent") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("ParentNotFoundException") # Accounts -mockname = 'mock-account' -mockdomain = 'moto-example.org' -mockemail = '@'.join([mockname, mockdomain]) +mockname = "mock-account" +mockdomain = "moto-example.org" +mockemail = "@".join([mockname, mockdomain]) @mock_organizations def test_create_account(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL') - create_status = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus'] + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + create_status = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ] validate_create_account_status(create_status) - create_status['AccountName'].should.equal(mockname) + create_status["AccountName"].should.equal(mockname) + + +@mock_organizations +def test_describe_create_account_status(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + request_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["Id"] + response = client.describe_create_account_status(CreateAccountRequestId=request_id) + validate_create_account_status(response["CreateAccountStatus"]) @mock_organizations def test_describe_account(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - account_id = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] response = client.describe_account(AccountId=account_id) - validate_account(org, response['Account']) - response['Account']['Name'].should.equal(mockname) - response['Account']['Email'].should.equal(mockemail) + validate_account(org, response["Account"]) + response["Account"]["Name"].should.equal(mockname) + response["Account"]["Email"].should.equal(mockemail) @mock_organizations def test_describe_account_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.describe_account(AccountId=utils.make_random_account_id()) ex = e.exception - ex.operation_name.should.equal('DescribeAccount') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("DescribeAccount") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") @mock_organizations def test_list_accounts(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] for i in range(5): name = mockname + str(i) - email = name + '@' + mockdomain + email = name + "@" + mockdomain client.create_account(AccountName=name, Email=email) response = client.list_accounts() - response.should.have.key('Accounts') - accounts = response['Accounts'] + response.should.have.key("Accounts") + accounts = response["Accounts"] len(accounts).should.equal(6) for account in accounts: validate_account(org, account) - accounts[4]['Name'].should.equal(mockname + '3') - accounts[3]['Email'].should.equal(mockname + '2' + '@' + mockdomain) + accounts[4]["Name"].should.equal(mockname + "3") + accounts[3]["Email"].should.equal(mockname + "2" + "@" + mockdomain) @mock_organizations def test_list_accounts_for_parent(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] response = client.list_accounts_for_parent(ParentId=root_id) - account_id.should.be.within([account['Id'] for account in response['Accounts']]) + account_id.should.be.within([account["Id"] for account in response["Accounts"]]) @mock_organizations def test_move_account(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - account_id = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus']['AccountId'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] client.move_account( - AccountId=account_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account_id, SourceParentId=root_id, DestinationParentId=ou01_id ) response = client.list_accounts_for_parent(ParentId=ou01_id) - account_id.should.be.within([account['Id'] for account in response['Accounts']]) + account_id.should.be.within([account["Id"] for account in response["Accounts"]]) @mock_organizations def test_list_parents_for_ou(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] response01 = client.list_parents(ChildId=ou01_id) - response01.should.have.key('Parents').should.be.a(list) - response01['Parents'][0].should.have.key('Id').should.equal(root_id) - response01['Parents'][0].should.have.key('Type').should.equal('ROOT') - ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02') - ou02_id = ou02['OrganizationalUnit']['Id'] + response01.should.have.key("Parents").should.be.a(list) + response01["Parents"][0].should.have.key("Id").should.equal(root_id) + response01["Parents"][0].should.have.key("Type").should.equal("ROOT") + ou02 = client.create_organizational_unit(ParentId=ou01_id, Name="ou02") + ou02_id = ou02["OrganizationalUnit"]["Id"] response02 = client.list_parents(ChildId=ou02_id) - response02.should.have.key('Parents').should.be.a(list) - response02['Parents'][0].should.have.key('Id').should.equal(ou01_id) - response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT') + response02.should.have.key("Parents").should.be.a(list) + response02["Parents"][0].should.have.key("Id").should.equal(ou01_id) + response02["Parents"][0].should.have.key("Type").should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_parents_for_accounts(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] account01_id = client.create_account( - AccountName='account01', - Email='account01@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account01", Email="account01@moto-example.org" + )["CreateAccountStatus"]["AccountId"] account02_id = client.create_account( - AccountName='account02', - Email='account02@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account02", Email="account02@moto-example.org" + )["CreateAccountStatus"]["AccountId"] client.move_account( - AccountId=account02_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account02_id, SourceParentId=root_id, DestinationParentId=ou01_id ) response01 = client.list_parents(ChildId=account01_id) - response01.should.have.key('Parents').should.be.a(list) - response01['Parents'][0].should.have.key('Id').should.equal(root_id) - response01['Parents'][0].should.have.key('Type').should.equal('ROOT') + response01.should.have.key("Parents").should.be.a(list) + response01["Parents"][0].should.have.key("Id").should.equal(root_id) + response01["Parents"][0].should.have.key("Type").should.equal("ROOT") response02 = client.list_parents(ChildId=account02_id) - response02.should.have.key('Parents').should.be.a(list) - response02['Parents'][0].should.have.key('Id').should.equal(ou01_id) - response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT') + response02.should.have.key("Parents").should.be.a(list) + response02["Parents"][0].should.have.key("Id").should.equal(ou01_id) + response02["Parents"][0].should.have.key("Type").should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_children(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] - ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02') - ou02_id = ou02['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] + ou02 = client.create_organizational_unit(ParentId=ou01_id, Name="ou02") + ou02_id = ou02["OrganizationalUnit"]["Id"] account01_id = client.create_account( - AccountName='account01', - Email='account01@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account01", Email="account01@moto-example.org" + )["CreateAccountStatus"]["AccountId"] account02_id = client.create_account( - AccountName='account02', - Email='account02@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account02", Email="account02@moto-example.org" + )["CreateAccountStatus"]["AccountId"] client.move_account( - AccountId=account02_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account02_id, SourceParentId=root_id, DestinationParentId=ou01_id ) - response01 = client.list_children(ParentId=root_id, ChildType='ACCOUNT') - response02 = client.list_children(ParentId=root_id, ChildType='ORGANIZATIONAL_UNIT') - response03 = client.list_children(ParentId=ou01_id, ChildType='ACCOUNT') - response04 = client.list_children(ParentId=ou01_id, ChildType='ORGANIZATIONAL_UNIT') - response01['Children'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID) - response01['Children'][0]['Type'].should.equal('ACCOUNT') - response01['Children'][1]['Id'].should.equal(account01_id) - response01['Children'][1]['Type'].should.equal('ACCOUNT') - response02['Children'][0]['Id'].should.equal(ou01_id) - response02['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT') - response03['Children'][0]['Id'].should.equal(account02_id) - response03['Children'][0]['Type'].should.equal('ACCOUNT') - response04['Children'][0]['Id'].should.equal(ou02_id) - response04['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT') + response01 = client.list_children(ParentId=root_id, ChildType="ACCOUNT") + response02 = client.list_children(ParentId=root_id, ChildType="ORGANIZATIONAL_UNIT") + response03 = client.list_children(ParentId=ou01_id, ChildType="ACCOUNT") + response04 = client.list_children(ParentId=ou01_id, ChildType="ORGANIZATIONAL_UNIT") + response01["Children"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response01["Children"][0]["Type"].should.equal("ACCOUNT") + response01["Children"][1]["Id"].should.equal(account01_id) + response01["Children"][1]["Type"].should.equal("ACCOUNT") + response02["Children"][0]["Id"].should.equal(ou01_id) + response02["Children"][0]["Type"].should.equal("ORGANIZATIONAL_UNIT") + response03["Children"][0]["Id"].should.equal(account02_id) + response03["Children"][0]["Type"].should.equal("ACCOUNT") + response04["Children"][0]["Id"].should.equal(ou02_id) + response04["Children"][0]["Type"].should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_children_exception(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] with assert_raises(ClientError) as e: response = client.list_children( - ParentId=utils.make_random_root_id(), - ChildType='ACCOUNT' + ParentId=utils.make_random_root_id(), ChildType="ACCOUNT" ) ex = e.exception - ex.operation_name.should.equal('ListChildren') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('ParentNotFoundException') + ex.operation_name.should.equal("ListChildren") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("ParentNotFoundException") with assert_raises(ClientError) as e: - response = client.list_children( - ParentId=root_id, - ChildType='BLEE' - ) + response = client.list_children(ParentId=root_id, ChildType="BLEE") ex = e.exception - ex.operation_name.should.equal('ListChildren') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("ListChildren") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") # Service Control Policies policy_doc01 = dict( - Version='2012-10-17', - Statement=[dict( - Sid='MockPolicyStatement', - Effect='Allow', - Action='s3:*', - Resource='*', - )] + Version="2012-10-17", + Statement=[ + dict(Sid="MockPolicyStatement", Effect="Allow", Action="s3:*", Resource="*") + ], ) + @mock_organizations def test_create_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] policy = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"] validate_service_control_policy(org, policy) - policy['PolicySummary']['Name'].should.equal('MockServiceControlPolicy') - policy['PolicySummary']['Description'].should.equal('A dummy service control policy') - policy['Content'].should.equal(json.dumps(policy_doc01)) + policy["PolicySummary"]["Name"].should.equal("MockServiceControlPolicy") + policy["PolicySummary"]["Description"].should.equal( + "A dummy service control policy" + ) + policy["Content"].should.equal(json.dumps(policy_doc01)) @mock_organizations def test_describe_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] - policy = client.describe_policy(PolicyId=policy_id)['Policy'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] + policy = client.describe_policy(PolicyId=policy_id)["Policy"] validate_service_control_policy(org, policy) - policy['PolicySummary']['Name'].should.equal('MockServiceControlPolicy') - policy['PolicySummary']['Description'].should.equal('A dummy service control policy') - policy['Content'].should.equal(json.dumps(policy_doc01)) + policy["PolicySummary"]["Name"].should.equal("MockServiceControlPolicy") + policy["PolicySummary"]["Description"].should.equal( + "A dummy service control policy" + ) + policy["Content"].should.equal(json.dumps(policy_doc01)) @mock_organizations def test_describe_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - policy_id = 'p-47fhe9s3' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + policy_id = "p-47fhe9s3" with assert_raises(ClientError) as e: response = client.describe_policy(PolicyId=policy_id) ex = e.exception - ex.operation_name.should.equal('DescribePolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('PolicyNotFoundException') + ex.operation_name.should.equal("DescribePolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("PolicyNotFoundException") with assert_raises(ClientError) as e: - response = client.describe_policy(PolicyId='meaninglessstring') + response = client.describe_policy(PolicyId="meaninglessstring") ex = e.exception - ex.operation_name.should.equal('DescribePolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("DescribePolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") @mock_organizations def test_attach_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] response = client.attach_policy(PolicyId=policy_id, TargetId=root_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) response = client.attach_policy(PolicyId=policy_id, TargetId=ou_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) response = client.attach_policy(PolicyId=policy_id, TargetId=account_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_organizations def test_attach_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - root_id='r-dj873' - ou_id='ou-gi99-i7r8eh2i2' - account_id='126644886543' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + root_id = "r-dj873" + ou_id = "ou-gi99-i7r8eh2i2" + account_id = "126644886543" policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=root_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=ou_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=account_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") with assert_raises(ClientError) as e: - response = client.attach_policy(PolicyId=policy_id, TargetId='meaninglessstring') + response = client.attach_policy( + PolicyId=policy_id, TargetId="meaninglessstring" + ) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") @mock_organizations def test_list_polices(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - for i in range(0,4): + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + for i in range(0, 4): client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy' + str(i), - Type='SERVICE_CONTROL_POLICY' + Description="A dummy service control policy", + Name="MockServiceControlPolicy" + str(i), + Type="SERVICE_CONTROL_POLICY", ) - response = client.list_policies(Filter='SERVICE_CONTROL_POLICY') - for policy in response['Policies']: + response = client.list_policies(Filter="SERVICE_CONTROL_POLICY") + for policy in response["Policies"]: validate_policy_summary(org, policy) - + @mock_organizations def test_list_policies_for_target(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] client.attach_policy(PolicyId=policy_id, TargetId=ou_id) response = client.list_policies_for_target( - TargetId=ou_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=ou_id, Filter="SERVICE_CONTROL_POLICY" ) - for policy in response['Policies']: + for policy in response["Policies"]: validate_policy_summary(org, policy) client.attach_policy(PolicyId=policy_id, TargetId=account_id) response = client.list_policies_for_target( - TargetId=account_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=account_id, Filter="SERVICE_CONTROL_POLICY" ) - for policy in response['Policies']: + for policy in response["Policies"]: validate_policy_summary(org, policy) @mock_organizations def test_list_policies_for_target_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - ou_id='ou-gi99-i7r8eh2i2' - account_id='126644886543' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + ou_id = "ou-gi99-i7r8eh2i2" + account_id = "126644886543" with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId=ou_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=ou_id, Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId=account_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=account_id, Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId='meaninglessstring', - Filter='SERVICE_CONTROL_POLICY', + TargetId="meaninglessstring", Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') - + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") + @mock_organizations def test_list_targets_for_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] client.attach_policy(PolicyId=policy_id, TargetId=root_id) client.attach_policy(PolicyId=policy_id, TargetId=ou_id) client.attach_policy(PolicyId=policy_id, TargetId=account_id) response = client.list_targets_for_policy(PolicyId=policy_id) - for target in response['Targets']: + for target in response["Targets"]: target.should.be.a(dict) - target.should.have.key('Name').should.be.a(six.string_types) - target.should.have.key('Arn').should.be.a(six.string_types) - target.should.have.key('TargetId').should.be.a(six.string_types) - target.should.have.key('Type').should.be.within( - ['ROOT', 'ORGANIZATIONAL_UNIT', 'ACCOUNT'] + target.should.have.key("Name").should.be.a(six.string_types) + target.should.have.key("Arn").should.be.a(six.string_types) + target.should.have.key("TargetId").should.be.a(six.string_types) + target.should.have.key("Type").should.be.within( + ["ROOT", "ORGANIZATIONAL_UNIT", "ACCOUNT"] ) @mock_organizations def test_list_targets_for_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - policy_id = 'p-47fhe9s3' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + policy_id = "p-47fhe9s3" with assert_raises(ClientError) as e: response = client.list_targets_for_policy(PolicyId=policy_id) ex = e.exception - ex.operation_name.should.equal('ListTargetsForPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('PolicyNotFoundException') + ex.operation_name.should.equal("ListTargetsForPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("PolicyNotFoundException") with assert_raises(ClientError) as e: - response = client.list_targets_for_policy(PolicyId='meaninglessstring') + response = client.list_targets_for_policy(PolicyId="meaninglessstring") ex = e.exception - ex.operation_name.should.equal('ListTargetsForPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("ListTargetsForPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") + + +@mock_organizations +def test_tag_resource(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + + client.tag_resource(ResourceId=account_id, Tags=[{"Key": "key", "Value": "value"}]) + + response = client.list_tags_for_resource(ResourceId=account_id) + response["Tags"].should.equal([{"Key": "key", "Value": "value"}]) + + # adding a tag with an existing key, will update the value + client.tag_resource( + ResourceId=account_id, Tags=[{"Key": "key", "Value": "new-value"}] + ) + + response = client.list_tags_for_resource(ResourceId=account_id) + response["Tags"].should.equal([{"Key": "key", "Value": "new-value"}]) + + +@mock_organizations +def test_tag_resource_errors(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + + with assert_raises(ClientError) as e: + client.tag_resource( + ResourceId="000000000000", Tags=[{"Key": "key", "Value": "value"},] + ) + ex = e.exception + ex.operation_name.should.equal("TagResource") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidInputException") + ex.response["Error"]["Message"].should.equal( + "You provided a value that does not match the required pattern." + ) + + +@mock_organizations +def test_list_tags_for_resource(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + client.tag_resource(ResourceId=account_id, Tags=[{"Key": "key", "Value": "value"}]) + + response = client.list_tags_for_resource(ResourceId=account_id) + + response["Tags"].should.equal([{"Key": "key", "Value": "value"}]) + + +@mock_organizations +def test_list_tags_for_resource_errors(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + + with assert_raises(ClientError) as e: + client.list_tags_for_resource(ResourceId="000000000000") + ex = e.exception + ex.operation_name.should.equal("ListTagsForResource") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidInputException") + ex.response["Error"]["Message"].should.equal( + "You provided a value that does not match the required pattern." + ) + + +@mock_organizations +def test_untag_resource(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + client.tag_resource(ResourceId=account_id, Tags=[{"Key": "key", "Value": "value"}]) + response = client.list_tags_for_resource(ResourceId=account_id) + response["Tags"].should.equal([{"Key": "key", "Value": "value"}]) + + # removing a non existing tag should not raise any error + client.untag_resource(ResourceId=account_id, TagKeys=["not-existing"]) + response = client.list_tags_for_resource(ResourceId=account_id) + response["Tags"].should.equal([{"Key": "key", "Value": "value"}]) + + client.untag_resource(ResourceId=account_id, TagKeys=["key"]) + response = client.list_tags_for_resource(ResourceId=account_id) + response["Tags"].should.have.length_of(0) + + +@mock_organizations +def test_untag_resource_errors(): + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + + with assert_raises(ClientError) as e: + client.untag_resource(ResourceId="000000000000", TagKeys=["key"]) + ex = e.exception + ex.operation_name.should.equal("UntagResource") + ex.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.response["Error"]["Code"].should.contain("InvalidInputException") + ex.response["Error"]["Message"].should.equal( + "You provided a value that does not match the required pattern." + ) diff --git a/tests/test_packages/__init__.py b/tests/test_packages/__init__.py index bf582e0b3..05b1d476b 100644 --- a/tests/test_packages/__init__.py +++ b/tests/test_packages/__init__.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging + # Disable extra logging for tests -logging.getLogger('boto').setLevel(logging.CRITICAL) -logging.getLogger('boto3').setLevel(logging.CRITICAL) -logging.getLogger('botocore').setLevel(logging.CRITICAL) -logging.getLogger('nose').setLevel(logging.CRITICAL) +logging.getLogger("boto").setLevel(logging.CRITICAL) +logging.getLogger("boto3").setLevel(logging.CRITICAL) +logging.getLogger("botocore").setLevel(logging.CRITICAL) +logging.getLogger("nose").setLevel(logging.CRITICAL) diff --git a/tests/test_packages/test_httpretty.py b/tests/test_packages/test_httpretty.py index 48277a2de..ccf9b98ef 100644 --- a/tests/test_packages/test_httpretty.py +++ b/tests/test_packages/test_httpretty.py @@ -3,35 +3,42 @@ from __future__ import unicode_literals import mock -from moto.packages.httpretty.core import HTTPrettyRequest, fake_gethostname, fake_gethostbyname +from moto.packages.httpretty.core import ( + HTTPrettyRequest, + fake_gethostname, + fake_gethostbyname, +) def test_parse_querystring(): - core = HTTPrettyRequest(headers='test test HTTP/1.1') + core = HTTPrettyRequest(headers="test test HTTP/1.1") - qs = 'test test' + qs = "test test" response = core.parse_querystring(qs) assert response == {} -def test_parse_request_body(): - core = HTTPrettyRequest(headers='test test HTTP/1.1') - qs = 'test' +def test_parse_request_body(): + core = HTTPrettyRequest(headers="test test HTTP/1.1") + + qs = "test" response = core.parse_request_body(qs) - assert response == 'test' + assert response == "test" + def test_fake_gethostname(): - response = fake_gethostname() + response = fake_gethostname() + + assert response == "localhost" - assert response == 'localhost' def test_fake_gethostbyname(): - host = 'test' + host = "test" response = fake_gethostbyname(host=host) - assert response == '127.0.0.1' \ No newline at end of file + assert response == "127.0.0.1" diff --git a/tests/test_polly/test_polly.py b/tests/test_polly/test_polly.py index ec85142fa..5428cdeb7 100644 --- a/tests/test_polly/test_polly.py +++ b/tests/test_polly/test_polly.py @@ -7,7 +7,7 @@ from nose.tools import assert_raises from moto import mock_polly # Polly only available in a few regions -DEFAULT_REGION = 'eu-west-1' +DEFAULT_REGION = "eu-west-1" LEXICON_XML = """ @mock_polly def test_describe_voices(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) resp = client.describe_voices() - len(resp['Voices']).should.be.greater_than(1) + len(resp["Voices"]).should.be.greater_than(1) - resp = client.describe_voices(LanguageCode='en-GB') - len(resp['Voices']).should.equal(3) + resp = client.describe_voices(LanguageCode="en-GB") + len(resp["Voices"]).should.equal(3) try: - client.describe_voices(LanguageCode='SOME_LANGUAGE') + client.describe_voices(LanguageCode="SOME_LANGUAGE") except ClientError as err: - err.response['Error']['Code'].should.equal('400') + err.response["Error"]["Code"].should.equal("400") else: - raise RuntimeError('Should of raised an exception') + raise RuntimeError("Should of raised an exception") @mock_polly def test_put_list_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) resp = client.list_lexicons() - len(resp['Lexicons']).should.equal(1) + len(resp["Lexicons"]).should.equal(1) @mock_polly def test_put_get_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) - resp = client.get_lexicon(Name='test') - resp.should.contain('Lexicon') - resp.should.contain('LexiconAttributes') + resp = client.get_lexicon(Name="test") + resp.should.contain("Lexicon") + resp.should.contain("LexiconAttributes") @mock_polly def test_put_lexicon_bad_name(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) try: - client.put_lexicon( - Name='test-invalid', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test-invalid", Content=LEXICON_XML) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised an exception') + raise RuntimeError("Should of raised an exception") @mock_polly def test_synthesize_speech(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) - tests = ( - ('pcm', 'audio/pcm'), - ('mp3', 'audio/mpeg'), - ('ogg_vorbis', 'audio/ogg'), - ) + tests = (("pcm", "audio/pcm"), ("mp3", "audio/mpeg"), ("ogg_vorbis", "audio/ogg")) for output_format, content_type in tests: resp = client.synthesize_speech( - LexiconNames=['test'], + LexiconNames=["test"], OutputFormat=output_format, - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) - resp['ContentType'].should.equal(content_type) + resp["ContentType"].should.equal(content_type) @mock_polly def test_synthesize_speech_bad_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test2'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test2"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('LexiconNotFoundException') + err.response["Error"]["Code"].should.equal("LexiconNotFoundException") else: - raise RuntimeError('Should of raised LexiconNotFoundException') + raise RuntimeError("Should of raised LexiconNotFoundException") @mock_polly def test_synthesize_speech_bad_output_format(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='invalid', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="invalid", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_sample_rate(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='18000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="18000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidSampleRateException') + err.response["Error"]["Code"].should.equal("InvalidSampleRateException") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_text_type(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='invalid', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="invalid", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_voice_id(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Luke' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Luke", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_text_too_long(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234'*376, # = 3008 characters - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234" * 376, # = 3008 characters + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('TextLengthExceededException') + err.response["Error"]["Code"].should.equal("TextLengthExceededException") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_speech_marks1(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - SpeechMarkTypes=['word'], - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + SpeechMarkTypes=["word"], + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('MarksNotSupportedForFormatException') + err.response["Error"]["Code"].should.equal( + "MarksNotSupportedForFormatException" + ) else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_speech_marks2(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='ssml', - SpeechMarkTypes=['word'], - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="ssml", + SpeechMarkTypes=["word"], + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('MarksNotSupportedForFormatException') + err.response["Error"]["Code"].should.equal( + "MarksNotSupportedForFormatException" + ) else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") diff --git a/tests/test_polly/test_server.py b/tests/test_polly/test_server.py index e080c7551..756c9d7e4 100644 --- a/tests/test_polly/test_server.py +++ b/tests/test_polly/test_server.py @@ -1,19 +1,19 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server -from moto import mock_polly - -''' -Test the different server responses -''' - - -@mock_polly -def test_polly_list(): - backend = server.create_backend_app("polly") - test_client = backend.test_client() - - res = test_client.get('/v1/lexicons') - res.status_code.should.equal(200) +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server +from moto import mock_polly + +""" +Test the different server responses +""" + + +@mock_polly +def test_polly_list(): + backend = server.create_backend_app("polly") + test_client = backend.test_client() + + res = test_client.get("/v1/lexicons") + res.status_code.should.equal(200) diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index af330e672..4ebea0cf3 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -14,17 +14,19 @@ from tests.helpers import disable_on_py3 def test_create_database(): conn = boto.rds.connect_to_region("us-west-2") - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2', - security_groups=["my_sg"]) + database = conn.create_dbinstance( + "db-master-1", 10, "db.m1.small", "root", "hunter2", security_groups=["my_sg"] + ) - database.status.should.equal('available') + database.status.should.equal("available") database.id.should.equal("db-master-1") database.allocated_storage.should.equal(10) database.instance_class.should.equal("db.m1.small") database.master_username.should.equal("root") database.endpoint.should.equal( - ('db-master-1.aaaaaaaaaa.us-west-2.rds.amazonaws.com', 3306)) - database.security_groups[0].name.should.equal('my_sg') + ("db-master-1.aaaaaaaaaa.us-west-2.rds.amazonaws.com", 3306) + ) + database.security_groups[0].name.should.equal("my_sg") @mock_rds_deprecated @@ -33,8 +35,8 @@ def test_get_databases(): list(conn.get_all_dbinstances()).should.have.length_of(0) - conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') - conn.create_dbinstance("db-master-2", 10, 'db.m1.small', 'root', 'hunter2') + conn.create_dbinstance("db-master-1", 10, "db.m1.small", "root", "hunter2") + conn.create_dbinstance("db-master-2", 10, "db.m1.small", "root", "hunter2") list(conn.get_all_dbinstances()).should.have.length_of(2) @@ -46,18 +48,20 @@ def test_get_databases(): @mock_rds def test_get_databases_paginated(): - conn = boto3.client('rds', region_name="us-west-2") + conn = boto3.client("rds", region_name="us-west-2") for i in range(51): - conn.create_db_instance(AllocatedStorage=5, - Port=5432, - DBInstanceIdentifier='rds%d' % i, - DBInstanceClass='db.t1.micro', - Engine='postgres') + conn.create_db_instance( + AllocatedStorage=5, + Port=5432, + DBInstanceIdentifier="rds%d" % i, + DBInstanceClass="db.t1.micro", + Engine="postgres", + ) resp = conn.describe_db_instances() resp["DBInstances"].should.have.length_of(50) - resp["Marker"].should.equal(resp["DBInstances"][-1]['DBInstanceIdentifier']) + resp["Marker"].should.equal(resp["DBInstances"][-1]["DBInstanceIdentifier"]) resp2 = conn.describe_db_instances(Marker=resp["Marker"]) resp2["DBInstances"].should.have.length_of(1) @@ -66,8 +70,7 @@ def test_get_databases_paginated(): @mock_rds_deprecated def test_describe_non_existant_database(): conn = boto.rds.connect_to_region("us-west-2") - conn.get_all_dbinstances.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.get_all_dbinstances.when.called_with("not-a-db").should.throw(BotoServerError) @mock_rds_deprecated @@ -75,7 +78,7 @@ def test_delete_database(): conn = boto.rds.connect_to_region("us-west-2") list(conn.get_all_dbinstances()).should.have.length_of(0) - conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + conn.create_dbinstance("db-master-1", 10, "db.m1.small", "root", "hunter2") list(conn.get_all_dbinstances()).should.have.length_of(1) conn.delete_dbinstance("db-master-1") @@ -85,16 +88,15 @@ def test_delete_database(): @mock_rds_deprecated def test_delete_non_existant_database(): conn = boto.rds.connect_to_region("us-west-2") - conn.delete_dbinstance.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.delete_dbinstance.when.called_with("not-a-db").should.throw(BotoServerError) @mock_rds_deprecated def test_create_database_security_group(): conn = boto.rds.connect_to_region("us-west-2") - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') - security_group.name.should.equal('db_sg') + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") + security_group.name.should.equal("db_sg") security_group.description.should.equal("DB Security Group") list(security_group.ip_ranges).should.equal([]) @@ -105,8 +107,8 @@ def test_get_security_groups(): list(conn.get_all_dbsecurity_groups()).should.have.length_of(0) - conn.create_dbsecurity_group('db_sg1', 'DB Security Group') - conn.create_dbsecurity_group('db_sg2', 'DB Security Group') + conn.create_dbsecurity_group("db_sg1", "DB Security Group") + conn.create_dbsecurity_group("db_sg2", "DB Security Group") list(conn.get_all_dbsecurity_groups()).should.have.length_of(2) @@ -119,14 +121,15 @@ def test_get_security_groups(): @mock_rds_deprecated def test_get_non_existant_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.get_all_dbsecurity_groups.when.called_with( - "not-a-sg").should.throw(BotoServerError) + conn.get_all_dbsecurity_groups.when.called_with("not-a-sg").should.throw( + BotoServerError + ) @mock_rds_deprecated def test_delete_database_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.create_dbsecurity_group('db_sg', 'DB Security Group') + conn.create_dbsecurity_group("db_sg", "DB Security Group") list(conn.get_all_dbsecurity_groups()).should.have.length_of(1) @@ -137,21 +140,22 @@ def test_delete_database_security_group(): @mock_rds_deprecated def test_delete_non_existant_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.delete_dbsecurity_group.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.delete_dbsecurity_group.when.called_with("not-a-db").should.throw( + BotoServerError + ) @disable_on_py3() @mock_rds_deprecated def test_security_group_authorize(): conn = boto.rds.connect_to_region("us-west-2") - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") list(security_group.ip_ranges).should.equal([]) - security_group.authorize(cidr_ip='10.3.2.45/32') + security_group.authorize(cidr_ip="10.3.2.45/32") security_group = conn.get_all_dbsecurity_groups()[0] list(security_group.ip_ranges).should.have.length_of(1) - security_group.ip_ranges[0].cidr_ip.should.equal('10.3.2.45/32') + security_group.ip_ranges[0].cidr_ip.should.equal("10.3.2.45/32") @mock_rds_deprecated @@ -159,8 +163,9 @@ def test_add_security_group_to_database(): conn = boto.rds.connect_to_region("us-west-2") database = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") database.modify(security_groups=[security_group]) database = conn.get_all_dbinstances()[0] @@ -179,9 +184,8 @@ def test_add_database_subnet_group(): subnet_ids = [subnet1.id, subnet2.id] conn = boto.rds.connect_to_region("us-west-2") - subnet_group = conn.create_db_subnet_group( - "db_subnet", "my db subnet", subnet_ids) - subnet_group.name.should.equal('db_subnet') + subnet_group = conn.create_db_subnet_group("db_subnet", "my db subnet", subnet_ids) + subnet_group.name.should.equal("db_subnet") subnet_group.description.should.equal("my db subnet") list(subnet_group.subnet_ids).should.equal(subnet_ids) @@ -200,8 +204,9 @@ def test_describe_database_subnet_group(): list(conn.get_all_db_subnet_groups()).should.have.length_of(2) list(conn.get_all_db_subnet_groups("db_subnet1")).should.have.length_of(1) - conn.get_all_db_subnet_groups.when.called_with( - "not-a-subnet").should.throw(BotoServerError) + conn.get_all_db_subnet_groups.when.called_with("not-a-subnet").should.throw( + BotoServerError + ) @mock_ec2_deprecated @@ -218,8 +223,9 @@ def test_delete_database_subnet_group(): conn.delete_db_subnet_group("db_subnet1") list(conn.get_all_db_subnet_groups()).should.have.length_of(0) - conn.delete_db_subnet_group.when.called_with( - "db_subnet1").should.throw(BotoServerError) + conn.delete_db_subnet_group.when.called_with("db_subnet1").should.throw( + BotoServerError + ) @mock_ec2_deprecated @@ -232,8 +238,14 @@ def test_create_database_in_subnet_group(): conn = boto.rds.connect_to_region("us-west-2") conn.create_db_subnet_group("db_subnet1", "my db subnet", [subnet.id]) - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', - 'root', 'hunter2', db_subnet_group_name="db_subnet1") + database = conn.create_dbinstance( + "db-master-1", + 10, + "db.m1.small", + "root", + "hunter2", + db_subnet_group_name="db_subnet1", + ) database = conn.get_all_dbinstances("db-master-1")[0] database.subnet_group.name.should.equal("db_subnet1") @@ -244,16 +256,18 @@ def test_create_database_replica(): conn = boto.rds.connect_to_region("us-west-2") primary = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) replica = conn.create_dbinstance_read_replica( - "replica", "db-master-1", "db.m1.small") + "replica", "db-master-1", "db.m1.small" + ) replica.id.should.equal("replica") replica.instance_class.should.equal("db.m1.small") status_info = replica.status_infos[0] status_info.normal.should.equal(True) - status_info.status_type.should.equal('read replication') - status_info.status.should.equal('replicating') + status_info.status_type.should.equal("read replication") + status_info.status.should.equal("replicating") primary = conn.get_all_dbinstances("db-master-1")[0] primary.read_replica_dbinstance_identifiers[0].should.equal("replica") @@ -270,13 +284,12 @@ def test_create_cross_region_database_replica(): west_2_conn = boto.rds.connect_to_region("us-west-2") primary = west_1_conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) primary_arn = "arn:aws:rds:us-west-1:1234567890:db:db-master-1" replica = west_2_conn.create_dbinstance_read_replica( - "replica", - primary_arn, - "db.m1.small", + "replica", primary_arn, "db.m1.small" ) primary = west_1_conn.get_all_dbinstances("db-master-1")[0] @@ -298,17 +311,19 @@ def test_connecting_to_us_east_1(): # https://github.com/boto/boto/blob/e271ff09364ea18d9d8b6f4d63d6b0ac6cbc9b75/boto/endpoints.json#L285 conn = boto.rds.connect_to_region("us-east-1") - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2', - security_groups=["my_sg"]) + database = conn.create_dbinstance( + "db-master-1", 10, "db.m1.small", "root", "hunter2", security_groups=["my_sg"] + ) - database.status.should.equal('available') + database.status.should.equal("available") database.id.should.equal("db-master-1") database.allocated_storage.should.equal(10) database.instance_class.should.equal("db.m1.small") database.master_username.should.equal("root") database.endpoint.should.equal( - ('db-master-1.aaaaaaaaaa.us-east-1.rds.amazonaws.com', 3306)) - database.security_groups[0].name.should.equal('my_sg') + ("db-master-1.aaaaaaaaaa.us-east-1.rds.amazonaws.com", 3306) + ) + database.security_groups[0].name.should.equal("my_sg") @mock_rds_deprecated @@ -316,9 +331,10 @@ def test_create_database_with_iops(): conn = boto.rds.connect_to_region("us-west-2") database = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2', iops=6000) + "db-master-1", 10, "db.m1.small", "root", "hunter2", iops=6000 + ) - database.status.should.equal('available') + database.status.should.equal("available") database.iops.should.equal(6000) # boto>2.36.0 may change the following property name to `storage_type` - database.StorageType.should.equal('io1') + database.StorageType.should.equal("io1") diff --git a/tests/test_rds/test_server.py b/tests/test_rds/test_server.py index 814620331..ab53e83b4 100644 --- a/tests/test_rds/test_server.py +++ b/tests/test_rds/test_server.py @@ -1,20 +1,20 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server -from moto import mock_rds - -''' -Test the different server responses -''' - - -@mock_rds -def test_list_databases(): - backend = server.create_backend_app("rds") - test_client = backend.test_client() - - res = test_client.get('/?Action=DescribeDBInstances') - - res.data.decode("utf-8").should.contain("") +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server +from moto import mock_rds + +""" +Test the different server responses +""" + + +@mock_rds +def test_list_databases(): + backend = server.create_backend_app("rds") + test_client = backend.test_client() + + res = test_client.get("/?Action=DescribeDBInstances") + + res.data.decode("utf-8").should.contain("") diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index aacaf04f1..47b45539d 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -8,230 +8,301 @@ from moto import mock_ec2, mock_kms, mock_rds2 @mock_rds2 def test_create_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - VpcSecurityGroupIds=['sg-123456']) - db_instance = database['DBInstance'] - db_instance['AllocatedStorage'].should.equal(10) - db_instance['DBInstanceClass'].should.equal("db.m1.small") - db_instance['LicenseModel'].should.equal("license-included") - db_instance['MasterUsername'].should.equal("root") - db_instance['DBSecurityGroups'][0][ - 'DBSecurityGroupName'].should.equal('my_sg') - db_instance['DBInstanceArn'].should.equal( - 'arn:aws:rds:us-west-2:1234567890:db:db-master-1') - db_instance['DBInstanceStatus'].should.equal('available') - db_instance['DBName'].should.equal('staging-postgres') - db_instance['DBInstanceIdentifier'].should.equal("db-master-1") - db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) - db_instance['DbiResourceId'].should.contain("db-") - db_instance['CopyTagsToSnapshot'].should.equal(False) - db_instance['InstanceCreateTime'].should.be.a("datetime.datetime") - db_instance['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + VpcSecurityGroupIds=["sg-123456"], + ) + db_instance = database["DBInstance"] + db_instance["AllocatedStorage"].should.equal(10) + db_instance["DBInstanceClass"].should.equal("db.m1.small") + db_instance["LicenseModel"].should.equal("license-included") + db_instance["MasterUsername"].should.equal("root") + db_instance["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal("my_sg") + db_instance["DBInstanceArn"].should.equal( + "arn:aws:rds:us-west-2:1234567890:db:db-master-1" + ) + db_instance["DBInstanceStatus"].should.equal("available") + db_instance["DBName"].should.equal("staging-postgres") + db_instance["DBInstanceIdentifier"].should.equal("db-master-1") + db_instance["IAMDatabaseAuthenticationEnabled"].should.equal(False) + db_instance["DbiResourceId"].should.contain("db-") + db_instance["CopyTagsToSnapshot"].should.equal(False) + db_instance["InstanceCreateTime"].should.be.a("datetime.datetime") + db_instance["VpcSecurityGroups"][0]["VpcSecurityGroupId"].should.equal("sg-123456") + + +@mock_rds2 +def test_create_database_no_allocated_storage(): + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + ) + db_instance = database["DBInstance"] + db_instance["Engine"].should.equal("postgres") + db_instance["StorageType"].should.equal("gp2") + db_instance["AllocatedStorage"].should.equal(20) @mock_rds2 def test_create_database_non_existing_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") database = conn.create_db_instance.when.called_with( - DBInstanceIdentifier='db-master-1', + DBInstanceIdentifier="db-master-1", AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - OptionGroupName='non-existing').should.throw(ClientError) + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + OptionGroupName="non-existing", + ).should.throw(ClientError) @mock_rds2 def test_create_database_with_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='my-og', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - OptionGroupName='my-og') - db_instance = database['DBInstance'] - db_instance['AllocatedStorage'].should.equal(10) - db_instance['DBInstanceClass'].should.equal('db.m1.small') - db_instance['DBName'].should.equal('staging-postgres') - db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="my-og", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + OptionGroupName="my-og", + ) + db_instance = database["DBInstance"] + db_instance["AllocatedStorage"].should.equal(10) + db_instance["DBInstanceClass"].should.equal("db.m1.small") + db_instance["DBName"].should.equal("staging-postgres") + db_instance["OptionGroupMemberships"][0]["OptionGroupName"].should.equal("my-og") @mock_rds2 def test_stop_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # test stopping database should shutdown - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") # test rdsclient error when trying to stop an already stopped database - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # test stopping a stopped database with snapshot should error and no snapshot should exist for that call - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap').should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ).should.throw(ClientError) response = conn.describe_db_snapshots() - response['DBSnapshots'].should.equal([]) + response["DBSnapshots"].should.equal([]) @mock_rds2 def test_start_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # test starting an already started database should error - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # stop and test start - should go from stopped to available, create snapshot and check snapshot - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap') - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") response = conn.describe_db_snapshots() - response['DBSnapshots'][0]['DBSnapshotIdentifier'].should.equal('rocky4570-rds-snap') - response = conn.start_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('available') + response["DBSnapshots"][0]["DBSnapshotIdentifier"].should.equal( + "rocky4570-rds-snap" + ) + response = conn.start_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("available") # starting database should not remove snapshot response = conn.describe_db_snapshots() - response['DBSnapshots'][0]['DBSnapshotIdentifier'].should.equal('rocky4570-rds-snap') + response["DBSnapshots"][0]["DBSnapshotIdentifier"].should.equal( + "rocky4570-rds-snap" + ) # test stopping database, create snapshot with existing snapshot already created should throw error - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap').should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ).should.throw(ClientError) # test stopping database not invoking snapshot should succeed. - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") @mock_rds2 def test_fail_to_stop_multi_az(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - MultiAZ=True) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + MultiAZ=True, + ) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # multi-az databases arent allowed to be shutdown at this time. - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # multi-az databases arent allowed to be started up at this time. - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) @mock_rds2 def test_fail_to_stop_readreplica(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - replica = conn.create_db_instance_read_replica(DBInstanceIdentifier="db-replica-1", - SourceDBInstanceIdentifier="db-master-1", - DBInstanceClass="db.m1.small") + replica = conn.create_db_instance_read_replica( + DBInstanceIdentifier="db-replica-1", + SourceDBInstanceIdentifier="db-master-1", + DBInstanceClass="db.m1.small", + ) - mydb = conn.describe_db_instances(DBInstanceIdentifier=replica['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + mydb = conn.describe_db_instances( + DBInstanceIdentifier=replica["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # read-replicas are not allowed to be stopped at this time. - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # read-replicas are not allowed to be started at this time. - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) @mock_rds2 def test_get_databases(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) + list(instances["DBInstances"]).should.have.length_of(0) - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - conn.create_db_instance(DBInstanceIdentifier='db-master-2', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + conn.create_db_instance( + DBInstanceIdentifier="db-master-2", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(2) + list(instances["DBInstances"]).should.have.length_of(2) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - list(instances['DBInstances']).should.have.length_of(1) - instances['DBInstances'][0][ - 'DBInstanceIdentifier'].should.equal("db-master-1") - instances['DBInstances'][0]['DBInstanceArn'].should.equal( - 'arn:aws:rds:us-west-2:1234567890:db:db-master-1') + list(instances["DBInstances"]).should.have.length_of(1) + instances["DBInstances"][0]["DBInstanceIdentifier"].should.equal("db-master-1") + instances["DBInstances"][0]["DBInstanceArn"].should.equal( + "arn:aws:rds:us-west-2:1234567890:db:db-master-1" + ) @mock_rds2 def test_get_databases_paginated(): - conn = boto3.client('rds', region_name="us-west-2") + conn = boto3.client("rds", region_name="us-west-2") for i in range(51): - conn.create_db_instance(AllocatedStorage=5, - Port=5432, - DBInstanceIdentifier='rds%d' % i, - DBInstanceClass='db.t1.micro', - Engine='postgres') + conn.create_db_instance( + AllocatedStorage=5, + Port=5432, + DBInstanceIdentifier="rds%d" % i, + DBInstanceClass="db.t1.micro", + Engine="postgres", + ) resp = conn.describe_db_instances() resp["DBInstances"].should.have.length_of(50) - resp["Marker"].should.equal(resp["DBInstances"][-1]['DBInstanceIdentifier']) + resp["Marker"].should.equal(resp["DBInstances"][-1]["DBInstanceIdentifier"]) resp2 = conn.describe_db_instances(Marker=resp["Marker"]) resp2["DBInstances"].should.have.length_of(1) @@ -242,1269 +313,1379 @@ def test_get_databases_paginated(): @mock_rds2 def test_describe_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_db_instances.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_modify_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - instances = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - instances['DBInstances'][0]['AllocatedStorage'].should.equal(10) - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=20, - ApplyImmediately=True, - VpcSecurityGroupIds=['sg-123456']) - instances = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - instances['DBInstances'][0]['AllocatedStorage'].should.equal(20) - instances['DBInstances'][0]['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + instances["DBInstances"][0]["AllocatedStorage"].should.equal(10) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=20, + ApplyImmediately=True, + VpcSecurityGroupIds=["sg-123456"], + ) + instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + instances["DBInstances"][0]["AllocatedStorage"].should.equal(20) + instances["DBInstances"][0]["VpcSecurityGroups"][0][ + "VpcSecurityGroupId" + ].should.equal("sg-123456") @mock_rds2 def test_rename_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - list(instances['DBInstances']).should.have.length_of(1) - conn.describe_db_instances.when.called_with(DBInstanceIdentifier="db-master-2").should.throw(ClientError) - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - NewDBInstanceIdentifier='db-master-2', - ApplyImmediately=True) - conn.describe_db_instances.when.called_with(DBInstanceIdentifier="db-master-1").should.throw(ClientError) + list(instances["DBInstances"]).should.have.length_of(1) + conn.describe_db_instances.when.called_with( + DBInstanceIdentifier="db-master-2" + ).should.throw(ClientError) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + NewDBInstanceIdentifier="db-master-2", + ApplyImmediately=True, + ) + conn.describe_db_instances.when.called_with( + DBInstanceIdentifier="db-master-1" + ).should.throw(ClientError) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-2") - list(instances['DBInstances']).should.have.length_of(1) + list(instances["DBInstances"]).should.have.length_of(1) @mock_rds2 def test_modify_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') - conn.modify_db_instance.when.called_with(DBInstanceIdentifier='not-a-db', - AllocatedStorage=20, - ApplyImmediately=True).should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.modify_db_instance.when.called_with( + DBInstanceIdentifier="not-a-db", AllocatedStorage=20, ApplyImmediately=True + ).should.throw(ClientError) @mock_rds2 def test_reboot_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - database = conn.reboot_db_instance(DBInstanceIdentifier='db-master-1') - database['DBInstance']['DBInstanceIdentifier'].should.equal("db-master-1") + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + database = conn.reboot_db_instance(DBInstanceIdentifier="db-master-1") + database["DBInstance"]["DBInstanceIdentifier"].should.equal("db-master-1") @mock_rds2 def test_reboot_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.reboot_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_delete_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + list(instances["DBInstances"]).should.have.length_of(0) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(1) + list(instances["DBInstances"]).should.have.length_of(1) - conn.delete_db_instance(DBInstanceIdentifier="db-primary-1", - FinalDBSnapshotIdentifier='primary-1-snapshot') + conn.delete_db_instance( + DBInstanceIdentifier="db-primary-1", + FinalDBSnapshotIdentifier="primary-1-snapshot", + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) + list(instances["DBInstances"]).should.have.length_of(0) # Saved the snapshot - snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get('DBSnapshots') - snapshots[0].get('Engine').should.equal('postgres') + snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get( + "DBSnapshots" + ) + snapshots[0].get("Engine").should.equal("postgres") @mock_rds2 def test_delete_non_existant_database(): - conn = boto3.client('rds2', region_name="us-west-2") + conn = boto3.client("rds2", region_name="us-west-2") conn.delete_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_create_db_snapshots(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_snapshot.when.called_with( - DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='g-1').get('DBSnapshot') + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="g-1" + ).get("DBSnapshot") - snapshot.get('Engine').should.equal('postgres') - snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') - snapshot.get('DBSnapshotIdentifier').should.equal('g-1') - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn']) - result['TagList'].should.equal([]) + snapshot.get("Engine").should.equal("postgres") + snapshot.get("DBInstanceIdentifier").should.equal("db-primary-1") + snapshot.get("DBSnapshotIdentifier").should.equal("g-1") + result = conn.list_tags_for_resource(ResourceName=snapshot["DBSnapshotArn"]) + result["TagList"].should.equal([]) @mock_rds2 def test_create_db_snapshots_copy_tags(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_snapshot.when.called_with( - DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - CopyTagsToSnapshot=True, - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + CopyTagsToSnapshot=True, + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='g-1').get('DBSnapshot') + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="g-1" + ).get("DBSnapshot") - snapshot.get('Engine').should.equal('postgres') - snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') - snapshot.get('DBSnapshotIdentifier').should.equal('g-1') - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + snapshot.get("Engine").should.equal("postgres") + snapshot.get("DBInstanceIdentifier").should.equal("db-primary-1") + snapshot.get("DBSnapshotIdentifier").should.equal("g-1") + result = conn.list_tags_for_resource(ResourceName=snapshot["DBSnapshotArn"]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_describe_db_snapshots(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - created = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').get('DBSnapshot') + created = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).get("DBSnapshot") - created.get('Engine').should.equal('postgres') + created.get("Engine").should.equal("postgres") - by_database_id = conn.describe_db_snapshots(DBInstanceIdentifier='db-primary-1').get('DBSnapshots') - by_snapshot_id = conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots') + by_database_id = conn.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1" + ).get("DBSnapshots") + by_snapshot_id = conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1").get( + "DBSnapshots" + ) by_snapshot_id.should.equal(by_database_id) snapshot = by_snapshot_id[0] snapshot.should.equal(created) - snapshot.get('Engine').should.equal('postgres') + snapshot.get("Engine").should.equal("postgres") - conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-2') - snapshots = conn.describe_db_snapshots(DBInstanceIdentifier='db-primary-1').get('DBSnapshots') + conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-2" + ) + snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get( + "DBSnapshots" + ) snapshots.should.have.length_of(2) @mock_rds2 def test_delete_db_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ) - conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots')[0] - conn.delete_db_snapshot(DBSnapshotIdentifier='snapshot-1') + conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1").get("DBSnapshots")[0] + conn.delete_db_snapshot(DBSnapshotIdentifier="snapshot-1") conn.describe_db_snapshots.when.called_with( - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) @mock_rds2 def test_create_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - option_group = conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_group['OptionGroup']['OptionGroupName'].should.equal('test') - option_group['OptionGroup']['EngineName'].should.equal('mysql') - option_group['OptionGroup'][ - 'OptionGroupDescription'].should.equal('test option group') - option_group['OptionGroup']['MajorEngineVersion'].should.equal('5.6') + conn = boto3.client("rds", region_name="us-west-2") + option_group = conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_group["OptionGroup"]["OptionGroupName"].should.equal("test") + option_group["OptionGroup"]["EngineName"].should.equal("mysql") + option_group["OptionGroup"]["OptionGroupDescription"].should.equal( + "test option group" + ) + option_group["OptionGroup"]["MajorEngineVersion"].should.equal("5.6") @mock_rds2 def test_create_option_group_bad_engine_name(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='invalid_engine', - MajorEngineVersion='5.6', - OptionGroupDescription='test invalid engine').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="invalid_engine", + MajorEngineVersion="5.6", + OptionGroupDescription="test invalid engine", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_bad_engine_major_version(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='6.6.6', - OptionGroupDescription='test invalid engine version').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="6.6.6", + OptionGroupDescription="test invalid engine version", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_empty_description(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_duplicate(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ).should.throw(ClientError) @mock_rds2 def test_describe_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_groups = conn.describe_option_groups(OptionGroupName='test') - option_groups['OptionGroupsList'][0][ - 'OptionGroupName'].should.equal('test') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_groups = conn.describe_option_groups(OptionGroupName="test") + option_groups["OptionGroupsList"][0]["OptionGroupName"].should.equal("test") @mock_rds2 def test_describe_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_option_groups.when.called_with( - OptionGroupName="not-a-option-group").should.throw(ClientError) + OptionGroupName="not-a-option-group" + ).should.throw(ClientError) @mock_rds2 def test_delete_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_groups = conn.describe_option_groups(OptionGroupName='test') - option_groups['OptionGroupsList'][0][ - 'OptionGroupName'].should.equal('test') - conn.delete_option_group(OptionGroupName='test') - conn.describe_option_groups.when.called_with( - OptionGroupName='test').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_groups = conn.describe_option_groups(OptionGroupName="test") + option_groups["OptionGroupsList"][0]["OptionGroupName"].should.equal("test") + conn.delete_option_group(OptionGroupName="test") + conn.describe_option_groups.when.called_with(OptionGroupName="test").should.throw( + ClientError + ) @mock_rds2 def test_delete_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_option_group.when.called_with( - OptionGroupName='non-existant').should.throw(ClientError) + OptionGroupName="non-existant" + ).should.throw(ClientError) @mock_rds2 def test_describe_option_group_options(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") + option_group_options = conn.describe_option_group_options(EngineName="sqlserver-ee") + len(option_group_options["OptionGroupOptions"]).should.equal(4) option_group_options = conn.describe_option_group_options( - EngineName='sqlserver-ee') - len(option_group_options['OptionGroupOptions']).should.equal(4) + EngineName="sqlserver-ee", MajorEngineVersion="11.00" + ) + len(option_group_options["OptionGroupOptions"]).should.equal(2) option_group_options = conn.describe_option_group_options( - EngineName='sqlserver-ee', MajorEngineVersion='11.00') - len(option_group_options['OptionGroupOptions']).should.equal(2) - option_group_options = conn.describe_option_group_options( - EngineName='mysql', MajorEngineVersion='5.6') - len(option_group_options['OptionGroupOptions']).should.equal(1) + EngineName="mysql", MajorEngineVersion="5.6" + ) + len(option_group_options["OptionGroupOptions"]).should.equal(1) conn.describe_option_group_options.when.called_with( - EngineName='non-existent').should.throw(ClientError) + EngineName="non-existent" + ).should.throw(ClientError) conn.describe_option_group_options.when.called_with( - EngineName='mysql', MajorEngineVersion='non-existent').should.throw(ClientError) + EngineName="mysql", MajorEngineVersion="non-existent" + ).should.throw(ClientError) @mock_rds2 def test_modify_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', EngineName='mysql', - MajorEngineVersion='5.6', OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) # TODO: create option and validate before deleting. # if Someone can tell me how the hell to use this function # to add options to an option_group, I can finish coding this. - result = conn.modify_option_group(OptionGroupName='test', OptionsToInclude=[ - ], OptionsToRemove=['MEMCACHED'], ApplyImmediately=True) - result['OptionGroup']['EngineName'].should.equal('mysql') - result['OptionGroup']['Options'].should.equal([]) - result['OptionGroup']['OptionGroupName'].should.equal('test') + result = conn.modify_option_group( + OptionGroupName="test", + OptionsToInclude=[], + OptionsToRemove=["MEMCACHED"], + ApplyImmediately=True, + ) + result["OptionGroup"]["EngineName"].should.equal("mysql") + result["OptionGroup"]["Options"].should.equal([]) + result["OptionGroup"]["OptionGroupName"].should.equal("test") @mock_rds2 def test_modify_option_group_no_options(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', EngineName='mysql', - MajorEngineVersion='5.6', OptionGroupDescription='test option group') - conn.modify_option_group.when.called_with( - OptionGroupName='test').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + conn.modify_option_group.when.called_with(OptionGroupName="test").should.throw( + ClientError + ) @mock_rds2 def test_modify_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.modify_option_group.when.called_with(OptionGroupName='non-existant', OptionsToInclude=[( - 'OptionName', 'Port', 'DBSecurityGroupMemberships', 'VpcSecurityGroupMemberships', 'OptionSettings')]).should.throw(ParamValidationError) + conn = boto3.client("rds", region_name="us-west-2") + conn.modify_option_group.when.called_with( + OptionGroupName="non-existant", + OptionsToInclude=[ + ( + "OptionName", + "Port", + "DBSecurityGroupMemberships", + "VpcSecurityGroupMemberships", + "OptionSettings", + ) + ], + ).should.throw(ParamValidationError) @mock_rds2 def test_delete_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_list_tags_invalid_arn(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.list_tags_for_resource.when.called_with( - ResourceName='arn:aws:rds:bad-arn').should.throw(ClientError) + ResourceName="arn:aws:rds:bad-arn" + ).should.throw(ClientError) @mock_rds2 def test_list_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:foo') - result['TagList'].should.equal([]) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:foo" + ) + result["TagList"].should.equal([]) test_instance = conn.create_db_instance( - DBInstanceIdentifier='db-with-tags', + DBInstanceIdentifier="db-with-tags", AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName=test_instance['DBInstance']['DBInstanceArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName=test_instance["DBInstance"]["DBInstanceArn"] + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-without-tags', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-without-tags", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags') - list(result['TagList']).should.have.length_of(2) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }, - ]) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags" + ) + list(result["TagList"]).should.have.length_of(2) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags') - list(result['TagList']).should.have.length_of(3) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags" + ) + list(result["TagList"]).should.have.length_of(3) @mock_rds2 def test_remove_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-with-tags', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-with-tags", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags" + ) + list(result["TagList"]).should.have.length_of(2) conn.remove_tags_from_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags', TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags", TagKeys=["foo"] + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags') - len(result['TagList']).should.equal(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags" + ) + len(result["TagList"]).should.equal(1) @mock_rds2 def test_list_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:foo') - result['TagList'].should.equal([]) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-with-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:foo" + ) + result["TagList"].should.equal([]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-with-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) + result = conn.list_tags_for_resource( + ResourceName=snapshot["DBSnapshot"]["DBSnapshotArn"] + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-without-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags') - list(result['TagList']).should.have.length_of(2) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }, - ]) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags" + ) + list(result["TagList"]).should.have.length_of(2) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags') - list(result['TagList']).should.have.length_of(3) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags" + ) + list(result["TagList"]).should.have.length_of(3) @mock_rds2 def test_remove_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-with-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-with-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags" + ) + list(result["TagList"]).should.have.length_of(2) conn.remove_tags_from_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags', TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags", + TagKeys=["foo"], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags') - len(result['TagList']).should.equal(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags" + ) + len(result["TagList"]).should.equal(1) @mock_rds2 def test_add_tags_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(0) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }]) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(0) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(2) @mock_rds2 def test_remove_tags_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }]) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(2) - conn.remove_tags_from_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(2) + conn.remove_tags_from_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", TagKeys=["foo"] + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(1) @mock_rds2 def test_create_database_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.create_db_security_group( - DBSecurityGroupName='db_sg', DBSecurityGroupDescription='DB Security Group') - result['DBSecurityGroup']['DBSecurityGroupName'].should.equal("db_sg") - result['DBSecurityGroup'][ - 'DBSecurityGroupDescription'].should.equal("DB Security Group") - result['DBSecurityGroup']['IPRanges'].should.equal([]) + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + result["DBSecurityGroup"]["DBSecurityGroupName"].should.equal("db_sg") + result["DBSecurityGroup"]["DBSecurityGroupDescription"].should.equal( + "DB Security Group" + ) + result["DBSecurityGroup"]["IPRanges"].should.equal([]) @mock_rds2 def test_get_security_groups(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(0) + result["DBSecurityGroups"].should.have.length_of(0) conn.create_db_security_group( - DBSecurityGroupName='db_sg1', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg1", DBSecurityGroupDescription="DB Security Group" + ) conn.create_db_security_group( - DBSecurityGroupName='db_sg2', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg2", DBSecurityGroupDescription="DB Security Group" + ) result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(2) + result["DBSecurityGroups"].should.have.length_of(2) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg1") - result['DBSecurityGroups'].should.have.length_of(1) - result['DBSecurityGroups'][0]['DBSecurityGroupName'].should.equal("db_sg1") + result["DBSecurityGroups"].should.have.length_of(1) + result["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal("db_sg1") @mock_rds2 def test_get_non_existant_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_db_security_groups.when.called_with( - DBSecurityGroupName="not-a-sg").should.throw(ClientError) + DBSecurityGroupName="not-a-sg" + ).should.throw(ClientError) @mock_rds2 def test_delete_database_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_security_group( - DBSecurityGroupName='db_sg', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(1) + result["DBSecurityGroups"].should.have.length_of(1) conn.delete_db_security_group(DBSecurityGroupName="db_sg") result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(0) + result["DBSecurityGroups"].should.have.length_of(0) @mock_rds2 def test_delete_non_existant_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_security_group.when.called_with( - DBSecurityGroupName="not-a-db").should.throw(ClientError) + DBSecurityGroupName="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_security_group_authorize(): - conn = boto3.client('rds', region_name='us-west-2') - security_group = conn.create_db_security_group(DBSecurityGroupName='db_sg', - DBSecurityGroupDescription='DB Security Group') - security_group['DBSecurityGroup']['IPRanges'].should.equal([]) + conn = boto3.client("rds", region_name="us-west-2") + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + security_group["DBSecurityGroup"]["IPRanges"].should.equal([]) - conn.authorize_db_security_group_ingress(DBSecurityGroupName='db_sg', - CIDRIP='10.3.2.45/32') + conn.authorize_db_security_group_ingress( + DBSecurityGroupName="db_sg", CIDRIP="10.3.2.45/32" + ) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg") - result['DBSecurityGroups'][0]['IPRanges'].should.have.length_of(1) - result['DBSecurityGroups'][0]['IPRanges'].should.equal( - [{'Status': 'authorized', 'CIDRIP': '10.3.2.45/32'}]) + result["DBSecurityGroups"][0]["IPRanges"].should.have.length_of(1) + result["DBSecurityGroups"][0]["IPRanges"].should.equal( + [{"Status": "authorized", "CIDRIP": "10.3.2.45/32"}] + ) - conn.authorize_db_security_group_ingress(DBSecurityGroupName='db_sg', - CIDRIP='10.3.2.46/32') + conn.authorize_db_security_group_ingress( + DBSecurityGroupName="db_sg", CIDRIP="10.3.2.46/32" + ) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg") - result['DBSecurityGroups'][0]['IPRanges'].should.have.length_of(2) - result['DBSecurityGroups'][0]['IPRanges'].should.equal([ - {'Status': 'authorized', 'CIDRIP': '10.3.2.45/32'}, - {'Status': 'authorized', 'CIDRIP': '10.3.2.46/32'}, - ]) + result["DBSecurityGroups"][0]["IPRanges"].should.have.length_of(2) + result["DBSecurityGroups"][0]["IPRanges"].should.equal( + [ + {"Status": "authorized", "CIDRIP": "10.3.2.45/32"}, + {"Status": "authorized", "CIDRIP": "10.3.2.46/32"}, + ] + ) @mock_rds2 def test_add_security_group_to_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) result = conn.describe_db_instances() - result['DBInstances'][0]['DBSecurityGroups'].should.equal([]) - conn.create_db_security_group(DBSecurityGroupName='db_sg', - DBSecurityGroupDescription='DB Security Group') - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - DBSecurityGroups=['db_sg']) + result["DBInstances"][0]["DBSecurityGroups"].should.equal([]) + conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", DBSecurityGroups=["db_sg"] + ) result = conn.describe_db_instances() - result['DBInstances'][0]['DBSecurityGroups'][0][ - 'DBSecurityGroupName'].should.equal('db_sg') + result["DBInstances"][0]["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal( + "db_sg" + ) @mock_rds2 def test_list_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group', - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSecurityGroup']['DBSecurityGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", + DBSecurityGroupDescription="DB Security Group", + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSecurityGroup"]["DBSecurityGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group')['DBSecurityGroup']['DBSecurityGroupName'] + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + )["DBSecurityGroup"]["DBSecurityGroupName"] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) - conn.add_tags_to_resource(ResourceName=resource, - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) + conn.add_tags_to_resource( + ResourceName=resource, + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + ) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_remove_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group', - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSecurityGroup']['DBSecurityGroupName'] + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", + DBSecurityGroupDescription="DB Security Group", + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSecurityGroup"]["DBSecurityGroupName"] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) - conn.remove_tags_from_resource(ResourceName=resource, TagKeys=['foo']) + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) + conn.remove_tags_from_resource(ResourceName=resource, TagKeys=["foo"]) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar1', 'Key': 'foo1'}]) + result["TagList"].should.equal([{"Value": "bar1", "Key": "foo1"}]) @mock_ec2 @mock_rds2 def test_create_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet1 = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] - subnet2 = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.2.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet1 = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] + subnet2 = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.2.0/24")[ + "Subnet" + ] - subnet_ids = [subnet1['SubnetId'], subnet2['SubnetId']] - conn = boto3.client('rds', region_name='us-west-2') - result = conn.create_db_subnet_group(DBSubnetGroupName='db_subnet', - DBSubnetGroupDescription='my db subnet', - SubnetIds=subnet_ids) - result['DBSubnetGroup']['DBSubnetGroupName'].should.equal("db_subnet") - result['DBSubnetGroup'][ - 'DBSubnetGroupDescription'].should.equal("my db subnet") - subnets = result['DBSubnetGroup']['Subnets'] - subnet_group_ids = [subnets[0]['SubnetIdentifier'], - subnets[1]['SubnetIdentifier']] + subnet_ids = [subnet1["SubnetId"], subnet2["SubnetId"]] + conn = boto3.client("rds", region_name="us-west-2") + result = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet", + DBSubnetGroupDescription="my db subnet", + SubnetIds=subnet_ids, + ) + result["DBSubnetGroup"]["DBSubnetGroupName"].should.equal("db_subnet") + result["DBSubnetGroup"]["DBSubnetGroupDescription"].should.equal("my db subnet") + subnets = result["DBSubnetGroup"]["Subnets"] + subnet_group_ids = [subnets[0]["SubnetIdentifier"], subnets[1]["SubnetIdentifier"]] list(subnet_group_ids).should.equal(subnet_ids) @mock_ec2 @mock_rds2 def test_create_database_in_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_subnet_group(DBSubnetGroupName='db_subnet1', - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSubnetGroupName='db_subnet1') - result = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - result['DBInstances'][0]['DBSubnetGroup'][ - 'DBSubnetGroupName'].should.equal('db_subnet1') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSubnetGroupName="db_subnet1", + ) + result = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + result["DBInstances"][0]["DBSubnetGroup"]["DBSubnetGroupName"].should.equal( + "db_subnet1" + ) @mock_ec2 @mock_rds2 def test_describe_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) - conn.create_db_subnet_group(DBSubnetGroupName='db_subnet2', - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet2", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) resp = conn.describe_db_subnet_groups() - resp['DBSubnetGroups'].should.have.length_of(2) + resp["DBSubnetGroups"].should.have.length_of(2) - subnets = resp['DBSubnetGroups'][0]['Subnets'] + subnets = resp["DBSubnetGroups"][0]["Subnets"] subnets.should.have.length_of(1) - list(conn.describe_db_subnet_groups(DBSubnetGroupName="db_subnet1") - ['DBSubnetGroups']).should.have.length_of(1) + list( + conn.describe_db_subnet_groups(DBSubnetGroupName="db_subnet1")["DBSubnetGroups"] + ).should.have.length_of(1) conn.describe_db_subnet_groups.when.called_with( - DBSubnetGroupName="not-a-subnet").should.throw(ClientError) + DBSubnetGroupName="not-a-subnet" + ).should.throw(ClientError) @mock_ec2 @mock_rds2 def test_delete_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(1) + result["DBSubnetGroups"].should.have.length_of(1) conn.delete_db_subnet_group(DBSubnetGroupName="db_subnet1") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) conn.delete_db_subnet_group.when.called_with( - DBSubnetGroupName="db_subnet1").should.throw(ClientError) + DBSubnetGroupName="db_subnet1" + ).should.throw(ClientError) @mock_ec2 @mock_rds2 def test_list_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSubnetGroup']['DBSubnetGroupName'] + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSubnetGroup"]["DBSubnetGroupName"] result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet)) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_ec2 @mock_rds2 def test_add_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[])['DBSubnetGroup']['DBSubnetGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet) + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[], + )["DBSubnetGroup"]["DBSubnetGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) - conn.add_tags_to_resource(ResourceName=resource, - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + conn.add_tags_to_resource( + ResourceName=resource, + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + ) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_ec2 @mock_rds2 def test_remove_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSubnetGroup']['DBSubnetGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet) + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSubnetGroup"]["DBSubnetGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) - conn.remove_tags_from_resource(ResourceName=resource, TagKeys=['foo']) + conn.remove_tags_from_resource(ResourceName=resource, TagKeys=["foo"]) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar1', 'Key': 'foo1'}]) + result["TagList"].should.equal([{"Value": "bar1", "Key": "foo1"}]) @mock_rds2 def test_create_database_replica(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - replica = conn.create_db_instance_read_replica(DBInstanceIdentifier="db-replica-1", - SourceDBInstanceIdentifier="db-master-1", - DBInstanceClass="db.m1.small") - replica['DBInstance'][ - 'ReadReplicaSourceDBInstanceIdentifier'].should.equal('db-master-1') - replica['DBInstance']['DBInstanceClass'].should.equal('db.m1.small') - replica['DBInstance']['DBInstanceIdentifier'].should.equal('db-replica-1') + replica = conn.create_db_instance_read_replica( + DBInstanceIdentifier="db-replica-1", + SourceDBInstanceIdentifier="db-master-1", + DBInstanceClass="db.m1.small", + ) + replica["DBInstance"]["ReadReplicaSourceDBInstanceIdentifier"].should.equal( + "db-master-1" + ) + replica["DBInstance"]["DBInstanceClass"].should.equal("db.m1.small") + replica["DBInstance"]["DBInstanceIdentifier"].should.equal("db-replica-1") master = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - master['DBInstances'][0]['ReadReplicaDBInstanceIdentifiers'].should.equal([ - 'db-replica-1']) + master["DBInstances"][0]["ReadReplicaDBInstanceIdentifiers"].should.equal( + ["db-replica-1"] + ) - conn.delete_db_instance( - DBInstanceIdentifier="db-replica-1", SkipFinalSnapshot=True) + conn.delete_db_instance(DBInstanceIdentifier="db-replica-1", SkipFinalSnapshot=True) master = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - master['DBInstances'][0][ - 'ReadReplicaDBInstanceIdentifiers'].should.equal([]) + master["DBInstances"][0]["ReadReplicaDBInstanceIdentifiers"].should.equal([]) @mock_rds2 @mock_kms def test_create_database_with_encrypted_storage(): - kms_conn = boto3.client('kms', region_name='us-west-2') - key = kms_conn.create_key(Policy='my RDS encryption policy', - Description='RDS encryption key', - KeyUsage='ENCRYPT_DECRYPT') + kms_conn = boto3.client("kms", region_name="us-west-2") + key = kms_conn.create_key( + Policy="my RDS encryption policy", + Description="RDS encryption key", + KeyUsage="ENCRYPT_DECRYPT", + ) - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - StorageEncrypted=True, - KmsKeyId=key['KeyMetadata']['KeyId']) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + StorageEncrypted=True, + KmsKeyId=key["KeyMetadata"]["KeyId"], + ) - database['DBInstance']['StorageEncrypted'].should.equal(True) - database['DBInstance']['KmsKeyId'].should.equal( - key['KeyMetadata']['KeyId']) + database["DBInstance"]["StorageEncrypted"].should.equal(True) + database["DBInstance"]["KmsKeyId"].should.equal(key["KeyMetadata"]["KeyId"]) @mock_rds2 def test_create_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - db_parameter_group['DBParameterGroup'][ - 'DBParameterGroupName'].should.equal('test') - db_parameter_group['DBParameterGroup'][ - 'DBParameterGroupFamily'].should.equal('mysql5.6') - db_parameter_group['DBParameterGroup'][ - 'Description'].should.equal('test parameter group') + db_parameter_group["DBParameterGroup"]["DBParameterGroupName"].should.equal("test") + db_parameter_group["DBParameterGroup"]["DBParameterGroupFamily"].should.equal( + "mysql5.6" + ) + db_parameter_group["DBParameterGroup"]["Description"].should.equal( + "test parameter group" + ) @mock_rds2 def test_create_db_instance_with_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='mysql', - DBInstanceClass='db.m1.small', - DBParameterGroupName='test', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="mysql", + DBInstanceClass="db.m1.small", + DBParameterGroupName="test", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) - len(database['DBInstance']['DBParameterGroups']).should.equal(1) - database['DBInstance']['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - database['DBInstance']['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + len(database["DBInstance"]["DBParameterGroups"]).should.equal(1) + database["DBInstance"]["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) + database["DBInstance"]["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal( + "in-sync" + ) @mock_rds2 def test_create_database_with_default_port(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - DBSecurityGroups=["my_sg"]) - database['DBInstance']['Endpoint']['Port'].should.equal(5432) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + DBSecurityGroups=["my_sg"], + ) + database["DBInstance"]["Endpoint"]["Port"].should.equal(5432) @mock_rds2 def test_modify_db_instance_with_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='mysql', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="mysql", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) - len(database['DBInstance']['DBParameterGroups']).should.equal(1) - database['DBInstance']['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('default.mysql5.6') - database['DBInstance']['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + len(database["DBInstance"]["DBParameterGroups"]).should.equal(1) + database["DBInstance"]["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "default.mysql5.6" + ) + database["DBInstance"]["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal( + "in-sync" + ) - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - DBParameterGroupName='test', - ApplyImmediately=True) + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + DBParameterGroupName="test", + ApplyImmediately=True, + ) - database = conn.describe_db_instances( - DBInstanceIdentifier='db-master-1')['DBInstances'][0] - len(database['DBParameterGroups']).should.equal(1) - database['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - database['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + database = conn.describe_db_instances(DBInstanceIdentifier="db-master-1")[ + "DBInstances" + ][0] + len(database["DBParameterGroups"]).should.equal(1) + database["DBParameterGroups"][0]["DBParameterGroupName"].should.equal("test") + database["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal("in-sync") @mock_rds2 def test_create_db_parameter_group_empty_description(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group.when.called_with(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group.when.called_with( + DBParameterGroupName="test", DBParameterGroupFamily="mysql5.6", Description="" + ).should.throw(ClientError) @mock_rds2 def test_create_db_parameter_group_duplicate(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - conn.create_db_parameter_group.when.called_with(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + conn.create_db_parameter_group.when.called_with( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ).should.throw(ClientError) @mock_rds2 def test_describe_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - db_parameter_groups['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + db_parameter_groups["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) @mock_rds2 def test_describe_non_existant_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - len(db_parameter_groups['DBParameterGroups']).should.equal(0) + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + len(db_parameter_groups["DBParameterGroups"]).should.equal(0) @mock_rds2 def test_delete_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - db_parameter_groups['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - conn.delete_db_parameter_group(DBParameterGroupName='test') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - len(db_parameter_groups['DBParameterGroups']).should.equal(0) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + db_parameter_groups["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) + conn.delete_db_parameter_group(DBParameterGroupName="test") + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + len(db_parameter_groups["DBParameterGroups"]).should.equal(0) @mock_rds2 def test_modify_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - modify_result = conn.modify_db_parameter_group(DBParameterGroupName='test', - Parameters=[{ - 'ParameterName': 'foo', - 'ParameterValue': 'foo_val', - 'Description': 'test param', - 'ApplyMethod': 'immediate' - }] - ) + modify_result = conn.modify_db_parameter_group( + DBParameterGroupName="test", + Parameters=[ + { + "ParameterName": "foo", + "ParameterValue": "foo_val", + "Description": "test param", + "ApplyMethod": "immediate", + } + ], + ) - modify_result['DBParameterGroupName'].should.equal('test') + modify_result["DBParameterGroupName"].should.equal("test") - db_parameters = conn.describe_db_parameters(DBParameterGroupName='test') - db_parameters['Parameters'][0]['ParameterName'].should.equal('foo') - db_parameters['Parameters'][0]['ParameterValue'].should.equal('foo_val') - db_parameters['Parameters'][0]['Description'].should.equal('test param') - db_parameters['Parameters'][0]['ApplyMethod'].should.equal('immediate') + db_parameters = conn.describe_db_parameters(DBParameterGroupName="test") + db_parameters["Parameters"][0]["ParameterName"].should.equal("foo") + db_parameters["Parameters"][0]["ParameterValue"].should.equal("foo_val") + db_parameters["Parameters"][0]["Description"].should.equal("test param") + db_parameters["Parameters"][0]["ApplyMethod"].should.equal("immediate") @mock_rds2 def test_delete_non_existant_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_parameter_group.when.called_with( - DBParameterGroupName='non-existant').should.throw(ClientError) + DBParameterGroupName="non-existant" + ).should.throw(ClientError) @mock_rds2 def test_create_parameter_group_with_tags(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group', - Tags=[{ - 'Key': 'foo', - 'Value': 'bar', - }]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + Tags=[{"Key": "foo", "Value": "bar"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:pg:test') - result['TagList'].should.equal([{'Value': 'bar', 'Key': 'foo'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:pg:test" + ) + result["TagList"].should.equal([{"Value": "bar", "Key": "foo"}]) diff --git a/tests/test_rds2/test_server.py b/tests/test_rds2/test_server.py index 8ae44fb58..dade82c9c 100644 --- a/tests/test_rds2/test_server.py +++ b/tests/test_rds2/test_server.py @@ -1,20 +1,20 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server -from moto import mock_rds2 - -''' -Test the different server responses -''' - - -#@mock_rds2 -# def test_list_databases(): -# backend = server.create_backend_app("rds2") -# test_client = backend.test_client() -# -# res = test_client.get('/?Action=DescribeDBInstances') -# -# res.data.decode("utf-8").should.contain("") +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server +from moto import mock_rds2 + +""" +Test the different server responses +""" + + +# @mock_rds2 +# def test_list_databases(): +# backend = server.create_backend_app("rds2") +# test_client = backend.test_client() +# +# res = test_client.get('/?Action=DescribeDBInstances') +# +# res.data.decode("utf-8").should.contain("") diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index 2c9b42a1d..6bb3b1396 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -9,79 +9,98 @@ from boto.redshift.exceptions import ( ClusterParameterGroupNotFound, ClusterSecurityGroupNotFound, ClusterSubnetGroupNotFound, - InvalidSubnet -) -from botocore.exceptions import ( - ClientError + InvalidSubnet, ) +from botocore.exceptions import ClientError import sure # noqa from moto import mock_ec2 from moto import mock_ec2_deprecated from moto import mock_redshift from moto import mock_redshift_deprecated +from moto.core import ACCOUNT_ID @mock_redshift def test_create_cluster_boto3(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") response = client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) - response['Cluster']['NodeType'].should.equal('ds2.xlarge') - create_time = response['Cluster']['ClusterCreateTime'] + response["Cluster"]["NodeType"].should.equal("ds2.xlarge") + create_time = response["Cluster"]["ClusterCreateTime"] create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) - create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) + create_time.should.be.greater_than( + datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1) + ) + response["Cluster"]["EnhancedVpcRouting"].should.equal(False) + + +@mock_redshift +def test_create_cluster_boto3(): + client = boto3.client("redshift", region_name="us-east-1") + response = client.create_cluster( + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", + EnhancedVpcRouting=True, + ) + response["Cluster"]["NodeType"].should.equal("ds2.xlarge") + create_time = response["Cluster"]["ClusterCreateTime"] + create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) + create_time.should.be.greater_than( + datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1) + ) + response["Cluster"]["EnhancedVpcRouting"].should.equal(True) @mock_redshift def test_create_snapshot_copy_grant(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") grants = client.create_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1', - KmsKeyId='fake', + SnapshotCopyGrantName="test-us-east-1", KmsKeyId="fake" ) - grants['SnapshotCopyGrant']['SnapshotCopyGrantName'].should.equal('test-us-east-1') - grants['SnapshotCopyGrant']['KmsKeyId'].should.equal('fake') + grants["SnapshotCopyGrant"]["SnapshotCopyGrantName"].should.equal("test-us-east-1") + grants["SnapshotCopyGrant"]["KmsKeyId"].should.equal("fake") - client.delete_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1', - ) + client.delete_snapshot_copy_grant(SnapshotCopyGrantName="test-us-east-1") client.describe_snapshot_copy_grants.when.called_with( - SnapshotCopyGrantName='test-us-east-1', + SnapshotCopyGrantName="test-us-east-1" ).should.throw(Exception) @mock_redshift def test_create_many_snapshot_copy_grants(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") for i in range(10): client.create_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1-{0}'.format(i), - KmsKeyId='fake', + SnapshotCopyGrantName="test-us-east-1-{0}".format(i), KmsKeyId="fake" ) response = client.describe_snapshot_copy_grants() - len(response['SnapshotCopyGrants']).should.equal(10) + len(response["SnapshotCopyGrants"]).should.equal(10) @mock_redshift def test_no_snapshot_copy_grants(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") response = client.describe_snapshot_copy_grants() - len(response['SnapshotCopyGrants']).should.equal(0) + len(response["SnapshotCopyGrants"]).should.equal(0) @mock_redshift_deprecated def test_create_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" cluster_response = conn.create_cluster( cluster_identifier, @@ -98,36 +117,40 @@ def test_create_cluster(): allow_version_upgrade=True, number_of_nodes=3, ) - cluster_response['CreateClusterResponse']['CreateClusterResult'][ - 'Cluster']['ClusterStatus'].should.equal('creating') + cluster_response["CreateClusterResponse"]["CreateClusterResult"]["Cluster"][ + "ClusterStatus" + ].should.equal("creating") cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['MasterUsername'].should.equal("username") - cluster['DBName'].should.equal("my_db") - cluster['ClusterSecurityGroups'][0][ - 'ClusterSecurityGroupName'].should.equal("Default") - cluster['VpcSecurityGroups'].should.equal([]) - cluster['ClusterSubnetGroupName'].should.equal(None) - cluster['AvailabilityZone'].should.equal("us-east-1d") - cluster['PreferredMaintenanceWindow'].should.equal("Mon:03:00-Mon:11:00") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("default.redshift-1.0") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(10) - cluster['Port'].should.equal(1234) - cluster['ClusterVersion'].should.equal("1.0") - cluster['AllowVersionUpgrade'].should.equal(True) - cluster['NumberOfNodes'].should.equal(3) + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["MasterUsername"].should.equal("username") + cluster["DBName"].should.equal("my_db") + cluster["ClusterSecurityGroups"][0]["ClusterSecurityGroupName"].should.equal( + "Default" + ) + cluster["VpcSecurityGroups"].should.equal([]) + cluster["ClusterSubnetGroupName"].should.equal(None) + cluster["AvailabilityZone"].should.equal("us-east-1d") + cluster["PreferredMaintenanceWindow"].should.equal("Mon:03:00-Mon:11:00") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "default.redshift-1.0" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(10) + cluster["Port"].should.equal(1234) + cluster["ClusterVersion"].should.equal("1.0") + cluster["AllowVersionUpgrade"].should.equal(True) + cluster["NumberOfNodes"].should.equal(3) @mock_redshift_deprecated def test_create_single_node_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, @@ -139,20 +162,21 @@ def test_create_single_node_cluster(): ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['MasterUsername'].should.equal("username") - cluster['DBName'].should.equal("my_db") - cluster['NumberOfNodes'].should.equal(1) + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["MasterUsername"].should.equal("username") + cluster["DBName"].should.equal("my_db") + cluster["NumberOfNodes"].should.equal(1) @mock_redshift_deprecated def test_default_cluster_attributes(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, @@ -162,29 +186,31 @@ def test_default_cluster_attributes(): ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['DBName'].should.equal("dev") - cluster['ClusterSubnetGroupName'].should.equal(None) - assert "us-east-" in cluster['AvailabilityZone'] - cluster['PreferredMaintenanceWindow'].should.equal("Mon:03:00-Mon:03:30") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("default.redshift-1.0") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(1) - cluster['Port'].should.equal(5439) - cluster['ClusterVersion'].should.equal("1.0") - cluster['AllowVersionUpgrade'].should.equal(True) - cluster['NumberOfNodes'].should.equal(1) + cluster["DBName"].should.equal("dev") + cluster["ClusterSubnetGroupName"].should.equal(None) + assert "us-east-" in cluster["AvailabilityZone"] + cluster["PreferredMaintenanceWindow"].should.equal("Mon:03:00-Mon:03:30") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "default.redshift-1.0" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(1) + cluster["Port"].should.equal(5439) + cluster["ClusterVersion"].should.equal("1.0") + cluster["AllowVersionUpgrade"].should.equal(True) + cluster["NumberOfNodes"].should.equal(1) @mock_redshift @mock_ec2 def test_create_cluster_in_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( ClusterSubnetGroupName="my_subnet_group", Description="This is my subnet group", @@ -196,25 +222,25 @@ def test_create_cluster_in_subnet_group(): NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSubnetGroupName='my_subnet_group', + ClusterSubnetGroupName="my_subnet_group", ) cluster_response = client.describe_clusters(ClusterIdentifier="my_cluster") - cluster = cluster_response['Clusters'][0] - cluster['ClusterSubnetGroupName'].should.equal('my_subnet_group') + cluster = cluster_response["Clusters"][0] + cluster["ClusterSubnetGroupName"].should.equal("my_subnet_group") @mock_redshift @mock_ec2 def test_create_cluster_in_subnet_group_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock='10.0.0.0/24') - client = boto3.client('redshift', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', - SubnetIds=[subnet.id] + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", + SubnetIds=[subnet.id], ) client.create_cluster( @@ -222,46 +248,42 @@ def test_create_cluster_in_subnet_group_boto3(): NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSubnetGroupName='my_subnet_group', + ClusterSubnetGroupName="my_subnet_group", ) cluster_response = client.describe_clusters(ClusterIdentifier="my_cluster") - cluster = cluster_response['Clusters'][0] - cluster['ClusterSubnetGroupName'].should.equal('my_subnet_group') + cluster = cluster_response["Clusters"][0] + cluster["ClusterSubnetGroupName"].should.equal("my_subnet_group") @mock_redshift_deprecated def test_create_cluster_with_security_group(): conn = boto.redshift.connect_to_region("us-east-1") - conn.create_cluster_security_group( - "security_group1", - "This is my security group", - ) - conn.create_cluster_security_group( - "security_group2", - "This is my security group", - ) + conn.create_cluster_security_group("security_group1", "This is my security group") + conn.create_cluster_security_group("security_group2", "This is my security group") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, node_type="dw.hs1.xlarge", master_username="username", master_user_password="password", - cluster_security_groups=["security_group1", "security_group2"] + cluster_security_groups=["security_group1", "security_group2"], ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - group_names = [group['ClusterSecurityGroupName'] - for group in cluster['ClusterSecurityGroups']] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + group_names = [ + group["ClusterSecurityGroupName"] for group in cluster["ClusterSecurityGroups"] + ] set(group_names).should.equal(set(["security_group1", "security_group2"])) @mock_redshift def test_create_cluster_with_security_group_boto3(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_security_group( ClusterSecurityGroupName="security_group1", Description="This is my security group", @@ -271,18 +293,19 @@ def test_create_cluster_with_security_group_boto3(): Description="This is my security group", ) - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" client.create_cluster( ClusterIdentifier=cluster_identifier, NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSecurityGroups=["security_group1", "security_group2"] + ClusterSecurityGroups=["security_group1", "security_group2"], ) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - group_names = [group['ClusterSecurityGroupName'] - for group in cluster['ClusterSecurityGroups']] + cluster = response["Clusters"][0] + group_names = [ + group["ClusterSecurityGroupName"] for group in cluster["ClusterSecurityGroups"] + ] set(group_names).should.equal({"security_group1", "security_group2"}) @@ -294,7 +317,8 @@ def test_create_cluster_with_vpc_security_groups(): redshift_conn = boto.connect_redshift() vpc = vpc_conn.create_vpc("10.0.0.0/16") security_group = ec2_conn.create_security_group( - "vpc_security_group", "a group", vpc_id=vpc.id) + "vpc_security_group", "a group", vpc_id=vpc.id + ) redshift_conn.create_cluster( "my_cluster", @@ -305,24 +329,23 @@ def test_create_cluster_with_vpc_security_groups(): ) cluster_response = redshift_conn.describe_clusters("my_cluster") - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - group_ids = [group['VpcSecurityGroupId'] - for group in cluster['VpcSecurityGroups']] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + group_ids = [group["VpcSecurityGroupId"] for group in cluster["VpcSecurityGroups"]] list(group_ids).should.equal([security_group.id]) @mock_redshift @mock_ec2 def test_create_cluster_with_vpc_security_groups_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - client = boto3.client('redshift', region_name='us-east-1') - cluster_id = 'my_cluster' + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + client = boto3.client("redshift", region_name="us-east-1") + cluster_id = "my_cluster" security_group = ec2.create_security_group( - Description="vpc_security_group", - GroupName="a group", - VpcId=vpc.id) + Description="vpc_security_group", GroupName="a group", VpcId=vpc.id + ) client.create_cluster( ClusterIdentifier=cluster_id, NodeType="dw.hs1.xlarge", @@ -331,27 +354,26 @@ def test_create_cluster_with_vpc_security_groups_boto3(): VpcSecurityGroupIds=[security_group.id], ) response = client.describe_clusters(ClusterIdentifier=cluster_id) - cluster = response['Clusters'][0] - group_ids = [group['VpcSecurityGroupId'] - for group in cluster['VpcSecurityGroups']] + cluster = response["Clusters"][0] + group_ids = [group["VpcSecurityGroupId"] for group in cluster["VpcSecurityGroups"]] list(group_ids).should.equal([security_group.id]) @mock_redshift def test_create_cluster_with_iam_roles(): - iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ] - client = boto3.client('redshift', region_name='us-east-1') - cluster_id = 'my_cluster' + iam_roles_arn = ["arn:aws:iam:::role/my-iam-role"] + client = boto3.client("redshift", region_name="us-east-1") + cluster_id = "my_cluster" client.create_cluster( ClusterIdentifier=cluster_id, NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - IamRoles=iam_roles_arn + IamRoles=iam_roles_arn, ) response = client.describe_clusters(ClusterIdentifier=cluster_id) - cluster = response['Clusters'][0] - iam_roles = [role['IamRoleArn'] for role in cluster['IamRoles']] + cluster = response["Clusters"][0] + iam_roles = [role["IamRoleArn"] for role in cluster["IamRoles"]] iam_roles_arn.should.equal(iam_roles) @@ -359,9 +381,7 @@ def test_create_cluster_with_iam_roles(): def test_create_cluster_with_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) conn.create_cluster( @@ -369,21 +389,25 @@ def test_create_cluster_with_parameter_group(): node_type="dw.hs1.xlarge", master_username="username", master_user_password="password", - cluster_parameter_group_name='my_parameter_group', + cluster_parameter_group_name="my_parameter_group", ) cluster_response = conn.describe_clusters("my_cluster") - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("my_parameter_group") + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "my_parameter_group" + ) @mock_redshift_deprecated def test_describe_non_existent_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - conn.describe_clusters.when.called_with( - "not-a-cluster").should.throw(ClusterNotFound) + conn.describe_clusters.when.called_with("not-a-cluster").should.throw( + ClusterNotFound + ) + @mock_redshift_deprecated def test_delete_cluster(): @@ -398,54 +422,114 @@ def test_delete_cluster(): master_user_password="password", ) - conn.delete_cluster.when.called_with(cluster_identifier, False).should.throw(AttributeError) + conn.delete_cluster.when.called_with(cluster_identifier, False).should.throw( + AttributeError + ) - clusters = conn.describe_clusters()['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = conn.describe_clusters()["DescribeClustersResponse"][ + "DescribeClustersResult" + ]["Clusters"] list(clusters).should.have.length_of(1) conn.delete_cluster( cluster_identifier=cluster_identifier, skip_final_cluster_snapshot=False, - final_cluster_snapshot_identifier=snapshot_identifier - ) + final_cluster_snapshot_identifier=snapshot_identifier, + ) - clusters = conn.describe_clusters()['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = conn.describe_clusters()["DescribeClustersResponse"][ + "DescribeClustersResult" + ]["Clusters"] list(clusters).should.have.length_of(0) snapshots = conn.describe_cluster_snapshots()["DescribeClusterSnapshotsResponse"][ - "DescribeClusterSnapshotsResult"]["Snapshots"] + "DescribeClusterSnapshotsResult" + ]["Snapshots"] list(snapshots).should.have.length_of(1) assert snapshot_identifier in snapshots[0]["SnapshotIdentifier"] # Delete invalid id - conn.delete_cluster.when.called_with( - "not-a-cluster").should.throw(ClusterNotFound) + conn.delete_cluster.when.called_with("not-a-cluster").should.throw(ClusterNotFound) + + +@mock_redshift +def test_modify_cluster_vpc_routing(): + iam_roles_arn = ["arn:aws:iam:::role/my-iam-role"] + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + + client.create_cluster( + ClusterIdentifier=cluster_identifier, + NodeType="single-node", + MasterUsername="username", + MasterUserPassword="password", + IamRoles=iam_roles_arn, + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response["Clusters"][0] + cluster["EnhancedVpcRouting"].should.equal(False) + + client.create_cluster_security_group( + ClusterSecurityGroupName="security_group", Description="security_group" + ) + + client.create_cluster_parameter_group( + ParameterGroupName="my_parameter_group", + ParameterGroupFamily="redshift-1.0", + Description="my_parameter_group", + ) + + client.modify_cluster( + ClusterIdentifier=cluster_identifier, + ClusterType="multi-node", + NodeType="ds2.8xlarge", + NumberOfNodes=3, + ClusterSecurityGroups=["security_group"], + MasterUserPassword="new_password", + ClusterParameterGroupName="my_parameter_group", + AutomatedSnapshotRetentionPeriod=7, + PreferredMaintenanceWindow="Tue:03:00-Tue:11:00", + AllowVersionUpgrade=False, + NewClusterIdentifier=cluster_identifier, + EnhancedVpcRouting=True, + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response["Clusters"][0] + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("ds2.8xlarge") + cluster["PreferredMaintenanceWindow"].should.equal("Tue:03:00-Tue:11:00") + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(7) + cluster["AllowVersionUpgrade"].should.equal(False) + # This one should remain unmodified. + cluster["NumberOfNodes"].should.equal(3) + cluster["EnhancedVpcRouting"].should.equal(True) @mock_redshift_deprecated def test_modify_cluster(): conn = boto.connect_redshift() - cluster_identifier = 'my_cluster' - conn.create_cluster_security_group( - "security_group", - "This is my security group", - ) + cluster_identifier = "my_cluster" + conn.create_cluster_security_group("security_group", "This is my security group") conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) conn.create_cluster( cluster_identifier, - node_type='single-node', + node_type="single-node", master_username="username", master_user_password="password", ) + cluster_response = conn.describe_clusters(cluster_identifier) + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["EnhancedVpcRouting"].should.equal(False) + conn.modify_cluster( cluster_identifier, cluster_type="multi-node", @@ -456,49 +540,51 @@ def test_modify_cluster(): automated_snapshot_retention_period=7, preferred_maintenance_window="Tue:03:00-Tue:11:00", allow_version_upgrade=False, - new_cluster_identifier="new_identifier", + new_cluster_identifier=cluster_identifier, ) - cluster_response = conn.describe_clusters("new_identifier") - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - - cluster['ClusterIdentifier'].should.equal("new_identifier") - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['ClusterSecurityGroups'][0][ - 'ClusterSecurityGroupName'].should.equal("security_group") - cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("my_parameter_group") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7) - cluster['AllowVersionUpgrade'].should.equal(False) + cluster_response = conn.describe_clusters(cluster_identifier) + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["ClusterSecurityGroups"][0]["ClusterSecurityGroupName"].should.equal( + "security_group" + ) + cluster["PreferredMaintenanceWindow"].should.equal("Tue:03:00-Tue:11:00") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "my_parameter_group" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(7) + cluster["AllowVersionUpgrade"].should.equal(False) # This one should remain unmodified. - cluster['NumberOfNodes'].should.equal(1) + cluster["NumberOfNodes"].should.equal(1) @mock_redshift @mock_ec2 def test_create_cluster_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.1.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet1.id, subnet2.id], ) subnets_response = client.describe_cluster_subnet_groups( - ClusterSubnetGroupName="my_subnet_group") - my_subnet = subnets_response['ClusterSubnetGroups'][0] + ClusterSubnetGroupName="my_subnet_group" + ) + my_subnet = subnets_response["ClusterSubnetGroups"][0] - my_subnet['ClusterSubnetGroupName'].should.equal("my_subnet_group") - my_subnet['Description'].should.equal("This is my subnet group") - subnet_ids = [subnet['SubnetIdentifier'] - for subnet in my_subnet['Subnets']] + my_subnet["ClusterSubnetGroupName"].should.equal("my_subnet_group") + my_subnet["Description"].should.equal("This is my subnet group") + subnet_ids = [subnet["SubnetIdentifier"] for subnet in my_subnet["Subnets"]] set(subnet_ids).should.equal(set([subnet1.id, subnet2.id])) @@ -507,9 +593,7 @@ def test_create_cluster_subnet_group(): def test_create_invalid_cluster_subnet_group(): redshift_conn = boto.connect_redshift() redshift_conn.create_cluster_subnet_group.when.called_with( - "my_subnet", - "This is my subnet group", - subnet_ids=["subnet-1234"], + "my_subnet", "This is my subnet group", subnet_ids=["subnet-1234"] ).should.throw(InvalidSubnet) @@ -517,748 +601,735 @@ def test_create_invalid_cluster_subnet_group(): def test_describe_non_existent_subnet_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_subnet_groups.when.called_with( - "not-a-subnet-group").should.throw(ClusterSubnetGroupNotFound) + "not-a-subnet-group" + ).should.throw(ClusterSubnetGroupNotFound) @mock_redshift @mock_ec2 def test_delete_cluster_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet.id], ) subnets_response = client.describe_cluster_subnet_groups() - subnets = subnets_response['ClusterSubnetGroups'] + subnets = subnets_response["ClusterSubnetGroups"] subnets.should.have.length_of(1) client.delete_cluster_subnet_group(ClusterSubnetGroupName="my_subnet_group") subnets_response = client.describe_cluster_subnet_groups() - subnets = subnets_response['ClusterSubnetGroups'] + subnets = subnets_response["ClusterSubnetGroups"] subnets.should.have.length_of(0) # Delete invalid id client.delete_cluster_subnet_group.when.called_with( - ClusterSubnetGroupName="not-a-subnet-group").should.throw(ClientError) + ClusterSubnetGroupName="not-a-subnet-group" + ).should.throw(ClientError) @mock_redshift_deprecated def test_create_cluster_security_group(): conn = boto.connect_redshift() - conn.create_cluster_security_group( - "my_security_group", - "This is my security group", - ) + conn.create_cluster_security_group("my_security_group", "This is my security group") - groups_response = conn.describe_cluster_security_groups( - "my_security_group") - my_group = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'][0] + groups_response = conn.describe_cluster_security_groups("my_security_group") + my_group = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"][0] - my_group['ClusterSecurityGroupName'].should.equal("my_security_group") - my_group['Description'].should.equal("This is my security group") - list(my_group['IPRanges']).should.equal([]) + my_group["ClusterSecurityGroupName"].should.equal("my_security_group") + my_group["Description"].should.equal("This is my security group") + list(my_group["IPRanges"]).should.equal([]) @mock_redshift_deprecated def test_describe_non_existent_security_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_security_groups.when.called_with( - "not-a-security-group").should.throw(ClusterSecurityGroupNotFound) + "not-a-security-group" + ).should.throw(ClusterSecurityGroupNotFound) @mock_redshift_deprecated def test_delete_cluster_security_group(): conn = boto.connect_redshift() - conn.create_cluster_security_group( - "my_security_group", - "This is my security group", - ) + conn.create_cluster_security_group("my_security_group", "This is my security group") groups_response = conn.describe_cluster_security_groups() - groups = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'] + groups = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"] groups.should.have.length_of(2) # The default group already exists conn.delete_cluster_security_group("my_security_group") groups_response = conn.describe_cluster_security_groups() - groups = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'] + groups = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"] groups.should.have.length_of(1) # Delete invalid id conn.delete_cluster_security_group.when.called_with( - "not-a-security-group").should.throw(ClusterSecurityGroupNotFound) + "not-a-security-group" + ).should.throw(ClusterSecurityGroupNotFound) @mock_redshift_deprecated def test_create_cluster_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) - groups_response = conn.describe_cluster_parameter_groups( - "my_parameter_group") - my_group = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'][0] + groups_response = conn.describe_cluster_parameter_groups("my_parameter_group") + my_group = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"][0] - my_group['ParameterGroupName'].should.equal("my_parameter_group") - my_group['ParameterGroupFamily'].should.equal("redshift-1.0") - my_group['Description'].should.equal("This is my parameter group") + my_group["ParameterGroupName"].should.equal("my_parameter_group") + my_group["ParameterGroupFamily"].should.equal("redshift-1.0") + my_group["Description"].should.equal("This is my parameter group") @mock_redshift_deprecated def test_describe_non_existent_parameter_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_parameter_groups.when.called_with( - "not-a-parameter-group").should.throw(ClusterParameterGroupNotFound) + "not-a-parameter-group" + ).should.throw(ClusterParameterGroupNotFound) @mock_redshift_deprecated def test_delete_cluster_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) groups_response = conn.describe_cluster_parameter_groups() - groups = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'] + groups = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"] groups.should.have.length_of(2) # The default group already exists conn.delete_cluster_parameter_group("my_parameter_group") groups_response = conn.describe_cluster_parameter_groups() - groups = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'] + groups = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"] groups.should.have.length_of(1) # Delete invalid id conn.delete_cluster_parameter_group.when.called_with( - "not-a-parameter-group").should.throw(ClusterParameterGroupNotFound) + "not-a-parameter-group" + ).should.throw(ClusterParameterGroupNotFound) @mock_redshift def test_create_cluster_snapshot_of_non_existent_cluster(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'non-existent-cluster-id' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "non-existent-cluster-id" client.create_cluster_snapshot.when.called_with( - SnapshotIdentifier='snapshot-id', - ClusterIdentifier=cluster_identifier, - ).should.throw(ClientError, 'Cluster {} not found.'.format(cluster_identifier)) + SnapshotIdentifier="snapshot-id", ClusterIdentifier=cluster_identifier + ).should.throw(ClientError, "Cluster {} not found.".format(cluster_identifier)) @mock_redshift def test_create_cluster_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" cluster_response = client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) - cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge') + cluster_response["Cluster"]["NodeType"].should.equal("ds2.xlarge") snapshot_response = client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': 'test-tag-key', - 'Value': 'test-tag-value'}] + Tags=[{"Key": "test-tag-key", "Value": "test-tag-value"}], ) - snapshot = snapshot_response['Snapshot'] - snapshot['SnapshotIdentifier'].should.equal(snapshot_identifier) - snapshot['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot['NumberOfNodes'].should.equal(1) - snapshot['NodeType'].should.equal('ds2.xlarge') - snapshot['MasterUsername'].should.equal('username') + snapshot = snapshot_response["Snapshot"] + snapshot["SnapshotIdentifier"].should.equal(snapshot_identifier) + snapshot["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot["NumberOfNodes"].should.equal(1) + snapshot["NodeType"].should.equal("ds2.xlarge") + snapshot["MasterUsername"].should.equal("username") @mock_redshift def test_describe_cluster_snapshots(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier_1 = 'my_snapshot_1' - snapshot_identifier_2 = 'my_snapshot_2' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier_1 = "my_snapshot_1" + snapshot_identifier_2 = "my_snapshot_2" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier_1, - ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier_1, ClusterIdentifier=cluster_identifier ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier_2, - ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier_2, ClusterIdentifier=cluster_identifier ) - resp_snap_1 = client.describe_cluster_snapshots(SnapshotIdentifier=snapshot_identifier_1) - snapshot_1 = resp_snap_1['Snapshots'][0] - snapshot_1['SnapshotIdentifier'].should.equal(snapshot_identifier_1) - snapshot_1['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot_1['NumberOfNodes'].should.equal(1) - snapshot_1['NodeType'].should.equal('ds2.xlarge') - snapshot_1['MasterUsername'].should.equal('username') + resp_snap_1 = client.describe_cluster_snapshots( + SnapshotIdentifier=snapshot_identifier_1 + ) + snapshot_1 = resp_snap_1["Snapshots"][0] + snapshot_1["SnapshotIdentifier"].should.equal(snapshot_identifier_1) + snapshot_1["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot_1["NumberOfNodes"].should.equal(1) + snapshot_1["NodeType"].should.equal("ds2.xlarge") + snapshot_1["MasterUsername"].should.equal("username") - resp_snap_2 = client.describe_cluster_snapshots(SnapshotIdentifier=snapshot_identifier_2) - snapshot_2 = resp_snap_2['Snapshots'][0] - snapshot_2['SnapshotIdentifier'].should.equal(snapshot_identifier_2) - snapshot_2['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot_2['NumberOfNodes'].should.equal(1) - snapshot_2['NodeType'].should.equal('ds2.xlarge') - snapshot_2['MasterUsername'].should.equal('username') + resp_snap_2 = client.describe_cluster_snapshots( + SnapshotIdentifier=snapshot_identifier_2 + ) + snapshot_2 = resp_snap_2["Snapshots"][0] + snapshot_2["SnapshotIdentifier"].should.equal(snapshot_identifier_2) + snapshot_2["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot_2["NumberOfNodes"].should.equal(1) + snapshot_2["NodeType"].should.equal("ds2.xlarge") + snapshot_2["MasterUsername"].should.equal("username") resp_clust = client.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier) - resp_clust['Snapshots'][0].should.equal(resp_snap_1['Snapshots'][0]) - resp_clust['Snapshots'][1].should.equal(resp_snap_2['Snapshots'][0]) + resp_clust["Snapshots"][0].should.equal(resp_snap_1["Snapshots"][0]) + resp_clust["Snapshots"][1].should.equal(resp_snap_2["Snapshots"][0]) @mock_redshift def test_describe_cluster_snapshots_not_found_error(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.describe_cluster_snapshots.when.called_with( - ClusterIdentifier=cluster_identifier, - ).should.throw(ClientError, 'Cluster {} not found.'.format(cluster_identifier)) + ClusterIdentifier=cluster_identifier + ).should.throw(ClientError, "Cluster {} not found.".format(cluster_identifier)) client.describe_cluster_snapshots.when.called_with( SnapshotIdentifier=snapshot_identifier - ).should.throw(ClientError, 'Snapshot {} not found.'.format(snapshot_identifier)) + ).should.throw(ClientError, "Snapshot {} not found.".format(snapshot_identifier)) @mock_redshift def test_delete_cluster_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.create_cluster( ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ) - snapshots = client.describe_cluster_snapshots()['Snapshots'] + snapshots = client.describe_cluster_snapshots()["Snapshots"] list(snapshots).should.have.length_of(1) - client.delete_cluster_snapshot(SnapshotIdentifier=snapshot_identifier)[ - 'Snapshot']['Status'].should.equal('deleted') + client.delete_cluster_snapshot(SnapshotIdentifier=snapshot_identifier)["Snapshot"][ + "Status" + ].should.equal("deleted") - snapshots = client.describe_cluster_snapshots()['Snapshots'] + snapshots = client.describe_cluster_snapshots()["Snapshots"] list(snapshots).should.have.length_of(0) # Delete invalid id client.delete_cluster_snapshot.when.called_with( - SnapshotIdentifier="not-a-snapshot").should.throw(ClientError) + SnapshotIdentifier="not-a-snapshot" + ).should.throw(ClientError) @mock_redshift def test_cluster_snapshot_already_exists(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ) client.create_cluster_snapshot.when.called_with( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ).should.throw(ClientError) @mock_redshift def test_create_cluster_from_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - original_cluster_identifier = 'original-cluster' - original_snapshot_identifier = 'original-snapshot' - new_cluster_identifier = 'new-cluster' + client = boto3.client("redshift", region_name="us-east-1") + original_cluster_identifier = "original-cluster" + original_snapshot_identifier = "original-snapshot" + new_cluster_identifier = "new-cluster" client.create_cluster( ClusterIdentifier=original_cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) + client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, - ClusterIdentifier=original_cluster_identifier + ClusterIdentifier=original_cluster_identifier, ) + response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, - Port=1234 + Port=1234, ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") - response = client.describe_clusters( - ClusterIdentifier=new_cluster_identifier - ) - new_cluster = response['Clusters'][0] - new_cluster['NodeType'].should.equal('ds2.xlarge') - new_cluster['MasterUsername'].should.equal('username') - new_cluster['Endpoint']['Port'].should.equal(1234) + response = client.describe_clusters(ClusterIdentifier=new_cluster_identifier) + new_cluster = response["Clusters"][0] + new_cluster["NodeType"].should.equal("ds2.xlarge") + new_cluster["MasterUsername"].should.equal("username") + new_cluster["Endpoint"]["Port"].should.equal(1234) + new_cluster["EnhancedVpcRouting"].should.equal(True) @mock_redshift def test_create_cluster_from_snapshot_with_waiter(): - client = boto3.client('redshift', region_name='us-east-1') - original_cluster_identifier = 'original-cluster' - original_snapshot_identifier = 'original-snapshot' - new_cluster_identifier = 'new-cluster' + client = boto3.client("redshift", region_name="us-east-1") + original_cluster_identifier = "original-cluster" + original_snapshot_identifier = "original-snapshot" + new_cluster_identifier = "new-cluster" client.create_cluster( ClusterIdentifier=original_cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, - ClusterIdentifier=original_cluster_identifier + ClusterIdentifier=original_cluster_identifier, ) response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, - Port=1234 + Port=1234, ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") - client.get_waiter('cluster_restored').wait( + client.get_waiter("cluster_restored").wait( ClusterIdentifier=new_cluster_identifier, - WaiterConfig={ - 'Delay': 1, - 'MaxAttempts': 2, - } + WaiterConfig={"Delay": 1, "MaxAttempts": 2}, ) - response = client.describe_clusters( - ClusterIdentifier=new_cluster_identifier - ) - new_cluster = response['Clusters'][0] - new_cluster['NodeType'].should.equal('ds2.xlarge') - new_cluster['MasterUsername'].should.equal('username') - new_cluster['Endpoint']['Port'].should.equal(1234) + response = client.describe_clusters(ClusterIdentifier=new_cluster_identifier) + new_cluster = response["Clusters"][0] + new_cluster["NodeType"].should.equal("ds2.xlarge") + new_cluster["MasterUsername"].should.equal("username") + new_cluster["EnhancedVpcRouting"].should.equal(True) + new_cluster["Endpoint"]["Port"].should.equal(1234) @mock_redshift def test_create_cluster_from_non_existent_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.restore_from_cluster_snapshot.when.called_with( - ClusterIdentifier='cluster-id', - SnapshotIdentifier='non-existent-snapshot', - ).should.throw(ClientError, 'Snapshot non-existent-snapshot not found.') + ClusterIdentifier="cluster-id", SnapshotIdentifier="non-existent-snapshot" + ).should.throw(ClientError, "Snapshot non-existent-snapshot not found.") @mock_redshift def test_create_cluster_status_update(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'test-cluster' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "test-cluster" response = client.create_cluster( ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") - response = client.describe_clusters( - ClusterIdentifier=cluster_identifier - ) - response['Clusters'][0]['ClusterStatus'].should.equal('available') + response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + response["Clusters"][0]["ClusterStatus"].should.equal("available") @mock_redshift def test_describe_tags_with_resource_type(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - snapshot_identifier = 'my_snapshot' - snapshot_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'snapshot:{}/{}'.format(cluster_identifier, - snapshot_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + cluster_arn = "arn:aws:redshift:us-east-1:{}:" "cluster:{}".format( + ACCOUNT_ID, cluster_identifier + ) + snapshot_identifier = "my_snapshot" + snapshot_arn = "arn:aws:redshift:us-east-1:{}:" "snapshot:{}/{}".format( + ACCOUNT_ID, cluster_identifier, snapshot_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=[{'Key': tag_key, - 'Value': tag_value}] + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=[{"Key": tag_key, "Value": tag_value}], ) - tags_response = client.describe_tags(ResourceType='cluster') - tagged_resources = tags_response['TaggedResources'] + tags_response = client.describe_tags(ResourceType="cluster") + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('cluster') - tagged_resources[0]['ResourceName'].should.equal(cluster_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("cluster") + tagged_resources[0]["ResourceName"].should.equal(cluster_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': tag_key, - 'Value': tag_value}] + Tags=[{"Key": tag_key, "Value": tag_value}], ) - tags_response = client.describe_tags(ResourceType='snapshot') - tagged_resources = tags_response['TaggedResources'] + tags_response = client.describe_tags(ResourceType="snapshot") + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('snapshot') - tagged_resources[0]['ResourceName'].should.equal(snapshot_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("snapshot") + tagged_resources[0]["ResourceName"].should.equal(snapshot_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) @mock_redshift def test_describe_tags_cannot_specify_resource_type_and_resource_name(): - client = boto3.client('redshift', region_name='us-east-1') - resource_name = 'arn:aws:redshift:us-east-1:123456789012:cluster:cluster-id' - resource_type = 'cluster' + client = boto3.client("redshift", region_name="us-east-1") + resource_name = "arn:aws:redshift:us-east-1:{}:cluster:cluster-id".format( + ACCOUNT_ID + ) + resource_type = "cluster" client.describe_tags.when.called_with( - ResourceName=resource_name, - ResourceType=resource_type - ).should.throw(ClientError, 'using either an ARN or a resource type') + ResourceName=resource_name, ResourceType=resource_type + ).should.throw(ClientError, "using either an ARN or a resource type") @mock_redshift def test_describe_tags_with_resource_name(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - snapshot_identifier = 'snapshot-id' - snapshot_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'snapshot:{}/{}'.format(cluster_identifier, - snapshot_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:{}:" "cluster:{}".format( + ACCOUNT_ID, cluster_identifier + ) + snapshot_identifier = "snapshot-id" + snapshot_arn = "arn:aws:redshift:us-east-1:{}:" "snapshot:{}/{}".format( + ACCOUNT_ID, cluster_identifier, snapshot_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=[{'Key': tag_key, - 'Value': tag_value}] + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=[{"Key": tag_key, "Value": tag_value}], ) tags_response = client.describe_tags(ResourceName=cluster_arn) - tagged_resources = tags_response['TaggedResources'] + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('cluster') - tagged_resources[0]['ResourceName'].should.equal(cluster_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("cluster") + tagged_resources[0]["ResourceName"].should.equal(cluster_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': tag_key, - 'Value': tag_value}] + Tags=[{"Key": tag_key, "Value": tag_value}], ) tags_response = client.describe_tags(ResourceName=snapshot_arn) - tagged_resources = tags_response['TaggedResources'] + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('snapshot') - tagged_resources[0]['ResourceName'].should.equal(snapshot_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("snapshot") + tagged_resources[0]["ResourceName"].should.equal(snapshot_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) @mock_redshift def test_create_tags(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:{}:" "cluster:{}".format( + ACCOUNT_ID, cluster_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" num_tags = 5 tags = [] for i in range(0, num_tags): - tag = {'Key': '{}-{}'.format(tag_key, i), - 'Value': '{}-{}'.format(tag_value, i)} + tag = {"Key": "{}-{}".format(tag_key, i), "Value": "{}-{}".format(tag_value, i)} tags.append(tag) client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - ) - client.create_tags( - ResourceName=cluster_arn, - Tags=tags + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) + client.create_tags(ResourceName=cluster_arn, Tags=tags) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - list(cluster['Tags']).should.have.length_of(num_tags) + cluster = response["Clusters"][0] + list(cluster["Tags"]).should.have.length_of(num_tags) response = client.describe_tags(ResourceName=cluster_arn) - list(response['TaggedResources']).should.have.length_of(num_tags) + list(response["TaggedResources"]).should.have.length_of(num_tags) @mock_redshift def test_delete_tags(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:{}:" "cluster:{}".format( + ACCOUNT_ID, cluster_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" tags = [] for i in range(1, 2): - tag = {'Key': '{}-{}'.format(tag_key, i), - 'Value': '{}-{}'.format(tag_value, i)} + tag = {"Key": "{}-{}".format(tag_key, i), "Value": "{}-{}".format(tag_value, i)} tags.append(tag) client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=tags + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=tags, ) client.delete_tags( ResourceName=cluster_arn, - TagKeys=[tag['Key'] for tag in tags - if tag['Key'] != '{}-1'.format(tag_key)] + TagKeys=[tag["Key"] for tag in tags if tag["Key"] != "{}-1".format(tag_key)], ) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - list(cluster['Tags']).should.have.length_of(1) + cluster = response["Clusters"][0] + list(cluster["Tags"]).should.have.length_of(1) response = client.describe_tags(ResourceName=cluster_arn) - list(response['TaggedResources']).should.have.length_of(1) + list(response["TaggedResources"]).should.have.length_of(1) @mock_ec2 @mock_redshift def test_describe_tags_all_resource_types(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock='10.0.0.0/24') - client = boto3.client('redshift', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") + client = boto3.client("redshift", region_name="us-east-1") response = client.describe_tags() - list(response['TaggedResources']).should.have.length_of(0) + list(response["TaggedResources"]).should.have.length_of(0) client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet.id], - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_security_group( ClusterSecurityGroupName="security_group1", Description="This is my security group", - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster( - DBName='test', - ClusterIdentifier='my_cluster', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + DBName="test", + ClusterIdentifier="my_cluster", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_snapshot( - SnapshotIdentifier='my_snapshot', - ClusterIdentifier='my_cluster', - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + SnapshotIdentifier="my_snapshot", + ClusterIdentifier="my_cluster", + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_parameter_group( ParameterGroupName="my_parameter_group", ParameterGroupFamily="redshift-1.0", Description="This is my parameter group", - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) response = client.describe_tags() - expected_types = ['cluster', 'parametergroup', 'securitygroup', 'snapshot', 'subnetgroup'] - tagged_resources = response['TaggedResources'] - returned_types = [resource['ResourceType'] for resource in tagged_resources] + expected_types = [ + "cluster", + "parametergroup", + "securitygroup", + "snapshot", + "subnetgroup", + ] + tagged_resources = response["TaggedResources"] + returned_types = [resource["ResourceType"] for resource in tagged_resources] list(tagged_resources).should.have.length_of(len(expected_types)) set(returned_types).should.equal(set(expected_types)) @mock_redshift def test_tagged_resource_not_found_error(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") - cluster_arn = 'arn:aws:redshift:us-east-1::cluster:fake' - client.describe_tags.when.called_with( - ResourceName=cluster_arn - ).should.throw(ClientError, 'cluster (fake) not found.') + cluster_arn = "arn:aws:redshift:us-east-1::cluster:fake" + client.describe_tags.when.called_with(ResourceName=cluster_arn).should.throw( + ClientError, "cluster (fake) not found." + ) - snapshot_arn = 'arn:aws:redshift:us-east-1::snapshot:cluster-id/snap-id' + snapshot_arn = "arn:aws:redshift:us-east-1::snapshot:cluster-id/snap-id" client.delete_tags.when.called_with( - ResourceName=snapshot_arn, - TagKeys=['test'] - ).should.throw(ClientError, 'snapshot (snap-id) not found.') + ResourceName=snapshot_arn, TagKeys=["test"] + ).should.throw(ClientError, "snapshot (snap-id) not found.") - client.describe_tags.when.called_with( - ResourceType='cluster' - ).should.throw(ClientError, "resource of type 'cluster' not found.") + client.describe_tags.when.called_with(ResourceType="cluster").should.throw( + ClientError, "resource of type 'cluster' not found." + ) - client.describe_tags.when.called_with( - ResourceName='bad:arn' - ).should.throw(ClientError, "Tagging is not supported for this type of resource") + client.describe_tags.when.called_with(ResourceName="bad:arn").should.throw( + ClientError, "Tagging is not supported for this type of resource" + ) @mock_redshift def test_enable_snapshot_copy(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - ClusterIdentifier='test', - ClusterType='single-node', - DBName='test', + ClusterIdentifier="test", + ClusterType="single-node", + DBName="test", Encrypted=True, - MasterUsername='user', - MasterUserPassword='password', - NodeType='ds2.xlarge', + MasterUsername="user", + MasterUserPassword="password", + NodeType="ds2.xlarge", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2' + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", + ) + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(3) + cluster_snapshot_copy_status["DestinationRegion"].should.equal("us-west-2") + cluster_snapshot_copy_status["SnapshotCopyGrantName"].should.equal( + "copy-us-east-1-to-us-west-2" ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(3) - cluster_snapshot_copy_status['DestinationRegion'].should.equal('us-west-2') - cluster_snapshot_copy_status['SnapshotCopyGrantName'].should.equal('copy-us-east-1-to-us-west-2') @mock_redshift def test_enable_snapshot_copy_unencrypted(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - ClusterIdentifier='test', - ClusterType='single-node', - DBName='test', - MasterUsername='user', - MasterUserPassword='password', - NodeType='ds2.xlarge', + ClusterIdentifier="test", + ClusterType="single-node", + DBName="test", + MasterUsername="user", + MasterUserPassword="password", + NodeType="ds2.xlarge", ) - client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', - ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(7) - cluster_snapshot_copy_status['DestinationRegion'].should.equal('us-west-2') + client.enable_snapshot_copy(ClusterIdentifier="test", DestinationRegion="us-west-2") + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(7) + cluster_snapshot_copy_status["DestinationRegion"].should.equal("us-west-2") @mock_redshift def test_disable_snapshot_copy(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2', + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", ) - client.disable_snapshot_copy( - ClusterIdentifier='test', - ) - response = client.describe_clusters(ClusterIdentifier='test') - response['Clusters'][0].shouldnt.contain('ClusterSnapshotCopyStatus') + client.disable_snapshot_copy(ClusterIdentifier="test") + response = client.describe_clusters(ClusterIdentifier="test") + response["Clusters"][0].shouldnt.contain("ClusterSnapshotCopyStatus") @mock_redshift def test_modify_snapshot_copy_retention_period(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2', + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", ) client.modify_snapshot_copy_retention_period( - ClusterIdentifier='test', - RetentionPeriod=5, + ClusterIdentifier="test", RetentionPeriod=5 ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(5) + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(5) diff --git a/tests/test_redshift/test_server.py b/tests/test_redshift/test_server.py index 47ccdc5f3..f4eee85e8 100644 --- a/tests/test_redshift/test_server.py +++ b/tests/test_redshift/test_server.py @@ -1,22 +1,22 @@ -from __future__ import unicode_literals - -import json -import sure # noqa - -import moto.server as server -from moto import mock_redshift - -''' -Test the different server responses -''' - - -@mock_redshift -def test_describe_clusters(): - backend = server.create_backend_app("redshift") - test_client = backend.test_client() - - res = test_client.get('/?Action=DescribeClusters') - - result = res.data.decode("utf-8") - result.should.contain("") +from __future__ import unicode_literals + +import json +import sure # noqa + +import moto.server as server +from moto import mock_redshift + +""" +Test the different server responses +""" + + +@mock_redshift +def test_describe_clusters(): + backend = server.create_backend_app("redshift") + test_client = backend.test_client() + + res = test_client.get("/?Action=DescribeClusters") + + result = res.data.decode("utf-8") + result.should.contain("") diff --git a/tests/test_resourcegroups/test_resourcegroups.py b/tests/test_resourcegroups/test_resourcegroups.py index bb3624413..29af9aad7 100644 --- a/tests/test_resourcegroups/test_resourcegroups.py +++ b/tests/test_resourcegroups/test_resourcegroups.py @@ -25,11 +25,13 @@ def test_create_group(): } ), }, - Tags={"resource_group_tag_key": "resource_group_tag_value"} + Tags={"resource_group_tag_key": "resource_group_tag_value"}, ) response["Group"]["Name"].should.contain("test_resource_group") response["ResourceQuery"]["Type"].should.contain("TAG_FILTERS_1_0") - response["Tags"]["resource_group_tag_key"].should.contain("resource_group_tag_value") + response["Tags"]["resource_group_tag_key"].should.contain( + "resource_group_tag_value" + ) @mock_resourcegroups @@ -76,7 +78,9 @@ def test_get_tags(): response = resource_groups.get_tags(Arn=response["Group"]["GroupArn"]) response["Tags"].should.have.length_of(1) - response["Tags"]["resource_group_tag_key"].should.contain("resource_group_tag_value") + response["Tags"]["resource_group_tag_key"].should.contain( + "resource_group_tag_value" + ) return response @@ -100,13 +104,17 @@ def test_tag(): response = resource_groups.tag( Arn=response["Arn"], - Tags={"resource_group_tag_key_2": "resource_group_tag_value_2"} + Tags={"resource_group_tag_key_2": "resource_group_tag_value_2"}, + ) + response["Tags"]["resource_group_tag_key_2"].should.contain( + "resource_group_tag_value_2" ) - response["Tags"]["resource_group_tag_key_2"].should.contain("resource_group_tag_value_2") response = resource_groups.get_tags(Arn=response["Arn"]) response["Tags"].should.have.length_of(2) - response["Tags"]["resource_group_tag_key_2"].should.contain("resource_group_tag_value_2") + response["Tags"]["resource_group_tag_key_2"].should.contain( + "resource_group_tag_value_2" + ) @mock_resourcegroups @@ -115,7 +123,9 @@ def test_untag(): response = test_get_tags() - response = resource_groups.untag(Arn=response["Arn"], Keys=["resource_group_tag_key"]) + response = resource_groups.untag( + Arn=response["Arn"], Keys=["resource_group_tag_key"] + ) response["Keys"].should.contain("resource_group_tag_key") response = resource_groups.get_tags(Arn=response["Arn"]) @@ -129,8 +139,7 @@ def test_update_group(): test_get_group() response = resource_groups.update_group( - GroupName="test_resource_group", - Description="description_2", + GroupName="test_resource_group", Description="description_2" ) response["Group"]["Description"].should.contain("description_2") @@ -154,12 +163,16 @@ def test_update_group_query(): "StackIdentifier": ( "arn:aws:cloudformation:eu-west-1:012345678912:stack/" "test_stack/c223eca0-e744-11e8-8910-500c41f59083" - ) + ), } ), }, ) - response["GroupQuery"]["ResourceQuery"]["Type"].should.contain("CLOUDFORMATION_STACK_1_0") + response["GroupQuery"]["ResourceQuery"]["Type"].should.contain( + "CLOUDFORMATION_STACK_1_0" + ) response = resource_groups.get_group_query(GroupName="test_resource_group") - response["GroupQuery"]["ResourceQuery"]["Type"].should.contain("CLOUDFORMATION_STACK_1_0") + response["GroupQuery"]["ResourceQuery"]["Type"].should.contain( + "CLOUDFORMATION_STACK_1_0" + ) diff --git a/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py b/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py index 1e42dfe55..84f7a8b86 100644 --- a/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py +++ b/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py @@ -13,7 +13,7 @@ from moto import mock_s3 @mock_resourcegroupstaggingapi def test_get_resources_s3(): # Tests pagination - s3_client = boto3.client('s3', region_name='eu-central-1') + s3_client = boto3.client("s3", region_name="eu-central-1") # Will end up having key1,key2,key3,key4 response_keys = set() @@ -21,26 +21,25 @@ def test_get_resources_s3(): # Create 4 buckets for i in range(1, 5): i_str = str(i) - s3_client.create_bucket(Bucket='test_bucket' + i_str) + s3_client.create_bucket(Bucket="test_bucket" + i_str) s3_client.put_bucket_tagging( - Bucket='test_bucket' + i_str, - Tagging={'TagSet': [{'Key': 'key' + i_str, 'Value': 'value' + i_str}]} + Bucket="test_bucket" + i_str, + Tagging={"TagSet": [{"Key": "key" + i_str, "Value": "value" + i_str}]}, ) - response_keys.add('key' + i_str) + response_keys.add("key" + i_str) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_resources(ResourcesPerPage=2) - for resource in resp['ResourceTagMappingList']: - response_keys.remove(resource['Tags'][0]['Key']) + for resource in resp["ResourceTagMappingList"]: + response_keys.remove(resource["Tags"][0]["Key"]) response_keys.should.have.length_of(2) resp = rtapi.get_resources( - ResourcesPerPage=2, - PaginationToken=resp['PaginationToken'] + ResourcesPerPage=2, PaginationToken=resp["PaginationToken"] ) - for resource in resp['ResourceTagMappingList']: - response_keys.remove(resource['Tags'][0]['Key']) + for resource in resp["ResourceTagMappingList"]: + response_keys.remove(resource["Tags"][0]["Key"]) response_keys.should.have.length_of(0) @@ -48,109 +47,86 @@ def test_get_resources_s3(): @mock_ec2 @mock_resourcegroupstaggingapi def test_get_resources_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") instances = client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - instance_id = instances['Instances'][0]['InstanceId'] - image_id = client.create_image(Name='testami', InstanceId=instance_id)['ImageId'] + instance_id = instances["Instances"][0]["InstanceId"] + image_id = client.create_image(Name="testami", InstanceId=instance_id)["ImageId"] - client.create_tags( - Resources=[image_id], - Tags=[{'Key': 'ami', 'Value': 'test'}] - ) + client.create_tags(Resources=[image_id], Tags=[{"Key": "ami", "Value": "test"}]) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_resources() # Check we have 1 entry for Instance, 1 Entry for AMI - resp['ResourceTagMappingList'].should.have.length_of(2) + resp["ResourceTagMappingList"].should.have.length_of(2) # 1 Entry for AMI - resp = rtapi.get_resources(ResourceTypeFilters=['ec2:image']) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('image/') + resp = rtapi.get_resources(ResourceTypeFilters=["ec2:image"]) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("image/") # As were iterating the same data, this rules out that the test above was a fluke - resp = rtapi.get_resources(ResourceTypeFilters=['ec2:instance']) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('instance/') + resp = rtapi.get_resources(ResourceTypeFilters=["ec2:instance"]) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("instance/") # Basic test of tag filters - resp = rtapi.get_resources(TagFilters=[{'Key': 'MY_TAG1', 'Values': ['MY_VALUE1', 'some_other_value']}]) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('instance/') + resp = rtapi.get_resources( + TagFilters=[{"Key": "MY_TAG1", "Values": ["MY_VALUE1", "some_other_value"]}] + ) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("instance/") @mock_ec2 @mock_resourcegroupstaggingapi def test_get_tag_keys_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_tag_keys() - resp['TagKeys'].should.contain('MY_TAG1') - resp['TagKeys'].should.contain('MY_TAG2') - resp['TagKeys'].should.contain('MY_TAG3') + resp["TagKeys"].should.contain("MY_TAG1") + resp["TagKeys"].should.contain("MY_TAG2") + resp["TagKeys"].should.contain("MY_TAG3") # TODO test pagenation @@ -158,148 +134,114 @@ def test_get_tag_keys_ec2(): @mock_ec2 @mock_resourcegroupstaggingapi def test_get_tag_values_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE4', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE5', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE4"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE5"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE6', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE6"}], }, ], ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') - resp = rtapi.get_tag_values(Key='MY_TAG1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") + resp = rtapi.get_tag_values(Key="MY_TAG1") + + resp["TagValues"].should.contain("MY_VALUE1") + resp["TagValues"].should.contain("MY_VALUE4") - resp['TagValues'].should.contain('MY_VALUE1') - resp['TagValues'].should.contain('MY_VALUE4') @mock_ec2 @mock_elbv2 @mock_kms @mock_resourcegroupstaggingapi def test_get_many_resources(): - elbv2 = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') - kms = boto3.client('kms', region_name='us-east-1') + elbv2 = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + kms = boto3.client("kms", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) elbv2.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', + Scheme="internal", Tags=[ - { - 'Key': 'key_name', - 'Value': 'a_value' - }, - { - 'Key': 'key_2', - 'Value': 'val2' - } - ] - ) + {"Key": "key_name", "Value": "a_value"}, + {"Key": "key_2", "Value": "val2"}, + ], + ) elbv2.create_load_balancer( - Name='my-other-lb', + Name="my-other-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - ) + Scheme="internal", + ) kms.create_key( - KeyUsage='ENCRYPT_DECRYPT', + KeyUsage="ENCRYPT_DECRYPT", Tags=[ - { - 'TagKey': 'key_name', - 'TagValue': 'a_value' - }, - { - 'TagKey': 'key_2', - 'TagValue': 'val2' - } - ] - ) + {"TagKey": "key_name", "TagValue": "a_value"}, + {"TagKey": "key_2", "TagValue": "val2"}, + ], + ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='us-east-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="us-east-1") - resp = rtapi.get_resources(ResourceTypeFilters=['elasticloadbalancer:loadbalancer']) + resp = rtapi.get_resources(ResourceTypeFilters=["elasticloadbalancer:loadbalancer"]) - resp['ResourceTagMappingList'].should.have.length_of(2) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('loadbalancer/') + resp["ResourceTagMappingList"].should.have.length_of(2) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("loadbalancer/") resp = rtapi.get_resources( - ResourceTypeFilters=['elasticloadbalancer:loadbalancer'], - TagFilters=[{ - 'Key': 'key_name' - }] - ) + ResourceTypeFilters=["elasticloadbalancer:loadbalancer"], + TagFilters=[{"Key": "key_name"}], + ) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['Tags'].should.contain({'Key': 'key_name', 'Value': 'a_value'}) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["Tags"].should.contain( + {"Key": "key_name", "Value": "a_value"} + ) # TODO test pagenation diff --git a/tests/test_resourcegroupstaggingapi/test_server.py b/tests/test_resourcegroupstaggingapi/test_server.py index 80a74b0b8..836fa5828 100644 --- a/tests/test_resourcegroupstaggingapi/test_server.py +++ b/tests/test_resourcegroupstaggingapi/test_server.py @@ -1,24 +1,24 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_resourcegroupstaggingapi_list(): - backend = server.create_backend_app("resourcegroupstaggingapi") - test_client = backend.test_client() - # do test - - headers = { - 'X-Amz-Target': 'ResourceGroupsTaggingAPI_20170126.GetResources', - 'X-Amz-Date': '20171114T234623Z' - } - resp = test_client.post('/', headers=headers, data='{}') - - assert resp.status_code == 200 - assert b'ResourceTagMappingList' in resp.data +from __future__ import unicode_literals + +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_resourcegroupstaggingapi_list(): + backend = server.create_backend_app("resourcegroupstaggingapi") + test_client = backend.test_client() + # do test + + headers = { + "X-Amz-Target": "ResourceGroupsTaggingAPI_20170126.GetResources", + "X-Amz-Date": "20171114T234623Z", + } + resp = test_client.post("/", headers=headers, data="{}") + + assert resp.status_code == 200 + assert b"ResourceTagMappingList" in resp.data diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index de9465d6d..0e9a1e2c0 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -17,7 +17,7 @@ from moto import mock_route53, mock_route53_deprecated @mock_route53_deprecated def test_hosted_zone(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") firstzone = conn.create_hosted_zone("testdns.aws.com") zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(1) @@ -26,30 +26,29 @@ def test_hosted_zone(): zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(2) - id1 = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + id1 = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] zone = conn.get_hosted_zone(id1) - zone["GetHostedZoneResponse"]["HostedZone"][ - "Name"].should.equal("testdns.aws.com.") + zone["GetHostedZoneResponse"]["HostedZone"]["Name"].should.equal("testdns.aws.com.") conn.delete_hosted_zone(id1) zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(1) conn.get_hosted_zone.when.called_with("abcd").should.throw( - boto.route53.exception.DNSServerError, "404 Not Found") + boto.route53.exception.DNSServerError, "404 Not Found" + ) @mock_route53_deprecated def test_rrset(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") conn.get_all_rrsets.when.called_with("abcd", type="A").should.throw( - boto.route53.exception.DNSServerError, "404 Not Found") + boto.route53.exception.DNSServerError, "404 Not Found" + ) zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("CREATE", "foo.bar.testdns.aws.com", "A") @@ -58,7 +57,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('1.2.3.4') + rrsets[0].resource_records[0].should.equal("1.2.3.4") rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(0) @@ -71,7 +70,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") changes = ResourceRecordSets(conn, zoneid) changes.add_change("DELETE", "foo.bar.testdns.aws.com", "A") @@ -87,7 +86,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('1.2.3.4') + rrsets[0].resource_records[0].should.equal("1.2.3.4") changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("UPSERT", "foo.bar.testdns.aws.com", "A") @@ -96,7 +95,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("UPSERT", "foo.bar.testdns.aws.com", "TXT") @@ -105,8 +104,8 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid) rrsets.should.have.length_of(2) - rrsets[0].resource_records[0].should.equal('5.6.7.8') - rrsets[1].resource_records[0].should.equal('foo') + rrsets[0].resource_records[0].should.equal("5.6.7.8") + rrsets[1].resource_records[0].should.equal("foo") changes = ResourceRecordSets(conn, zoneid) changes.add_change("DELETE", "foo.bar.testdns.aws.com", "A") @@ -123,29 +122,25 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(2) - rrsets = conn.get_all_rrsets( - zoneid, name="bar.foo.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="bar.foo.testdns.aws.com", type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") - rrsets = conn.get_all_rrsets( - zoneid, name="foo.bar.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="foo.bar.testdns.aws.com", type="A") rrsets.should.have.length_of(2) resource_records = [rr for rr_set in rrsets for rr in rr_set.resource_records] - resource_records.should.contain('1.2.3.4') - resource_records.should.contain('5.6.7.8') + resource_records.should.contain("1.2.3.4") + resource_records.should.contain("5.6.7.8") - rrsets = conn.get_all_rrsets( - zoneid, name="foo.foo.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="foo.foo.testdns.aws.com", type="A") rrsets.should.have.length_of(0) @mock_route53_deprecated def test_rrset_with_multiple_values(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("CREATE", "foo.bar.testdns.aws.com", "A") @@ -155,39 +150,48 @@ def test_rrset_with_multiple_values(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - set(rrsets[0].resource_records).should.equal(set(['1.2.3.4', '5.6.7.8'])) + set(rrsets[0].resource_records).should.equal(set(["1.2.3.4", "5.6.7.8"])) @mock_route53_deprecated def test_alias_rrset(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) - changes.add_change("CREATE", "foo.alias.testdns.aws.com", "A", - alias_hosted_zone_id="Z3DG6IL3SJCGPX", alias_dns_name="foo.testdns.aws.com") - changes.add_change("CREATE", "bar.alias.testdns.aws.com", "CNAME", - alias_hosted_zone_id="Z3DG6IL3SJCGPX", alias_dns_name="bar.testdns.aws.com") + changes.add_change( + "CREATE", + "foo.alias.testdns.aws.com", + "A", + alias_hosted_zone_id="Z3DG6IL3SJCGPX", + alias_dns_name="foo.testdns.aws.com", + ) + changes.add_change( + "CREATE", + "bar.alias.testdns.aws.com", + "CNAME", + alias_hosted_zone_id="Z3DG6IL3SJCGPX", + alias_dns_name="bar.testdns.aws.com", + ) changes.commit() rrsets = conn.get_all_rrsets(zoneid, type="A") alias_targets = [rr_set.alias_dns_name for rr_set in rrsets] alias_targets.should.have.length_of(2) - alias_targets.should.contain('foo.testdns.aws.com') - alias_targets.should.contain('bar.testdns.aws.com') - rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com') + alias_targets.should.contain("foo.testdns.aws.com") + alias_targets.should.contain("bar.testdns.aws.com") + rrsets[0].alias_dns_name.should.equal("foo.testdns.aws.com") rrsets[0].resource_records.should.have.length_of(0) rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(1) - rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal("bar.testdns.aws.com") rrsets[0].resource_records.should.have.length_of(0) @mock_route53_deprecated def test_create_health_check(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") check = HealthCheck( ip_addr="10.0.0.25", @@ -201,65 +205,51 @@ def test_create_health_check(): ) conn.create_health_check(check) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(1) check = checks[0] - config = check['HealthCheckConfig'] - config['IPAddress'].should.equal("10.0.0.25") - config['Port'].should.equal("80") - config['Type'].should.equal("HTTP") - config['ResourcePath'].should.equal("/") - config['FullyQualifiedDomainName'].should.equal("example.com") - config['SearchString'].should.equal("a good response") - config['RequestInterval'].should.equal("10") - config['FailureThreshold'].should.equal("2") + config = check["HealthCheckConfig"] + config["IPAddress"].should.equal("10.0.0.25") + config["Port"].should.equal("80") + config["Type"].should.equal("HTTP") + config["ResourcePath"].should.equal("/") + config["FullyQualifiedDomainName"].should.equal("example.com") + config["SearchString"].should.equal("a good response") + config["RequestInterval"].should.equal("10") + config["FailureThreshold"].should.equal("2") @mock_route53_deprecated def test_delete_health_check(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - check = HealthCheck( - ip_addr="10.0.0.25", - port=80, - hc_type="HTTP", - resource_path="/", - ) + check = HealthCheck(ip_addr="10.0.0.25", port=80, hc_type="HTTP", resource_path="/") conn.create_health_check(check) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(1) - health_check_id = checks[0]['Id'] + health_check_id = checks[0]["Id"] conn.delete_health_check(health_check_id) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(0) @mock_route53_deprecated def test_use_health_check_in_resource_record_set(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - check = HealthCheck( - ip_addr="10.0.0.25", - port=80, - hc_type="HTTP", - resource_path="/", - ) - check = conn.create_health_check( - check)['CreateHealthCheckResponse']['HealthCheck'] - check_id = check['Id'] + check = HealthCheck(ip_addr="10.0.0.25", port=80, hc_type="HTTP", resource_path="/") + check = conn.create_health_check(check)["CreateHealthCheckResponse"]["HealthCheck"] + check_id = check["Id"] zone = conn.create_hosted_zone("testdns.aws.com") - zone_id = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zone_id = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zone_id) change = changes.add_change( - "CREATE", "foo.bar.testdns.aws.com", "A", health_check=check_id) + "CREATE", "foo.bar.testdns.aws.com", "A", health_check=check_id + ) change.add_value("1.2.3.4") changes.commit() @@ -269,20 +259,20 @@ def test_use_health_check_in_resource_record_set(): @mock_route53_deprecated def test_hosted_zone_comment_preserved(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - firstzone = conn.create_hosted_zone( - "testdns.aws.com.", comment="test comment") - zone_id = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + firstzone = conn.create_hosted_zone("testdns.aws.com.", comment="test comment") + zone_id = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] hosted_zone = conn.get_hosted_zone(zone_id) - hosted_zone["GetHostedZoneResponse"]["HostedZone"][ - "Config"]["Comment"].should.equal("test comment") + hosted_zone["GetHostedZoneResponse"]["HostedZone"]["Config"][ + "Comment" + ].should.equal("test comment") hosted_zones = conn.get_all_hosted_zones() - hosted_zones["ListHostedZonesResponse"]["HostedZones"][ - 0]["Config"]["Comment"].should.equal("test comment") + hosted_zones["ListHostedZonesResponse"]["HostedZones"][0]["Config"][ + "Comment" + ].should.equal("test comment") zone = conn.get_zone("testdns.aws.com.") zone.config["Comment"].should.equal("test comment") @@ -295,21 +285,22 @@ def test_deleting_weighted_route(): conn.create_hosted_zone("testdns.aws.com.") zone = conn.get_zone("testdns.aws.com.") - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-foo', '50')) - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-bar', '50')) + zone.add_cname( + "cname.testdns.aws.com", "example.com", identifier=("success-test-foo", "50") + ) + zone.add_cname( + "cname.testdns.aws.com", "example.com", identifier=("success-test-bar", "50") + ) - cnames = zone.get_cname('cname.testdns.aws.com.', all=True) + cnames = zone.get_cname("cname.testdns.aws.com.", all=True) cnames.should.have.length_of(2) - foo_cname = [cname for cname in cnames if cname.identifier == - 'success-test-foo'][0] + foo_cname = [cname for cname in cnames if cname.identifier == "success-test-foo"][0] zone.delete_record(foo_cname) - cname = zone.get_cname('cname.testdns.aws.com.', all=True) + cname = zone.get_cname("cname.testdns.aws.com.", all=True) # When get_cname only had one result, it returns just that result instead # of a list. - cname.identifier.should.equal('success-test-bar') + cname.identifier.should.equal("success-test-bar") @mock_route53_deprecated @@ -319,59 +310,63 @@ def test_deleting_latency_route(): conn.create_hosted_zone("testdns.aws.com.") zone = conn.get_zone("testdns.aws.com.") - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-foo', 'us-west-2')) - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-bar', 'us-west-1')) + zone.add_cname( + "cname.testdns.aws.com", + "example.com", + identifier=("success-test-foo", "us-west-2"), + ) + zone.add_cname( + "cname.testdns.aws.com", + "example.com", + identifier=("success-test-bar", "us-west-1"), + ) - cnames = zone.get_cname('cname.testdns.aws.com.', all=True) + cnames = zone.get_cname("cname.testdns.aws.com.", all=True) cnames.should.have.length_of(2) - foo_cname = [cname for cname in cnames if cname.identifier == - 'success-test-foo'][0] - foo_cname.region.should.equal('us-west-2') + foo_cname = [cname for cname in cnames if cname.identifier == "success-test-foo"][0] + foo_cname.region.should.equal("us-west-2") zone.delete_record(foo_cname) - cname = zone.get_cname('cname.testdns.aws.com.', all=True) + cname = zone.get_cname("cname.testdns.aws.com.", all=True) # When get_cname only had one result, it returns just that result instead # of a list. - cname.identifier.should.equal('success-test-bar') - cname.region.should.equal('us-west-1') + cname.identifier.should.equal("success-test-bar") + cname.region.should.equal("us-west-1") @mock_route53_deprecated def test_hosted_zone_private_zone_preserved(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") firstzone = conn.create_hosted_zone( - "testdns.aws.com.", private_zone=True, vpc_id='vpc-fake', vpc_region='us-east-1') - zone_id = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + "testdns.aws.com.", private_zone=True, vpc_id="vpc-fake", vpc_region="us-east-1" + ) + zone_id = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] hosted_zone = conn.get_hosted_zone(zone_id) # in (original) boto, these bools returned as strings. - hosted_zone["GetHostedZoneResponse"]["HostedZone"][ - "Config"]["PrivateZone"].should.equal('True') + hosted_zone["GetHostedZoneResponse"]["HostedZone"]["Config"][ + "PrivateZone" + ].should.equal("True") hosted_zones = conn.get_all_hosted_zones() - hosted_zones["ListHostedZonesResponse"]["HostedZones"][ - 0]["Config"]["PrivateZone"].should.equal('True') + hosted_zones["ListHostedZonesResponse"]["HostedZones"][0]["Config"][ + "PrivateZone" + ].should.equal("True") zone = conn.get_zone("testdns.aws.com.") - zone.config["PrivateZone"].should.equal('True') + zone.config["PrivateZone"].should.equal("True") @mock_route53 def test_hosted_zone_private_zone_preserved_boto3(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") # TODO: actually create_hosted_zone statements with PrivateZone=True, but without # a _valid_ vpc-id should fail. firstzone = conn.create_hosted_zone( Name="testdns.aws.com.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="Test", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="Test"), ) zone_id = firstzone["HostedZone"]["Id"].split("/")[-1] @@ -389,112 +384,109 @@ def test_hosted_zone_private_zone_preserved_boto3(): @mock_route53 def test_list_or_change_tags_for_resource_request(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") health_check = conn.create_health_check( - CallerReference='foobar', + CallerReference="foobar", HealthCheckConfig={ - 'IPAddress': '192.0.2.44', - 'Port': 123, - 'Type': 'HTTP', - 'ResourcePath': '/', - 'RequestInterval': 30, - 'FailureThreshold': 123, - 'HealthThreshold': 123, - } + "IPAddress": "192.0.2.44", + "Port": 123, + "Type": "HTTP", + "ResourcePath": "/", + "RequestInterval": 30, + "FailureThreshold": 123, + "HealthThreshold": 123, + }, ) - healthcheck_id = health_check['HealthCheck']['Id'] + healthcheck_id = health_check["HealthCheck"]["Id"] + + # confirm this works for resources with zero tags + response = conn.list_tags_for_resource( + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response["ResourceTagSet"]["Tags"].should.be.empty tag1 = {"Key": "Deploy", "Value": "True"} tag2 = {"Key": "Name", "Value": "UnitTest"} # Test adding a tag for a resource id conn.change_tags_for_resource( - ResourceType='healthcheck', - ResourceId=healthcheck_id, - AddTags=[tag1, tag2] + ResourceType="healthcheck", ResourceId=healthcheck_id, AddTags=[tag1, tag2] ) # Check to make sure that the response has the 'ResourceTagSet' key response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response.should.contain('ResourceTagSet') + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response.should.contain("ResourceTagSet") # Validate that each key was added - response['ResourceTagSet']['Tags'].should.contain(tag1) - response['ResourceTagSet']['Tags'].should.contain(tag2) + response["ResourceTagSet"]["Tags"].should.contain(tag1) + response["ResourceTagSet"]["Tags"].should.contain(tag2) - len(response['ResourceTagSet']['Tags']).should.equal(2) + len(response["ResourceTagSet"]["Tags"]).should.equal(2) # Try to remove the tags conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag1['Key']] + RemoveTagKeys=[tag1["Key"]], ) # Check to make sure that the response has the 'ResourceTagSet' key response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response.should.contain('ResourceTagSet') - response['ResourceTagSet']['Tags'].should_not.contain(tag1) - response['ResourceTagSet']['Tags'].should.contain(tag2) + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response.should.contain("ResourceTagSet") + response["ResourceTagSet"]["Tags"].should_not.contain(tag1) + response["ResourceTagSet"]["Tags"].should.contain(tag2) # Remove the second tag conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag2['Key']] + RemoveTagKeys=[tag2["Key"]], ) response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response['ResourceTagSet']['Tags'].should_not.contain(tag2) + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response["ResourceTagSet"]["Tags"].should_not.contain(tag2) # Re-add the tags conn.change_tags_for_resource( - ResourceType='healthcheck', - ResourceId=healthcheck_id, - AddTags=[tag1, tag2] + ResourceType="healthcheck", ResourceId=healthcheck_id, AddTags=[tag1, tag2] ) # Remove both conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag1['Key'], tag2['Key']] + RemoveTagKeys=[tag1["Key"], tag2["Key"]], ) response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response['ResourceTagSet']['Tags'].should.be.empty + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response["ResourceTagSet"]["Tags"].should.be.empty @mock_route53 def test_list_hosted_zones_by_name(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="test.b.com.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test com", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test com"), ) conn.create_hosted_zone( Name="test.a.org.", - CallerReference=str(hash('bar')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test org", - ) + CallerReference=str(hash("bar")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test org"), ) conn.create_hosted_zone( Name="test.a.org.", - CallerReference=str(hash('bar')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test org 2", - ) + CallerReference=str(hash("bar")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test org 2"), ) # test lookup @@ -516,14 +508,11 @@ def test_list_hosted_zones_by_name(): @mock_route53 def test_change_resource_record_sets_crud_valid(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) zones = conn.list_hosted_zones_by_name(DNSName="db.") @@ -533,244 +522,244 @@ def test_change_resource_record_sets_crud_valid(): # Create A Record. a_record_endpoint_payload = { - 'Comment': 'Create A record prod.redis.db', - 'Changes': [ + "Comment": "Create A record prod.redis.db", + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=a_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=a_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) - a_record_detail = response['ResourceRecordSets'][0] - a_record_detail['Name'].should.equal('prod.redis.db.') - a_record_detail['Type'].should.equal('A') - a_record_detail['TTL'].should.equal(10) - a_record_detail['ResourceRecords'].should.equal([{'Value': '127.0.0.1'}]) + len(response["ResourceRecordSets"]).should.equal(1) + a_record_detail = response["ResourceRecordSets"][0] + a_record_detail["Name"].should.equal("prod.redis.db.") + a_record_detail["Type"].should.equal("A") + a_record_detail["TTL"].should.equal(10) + a_record_detail["ResourceRecords"].should.equal([{"Value": "127.0.0.1"}]) # Update A Record. cname_record_endpoint_payload = { - 'Comment': 'Update A record prod.redis.db', - 'Changes': [ + "Comment": "Update A record prod.redis.db", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 60, - 'ResourceRecords': [{ - 'Value': '192.168.1.1' - }] - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 60, + "ResourceRecords": [{"Value": "192.168.1.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=cname_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) - cname_record_detail = response['ResourceRecordSets'][0] - cname_record_detail['Name'].should.equal('prod.redis.db.') - cname_record_detail['Type'].should.equal('A') - cname_record_detail['TTL'].should.equal(60) - cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}]) + len(response["ResourceRecordSets"]).should.equal(1) + cname_record_detail = response["ResourceRecordSets"][0] + cname_record_detail["Name"].should.equal("prod.redis.db.") + cname_record_detail["Type"].should.equal("A") + cname_record_detail["TTL"].should.equal(60) + cname_record_detail["ResourceRecords"].should.equal([{"Value": "192.168.1.1"}]) # Update to add Alias. cname_alias_record_endpoint_payload = { - 'Comment': 'Update to Alias prod.redis.db', - 'Changes': [ + "Comment": "Update to Alias prod.redis.db", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 60, - 'AliasTarget': { - 'HostedZoneId': hosted_zone_id, - 'DNSName': 'prod.redis.alias.', - 'EvaluateTargetHealth': False, - } - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 60, + "AliasTarget": { + "HostedZoneId": hosted_zone_id, + "DNSName": "prod.redis.alias.", + "EvaluateTargetHealth": False, + }, + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - cname_alias_record_detail = response['ResourceRecordSets'][0] - cname_alias_record_detail['Name'].should.equal('prod.redis.db.') - cname_alias_record_detail['Type'].should.equal('A') - cname_alias_record_detail['TTL'].should.equal(60) - cname_alias_record_detail['AliasTarget'].should.equal({ - 'HostedZoneId': hosted_zone_id, - 'DNSName': 'prod.redis.alias.', - 'EvaluateTargetHealth': False, - }) - cname_alias_record_detail.should_not.contain('ResourceRecords') + cname_alias_record_detail = response["ResourceRecordSets"][0] + cname_alias_record_detail["Name"].should.equal("prod.redis.db.") + cname_alias_record_detail["Type"].should.equal("A") + cname_alias_record_detail["TTL"].should.equal(60) + cname_alias_record_detail["AliasTarget"].should.equal( + { + "HostedZoneId": hosted_zone_id, + "DNSName": "prod.redis.alias.", + "EvaluateTargetHealth": False, + } + ) + cname_alias_record_detail.should_not.contain("ResourceRecords") # Delete record with wrong type. delete_payload = { - 'Comment': 'delete prod.redis.db', - 'Changes': [ + "Comment": "delete prod.redis.db", + "Changes": [ { - 'Action': 'DELETE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db', - 'Type': 'CNAME', - } + "Action": "DELETE", + "ResourceRecordSet": {"Name": "prod.redis.db", "Type": "CNAME"}, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) + len(response["ResourceRecordSets"]).should.equal(1) # Delete record. delete_payload = { - 'Comment': 'delete prod.redis.db', - 'Changes': [ + "Comment": "delete prod.redis.db", + "Changes": [ { - 'Action': 'DELETE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db', - 'Type': 'A', - } + "Action": "DELETE", + "ResourceRecordSet": {"Name": "prod.redis.db", "Type": "A"}, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) + @mock_route53 def test_change_weighted_resource_record_sets(): - conn = boto3.client('route53', region_name='us-east-2') + conn = boto3.client("route53", region_name="us-east-2") conn.create_hosted_zone( - Name='test.vpc.internal.', - CallerReference=str(hash('test')) + Name="test.vpc.internal.", CallerReference=str(hash("test")) ) - zones = conn.list_hosted_zones_by_name( - DNSName='test.vpc.internal.' - ) + zones = conn.list_hosted_zones_by_name(DNSName="test.vpc.internal.") - hosted_zone_id = zones['HostedZones'][0]['Id'] + hosted_zone_id = zones["HostedZones"][0]["Id"] - #Create 2 weighted records + # Create 2 weighted records conn.change_resource_record_sets( HostedZoneId=hosted_zone_id, ChangeBatch={ - 'Changes': [ + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'test.vpc.internal', - 'Type': 'A', - 'SetIdentifier': 'test1', - 'Weight': 50, - 'AliasTarget': { - 'HostedZoneId': 'Z3AADJGX6KTTL2', - 'DNSName': 'internal-test1lb-447688172.us-east-2.elb.amazonaws.com.', - 'EvaluateTargetHealth': True - } - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "test.vpc.internal", + "Type": "A", + "SetIdentifier": "test1", + "Weight": 50, + "AliasTarget": { + "HostedZoneId": "Z3AADJGX6KTTL2", + "DNSName": "internal-test1lb-447688172.us-east-2.elb.amazonaws.com.", + "EvaluateTargetHealth": True, + }, + }, }, - { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'test.vpc.internal', - 'Type': 'A', - 'SetIdentifier': 'test2', - 'Weight': 50, - 'AliasTarget': { - 'HostedZoneId': 'Z3AADJGX6KTTL2', - 'DNSName': 'internal-testlb2-1116641781.us-east-2.elb.amazonaws.com.', - 'EvaluateTargetHealth': True - } - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "test.vpc.internal", + "Type": "A", + "SetIdentifier": "test2", + "Weight": 50, + "AliasTarget": { + "HostedZoneId": "Z3AADJGX6KTTL2", + "DNSName": "internal-testlb2-1116641781.us-east-2.elb.amazonaws.com.", + "EvaluateTargetHealth": True, + }, + }, + }, + ] + }, + ) + + response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) + record = response["ResourceRecordSets"][0] + # Update the first record to have a weight of 90 + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, + ChangeBatch={ + "Changes": [ + { + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": record["Name"], + "Type": record["Type"], + "SetIdentifier": record["SetIdentifier"], + "Weight": 90, + "AliasTarget": { + "HostedZoneId": record["AliasTarget"]["HostedZoneId"], + "DNSName": record["AliasTarget"]["DNSName"], + "EvaluateTargetHealth": record["AliasTarget"][ + "EvaluateTargetHealth" + ], + }, + }, } ] - } + }, ) - response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - record = response['ResourceRecordSets'][0] - #Update the first record to have a weight of 90 + record = response["ResourceRecordSets"][1] + # Update the second record to have a weight of 10 conn.change_resource_record_sets( HostedZoneId=hosted_zone_id, ChangeBatch={ - 'Changes' : [ + "Changes": [ { - 'Action' : 'UPSERT', - 'ResourceRecordSet' : { - 'Name' : record['Name'], - 'Type' : record['Type'], - 'SetIdentifier' : record['SetIdentifier'], - 'Weight' : 90, - 'AliasTarget' : { - 'HostedZoneId' : record['AliasTarget']['HostedZoneId'], - 'DNSName' : record['AliasTarget']['DNSName'], - 'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth'] - } - } - }, + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": record["Name"], + "Type": record["Type"], + "SetIdentifier": record["SetIdentifier"], + "Weight": 10, + "AliasTarget": { + "HostedZoneId": record["AliasTarget"]["HostedZoneId"], + "DNSName": record["AliasTarget"]["DNSName"], + "EvaluateTargetHealth": record["AliasTarget"][ + "EvaluateTargetHealth" + ], + }, + }, + } ] - } - ) - - record = response['ResourceRecordSets'][1] - #Update the second record to have a weight of 10 - conn.change_resource_record_sets( - HostedZoneId=hosted_zone_id, - ChangeBatch={ - 'Changes' : [ - { - 'Action' : 'UPSERT', - 'ResourceRecordSet' : { - 'Name' : record['Name'], - 'Type' : record['Type'], - 'SetIdentifier' : record['SetIdentifier'], - 'Weight' : 10, - 'AliasTarget' : { - 'HostedZoneId' : record['AliasTarget']['HostedZoneId'], - 'DNSName' : record['AliasTarget']['DNSName'], - 'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth'] - } - } - }, - ] - } + }, ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - for record in response['ResourceRecordSets']: - if record['SetIdentifier'] == 'test1': - record['Weight'].should.equal(90) - if record['SetIdentifier'] == 'test2': - record['Weight'].should.equal(10) + for record in response["ResourceRecordSets"]: + if record["SetIdentifier"] == "test1": + record["Weight"].should.equal(90) + if record["SetIdentifier"] == "test2": + record["Weight"].should.equal(10) @mock_route53 def test_change_resource_record_invalid(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) zones = conn.list_hosted_zones_by_name(DNSName="db.") @@ -779,92 +768,89 @@ def test_change_resource_record_invalid(): hosted_zone_id = zones["HostedZones"][0]["Id"] invalid_a_record_payload = { - 'Comment': 'this should fail', - 'Changes': [ + "Comment": "this should fail", + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'prod.scooby.doo', - 'Type': 'A', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "prod.scooby.doo", + "Type": "A", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } with assert_raises(botocore.exceptions.ClientError): - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=invalid_a_record_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=invalid_a_record_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) invalid_cname_record_payload = { - 'Comment': 'this should also fail', - 'Changes': [ + "Comment": "this should also fail", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.scooby.doo', - 'Type': 'CNAME', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.scooby.doo", + "Type": "CNAME", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } with assert_raises(botocore.exceptions.ClientError): - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=invalid_cname_record_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=invalid_cname_record_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) @mock_route53 def test_list_resource_record_sets_name_type_filters(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") create_hosted_zone_response = conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) - hosted_zone_id = create_hosted_zone_response['HostedZone']['Id'] + hosted_zone_id = create_hosted_zone_response["HostedZone"]["Id"] def create_resource_record_set(rec_type, rec_name): payload = { - 'Comment': 'create {} record {}'.format(rec_type, rec_name), - 'Changes': [ + "Comment": "create {} record {}".format(rec_type, rec_name), + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': rec_name, - 'Type': rec_type, - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": rec_name, + "Type": rec_type, + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=payload + ) # record_type, record_name all_records = [ - ('A', 'a.a.db.'), - ('A', 'a.b.db.'), - ('A', 'b.b.db.'), - ('CNAME', 'b.b.db.'), - ('CNAME', 'b.c.db.'), - ('CNAME', 'c.c.db.') + ("A", "a.a.db."), + ("A", "a.b.db."), + ("A", "b.b.db."), + ("CNAME", "b.b.db."), + ("CNAME", "b.c.db."), + ("CNAME", "c.c.db."), ] for record_type, record_name in all_records: create_resource_record_set(record_type, record_name) @@ -873,10 +859,12 @@ def test_list_resource_record_sets_name_type_filters(): response = conn.list_resource_record_sets( HostedZoneId=hosted_zone_id, StartRecordType=all_records[start_with][0], - StartRecordName=all_records[start_with][1] + StartRecordName=all_records[start_with][1], ) - returned_records = [(record['Type'], record['Name']) for record in response['ResourceRecordSets']] + returned_records = [ + (record["Type"], record["Name"]) for record in response["ResourceRecordSets"] + ] len(returned_records).should.equal(len(all_records) - start_with) for desired_record in all_records[start_with:]: returned_records.should.contain(desired_record) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index cd57fc92b..3cf3bc6f1 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -2,6 +2,9 @@ from __future__ import unicode_literals import datetime +import os +import sys + from six.moves.urllib.request import urlopen from six.moves.urllib.error import HTTPError from functools import wraps @@ -20,15 +23,19 @@ from botocore.handlers import disable_signing from boto.s3.connection import S3Connection from boto.s3.key import Key from freezegun import freeze_time +from parameterized import parameterized import six import requests import tests.backport_assert_raises # noqa +from nose import SkipTest from nose.tools import assert_raises import sure # noqa -from moto import settings, mock_s3, mock_s3_deprecated +from moto import settings, mock_s3, mock_s3_deprecated, mock_config import moto.s3.models as s3model +from moto.core.exceptions import InvalidNextTokenException +from moto.core.utils import py2_strip_unicode_keys if settings.TEST_SERVER_MODE: REDUCED_PART_SIZE = s3model.UPLOAD_PART_MIN_SIZE @@ -56,21 +63,20 @@ def reduced_min_part_size(f): class MyModel(object): - def __init__(self, name, value): self.name = name self.value = value def save(self): - s3 = boto3.client('s3', region_name='us-east-1') - s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value) + s3 = boto3.client("s3", region_name="us-east-1") + s3.put_object(Bucket="mybucket", Key=self.name, Body=self.value) @mock_s3 def test_keys_are_pickleable(): """Keys must be pickleable due to boto3 implementation details.""" - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" pickled = pickle.dumps(key) loaded = pickle.loads(pickled) @@ -79,72 +85,73 @@ def test_keys_are_pickleable(): @mock_s3 def test_append_to_value__basic(): - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" assert key.size == 5 - key.append_to_value(b' And even more data') - assert key.value == b'data! And even more data' + key.append_to_value(b" And even more data") + assert key.value == b"data! And even more data" assert key.size == 24 @mock_s3 def test_append_to_value__nothing_added(): - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" assert key.size == 5 - key.append_to_value(b'') - assert key.value == b'data!' + key.append_to_value(b"") + assert key.value == b"data!" assert key.size == 5 @mock_s3 def test_append_to_value__empty_key(): - key = s3model.FakeKey('name', b'') - assert key.value == b'' + key = s3model.FakeKey("name", b"") + assert key.value == b"" assert key.size == 0 - key.append_to_value(b'stuff') - assert key.value == b'stuff' + key.append_to_value(b"stuff") + assert key.value == b"stuff" assert key.size == 5 @mock_s3 def test_my_model_save(): # Create Bucket so that test can run - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='mybucket') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket="mybucket") #################################### - model_instance = MyModel('steve', 'is awesome') + model_instance = MyModel("steve", "is awesome") model_instance.save() - body = conn.Object('mybucket', 'steve').get()['Body'].read().decode() + body = conn.Object("mybucket", "steve").get()["Body"].read().decode() - assert body == 'is awesome' + assert body == "is awesome" @mock_s3 def test_key_etag(): - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='mybucket') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket="mybucket") - model_instance = MyModel('steve', 'is awesome') + model_instance = MyModel("steve", "is awesome") model_instance.save() - conn.Bucket('mybucket').Object('steve').e_tag.should.equal( - '"d32bda93738f7e03adb22e66c90fbc04"') + conn.Bucket("mybucket").Object("steve").e_tag.should.equal( + '"d32bda93738f7e03adb22e66c90fbc04"' + ) @mock_s3_deprecated def test_multipart_upload_too_small(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - multipart.upload_part_from_file(BytesIO(b'hello'), 1) - multipart.upload_part_from_file(BytesIO(b'world'), 2) + multipart.upload_part_from_file(BytesIO(b"hello"), 1) + multipart.upload_part_from_file(BytesIO(b"world"), 2) # Multipart with total size under 5MB is refused multipart.complete_upload.should.throw(S3ResponseError) @@ -152,48 +159,45 @@ def test_multipart_upload_too_small(): @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # we should get both parts as the key contents - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_out_of_order(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 4) - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 2) multipart.complete_upload() # we should get both parts as the key contents - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_with_headers(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - multipart = bucket.initiate_multipart_upload( - "the-key", metadata={"foo": "bar"}) - part1 = b'0' * 10 + multipart = bucket.initiate_multipart_upload("the-key", metadata={"foo": "bar"}) + part1 = b"0" * 10 multipart.upload_part_from_file(BytesIO(part1), 1) multipart.complete_upload() @@ -204,29 +208,28 @@ def test_multipart_upload_with_headers(): @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_with_copy_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "original-key" key.set_contents_from_string("key_value") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) multipart.copy_part_from_key("foobar", "original-key", 2, 0, 3) multipart.complete_upload() - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + b"key_") + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + b"key_") @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_cancel(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) multipart.cancel_upload() # TODO we really need some sort of assertion here, but we don't currently @@ -237,14 +240,14 @@ def test_multipart_upload_cancel(): @reduced_min_part_size def test_multipart_etag(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # we should get both parts as the key contents @@ -255,76 +258,80 @@ def test_multipart_etag(): @reduced_min_part_size def test_multipart_invalid_order(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * 5242880 + part1 = b"0" * 5242880 etag1 = multipart.upload_part_from_file(BytesIO(part1), 1).etag # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etag2 = multipart.upload_part_from_file(BytesIO(part2), 2).etag xml = "{0}{1}" xml = xml.format(2, etag2) + xml.format(1, etag1) xml = "{0}".format(xml) bucket.complete_multipart_upload.when.called_with( - multipart.key_name, multipart.id, xml).should.throw(S3ResponseError) + multipart.key_name, multipart.id, xml + ).should.throw(S3ResponseError) + @mock_s3_deprecated @reduced_min_part_size def test_multipart_etag_quotes_stripped(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE etag1 = multipart.upload_part_from_file(BytesIO(part1), 1).etag # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etag2 = multipart.upload_part_from_file(BytesIO(part2), 2).etag # Strip quotes from etags - etag1 = etag1.replace('"','') - etag2 = etag2.replace('"','') + etag1 = etag1.replace('"', "") + etag2 = etag2.replace('"', "") xml = "{0}{1}" xml = xml.format(1, etag1) + xml.format(2, etag2) xml = "{0}".format(xml) bucket.complete_multipart_upload.when.called_with( - multipart.key_name, multipart.id, xml).should_not.throw(S3ResponseError) + multipart.key_name, multipart.id, xml + ).should_not.throw(S3ResponseError) # we should get both parts as the key contents bucket.get_key("the-key").etag.should.equal(EXPECTED_ETAG) + @mock_s3_deprecated @reduced_min_part_size def test_multipart_duplicate_upload(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # same part again multipart.upload_part_from_file(BytesIO(part1), 1) - part2 = b'1' * 1024 + part2 = b"1" * 1024 multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # We should get only one copy of part 1. - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated def test_list_multiparts(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart1 = bucket.initiate_multipart_upload("one-key") multipart2 = bucket.initiate_multipart_upload("two-key") uploads = bucket.get_all_multipart_uploads() uploads.should.have.length_of(2) dict([(u.key_name, u.id) for u in uploads]).should.equal( - {'one-key': multipart1.id, 'two-key': multipart2.id}) + {"one-key": multipart1.id, "two-key": multipart2.id} + ) multipart2.cancel_upload() uploads = bucket.get_all_multipart_uploads() uploads.should.have.length_of(1) @@ -336,34 +343,36 @@ def test_list_multiparts(): @mock_s3_deprecated def test_key_save_to_missing_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.get_bucket('mybucket', validate=False) + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.get_bucket("mybucket", validate=False) key = Key(bucket) key.key = "the-key" - key.set_contents_from_string.when.called_with( - "foobar").should.throw(S3ResponseError) + key.set_contents_from_string.when.called_with("foobar").should.throw( + S3ResponseError + ) @mock_s3_deprecated def test_missing_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") bucket.get_key("the-key").should.equal(None) @mock_s3_deprecated def test_missing_key_urllib2(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") conn.create_bucket("foobar") - urlopen.when.called_with( - "http://foobar.s3.amazonaws.com/the-key").should.throw(HTTPError) + urlopen.when.called_with("http://foobar.s3.amazonaws.com/the-key").should.throw( + HTTPError + ) @mock_s3_deprecated def test_empty_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" @@ -371,12 +380,12 @@ def test_empty_key(): key = bucket.get_key("the-key") key.size.should.equal(0) - key.get_contents_as_string().should.equal(b'') + key.get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_empty_key_set_on_existing_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" @@ -384,59 +393,55 @@ def test_empty_key_set_on_existing_key(): key = bucket.get_key("the-key") key.size.should.equal(6) - key.get_contents_as_string().should.equal(b'foobar') + key.get_contents_as_string().should.equal(b"foobar") key.set_contents_from_string("") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_large_key_save(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("foobar" * 100000) - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b'foobar' * 100000) + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar" * 100000) @mock_s3_deprecated def test_copy_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key') + bucket.copy_key("new-key", "foobar", "the-key") - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") +@parameterized([("the-unicode-💩-key",), ("key-with?question-mark",)]) @mock_s3_deprecated -def test_copy_key_with_unicode(): - conn = boto.connect_s3('the_key', 'the_secret') +def test_copy_key_with_special_chars(key_name): + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) - key.key = "the-unicode-💩-key" + key.key = key_name key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-unicode-💩-key') + bucket.copy_key("new-key", "foobar", key_name) - bucket.get_key( - "the-unicode-💩-key").get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key(key_name).get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") @mock_s3_deprecated def test_copy_key_with_version(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) key = Key(bucket) @@ -444,46 +449,40 @@ def test_copy_key_with_version(): key.set_contents_from_string("some value") key.set_contents_from_string("another value") - key = [ - key.version_id - for key in bucket.get_all_versions() - if not key.is_latest - ][0] - bucket.copy_key('new-key', 'foobar', 'the-key', src_version_id=key) + key = [key.version_id for key in bucket.get_all_versions() if not key.is_latest][0] + bucket.copy_key("new-key", "foobar", "the-key", src_version_id=key) - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"another value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"another value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") @mock_s3_deprecated def test_set_metadata(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) - key.key = 'the-key' - key.set_metadata('md', 'Metadatastring') + key.key = "the-key" + key.set_metadata("md", "Metadatastring") key.set_contents_from_string("Testval") - bucket.get_key('the-key').get_metadata('md').should.equal('Metadatastring') + bucket.get_key("the-key").get_metadata("md").should.equal("Metadatastring") @mock_s3_deprecated def test_copy_key_replace_metadata(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" - key.set_metadata('md', 'Metadatastring') + key.set_metadata("md", "Metadatastring") key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key', - metadata={'momd': 'Mometadatastring'}) + bucket.copy_key( + "new-key", "foobar", "the-key", metadata={"momd": "Mometadatastring"} + ) - bucket.get_key("new-key").get_metadata('md').should.be.none - bucket.get_key( - "new-key").get_metadata('momd').should.equal('Mometadatastring') + bucket.get_key("new-key").get_metadata("md").should.be.none + bucket.get_key("new-key").get_metadata("momd").should.equal("Mometadatastring") @freeze_time("2012-01-01 12:00:00") @@ -497,23 +496,23 @@ def test_last_modified(): key.set_contents_from_string("some value") rs = bucket.get_all_keys() - rs[0].last_modified.should.equal('2012-01-01T12:00:00.000Z') + rs[0].last_modified.should.equal("2012-01-01T12:00:00.000Z") - bucket.get_key( - "the-key").last_modified.should.equal('Sun, 01 Jan 2012 12:00:00 GMT') + bucket.get_key("the-key").last_modified.should.equal( + "Sun, 01 Jan 2012 12:00:00 GMT" + ) @mock_s3_deprecated def test_missing_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.get_bucket.when.called_with('mybucket').should.throw(S3ResponseError) + conn = boto.connect_s3("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket").should.throw(S3ResponseError) @mock_s3_deprecated def test_bucket_with_dash(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.get_bucket.when.called_with( - 'mybucket-test').should.throw(S3ResponseError) + conn = boto.connect_s3("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket-test").should.throw(S3ResponseError) @mock_s3_deprecated @@ -522,7 +521,7 @@ def test_create_existing_bucket(): conn = boto.s3.connect_to_region("us-west-2") conn.create_bucket("foobar") with assert_raises(S3CreateError): - conn.create_bucket('foobar') + conn.create_bucket("foobar") @mock_s3_deprecated @@ -544,15 +543,14 @@ def test_create_existing_bucket_in_us_east_1(): @mock_s3_deprecated def test_other_region(): - conn = S3Connection( - 'key', 'secret', host='s3-website-ap-southeast-2.amazonaws.com') + conn = S3Connection("key", "secret", host="s3-website-ap-southeast-2.amazonaws.com") conn.create_bucket("foobar") list(conn.get_bucket("foobar").get_all_keys()).should.equal([]) @mock_s3_deprecated def test_bucket_deletion(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) @@ -574,7 +572,7 @@ def test_bucket_deletion(): @mock_s3_deprecated def test_get_all_buckets(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") conn.create_bucket("foobar") conn.create_bucket("foobar2") buckets = conn.get_all_buckets() @@ -585,36 +583,34 @@ def test_get_all_buckets(): @mock_s3 @mock_s3_deprecated def test_post_to_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://foobar.s3.amazonaws.com/", { - 'key': 'the-key', - 'file': 'nothing' - }) + requests.post( + "https://foobar.s3.amazonaws.com/", {"key": "the-key", "file": "nothing"} + ) - bucket.get_key('the-key').get_contents_as_string().should.equal(b'nothing') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"nothing") @mock_s3 @mock_s3_deprecated def test_post_with_metadata_to_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://foobar.s3.amazonaws.com/", { - 'key': 'the-key', - 'file': 'nothing', - 'x-amz-meta-test': 'metadata' - }) + requests.post( + "https://foobar.s3.amazonaws.com/", + {"key": "the-key", "file": "nothing", "x-amz-meta-test": "metadata"}, + ) - bucket.get_key('the-key').get_metadata('test').should.equal('metadata') + bucket.get_key("the-key").get_metadata("test").should.equal("metadata") @mock_s3_deprecated def test_delete_missing_key(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") deleted_key = bucket.delete_key("foobar") deleted_key.key.should.equal("foobar") @@ -622,40 +618,40 @@ def test_delete_missing_key(): @mock_s3_deprecated def test_delete_keys(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") - result = bucket.delete_keys(['file2', 'file3']) + result = bucket.delete_keys(["file2", "file3"]) result.deleted.should.have.length_of(2) result.errors.should.have.length_of(0) keys = bucket.get_all_keys() keys.should.have.length_of(2) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") @mock_s3_deprecated def test_delete_keys_invalid(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") # non-existing key case - result = bucket.delete_keys(['abc', 'file3']) + result = bucket.delete_keys(["abc", "file3"]) result.deleted.should.have.length_of(1) result.errors.should.have.length_of(1) keys = bucket.get_all_keys() keys.should.have.length_of(3) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") # empty keys result = bucket.delete_keys([]) @@ -663,136 +659,141 @@ def test_delete_keys_invalid(): result.deleted.should.have.length_of(0) result.errors.should.have.length_of(0) + @mock_s3 def test_boto3_delete_empty_keys_list(): with assert_raises(ClientError) as err: - boto3.client('s3').delete_objects(Bucket='foobar', Delete={'Objects': []}) + boto3.client("s3").delete_objects(Bucket="foobar", Delete={"Objects": []}) assert err.exception.response["Error"]["Code"] == "MalformedXML" @mock_s3_deprecated def test_bucket_name_with_dot(): conn = boto.connect_s3() - bucket = conn.create_bucket('firstname.lastname') + bucket = conn.create_bucket("firstname.lastname") - k = Key(bucket, 'somekey') - k.set_contents_from_string('somedata') + k = Key(bucket, "somekey") + k.set_contents_from_string("somedata") @mock_s3_deprecated def test_key_with_special_characters(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key = Key(bucket, 'test_list_keys_2/x?y') - key.set_contents_from_string('value1') + key = Key(bucket, "test_list_keys_2/x?y") + key.set_contents_from_string("value1") - key_list = bucket.list('test_list_keys_2/', '/') + key_list = bucket.list("test_list_keys_2/", "/") keys = [x for x in key_list] keys[0].name.should.equal("test_list_keys_2/x?y") @mock_s3_deprecated def test_unicode_key_with_slash(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "/the-key-unîcode/test" key.set_contents_from_string("value") key = bucket.get_key("/the-key-unîcode/test") - key.get_contents_as_string().should.equal(b'value') + key.get_contents_as_string().should.equal(b"value") @mock_s3_deprecated def test_bucket_key_listing_order(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket') - prefix = 'toplevel/' + bucket = conn.create_bucket("test_bucket") + prefix = "toplevel/" def store(name): k = Key(bucket, prefix + name) - k.set_contents_from_string('somedata') + k.set_contents_from_string("somedata") - names = ['x/key', 'y.key1', 'y.key2', 'y.key3', 'x/y/key', 'x/y/z/key'] + names = ["x/key", "y.key1", "y.key2", "y.key3", "x/y/key", "x/y/z/key"] for name in names: store(name) delimiter = None keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key', - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3' - ]) + keys.should.equal( + [ + "toplevel/x/key", + "toplevel/x/y/key", + "toplevel/x/y/z/key", + "toplevel/y.key1", + "toplevel/y.key2", + "toplevel/y.key3", + ] + ) - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3', 'toplevel/x/' - ]) + keys.should.equal( + ["toplevel/y.key1", "toplevel/y.key2", "toplevel/y.key3", "toplevel/x/"] + ) # Test delimiter with no prefix - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix=None, delimiter=delimiter)] - keys.should.equal(['toplevel/']) + keys.should.equal(["toplevel/"]) delimiter = None - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal( - [u'toplevel/x/key', u'toplevel/x/y/key', u'toplevel/x/y/z/key']) + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/key", "toplevel/x/y/key", "toplevel/x/y/z/key"]) - delimiter = '/' - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal([u'toplevel/x/']) + delimiter = "/" + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/"]) @mock_s3_deprecated def test_key_with_reduced_redundancy(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key = Key(bucket, 'test_rr_key') - key.set_contents_from_string('value1', reduced_redundancy=True) + key = Key(bucket, "test_rr_key") + key.set_contents_from_string("value1", reduced_redundancy=True) # we use the bucket iterator because of: # https:/github.com/boto/boto/issues/1173 - list(bucket)[0].storage_class.should.equal('REDUCED_REDUNDANCY') + list(bucket)[0].storage_class.should.equal("REDUCED_REDUNDANCY") @mock_s3_deprecated def test_copy_key_reduced_redundancy(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key', - storage_class='REDUCED_REDUNDANCY') + bucket.copy_key("new-key", "foobar", "the-key", storage_class="REDUCED_REDUNDANCY") # we use the bucket iterator because of: # https:/github.com/boto/boto/issues/1173 keys = dict([(k.name, k) for k in bucket]) - keys['new-key'].storage_class.should.equal("REDUCED_REDUNDANCY") - keys['the-key'].storage_class.should.equal("STANDARD") + keys["new-key"].storage_class.should.equal("REDUCED_REDUNDANCY") + keys["the-key"].storage_class.should.equal("STANDARD") @freeze_time("2012-01-01 12:00:00") @mock_s3_deprecated def test_restore_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") list(bucket)[0].ongoing_restore.should.be.none key.restore(1) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Mon, 02 Jan 2012 12:00:00 GMT") key.restore(2) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Tue, 03 Jan 2012 12:00:00 GMT") @@ -801,13 +802,13 @@ def test_restore_key(): @freeze_time("2012-01-01 12:00:00") @mock_s3_deprecated def test_restore_key_headers(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - key.restore(1, headers={'foo': 'bar'}) - key = bucket.get_key('the-key') + key.restore(1, headers={"foo": "bar"}) + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Mon, 02 Jan 2012 12:00:00 GMT") @@ -815,51 +816,51 @@ def test_restore_key_headers(): @mock_s3_deprecated def test_get_versioning_status(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") d = bucket.get_versioning_status() d.should.be.empty bucket.configure_versioning(versioning=True) d = bucket.get_versioning_status() d.shouldnt.be.empty - d.should.have.key('Versioning').being.equal('Enabled') + d.should.have.key("Versioning").being.equal("Enabled") bucket.configure_versioning(versioning=False) d = bucket.get_versioning_status() - d.should.have.key('Versioning').being.equal('Suspended') + d.should.have.key("Versioning").being.equal("Suspended") @mock_s3_deprecated def test_key_version(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) versions = [] key = Key(bucket) - key.key = 'the-key' + key.key = "the-key" key.version_id.should.be.none - key.set_contents_from_string('some string') + key.set_contents_from_string("some string") versions.append(key.version_id) - key.set_contents_from_string('some string') + key.set_contents_from_string("some string") versions.append(key.version_id) set(versions).should.have.length_of(2) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.version_id.should.equal(versions[-1]) @mock_s3_deprecated def test_list_versions(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) key_versions = [] - key = Key(bucket, 'the-key') + key = Key(bucket, "the-key") key.version_id.should.be.none key.set_contents_from_string("Version 1") key_versions.append(key.version_id) @@ -870,32 +871,32 @@ def test_list_versions(): versions = list(bucket.list_versions()) versions.should.have.length_of(2) - versions[0].name.should.equal('the-key') + versions[0].name.should.equal("the-key") versions[0].version_id.should.equal(key_versions[0]) versions[0].get_contents_as_string().should.equal(b"Version 1") - versions[1].name.should.equal('the-key') + versions[1].name.should.equal("the-key") versions[1].version_id.should.equal(key_versions[1]) versions[1].get_contents_as_string().should.equal(b"Version 2") - key = Key(bucket, 'the2-key') + key = Key(bucket, "the2-key") key.set_contents_from_string("Version 1") keys = list(bucket.list()) keys.should.have.length_of(2) - versions = list(bucket.list_versions(prefix='the2-')) + versions = list(bucket.list_versions(prefix="the2-")) versions.should.have.length_of(1) @mock_s3_deprecated def test_acl_setting(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' + key.content_type = "text/plain" key.set_contents_from_string(content) key.make_public() @@ -904,147 +905,175 @@ def test_acl_setting(): assert key.get_contents_as_string() == content grants = key.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_acl_setting_via_headers(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' - key.set_contents_from_string(content, headers={ - 'x-amz-grant-full-control': 'uri="http://acs.amazonaws.com/groups/global/AllUsers"' - }) + key.content_type = "text/plain" + key.set_contents_from_string( + content, + headers={ + "x-amz-grant-full-control": 'uri="http://acs.amazonaws.com/groups/global/AllUsers"' + }, + ) key = bucket.get_key(keyname) assert key.get_contents_as_string() == content grants = key.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'FULL_CONTROL' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "FULL_CONTROL" + for g in grants + ), grants @mock_s3_deprecated def test_acl_switching(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' - key.set_contents_from_string(content, policy='public-read') - key.set_acl('private') + key.content_type = "text/plain" + key.set_contents_from_string(content, policy="public-read") + key.set_acl("private") grants = key.get_acl().acl.grants - assert not any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert not any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_bucket_acl_setting(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") bucket.make_public() grants = bucket.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_bucket_acl_switching(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") bucket.make_public() - bucket.set_acl('private') + bucket.set_acl("private") grants = bucket.get_acl().acl.grants - assert not any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert not any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3 def test_s3_object_in_public_bucket(): - s3 = boto3.resource('s3') - bucket = s3.Bucket('test-bucket') - bucket.create(ACL='public-read') - bucket.put_object(Body=b'ABCD', Key='file.txt') + s3 = boto3.resource("s3") + bucket = s3.Bucket("test-bucket") + bucket.create(ACL="public-read") + bucket.put_object(Body=b"ABCD", Key="file.txt") - s3_anonymous = boto3.resource('s3') - s3_anonymous.meta.client.meta.events.register('choose-signer.s3.*', disable_signing) + s3_anonymous = boto3.resource("s3") + s3_anonymous.meta.client.meta.events.register("choose-signer.s3.*", disable_signing) - contents = s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get()['Body'].read() - contents.should.equal(b'ABCD') + contents = ( + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket") + .get()["Body"] + .read() + ) + contents.should.equal(b"ABCD") - bucket.put_object(ACL='private', Body=b'ABCD', Key='file.txt') + bucket.put_object(ACL="private", Body=b"ABCD", Key="file.txt") with assert_raises(ClientError) as exc: - s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get() - exc.exception.response['Error']['Code'].should.equal('403') + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket").get() + exc.exception.response["Error"]["Code"].should.equal("403") - params = {'Bucket': 'test-bucket', 'Key': 'file.txt'} - presigned_url = boto3.client('s3').generate_presigned_url('get_object', params, ExpiresIn=900) + params = {"Bucket": "test-bucket", "Key": "file.txt"} + presigned_url = boto3.client("s3").generate_presigned_url( + "get_object", params, ExpiresIn=900 + ) response = requests.get(presigned_url) assert response.status_code == 200 @mock_s3 def test_s3_object_in_private_bucket(): - s3 = boto3.resource('s3') - bucket = s3.Bucket('test-bucket') - bucket.create(ACL='private') - bucket.put_object(ACL='private', Body=b'ABCD', Key='file.txt') + s3 = boto3.resource("s3") + bucket = s3.Bucket("test-bucket") + bucket.create(ACL="private") + bucket.put_object(ACL="private", Body=b"ABCD", Key="file.txt") - s3_anonymous = boto3.resource('s3') - s3_anonymous.meta.client.meta.events.register('choose-signer.s3.*', disable_signing) + s3_anonymous = boto3.resource("s3") + s3_anonymous.meta.client.meta.events.register("choose-signer.s3.*", disable_signing) with assert_raises(ClientError) as exc: - s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get() - exc.exception.response['Error']['Code'].should.equal('403') + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket").get() + exc.exception.response["Error"]["Code"].should.equal("403") - bucket.put_object(ACL='public-read', Body=b'ABCD', Key='file.txt') - contents = s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get()['Body'].read() - contents.should.equal(b'ABCD') + bucket.put_object(ACL="public-read", Body=b"ABCD", Key="file.txt") + contents = ( + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket") + .get()["Body"] + .read() + ) + contents.should.equal(b"ABCD") @mock_s3_deprecated def test_unicode_key(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = u'こんにちは.jpg' - key.set_contents_from_string('Hello world!') + key.key = "こんにちは.jpg" + key.set_contents_from_string("Hello world!") assert [listed_key.key for listed_key in bucket.list()] == [key.key] fetched_key = bucket.get_key(key.key) assert fetched_key.key == key.key - assert fetched_key.get_contents_as_string().decode("utf-8") == 'Hello world!' + assert fetched_key.get_contents_as_string().decode("utf-8") == "Hello world!" @mock_s3_deprecated def test_unicode_value(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = 'some_key' - key.set_contents_from_string(u'こんにちは.jpg') + key.key = "some_key" + key.set_contents_from_string("こんにちは.jpg") list(bucket.list()) key = bucket.get_key(key.key) - assert key.get_contents_as_string().decode("utf-8") == u'こんにちは.jpg' + assert key.get_contents_as_string().decode("utf-8") == "こんにちは.jpg" @mock_s3_deprecated def test_setting_content_encoding(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = bucket.new_key("keyname") key.set_metadata("Content-Encoding", "gzip") compressed_data = "abcdef" @@ -1057,77 +1086,57 @@ def test_setting_content_encoding(): @mock_s3_deprecated def test_bucket_location(): conn = boto.s3.connect_to_region("us-west-2") - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") bucket.get_location().should.equal("us-west-2") @mock_s3 def test_bucket_location_us_east_1(): - cli = boto3.client('s3') - bucket_name = 'mybucket' + cli = boto3.client("s3") + bucket_name = "mybucket" # No LocationConstraint ==> us-east-1 cli.create_bucket(Bucket=bucket_name) - cli.get_bucket_location(Bucket=bucket_name)['LocationConstraint'].should.equal(None) + cli.get_bucket_location(Bucket=bucket_name)["LocationConstraint"].should.equal(None) @mock_s3_deprecated def test_ranged_get(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = 'bigkey' + key.key = "bigkey" rep = b"0123456789" key.set_contents_from_string(rep * 10) # Implicitly bounded range requests. - key.get_contents_as_string( - headers={'Range': 'bytes=0-'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=50-'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=99-'}).should.equal(b'9') + key.get_contents_as_string(headers={"Range": "bytes=0-"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=50-"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=99-"}).should.equal(b"9") # Explicitly bounded range requests starting from the first byte. - key.get_contents_as_string( - headers={'Range': 'bytes=0-0'}).should.equal(b'0') - key.get_contents_as_string( - headers={'Range': 'bytes=0-49'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=0-99'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=0-100'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=0-700'}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-0"}).should.equal(b"0") + key.get_contents_as_string(headers={"Range": "bytes=0-49"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=0-99"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-100"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-700"}).should.equal(rep * 10) # Explicitly bounded range requests starting from the / a middle byte. - key.get_contents_as_string( - headers={'Range': 'bytes=50-54'}).should.equal(rep[:5]) - key.get_contents_as_string( - headers={'Range': 'bytes=50-99'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=50-100'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=50-700'}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-54"}).should.equal(rep[:5]) + key.get_contents_as_string(headers={"Range": "bytes=50-99"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-100"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-700"}).should.equal(rep * 5) # Explicitly bounded range requests starting from the last byte. - key.get_contents_as_string( - headers={'Range': 'bytes=99-99'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=99-100'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=99-700'}).should.equal(b'9') + key.get_contents_as_string(headers={"Range": "bytes=99-99"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=99-100"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=99-700"}).should.equal(b"9") # Suffix range requests. - key.get_contents_as_string( - headers={'Range': 'bytes=-1'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=-60'}).should.equal(rep * 6) - key.get_contents_as_string( - headers={'Range': 'bytes=-100'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=-101'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=-700'}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-1"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=-60"}).should.equal(rep * 6) + key.get_contents_as_string(headers={"Range": "bytes=-100"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-101"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-700"}).should.equal(rep * 10) key.size.should.equal(100) @@ -1135,36 +1144,40 @@ def test_ranged_get(): @mock_s3_deprecated def test_policy(): conn = boto.connect_s3() - bucket_name = 'mybucket' + bucket_name = "mybucket" bucket = conn.create_bucket(bucket_name) - policy = json.dumps({ - "Version": "2012-10-17", - "Id": "PutObjPolicy", - "Statement": [ - { - "Sid": "DenyUnEncryptedObjectUploads", - "Effect": "Deny", - "Principal": "*", - "Action": "s3:PutObject", - "Resource": "arn:aws:s3:::{bucket_name}/*".format(bucket_name=bucket_name), - "Condition": { - "StringNotEquals": { - "s3:x-amz-server-side-encryption": "aws:kms" - } + policy = json.dumps( + { + "Version": "2012-10-17", + "Id": "PutObjPolicy", + "Statement": [ + { + "Sid": "DenyUnEncryptedObjectUploads", + "Effect": "Deny", + "Principal": "*", + "Action": "s3:PutObject", + "Resource": "arn:aws:s3:::{bucket_name}/*".format( + bucket_name=bucket_name + ), + "Condition": { + "StringNotEquals": { + "s3:x-amz-server-side-encryption": "aws:kms" + } + }, } - } - ] - }) + ], + } + ) with assert_raises(S3ResponseError) as err: bucket.get_policy() ex = err.exception ex.box_usage.should.be.none - ex.error_code.should.equal('NoSuchBucketPolicy') - ex.message.should.equal('The bucket policy does not exist') - ex.reason.should.equal('Not Found') + ex.error_code.should.equal("NoSuchBucketPolicy") + ex.message.should.equal("The bucket policy does not exist") + ex.reason.should.equal("Not Found") ex.resource.should.be.none ex.status.should.equal(404) ex.body.should.contain(bucket_name) @@ -1174,7 +1187,7 @@ def test_policy(): bucket = conn.get_bucket(bucket_name) - bucket.get_policy().decode('utf-8').should.equal(policy) + bucket.get_policy().decode("utf-8").should.equal(policy) bucket.delete_policy() @@ -1185,7 +1198,7 @@ def test_policy(): @mock_s3_deprecated def test_website_configuration_xml(): conn = boto.connect_s3() - bucket = conn.create_bucket('test-bucket') + bucket = conn.create_bucket("test-bucket") bucket.set_website_configuration_xml(TEST_XML) bucket.get_website_configuration_xml().should.equal(TEST_XML) @@ -1193,81 +1206,129 @@ def test_website_configuration_xml(): @mock_s3_deprecated def test_key_with_trailing_slash_in_ordinary_calling_format(): conn = boto.connect_s3( - 'access_key', - 'secret_key', - calling_format=boto.s3.connection.OrdinaryCallingFormat() + "access_key", + "secret_key", + calling_format=boto.s3.connection.OrdinaryCallingFormat(), ) - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key_name = 'key_with_slash/' + key_name = "key_with_slash/" key = Key(bucket, key_name) - key.set_contents_from_string('some value') + key.set_contents_from_string("some value") [k.name for k in bucket.get_all_keys()].should.contain(key_name) -""" -boto3 -""" - - @mock_s3 def test_boto3_key_etag(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome') - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp['ETag'].should.equal('"d32bda93738f7e03adb22e66c90fbc04"') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="steve", Body=b"is awesome") + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp["ETag"].should.equal('"d32bda93738f7e03adb22e66c90fbc04"') @mock_s3 def test_website_redirect_location(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome') - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp.get('WebsiteRedirectLocation').should.be.none + s3.put_object(Bucket="mybucket", Key="steve", Body=b"is awesome") + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp.get("WebsiteRedirectLocation").should.be.none - url = 'https://github.com/spulec/moto' - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome', WebsiteRedirectLocation=url) - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp['WebsiteRedirectLocation'].should.equal(url) + url = "https://github.com/spulec/moto" + s3.put_object( + Bucket="mybucket", Key="steve", Body=b"is awesome", WebsiteRedirectLocation=url + ) + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp["WebsiteRedirectLocation"].should.equal(url) + + +@mock_s3 +def test_boto3_list_objects_truncated_response(): + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") + + # First list + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1) + listed_object = resp["Contents"][0] + + assert listed_object["Key"] == "one" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == True + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" in resp + + next_marker = resp["NextMarker"] + + # Second list + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1, Marker=next_marker) + listed_object = resp["Contents"][0] + + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == True + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" in resp + + next_marker = resp["NextMarker"] + + # Third list + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1, Marker=next_marker) + listed_object = resp["Contents"][0] + + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == False + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" not in resp @mock_s3 def test_boto3_list_keys_xml_escaped(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - key_name = 'Q&A.txt' - s3.put_object(Bucket='mybucket', Key=key_name, Body=b'is awesome') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + key_name = "Q&A.txt" + s3.put_object(Bucket="mybucket", Key=key_name, Body=b"is awesome") - resp = s3.list_objects_v2(Bucket='mybucket', Prefix=key_name) + resp = s3.list_objects_v2(Bucket="mybucket", Prefix=key_name) - assert resp['Contents'][0]['Key'] == key_name - assert resp['KeyCount'] == 1 - assert resp['MaxKeys'] == 1000 - assert resp['Prefix'] == key_name - assert resp['IsTruncated'] == False - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'NextContinuationToken' not in resp - assert 'Owner' not in resp['Contents'][0] + assert resp["Contents"][0]["Key"] == key_name + assert resp["KeyCount"] == 1 + assert resp["MaxKeys"] == 1000 + assert resp["Prefix"] == key_name + assert resp["IsTruncated"] == False + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "NextContinuationToken" not in resp + assert "Owner" not in resp["Contents"][0] @mock_s3 def test_boto3_list_objects_v2_common_prefix_pagination(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") max_keys = 1 - keys = ['test/{i}/{i}'.format(i=i) for i in range(3)] + keys = ["test/{i}/{i}".format(i=i) for i in range(3)] for key in keys: - s3.put_object(Bucket='mybucket', Key=key, Body=b'v') + s3.put_object(Bucket="mybucket", Key=key, Body=b"v") prefixes = [] - args = {"Bucket": 'mybucket', "Delimiter": "/", "Prefix": "test/", "MaxKeys": max_keys} + args = { + "Bucket": "mybucket", + "Delimiter": "/", + "Prefix": "test/", + "MaxKeys": max_keys, + } resp = {"IsTruncated": True} while resp.get("IsTruncated", False): if "NextContinuationToken" in resp: @@ -1277,186 +1338,220 @@ def test_boto3_list_objects_v2_common_prefix_pagination(): assert len(resp["CommonPrefixes"]) == max_keys prefixes.extend(i["Prefix"] for i in resp["CommonPrefixes"]) - assert prefixes == [k[:k.rindex('/') + 1] for k in keys] + assert prefixes == [k[: k.rindex("/") + 1] for k in keys] @mock_s3 def test_boto3_list_objects_v2_truncated_response(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'1') - s3.put_object(Bucket='mybucket', Key='two', Body=b'22') - s3.put_object(Bucket='mybucket', Key='three', Body=b'333') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") # First list - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1) - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2(Bucket="mybucket", MaxKeys=1) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'one' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'Owner' not in listed_object # owner info was not requested + assert listed_object["Key"] == "one" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "Owner" not in listed_object # owner info was not requested - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Second list resp = s3.list_objects_v2( - Bucket='mybucket', MaxKeys=1, ContinuationToken=next_token) - listed_object = resp['Contents'][0] + Bucket="mybucket", MaxKeys=1, ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'three' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'Owner' not in listed_object + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "Owner" not in listed_object - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Third list resp = s3.list_objects_v2( - Bucket='mybucket', MaxKeys=1, ContinuationToken=next_token) - listed_object = resp['Contents'][0] + Bucket="mybucket", MaxKeys=1, ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'two' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == False - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object - assert 'StartAfter' not in resp - assert 'NextContinuationToken' not in resp + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == False + assert "Delimiter" not in resp + assert "Owner" not in listed_object + assert "StartAfter" not in resp + assert "NextContinuationToken" not in resp @mock_s3 def test_boto3_list_objects_v2_truncated_response_start_after(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'1') - s3.put_object(Bucket='mybucket', Key='two', Body=b'22') - s3.put_object(Bucket='mybucket', Key='three', Body=b'333') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") # First list - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1, StartAfter='one') - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2(Bucket="mybucket", MaxKeys=1, StartAfter="one") + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'three' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert resp['StartAfter'] == 'one' - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert resp["StartAfter"] == "one" + assert "Delimiter" not in resp + assert "Owner" not in listed_object - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Second list # The ContinuationToken must take precedence over StartAfter. - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1, StartAfter='one', - ContinuationToken=next_token) - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2( + Bucket="mybucket", MaxKeys=1, StartAfter="one", ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'two' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == False + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == False # When ContinuationToken is given, StartAfter is ignored. This also means # AWS does not return it in the response. - assert 'StartAfter' not in resp - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object + assert "StartAfter" not in resp + assert "Delimiter" not in resp + assert "Owner" not in listed_object @mock_s3 def test_boto3_list_objects_v2_fetch_owner(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'11') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"11") - resp = s3.list_objects_v2(Bucket='mybucket', FetchOwner=True) - owner = resp['Contents'][0]['Owner'] + resp = s3.list_objects_v2(Bucket="mybucket", FetchOwner=True) + owner = resp["Contents"][0]["Owner"] - assert 'ID' in owner - assert 'DisplayName' in owner + assert "ID" in owner + assert "DisplayName" in owner assert len(owner.keys()) == 2 +@mock_s3 +def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="1/2", Body="") + s3.put_object(Bucket="mybucket", Key="2", Body="") + s3.put_object(Bucket="mybucket", Key="3/4", Body="") + s3.put_object(Bucket="mybucket", Key="4", Body="") + + resp = s3.list_objects_v2(Bucket="mybucket", Prefix="", MaxKeys=2, Delimiter="/") + assert "Delimiter" in resp + assert resp["IsTruncated"] is True + assert resp["KeyCount"] == 2 + assert len(resp["Contents"]) == 1 + assert resp["Contents"][0]["Key"] == "2" + assert len(resp["CommonPrefixes"]) == 1 + assert resp["CommonPrefixes"][0]["Prefix"] == "1/" + + last_tail = resp["NextContinuationToken"] + resp = s3.list_objects_v2( + Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail + ) + assert resp["KeyCount"] == 2 + assert resp["IsTruncated"] is False + assert len(resp["Contents"]) == 1 + assert resp["Contents"][0]["Key"] == "4" + assert len(resp["CommonPrefixes"]) == 1 + assert resp["CommonPrefixes"][0]["Prefix"] == "3/" + + @mock_s3 def test_boto3_bucket_create(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').get()['Body'].read().decode( - "utf-8").should.equal("some text") + s3.Object("blah", "hello.txt").get()["Body"].read().decode("utf-8").should.equal( + "some text" + ) @mock_s3 def test_bucket_create_duplicate(): - s3 = boto3.resource('s3', region_name='us-west-2') - s3.create_bucket(Bucket="blah", CreateBucketConfiguration={ - 'LocationConstraint': 'us-west-2', - }) + s3 = boto3.resource("s3", region_name="us-west-2") + s3.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-west-2"} + ) with assert_raises(ClientError) as exc: s3.create_bucket( - Bucket="blah", - CreateBucketConfiguration={ - 'LocationConstraint': 'us-west-2', - } + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-west-2"} ) - exc.exception.response['Error']['Code'].should.equal('BucketAlreadyExists') + exc.exception.response["Error"]["Code"].should.equal("BucketAlreadyExists") @mock_s3 def test_bucket_create_force_us_east_1(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket="blah", CreateBucketConfiguration={ - 'LocationConstraint': 'us-east-1', - }) - exc.exception.response['Error']['Code'].should.equal('InvalidLocationConstraint') + s3.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-east-1"} + ) + exc.exception.response["Error"]["Code"].should.equal("InvalidLocationConstraint") @mock_s3 def test_boto3_bucket_create_eu_central(): - s3 = boto3.resource('s3', region_name='eu-central-1') + s3 = boto3.resource("s3", region_name="eu-central-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').get()['Body'].read().decode( - "utf-8").should.equal("some text") + s3.Object("blah", "hello.txt").get()["Body"].read().decode("utf-8").should.equal( + "some text" + ) @mock_s3 def test_boto3_head_object(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) with assert_raises(ClientError) as e: - s3.Object('blah', 'hello2.txt').meta.client.head_object( - Bucket='blah', Key='hello_bad.txt') - e.exception.response['Error']['Code'].should.equal('404') + s3.Object("blah", "hello2.txt").meta.client.head_object( + Bucket="blah", Key="hello_bad.txt" + ) + e.exception.response["Error"]["Code"].should.equal("404") @mock_s3 def test_boto3_bucket_deletion(): - cli = boto3.client('s3', region_name='us-east-1') + cli = boto3.client("s3", region_name="us-east-1") cli.create_bucket(Bucket="foobar") cli.put_object(Bucket="foobar", Key="the-key", Body="some value") @@ -1464,8 +1559,11 @@ def test_boto3_bucket_deletion(): # Try to delete a bucket that still has keys cli.delete_bucket.when.called_with(Bucket="foobar").should.throw( cli.exceptions.ClientError, - ('An error occurred (BucketNotEmpty) when calling the DeleteBucket operation: ' - 'The bucket you tried to delete is not empty')) + ( + "An error occurred (BucketNotEmpty) when calling the DeleteBucket operation: " + "The bucket you tried to delete is not empty" + ), + ) cli.delete_object(Bucket="foobar", Key="the-key") cli.delete_bucket(Bucket="foobar") @@ -1473,110 +1571,158 @@ def test_boto3_bucket_deletion(): # Get non-existing bucket cli.head_bucket.when.called_with(Bucket="foobar").should.throw( cli.exceptions.ClientError, - "An error occurred (404) when calling the HeadBucket operation: Not Found") + "An error occurred (404) when calling the HeadBucket operation: Not Found", + ) # Delete non-existing bucket - cli.delete_bucket.when.called_with(Bucket="foobar").should.throw(cli.exceptions.NoSuchBucket) + cli.delete_bucket.when.called_with(Bucket="foobar").should.throw( + cli.exceptions.NoSuchBucket + ) @mock_s3 def test_boto3_get_object(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) with assert_raises(ClientError) as e: - s3.Object('blah', 'hello2.txt').get() + s3.Object("blah", "hello2.txt").get() - e.exception.response['Error']['Code'].should.equal('NoSuchKey') + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") + + +@mock_s3 +def test_boto3_get_missing_object_with_part_number(): + s3 = boto3.resource("s3", region_name="us-east-1") + s3.create_bucket(Bucket="blah") + + with assert_raises(ClientError) as e: + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt", PartNumber=123 + ) + + e.exception.response["Error"]["Code"].should.equal("404") @mock_s3 def test_boto3_head_object_with_versioning(): - s3 = boto3.resource('s3', region_name='us-east-1') - bucket = s3.create_bucket(Bucket='blah') + s3 = boto3.resource("s3", region_name="us-east-1") + bucket = s3.create_bucket(Bucket="blah") bucket.Versioning().enable() - old_content = 'some text' - new_content = 'some new text' - s3.Object('blah', 'hello.txt').put(Body=old_content) - s3.Object('blah', 'hello.txt').put(Body=new_content) + old_content = "some text" + new_content = "some new text" + s3.Object("blah", "hello.txt").put(Body=old_content) + s3.Object("blah", "hello.txt").put(Body=new_content) - versions = list(s3.Bucket('blah').object_versions.all()) + versions = list(s3.Bucket("blah").object_versions.all()) latest = list(filter(lambda item: item.is_latest, versions))[0] oldest = list(filter(lambda item: not item.is_latest, versions))[0] - head_object = s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') - head_object['VersionId'].should.equal(latest.id) - head_object['ContentLength'].should.equal(len(new_content)) + head_object = s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) + head_object["VersionId"].should.equal(latest.id) + head_object["ContentLength"].should.equal(len(new_content)) - old_head_object = s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt', VersionId=oldest.id) - old_head_object['VersionId'].should.equal(oldest.id) - old_head_object['ContentLength'].should.equal(len(old_content)) + old_head_object = s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt", VersionId=oldest.id + ) + old_head_object["VersionId"].should.equal(oldest.id) + old_head_object["ContentLength"].should.equal(len(old_content)) - old_head_object['VersionId'].should_not.equal(head_object['VersionId']) + old_head_object["VersionId"].should_not.equal(head_object["VersionId"]) @mock_s3 def test_boto3_copy_object_with_versioning(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.put_object(Bucket='blah', Key='test2', Body=b'test2') + client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.put_object(Bucket="blah", Key="test2", Body=b"test2") - obj1_version = client.get_object(Bucket='blah', Key='test1')['VersionId'] - obj2_version = client.get_object(Bucket='blah', Key='test2')['VersionId'] + obj1_version = client.get_object(Bucket="blah", Key="test1")["VersionId"] + obj2_version = client.get_object(Bucket="blah", Key="test2")["VersionId"] - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test1'}, Bucket='blah', Key='test2') - obj2_version_new = client.get_object(Bucket='blah', Key='test2')['VersionId'] + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test1"}, Bucket="blah", Key="test2" + ) + obj2_version_new = client.get_object(Bucket="blah", Key="test2")["VersionId"] # Version should be different to previous version obj2_version_new.should_not.equal(obj2_version) - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test2', 'VersionId': obj2_version}, Bucket='blah', Key='test3') - obj3_version_new = client.get_object(Bucket='blah', Key='test3')['VersionId'] + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test2", "VersionId": obj2_version}, + Bucket="blah", + Key="test3", + ) + obj3_version_new = client.get_object(Bucket="blah", Key="test3")["VersionId"] obj3_version_new.should_not.equal(obj2_version_new) # Copy file that doesn't exist with assert_raises(ClientError) as e: - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test4', 'VersionId': obj2_version}, Bucket='blah', Key='test5') - e.exception.response['Error']['Code'].should.equal('404') + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test4", "VersionId": obj2_version}, + Bucket="blah", + Key="test5", + ) + e.exception.response["Error"]["Code"].should.equal("404") - response = client.create_multipart_upload(Bucket='blah', Key='test4') - upload_id = response['UploadId'] - response = client.upload_part_copy(Bucket='blah', Key='test4', CopySource={'Bucket': 'blah', 'Key': 'test3', 'VersionId': obj3_version_new}, - UploadId=upload_id, PartNumber=1) + response = client.create_multipart_upload(Bucket="blah", Key="test4") + upload_id = response["UploadId"] + response = client.upload_part_copy( + Bucket="blah", + Key="test4", + CopySource={"Bucket": "blah", "Key": "test3", "VersionId": obj3_version_new}, + UploadId=upload_id, + PartNumber=1, + ) etag = response["CopyPartResult"]["ETag"] client.complete_multipart_upload( - Bucket='blah', Key='test4', UploadId=upload_id, - MultipartUpload={'Parts': [{'ETag': etag, 'PartNumber': 1}]}) + Bucket="blah", + Key="test4", + UploadId=upload_id, + MultipartUpload={"Parts": [{"ETag": etag, "PartNumber": 1}]}, + ) - response = client.get_object(Bucket='blah', Key='test4') + response = client.get_object(Bucket="blah", Key="test4") data = response["Body"].read() - data.should.equal(b'test2') + data.should.equal(b"test2") @mock_s3 def test_boto3_copy_object_from_unversioned_to_versioned_bucket(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='src', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.create_bucket(Bucket='dest', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.put_bucket_versioning(Bucket='dest', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket( + Bucket="src", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.create_bucket( + Bucket="dest", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.put_bucket_versioning( + Bucket="dest", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='src', Key='test', Body=b'content') + client.put_object(Bucket="src", Key="test", Body=b"content") - obj2_version_new = client.copy_object(CopySource={'Bucket': 'src', 'Key': 'test'}, Bucket='dest', Key='test') \ - .get('VersionId') + obj2_version_new = client.copy_object( + CopySource={"Bucket": "src", "Key": "test"}, Bucket="dest", Key="test" + ).get("VersionId") # VersionId should be present in the response obj2_version_new.should_not.equal(None) @@ -1584,123 +1730,138 @@ def test_boto3_copy_object_from_unversioned_to_versioned_bucket(): @mock_s3 def test_boto3_deleted_versionings_list(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah') - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket(Bucket="blah") + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.put_object(Bucket='blah', Key='test2', Body=b'test2') - client.delete_objects(Bucket='blah', Delete={'Objects': [{'Key': 'test1'}]}) + client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.put_object(Bucket="blah", Key="test2", Body=b"test2") + client.delete_objects(Bucket="blah", Delete={"Objects": [{"Key": "test1"}]}) - listed = client.list_objects_v2(Bucket='blah') - assert len(listed['Contents']) == 1 + listed = client.list_objects_v2(Bucket="blah") + assert len(listed["Contents"]) == 1 @mock_s3 def test_boto3_delete_versioned_bucket(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah') - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket(Bucket="blah") + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - resp = client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.delete_object(Bucket='blah', Key='test1', VersionId=resp["VersionId"]) + resp = client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.delete_object(Bucket="blah", Key="test1", VersionId=resp["VersionId"]) + + client.delete_bucket(Bucket="blah") - client.delete_bucket(Bucket='blah') @mock_s3 def test_boto3_get_object_if_modified_since(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "blah" s3.create_bucket(Bucket=bucket_name) - key = 'hello.txt' + key = "hello.txt" - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") with assert_raises(botocore.exceptions.ClientError) as err: s3.get_object( Bucket=bucket_name, Key=key, - IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1), ) e = err.exception - e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) + e.response["Error"].should.equal({"Code": "304", "Message": "Not Modified"}) + @mock_s3 def test_boto3_head_object_if_modified_since(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "blah" s3.create_bucket(Bucket=bucket_name) - key = 'hello.txt' + key = "hello.txt" - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") with assert_raises(botocore.exceptions.ClientError) as err: s3.head_object( Bucket=bucket_name, Key=key, - IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1), ) e = err.exception - e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) + e.response["Error"].should.equal({"Code": "304", "Message": "Not Modified"}) @mock_s3 @reduced_min_part_size def test_boto3_multipart_etag(): # Create Bucket so that test can run - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - upload_id = s3.create_multipart_upload( - Bucket='mybucket', Key='the-key')['UploadId'] - part1 = b'0' * REDUCED_PART_SIZE + upload_id = s3.create_multipart_upload(Bucket="mybucket", Key="the-key")["UploadId"] + part1 = b"0" * REDUCED_PART_SIZE etags = [] etags.append( - s3.upload_part(Bucket='mybucket', Key='the-key', PartNumber=1, - UploadId=upload_id, Body=part1)['ETag']) + s3.upload_part( + Bucket="mybucket", + Key="the-key", + PartNumber=1, + UploadId=upload_id, + Body=part1, + )["ETag"] + ) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etags.append( - s3.upload_part(Bucket='mybucket', Key='the-key', PartNumber=2, - UploadId=upload_id, Body=part2)['ETag']) + s3.upload_part( + Bucket="mybucket", + Key="the-key", + PartNumber=2, + UploadId=upload_id, + Body=part2, + )["ETag"] + ) s3.complete_multipart_upload( - Bucket='mybucket', Key='the-key', UploadId=upload_id, - MultipartUpload={'Parts': [{'ETag': etag, 'PartNumber': i} - for i, etag in enumerate(etags, 1)]}) + Bucket="mybucket", + Key="the-key", + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"ETag": etag, "PartNumber": i} for i, etag in enumerate(etags, 1) + ] + }, + ) # we should get both parts as the key contents - resp = s3.get_object(Bucket='mybucket', Key='the-key') - resp['ETag'].should.equal(EXPECTED_ETAG) + resp = s3.get_object(Bucket="mybucket", Key="the-key") + resp["ETag"].should.equal(EXPECTED_ETAG) @mock_s3 @reduced_min_part_size def test_boto3_multipart_part_size(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - mpu = s3.create_multipart_upload(Bucket='mybucket', Key='the-key') + mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key") mpu_id = mpu["UploadId"] parts = [] n_parts = 10 for i in range(1, n_parts + 1): part_size = REDUCED_PART_SIZE + i - body = b'1' * part_size + body = b"1" * part_size part = s3.upload_part( - Bucket='mybucket', - Key='the-key', + Bucket="mybucket", + Key="the-key", PartNumber=i, UploadId=mpu_id, Body=body, @@ -1709,34 +1870,29 @@ def test_boto3_multipart_part_size(): parts.append({"PartNumber": i, "ETag": part["ETag"]}) s3.complete_multipart_upload( - Bucket='mybucket', - Key='the-key', + Bucket="mybucket", + Key="the-key", UploadId=mpu_id, MultipartUpload={"Parts": parts}, ) for i in range(1, n_parts + 1): - obj = s3.head_object(Bucket='mybucket', Key='the-key', PartNumber=i) + obj = s3.head_object(Bucket="mybucket", Key="the-key", PartNumber=i) assert obj["ContentLength"] == REDUCED_PART_SIZE + i @mock_s3 def test_boto3_put_object_with_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test', - Tagging='foo=bar', - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test", Tagging="foo=bar") resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.contain({'Key': 'foo', 'Value': 'bar'}) + resp["TagSet"].should.contain({"Key": "foo", "Value": "bar"}) @mock_s3 @@ -1746,87 +1902,68 @@ def test_boto3_put_bucket_tagging(): s3.create_bucket(Bucket=bucket_name) # With 1 tag: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - } - ] - }) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, Tagging={"TagSet": [{"Key": "TagOne", "Value": "ValueOne"}]} + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # With multiple tags: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # No tags is also OK: - resp = s3.put_bucket_tagging(Bucket=bucket_name, Tagging={ - "TagSet": [] - }) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp = s3.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": []}) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # With duplicate tag keys: with assert_raises(ClientError) as err: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagOne", - "Value": "ValueOneAgain" - } - ] - }) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagOne", "Value": "ValueOneAgain"}, + ] + }, + ) e = err.exception e.response["Error"]["Code"].should.equal("InvalidTag") - e.response["Error"]["Message"].should.equal("Cannot provide multiple Tags with the same key") + e.response["Error"]["Message"].should.equal( + "Cannot provide multiple Tags with the same key" + ) + @mock_s3 def test_boto3_get_bucket_tagging(): s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) # Get the tags for the bucket: resp = s3.get_bucket_tagging(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) len(resp["TagSet"]).should.equal(2) # With no tags: - s3.put_bucket_tagging(Bucket=bucket_name, Tagging={ - "TagSet": [] - }) + s3.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": []}) with assert_raises(ClientError) as err: s3.get_bucket_tagging(Bucket=bucket_name) @@ -1842,22 +1979,18 @@ def test_boto3_delete_bucket_tagging(): bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) resp = s3.delete_bucket_tagging(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(204) with assert_raises(ClientError) as err: s3.get_bucket_tagging(Bucket=bucket_name) @@ -1873,76 +2006,56 @@ def test_boto3_put_bucket_cors(): bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - resp = s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET", - "POST" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - }, - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "PUT" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - } - ] - }) - - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ + resp = s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ "CORSRules": [ { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "NOTREAL", - "POST" - ] - } + "AllowedOrigins": ["*"], + "AllowedMethods": ["GET", "POST"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["PUT"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, ] - }) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidRequest") - e.response["Error"]["Message"].should.equal("Found unsupported HTTP method in CORS config. " - "Unsupported method is NOTREAL") + }, + ) + + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [ + {"AllowedOrigins": ["*"], "AllowedMethods": ["NOTREAL", "POST"]} + ] + }, + ) + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidRequest") + e.response["Error"]["Message"].should.equal( + "Found unsupported HTTP method in CORS config. " "Unsupported method is NOTREAL" + ) + + with assert_raises(ClientError) as err: + s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={"CORSRules": []}) e = err.exception e.response["Error"]["Code"].should.equal("MalformedXML") # And 101: many_rules = [{"AllowedOrigins": ["*"], "AllowedMethods": ["GET"]}] * 101 with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": many_rules - }) + s3.put_bucket_cors( + Bucket=bucket_name, CORSConfiguration={"CORSRules": many_rules} + ) e = err.exception e.response["Error"]["Code"].should.equal("MalformedXML") @@ -1961,44 +2074,30 @@ def test_boto3_get_bucket_cors(): e.response["Error"]["Code"].should.equal("NoSuchCORSConfiguration") e.response["Error"]["Message"].should.equal("The CORS configuration does not exist") - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET", - "POST" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - }, - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "PUT" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - } - ] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [ + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["GET", "POST"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["PUT"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + ] + }, + ) resp = s3.get_bucket_cors(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) len(resp["CORSRules"]).should.equal(2) @@ -2007,21 +2106,15 @@ def test_boto3_delete_bucket_cors(): s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET" - ] - } - ] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [{"AllowedOrigins": ["*"], "AllowedMethods": ["GET"]}] + }, + ) resp = s3.delete_bucket_cors(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(204) # Verify deletion: with assert_raises(ClientError) as err: @@ -2037,25 +2130,28 @@ def test_put_bucket_acl_body(): s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="bucket") bucket_owner = s3.get_bucket_acl(Bucket="bucket")["Owner"] - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", }, - "Permission": "WRITE" - }, - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "READ_ACP", }, - "Permission": "READ_ACP" - } - ], - "Owner": bucket_owner - }) + ], + "Owner": bucket_owner, + }, + ) result = s3.get_bucket_acl(Bucket="bucket") assert len(result["Grants"]) == 2 @@ -2065,54 +2161,65 @@ def test_put_bucket_acl_body(): assert g["Permission"] in ["WRITE", "READ_ACP"] # With one: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "WRITE" - } - ], - "Owner": bucket_owner - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", + } + ], + "Owner": bucket_owner, + }, + ) result = s3.get_bucket_acl(Bucket="bucket") assert len(result["Grants"]) == 1 # With no owner: with assert_raises(ClientError) as err: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "WRITE" - } - ] - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedACLError" # With incorrect permission: with assert_raises(ClientError) as err: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "lskjflkasdjflkdsjfalisdjflkdsjf" - } - ], - "Owner": bucket_owner - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "lskjflkasdjflkdsjfalisdjflkdsjf", + } + ], + "Owner": bucket_owner, + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedACLError" # Clear the ACLs: - result = s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={"Grants": [], "Owner": bucket_owner}) + result = s3.put_bucket_acl( + Bucket="bucket", AccessControlPolicy={"Grants": [], "Owner": bucket_owner} + ) assert not result.get("Grants") @@ -2128,46 +2235,43 @@ def test_put_bucket_notification(): assert not result.get("LambdaFunctionConfigurations") # Place proper topic configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "TopicConfigurations": [ - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", - "Events": [ - "s3:ObjectCreated:*", - "s3:ObjectRemoved:*" - ] - }, - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:myothertopic", - "Events": [ - "s3:ObjectCreated:*" - ], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - }, - { - "Name": "suffix", - "Value": "png" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "TopicConfigurations": [ + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", + "Events": ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"], + }, + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:myothertopic", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": { + "FilterRules": [ + {"Name": "prefix", "Value": "images/"}, + {"Name": "suffix", "Value": "png"}, + ] + } + }, + }, + ] + }, + ) # Verify to completion: result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["TopicConfigurations"]) == 2 assert not result.get("QueueConfigurations") assert not result.get("LambdaFunctionConfigurations") - assert result["TopicConfigurations"][0]["TopicArn"] == "arn:aws:sns:us-east-1:012345678910:mytopic" - assert result["TopicConfigurations"][1]["TopicArn"] == "arn:aws:sns:us-east-1:012345678910:myothertopic" + assert ( + result["TopicConfigurations"][0]["TopicArn"] + == "arn:aws:sns:us-east-1:012345678910:mytopic" + ) + assert ( + result["TopicConfigurations"][1]["TopicArn"] + == "arn:aws:sns:us-east-1:012345678910:myothertopic" + ) assert len(result["TopicConfigurations"][0]["Events"]) == 2 assert len(result["TopicConfigurations"][1]["Events"]) == 1 assert result["TopicConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" @@ -2177,111 +2281,138 @@ def test_put_bucket_notification(): assert result["TopicConfigurations"][1]["Id"] assert not result["TopicConfigurations"][0].get("Filter") assert len(result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"]) == 2 - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Name"] == "suffix" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Value"] == "png" + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Name"] + == "prefix" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Value"] + == "images/" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Name"] + == "suffix" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Value"] + == "png" + ) # Place proper queue configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "Id": "SomeID", - "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", - "Events": ["s3:ObjectCreated:*"], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "Id": "SomeID", + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": {"FilterRules": [{"Name": "prefix", "Value": "images/"}]} + }, + } + ] + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["QueueConfigurations"]) == 1 assert not result.get("TopicConfigurations") assert not result.get("LambdaFunctionConfigurations") assert result["QueueConfigurations"][0]["Id"] == "SomeID" - assert result["QueueConfigurations"][0]["QueueArn"] == "arn:aws:sqs:us-east-1:012345678910:myQueue" + assert ( + result["QueueConfigurations"][0]["QueueArn"] + == "arn:aws:sqs:us-east-1:012345678910:myQueue" + ) assert result["QueueConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" assert len(result["QueueConfigurations"][0]["Events"]) == 1 assert len(result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - assert result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" + assert ( + result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] + == "prefix" + ) + assert ( + result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] + == "images/" + ) # Place proper Lambda configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "LambdaFunctionConfigurations": [ - { - "LambdaFunctionArn": - "arn:aws:lambda:us-east-1:012345678910:function:lambda", - "Events": ["s3:ObjectCreated:*"], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "LambdaFunctionConfigurations": [ + { + "LambdaFunctionArn": "arn:aws:lambda:us-east-1:012345678910:function:lambda", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": {"FilterRules": [{"Name": "prefix", "Value": "images/"}]} + }, + } + ] + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["LambdaFunctionConfigurations"]) == 1 assert not result.get("TopicConfigurations") assert not result.get("QueueConfigurations") assert result["LambdaFunctionConfigurations"][0]["Id"] - assert result["LambdaFunctionConfigurations"][0]["LambdaFunctionArn"] == \ - "arn:aws:lambda:us-east-1:012345678910:function:lambda" - assert result["LambdaFunctionConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" + assert ( + result["LambdaFunctionConfigurations"][0]["LambdaFunctionArn"] + == "arn:aws:lambda:us-east-1:012345678910:function:lambda" + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" + ) assert len(result["LambdaFunctionConfigurations"][0]["Events"]) == 1 - assert len(result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - assert result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" + assert ( + len(result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) + == 1 + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0][ + "Name" + ] + == "prefix" + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0][ + "Value" + ] + == "images/" + ) # And with all 3 set: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "TopicConfigurations": [ - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", - "Events": [ - "s3:ObjectCreated:*", - "s3:ObjectRemoved:*" - ] - } - ], - "LambdaFunctionConfigurations": [ - { - "LambdaFunctionArn": - "arn:aws:lambda:us-east-1:012345678910:function:lambda", - "Events": ["s3:ObjectCreated:*"] - } - ], - "QueueConfigurations": [ - { - "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "TopicConfigurations": [ + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", + "Events": ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"], + } + ], + "LambdaFunctionConfigurations": [ + { + "LambdaFunctionArn": "arn:aws:lambda:us-east-1:012345678910:function:lambda", + "Events": ["s3:ObjectCreated:*"], + } + ], + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", + "Events": ["s3:ObjectCreated:*"], + } + ], + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["LambdaFunctionConfigurations"]) == 1 assert len(result["TopicConfigurations"]) == 1 assert len(result["QueueConfigurations"]) == 1 # And clear it out: - s3.put_bucket_notification_configuration(Bucket="bucket", NotificationConfiguration={}) + s3.put_bucket_notification_configuration( + Bucket="bucket", NotificationConfiguration={} + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert not result.get("TopicConfigurations") assert not result.get("QueueConfigurations") @@ -2296,51 +2427,63 @@ def test_put_bucket_notification_errors(): # With incorrect ARNs: for tech, arn in [("Queue", "sqs"), ("Topic", "sns"), ("LambdaFunction", "lambda")]: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "{}Configurations".format(tech): [ - { - "{}Arn".format(tech): - "arn:aws:{}:us-east-1:012345678910:lksajdfkldskfj", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "{}Configurations".format(tech): [ + { + "{}Arn".format( + tech + ): "arn:aws:{}:us-east-1:012345678910:lksajdfkldskfj", + "Events": ["s3:ObjectCreated:*"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == "The ARN is not well formed" + assert ( + err.exception.response["Error"]["Message"] == "The ARN is not well formed" + ) # Region not the same as the bucket: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "QueueArn": - "arn:aws:sqs:us-west-2:012345678910:lksajdfkldskfj", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-west-2:012345678910:lksajdfkldskfj", + "Events": ["s3:ObjectCreated:*"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == \ - "The notification destination service region is not valid for the bucket location constraint" + assert ( + err.exception.response["Error"]["Message"] + == "The notification destination service region is not valid for the bucket location constraint" + ) # Invalid event name: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "QueueArn": - "arn:aws:sqs:us-east-1:012345678910:lksajdfkldskfj", - "Events": ["notarealeventname"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:lksajdfkldskfj", + "Events": ["notarealeventname"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == "The event is not supported for notifications" + assert ( + err.exception.response["Error"]["Message"] + == "The event is not supported for notifications" + ) @mock_s3 @@ -2351,7 +2494,10 @@ def test_boto3_put_bucket_logging(): wrong_region_bucket = "wrongregionlogbucket" s3.create_bucket(Bucket=bucket_name) s3.create_bucket(Bucket=log_bucket) # Adding the ACL for log-delivery later... - s3.create_bucket(Bucket=wrong_region_bucket, CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + s3.create_bucket( + Bucket=wrong_region_bucket, + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) # No logging config: result = s3.get_bucket_logging(Bucket=bucket_name) @@ -2359,72 +2505,78 @@ def test_boto3_put_bucket_logging(): # A log-bucket that doesn't exist: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": "IAMNOTREAL", - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": {"TargetBucket": "IAMNOTREAL", "TargetPrefix": ""} + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidTargetBucketForLogging" # A log-bucket that's missing the proper ACLs for LogDelivery: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": {"TargetBucket": log_bucket, "TargetPrefix": ""} + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidTargetBucketForLogging" assert "log-delivery" in err.exception.response["Error"]["Message"] # Add the proper "log-delivery" ACL to the log buckets: bucket_owner = s3.get_bucket_acl(Bucket=log_bucket)["Owner"] for bucket in [log_bucket, wrong_region_bucket]: - s3.put_bucket_acl(Bucket=bucket, AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + s3.put_bucket_acl( + Bucket=bucket, + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", }, - "Permission": "WRITE" - }, - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "READ_ACP", }, - "Permission": "READ_ACP" - }, - { - "Grantee": { - "Type": "CanonicalUser", - "ID": bucket_owner["ID"] + { + "Grantee": {"Type": "CanonicalUser", "ID": bucket_owner["ID"]}, + "Permission": "FULL_CONTROL", }, - "Permission": "FULL_CONTROL" - } - ], - "Owner": bucket_owner - }) + ], + "Owner": bucket_owner, + }, + ) # A log-bucket that's in the wrong region: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": wrong_region_bucket, - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": wrong_region_bucket, + "TargetPrefix": "", + } + }, + ) assert err.exception.response["Error"]["Code"] == "CrossLocationLoggingProhibitted" # Correct logging: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name) - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + } + }, + ) result = s3.get_bucket_logging(Bucket=bucket_name) assert result["LoggingEnabled"]["TargetBucket"] == log_bucket assert result["LoggingEnabled"]["TargetPrefix"] == "{}/".format(bucket_name) @@ -2435,56 +2587,9 @@ def test_boto3_put_bucket_logging(): assert not s3.get_bucket_logging(Bucket=bucket_name).get("LoggingEnabled") # And enabling with multiple target grants: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name), - "TargetGrants": [ - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "READ" - }, - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "WRITE" - } - ] - } - }) - - result = s3.get_bucket_logging(Bucket=bucket_name) - assert len(result["LoggingEnabled"]["TargetGrants"]) == 2 - assert result["LoggingEnabled"]["TargetGrants"][0]["Grantee"]["ID"] == \ - "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274" - - # Test with just 1 grant: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name), - "TargetGrants": [ - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "READ" - } - ] - } - }) - result = s3.get_bucket_logging(Bucket=bucket_name) - assert len(result["LoggingEnabled"]["TargetGrants"]) == 1 - - # With an invalid grant: - with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ "LoggingEnabled": { "TargetBucket": log_bucket, "TargetPrefix": "{}/".format(bucket_name), @@ -2492,393 +2597,513 @@ def test_boto3_put_bucket_logging(): { "Grantee": { "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" + "Type": "CanonicalUser", }, - "Permission": "NOTAREALPERM" - } - ] + "Permission": "READ", + }, + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "WRITE", + }, + ], } - }) + }, + ) + + result = s3.get_bucket_logging(Bucket=bucket_name) + assert len(result["LoggingEnabled"]["TargetGrants"]) == 2 + assert ( + result["LoggingEnabled"]["TargetGrants"][0]["Grantee"]["ID"] + == "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274" + ) + + # Test with just 1 grant: + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + "TargetGrants": [ + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "READ", + } + ], + } + }, + ) + result = s3.get_bucket_logging(Bucket=bucket_name) + assert len(result["LoggingEnabled"]["TargetGrants"]) == 1 + + # With an invalid grant: + with assert_raises(ClientError) as err: + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + "TargetGrants": [ + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "NOTAREALPERM", + } + ], + } + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" @mock_s3 def test_boto3_put_object_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as err: s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) e = err.exception - e.response['Error'].should.equal({ - 'Code': 'NoSuchKey', - 'Message': 'The specified key does not exist.', - 'RequestID': '7a62c49f-347e-4fc4-9331-6e8eEXAMPLE', - }) - - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + + +@mock_s3 +def test_boto3_put_object_tagging_on_earliest_version(): + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" + s3.create_bucket(Bucket=bucket_name) + s3_resource = boto3.resource("s3") + bucket_versioning = s3_resource.BucketVersioning(bucket_name) + bucket_versioning.enable() + bucket_versioning.status.should.equal("Enabled") + + with assert_raises(ClientError) as err: + s3.put_object_tagging( + Bucket=bucket_name, + Key=key, + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + ) + + e = err.exception + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } + ) + + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + s3.put_object(Bucket=bucket_name, Key=key, Body="test_updated") + + object_versions = list(s3_resource.Bucket(bucket_name).object_versions.all()) + first_object = object_versions[0] + second_object = object_versions[1] + + resp = s3.put_object_tagging( + Bucket=bucket_name, + Key=key, + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + VersionId=first_object.id, + ) + + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + + # Older version has tags while the most recent does not + resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=first_object.id) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "foo"}, {"Key": "item2", "Value": "bar"}] + ) + + resp = s3.get_object_tagging( + Bucket=bucket_name, Key=key, VersionId=second_object.id + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal([]) + + +@mock_s3 +def test_boto3_put_object_tagging_on_both_version(): + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" + s3.create_bucket(Bucket=bucket_name) + s3_resource = boto3.resource("s3") + bucket_versioning = s3_resource.BucketVersioning(bucket_name) + bucket_versioning.enable() + bucket_versioning.status.should.equal("Enabled") + + with assert_raises(ClientError) as err: + s3.put_object_tagging( + Bucket=bucket_name, + Key=key, + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + ) + + e = err.exception + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } + ) + + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + s3.put_object(Bucket=bucket_name, Key=key, Body="test_updated") + + object_versions = list(s3_resource.Bucket(bucket_name).object_versions.all()) + first_object = object_versions[0] + second_object = object_versions[1] + + resp = s3.put_object_tagging( + Bucket=bucket_name, + Key=key, + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + VersionId=first_object.id, + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + + resp = s3.put_object_tagging( + Bucket=bucket_name, + Key=key, + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "baz"}, + {"Key": "item2", "Value": "bin"}, + ] + }, + VersionId=second_object.id, + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + + resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=first_object.id) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "foo"}, {"Key": "item2", "Value": "bar"}] + ) + + resp = s3.get_object_tagging( + Bucket=bucket_name, Key=key, VersionId=second_object.id + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "baz"}, {"Key": "item2", "Value": "bin"}] + ) @mock_s3 def test_boto3_put_object_tagging_with_single_tag(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'} - ]} + Tagging={"TagSet": [{"Key": "item1", "Value": "foo"}]}, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_s3 def test_boto3_get_object_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.have.length_of(0) + resp["TagSet"].should.have.length_of(0) resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.have.length_of(2) - resp['TagSet'].should.contain({'Key': 'item1', 'Value': 'foo'}) - resp['TagSet'].should.contain({'Key': 'item2', 'Value': 'bar'}) + resp["TagSet"].should.have.length_of(2) + resp["TagSet"].should.contain({"Key": "item1", "Value": "foo"}) + resp["TagSet"].should.contain({"Key": "item2", "Value": "bar"}) @mock_s3 def test_boto3_list_object_versions(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name) # Two object versions should be returned - len(response['Versions']).should.equal(2) - keys = set([item['Key'] for item in response['Versions']]) + len(response["Versions"]).should.equal(2) + keys = set([item["Key"] for item in response["Versions"]]) keys.should.equal({key}) # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) @mock_s3 def test_boto3_list_object_versions_with_versioning_disabled(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name) # One object version should be returned - len(response['Versions']).should.equal(1) - response['Versions'][0]['Key'].should.equal(key) + len(response["Versions"]).should.equal(1) + response["Versions"][0]["Key"].should.equal(key) # The version id should be the string null - response['Versions'][0]['VersionId'].should.equal('null') + response["Versions"][0]["VersionId"].should.equal("null") # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) @mock_s3 def test_boto3_list_object_versions_with_versioning_enabled_late(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) - items = (six.b('v1'), six.b('v2')) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=six.b('v1') - ) + items = (six.b("v1"), six.b("v2")) + s3.put_object(Bucket=bucket_name, Key=key, Body=six.b("v1")) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } - ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=six.b('v2') - ) - response = s3.list_object_versions( - Bucket=bucket_name + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) + s3.put_object(Bucket=bucket_name, Key=key, Body=six.b("v2")) + response = s3.list_object_versions(Bucket=bucket_name) # Two object versions should be returned - len(response['Versions']).should.equal(2) - keys = set([item['Key'] for item in response['Versions']]) + len(response["Versions"]).should.equal(2) + keys = set([item["Key"] for item in response["Versions"]]) keys.should.equal({key}) # There should still be a null version id. - versionsId = set([item['VersionId'] for item in response['Versions']]) - versionsId.should.contain('null') + versionsId = set([item["VersionId"] for item in response["Versions"]]) + versionsId.should.contain("null") # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) + @mock_s3 def test_boto3_bad_prefix_list_object_versions(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' - bad_prefix = 'key-that-does-not-exist' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" + bad_prefix = "key-that-does-not-exist" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name, - Prefix=bad_prefix, - ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response.should_not.contain('Versions') - response.should_not.contain('DeleteMarkers') + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name, Prefix=bad_prefix) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response.should_not.contain("Versions") + response.should_not.contain("DeleteMarkers") @mock_s3 def test_boto3_delete_markers(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = u'key-with-versions-and-unicode-ó' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions-and-unicode-ó" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) - s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]}) + s3.delete_objects(Bucket=bucket_name, Delete={"Objects": [{"Key": key}]}) with assert_raises(ClientError) as e: - s3.get_object( - Bucket=bucket_name, - Key=key - ) - e.exception.response['Error']['Code'].should.equal('NoSuchKey') + s3.get_object(Bucket=bucket_name, Key=key) + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") - response = s3.list_object_versions( - Bucket=bucket_name - ) - response['Versions'].should.have.length_of(2) - response['DeleteMarkers'].should.have.length_of(1) + response = s3.list_object_versions(Bucket=bucket_name) + response["Versions"].should.have.length_of(2) + response["DeleteMarkers"].should.have.length_of(1) s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][0]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][0]["VersionId"] ) - response = s3.get_object( - Bucket=bucket_name, - Key=key - ) - response['Body'].read().should.equal(items[-1]) + response = s3.get_object(Bucket=bucket_name, Key=key) + response["Body"].read().should.equal(items[-1]) - response = s3.list_object_versions( - Bucket=bucket_name - ) - response['Versions'].should.have.length_of(2) + response = s3.list_object_versions(Bucket=bucket_name) + response["Versions"].should.have.length_of(2) # We've asserted there is only 2 records so one is newest, one is oldest - latest = list(filter(lambda item: item['IsLatest'], response['Versions']))[0] - oldest = list(filter(lambda item: not item['IsLatest'], response['Versions']))[0] + latest = list(filter(lambda item: item["IsLatest"], response["Versions"]))[0] + oldest = list(filter(lambda item: not item["IsLatest"], response["Versions"]))[0] # Double check ordering of version ID's - latest['VersionId'].should_not.equal(oldest['VersionId']) + latest["VersionId"].should_not.equal(oldest["VersionId"]) # Double check the name is still unicode - latest['Key'].should.equal('key-with-versions-and-unicode-ó') - oldest['Key'].should.equal('key-with-versions-and-unicode-ó') + latest["Key"].should.equal("key-with-versions-and-unicode-ó") + oldest["Key"].should.equal("key-with-versions-and-unicode-ó") @mock_s3 def test_boto3_multiple_delete_markers(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = u'key-with-versions-and-unicode-ó' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions-and-unicode-ó" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) # Delete the object twice to add multiple delete markers s3.delete_object(Bucket=bucket_name, Key=key) s3.delete_object(Bucket=bucket_name, Key=key) response = s3.list_object_versions(Bucket=bucket_name) - response['DeleteMarkers'].should.have.length_of(2) + response["DeleteMarkers"].should.have.length_of(2) with assert_raises(ClientError) as e: - s3.get_object( - Bucket=bucket_name, - Key=key - ) - e.response['Error']['Code'].should.equal('404') + s3.get_object(Bucket=bucket_name, Key=key) + e.response["Error"]["Code"].should.equal("404") # Remove both delete markers to restore the object s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][0]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][0]["VersionId"] ) s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][1]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][1]["VersionId"] ) - response = s3.get_object( - Bucket=bucket_name, - Key=key - ) - response['Body'].read().should.equal(items[-1]) + response = s3.get_object(Bucket=bucket_name, Key=key) + response["Body"].read().should.equal(items[-1]) response = s3.list_object_versions(Bucket=bucket_name) - response['Versions'].should.have.length_of(2) + response["Versions"].should.have.length_of(2) # We've asserted there is only 2 records so one is newest, one is oldest - latest = list(filter(lambda item: item['IsLatest'], response['Versions']))[0] - oldest = list(filter(lambda item: not item['IsLatest'], response['Versions']))[0] + latest = list(filter(lambda item: item["IsLatest"], response["Versions"]))[0] + oldest = list(filter(lambda item: not item["IsLatest"], response["Versions"]))[0] # Double check ordering of version ID's - latest['VersionId'].should_not.equal(oldest['VersionId']) + latest["VersionId"].should_not.equal(oldest["VersionId"]) # Double check the name is still unicode - latest['Key'].should.equal('key-with-versions-and-unicode-ó') - oldest['Key'].should.equal('key-with-versions-and-unicode-ó') + latest["Key"].should.equal("key-with-versions-and-unicode-ó") + oldest["Key"].should.equal("key-with-versions-and-unicode-ó") @mock_s3 def test_get_stream_gzipped(): payload = b"this is some stuff here" - s3_client = boto3.client("s3", region_name='us-east-1') - s3_client.create_bucket(Bucket='moto-tests') + s3_client = boto3.client("s3", region_name="us-east-1") + s3_client.create_bucket(Bucket="moto-tests") buffer_ = BytesIO() - with GzipFile(fileobj=buffer_, mode='w') as f: + with GzipFile(fileobj=buffer_, mode="w") as f: f.write(payload) payload_gz = buffer_.getvalue() s3_client.put_object( - Bucket='moto-tests', - Key='keyname', - Body=payload_gz, - ContentEncoding='gzip', + Bucket="moto-tests", Key="keyname", Body=payload_gz, ContentEncoding="gzip" ) - obj = s3_client.get_object( - Bucket='moto-tests', - Key='keyname', - ) - res = zlib.decompress(obj['Body'].read(), 16 + zlib.MAX_WBITS) + obj = s3_client.get_object(Bucket="moto-tests", Key="keyname") + res = zlib.decompress(obj["Body"].read(), 16 + zlib.MAX_WBITS) assert res == payload @@ -2901,93 +3126,812 @@ TEST_XML = """\ """ + @mock_s3 def test_boto3_bucket_name_too_long(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket='x'*64) - exc.exception.response['Error']['Code'].should.equal('InvalidBucketName') + s3.create_bucket(Bucket="x" * 64) + exc.exception.response["Error"]["Code"].should.equal("InvalidBucketName") + @mock_s3 def test_boto3_bucket_name_too_short(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket='x'*2) - exc.exception.response['Error']['Code'].should.equal('InvalidBucketName') + s3.create_bucket(Bucket="x" * 2) + exc.exception.response["Error"]["Code"].should.equal("InvalidBucketName") + @mock_s3 def test_accelerated_none_when_unspecified(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.shouldnt.have.key('Status') + resp.shouldnt.have.key("Status") + @mock_s3 def test_can_enable_bucket_acceleration(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.should.have.key('Status') - resp['Status'].should.equal('Enabled') + resp.should.have.key("Status") + resp["Status"].should.equal("Enabled") + @mock_s3 def test_can_suspend_bucket_acceleration(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Suspended'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Suspended"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.should.have.key('Status') - resp['Status'].should.equal('Suspended') + resp.should.have.key("Status") + resp["Status"].should.equal("Suspended") + @mock_s3 def test_suspending_acceleration_on_not_configured_bucket_does_nothing(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Suspended'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Suspended"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.shouldnt.have.key('Status') + resp.shouldnt.have.key("Status") + @mock_s3 def test_accelerate_configuration_status_validation(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as exc: s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'bad_status'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "bad_status"} ) - exc.exception.response['Error']['Code'].should.equal('MalformedXML') + exc.exception.response["Error"]["Code"].should.equal("MalformedXML") + @mock_s3 def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots(): - bucket_name = 'some.bucket.with.dots' - s3 = boto3.client('s3') + bucket_name = "some.bucket.with.dots" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as exc: s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) - exc.exception.response['Error']['Code'].should.equal('InvalidRequest') + exc.exception.response["Error"]["Code"].should.equal("InvalidRequest") + + +def store_and_read_back_a_key(key): + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + body = b"Some body" + + s3.create_bucket(Bucket=bucket_name) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + + response = s3.get_object(Bucket=bucket_name, Key=key) + response["Body"].read().should.equal(body) + + +@mock_s3 +def test_paths_with_leading_slashes_work(): + store_and_read_back_a_key("/a-key") + + +@mock_s3 +def test_root_dir_with_empty_name_works(): + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Does not work in server mode due to error in Workzeug") + store_and_read_back_a_key("/") + + +@parameterized( + [("foo/bar/baz",), ("foo",), ("foo/run_dt%3D2019-01-01%252012%253A30%253A00",)] +) +@mock_s3 +def test_delete_objects_with_url_encoded_key(key): + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + body = b"Some body" + + s3.create_bucket(Bucket=bucket_name) + + def put_object(): + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + + def assert_deleted(): + with assert_raises(ClientError) as e: + s3.get_object(Bucket=bucket_name, Key=key) + + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") + + put_object() + s3.delete_object(Bucket=bucket_name, Key=key) + assert_deleted() + + put_object() + s3.delete_objects(Bucket=bucket_name, Delete={"Objects": [{"Key": key}]}) + assert_deleted() + + +@mock_s3 +@mock_config +def test_public_access_block(): + client = boto3.client("s3") + client.create_bucket(Bucket="mybucket") + + # Try to get the public access block (should not exist by default) + with assert_raises(ClientError) as ce: + client.get_public_access_block(Bucket="mybucket") + + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchPublicAccessBlockConfiguration" + ) + assert ( + ce.exception.response["Error"]["Message"] + == "The public access block configuration was not found" + ) + assert ce.exception.response["ResponseMetadata"]["HTTPStatusCode"] == 404 + + # Put a public block in place: + test_map = { + "BlockPublicAcls": False, + "IgnorePublicAcls": False, + "BlockPublicPolicy": False, + "RestrictPublicBuckets": False, + } + + for field in test_map.keys(): + # Toggle: + test_map[field] = True + + client.put_public_access_block( + Bucket="mybucket", PublicAccessBlockConfiguration=test_map + ) + + # Test: + assert ( + test_map + == client.get_public_access_block(Bucket="mybucket")[ + "PublicAccessBlockConfiguration" + ] + ) + + # Assume missing values are default False: + client.put_public_access_block( + Bucket="mybucket", PublicAccessBlockConfiguration={"BlockPublicAcls": True} + ) + assert client.get_public_access_block(Bucket="mybucket")[ + "PublicAccessBlockConfiguration" + ] == { + "BlockPublicAcls": True, + "IgnorePublicAcls": False, + "BlockPublicPolicy": False, + "RestrictPublicBuckets": False, + } + + # Test with a blank PublicAccessBlockConfiguration: + with assert_raises(ClientError) as ce: + client.put_public_access_block( + Bucket="mybucket", PublicAccessBlockConfiguration={} + ) + + assert ce.exception.response["Error"]["Code"] == "InvalidRequest" + assert ( + ce.exception.response["Error"]["Message"] + == "Must specify at least one configuration." + ) + assert ce.exception.response["ResponseMetadata"]["HTTPStatusCode"] == 400 + + # Test that things work with AWS Config: + config_client = boto3.client("config", region_name="us-east-1") + result = config_client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="mybucket" + ) + pub_block_config = json.loads( + result["configurationItems"][0]["supplementaryConfiguration"][ + "PublicAccessBlockConfiguration" + ] + ) + + assert pub_block_config == { + "blockPublicAcls": True, + "ignorePublicAcls": False, + "blockPublicPolicy": False, + "restrictPublicBuckets": False, + } + + # Delete: + client.delete_public_access_block(Bucket="mybucket") + + with assert_raises(ClientError) as ce: + client.get_public_access_block(Bucket="mybucket") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchPublicAccessBlockConfiguration" + ) + + +@mock_s3 +def test_s3_public_access_block_to_config_dict(): + from moto.s3.config import s3_config_query + + # With 1 bucket in us-west-2: + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + + public_access_block = { + "BlockPublicAcls": "True", + "IgnorePublicAcls": "False", + "BlockPublicPolicy": "True", + "RestrictPublicBuckets": "False", + } + + # Python 2 unicode issues: + if sys.version_info[0] < 3: + public_access_block = py2_strip_unicode_keys(public_access_block) + + # Add a public access block: + s3_config_query.backends["global"].put_bucket_public_access_block( + "bucket1", public_access_block + ) + + result = ( + s3_config_query.backends["global"] + .buckets["bucket1"] + .public_access_block.to_config_dict() + ) + + convert_bool = lambda x: x == "True" + for key, value in public_access_block.items(): + assert result[ + "{lowercase}{rest}".format(lowercase=key[0].lower(), rest=key[1:]) + ] == convert_bool(value) + + # Verify that this resides in the full bucket's to_config_dict: + full_result = s3_config_query.backends["global"].buckets["bucket1"].to_config_dict() + assert ( + json.loads( + full_result["supplementaryConfiguration"]["PublicAccessBlockConfiguration"] + ) + == result + ) + + +@mock_s3 +def test_list_config_discovered_resources(): + from moto.s3.config import s3_config_query + + # Without any buckets: + assert s3_config_query.list_config_service_resources( + "global", "global", None, None, 100, None + ) == ([], None) + + # With 10 buckets in us-west-2: + for x in range(0, 10): + s3_config_query.backends["global"].create_bucket( + "bucket{}".format(x), "us-west-2" + ) + + # With 2 buckets in eu-west-1: + for x in range(10, 12): + s3_config_query.backends["global"].create_bucket( + "eu-bucket{}".format(x), "eu-west-1" + ) + + result, next_token = s3_config_query.list_config_service_resources( + None, None, 100, None + ) + assert not next_token + assert len(result) == 12 + for x in range(0, 10): + assert result[x] == { + "type": "AWS::S3::Bucket", + "id": "bucket{}".format(x), + "name": "bucket{}".format(x), + "region": "us-west-2", + } + for x in range(10, 12): + assert result[x] == { + "type": "AWS::S3::Bucket", + "id": "eu-bucket{}".format(x), + "name": "eu-bucket{}".format(x), + "region": "eu-west-1", + } + + # With a name: + result, next_token = s3_config_query.list_config_service_resources( + None, "bucket0", 100, None + ) + assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token + + # With a region: + result, next_token = s3_config_query.list_config_service_resources( + None, None, 100, None, resource_region="eu-west-1" + ) + assert len(result) == 2 and not next_token and result[1]["name"] == "eu-bucket11" + + # With resource ids: + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket1"], None, 100, None + ) + assert ( + len(result) == 2 + and result[0]["name"] == "bucket0" + and result[1]["name"] == "bucket1" + and not next_token + ) + + # With duplicated resource ids: + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket0"], None, 100, None + ) + assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token + + # Pagination: + result, next_token = s3_config_query.list_config_service_resources( + None, None, 1, None + ) + assert ( + len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" + ) + + # Last Page: + result, next_token = s3_config_query.list_config_service_resources( + None, None, 1, "eu-bucket11", resource_region="eu-west-1" + ) + assert len(result) == 1 and result[0]["name"] == "eu-bucket11" and not next_token + + # With a list of buckets: + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket1"], None, 1, None + ) + assert ( + len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" + ) + + # With an invalid page: + with assert_raises(InvalidNextTokenException) as inte: + s3_config_query.list_config_service_resources(None, None, 1, "notabucket") + + assert "The nextToken provided is invalid" in inte.exception.message + + +@mock_s3 +def test_s3_lifecycle_config_dict(): + from moto.s3.config import s3_config_query + + # With 1 bucket in us-west-2: + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + + # And a lifecycle policy + lifecycle = [ + { + "ID": "rule1", + "Status": "Enabled", + "Filter": {"Prefix": ""}, + "Expiration": {"Days": 1}, + }, + { + "ID": "rule2", + "Status": "Enabled", + "Filter": { + "And": { + "Prefix": "some/path", + "Tag": [{"Key": "TheKey", "Value": "TheValue"}], + } + }, + "Expiration": {"Days": 1}, + }, + {"ID": "rule3", "Status": "Enabled", "Filter": {}, "Expiration": {"Days": 1}}, + { + "ID": "rule4", + "Status": "Enabled", + "Filter": {"Prefix": ""}, + "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 1}, + }, + ] + s3_config_query.backends["global"].set_bucket_lifecycle("bucket1", lifecycle) + + # Get the rules for this: + lifecycles = [ + rule.to_config_dict() + for rule in s3_config_query.backends["global"].buckets["bucket1"].rules + ] + + # Verify the first: + assert lifecycles[0] == { + "id": "rule1", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": {"predicate": {"type": "LifecyclePrefixPredicate", "prefix": ""}}, + } + + # Verify the second: + assert lifecycles[1] == { + "id": "rule2", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": { + "predicate": { + "type": "LifecycleAndOperator", + "operands": [ + {"type": "LifecyclePrefixPredicate", "prefix": "some/path"}, + { + "type": "LifecycleTagPredicate", + "tag": {"key": "TheKey", "value": "TheValue"}, + }, + ], + } + }, + } + + # And the third: + assert lifecycles[2] == { + "id": "rule3", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": {"predicate": None}, + } + + # And the last: + assert lifecycles[3] == { + "id": "rule4", + "prefix": None, + "status": "Enabled", + "expirationInDays": None, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": {"daysAfterInitiation": 1}, + "filter": {"predicate": {"type": "LifecyclePrefixPredicate", "prefix": ""}}, + } + + +@mock_s3 +def test_s3_notification_config_dict(): + from moto.s3.config import s3_config_query + + # With 1 bucket in us-west-2: + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + + # And some notifications: + notifications = { + "TopicConfiguration": [ + { + "Id": "Topic", + "Topic": "arn:aws:sns:us-west-2:012345678910:mytopic", + "Event": [ + "s3:ReducedRedundancyLostObject", + "s3:ObjectRestore:Completed", + ], + } + ], + "QueueConfiguration": [ + { + "Id": "Queue", + "Queue": "arn:aws:sqs:us-west-2:012345678910:myqueue", + "Event": ["s3:ObjectRemoved:Delete"], + "Filter": { + "S3Key": { + "FilterRule": [{"Name": "prefix", "Value": "stuff/here/"}] + } + }, + } + ], + "CloudFunctionConfiguration": [ + { + "Id": "Lambda", + "CloudFunction": "arn:aws:lambda:us-west-2:012345678910:function:mylambda", + "Event": [ + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:Put", + ], + "Filter": { + "S3Key": {"FilterRule": [{"Name": "suffix", "Value": ".png"}]} + }, + } + ], + } + + s3_config_query.backends["global"].put_bucket_notification_configuration( + "bucket1", notifications + ) + + # Get the notifications for this: + notifications = ( + s3_config_query.backends["global"] + .buckets["bucket1"] + .notification_configuration.to_config_dict() + ) + + # Verify it all: + assert notifications == { + "configurations": { + "Topic": { + "events": [ + "s3:ReducedRedundancyLostObject", + "s3:ObjectRestore:Completed", + ], + "filter": None, + "objectPrefixes": [], + "topicARN": "arn:aws:sns:us-west-2:012345678910:mytopic", + "type": "TopicConfiguration", + }, + "Queue": { + "events": ["s3:ObjectRemoved:Delete"], + "filter": { + "s3KeyFilter": { + "filterRules": [{"name": "prefix", "value": "stuff/here/"}] + } + }, + "objectPrefixes": [], + "queueARN": "arn:aws:sqs:us-west-2:012345678910:myqueue", + "type": "QueueConfiguration", + }, + "Lambda": { + "events": [ + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:Put", + ], + "filter": { + "s3KeyFilter": { + "filterRules": [{"name": "suffix", "value": ".png"}] + } + }, + "objectPrefixes": [], + "queueARN": "arn:aws:lambda:us-west-2:012345678910:function:mylambda", + "type": "LambdaConfiguration", + }, + } + } + + +@mock_s3 +def test_s3_acl_to_config_dict(): + from moto.s3.config import s3_config_query + from moto.s3.models import FakeAcl, FakeGrant, FakeGrantee, OWNER + + # With 1 bucket in us-west-2: + s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") + + # Get the config dict with nothing other than the owner details: + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + assert acls == {"grantSet": None, "owner": {"displayName": None, "id": OWNER}} + + # Add some Log Bucket ACLs: + log_acls = FakeAcl( + [ + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "WRITE", + ), + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "READ_ACP", + ), + FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL"), + ] + ) + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) + + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + assert acls == { + "grantSet": None, + "grantList": [ + {"grantee": "LogDelivery", "permission": "Write"}, + {"grantee": "LogDelivery", "permission": "ReadAcp"}, + ], + "owner": {"displayName": None, "id": OWNER}, + } + + # Give the owner less than full_control permissions: + log_acls = FakeAcl( + [ + FakeGrant([FakeGrantee(id=OWNER)], "READ_ACP"), + FakeGrant([FakeGrantee(id=OWNER)], "WRITE_ACP"), + ] + ) + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + assert acls == { + "grantSet": None, + "grantList": [ + {"grantee": {"id": OWNER, "displayName": None}, "permission": "ReadAcp"}, + {"grantee": {"id": OWNER, "displayName": None}, "permission": "WriteAcp"}, + ], + "owner": {"displayName": None, "id": OWNER}, + } + + +@mock_s3 +def test_s3_config_dict(): + from moto.s3.config import s3_config_query + from moto.s3.models import ( + FakeAcl, + FakeGrant, + FakeGrantee, + FakeTag, + FakeTagging, + FakeTagSet, + OWNER, + ) + + # Without any buckets: + assert not s3_config_query.get_config_resource("some_bucket") + + tags = FakeTagging( + FakeTagSet( + [FakeTag("someTag", "someValue"), FakeTag("someOtherTag", "someOtherValue")] + ) + ) + + # With 1 bucket in us-west-2: + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + s3_config_query.backends["global"].put_bucket_tagging("bucket1", tags) + + # With a log bucket: + s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") + log_acls = FakeAcl( + [ + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "WRITE", + ), + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "READ_ACP", + ), + FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL"), + ] + ) + + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) + s3_config_query.backends["global"].put_bucket_logging( + "bucket1", {"TargetBucket": "logbucket", "TargetPrefix": ""} + ) + + policy = json.dumps( + { + "Statement": [ + { + "Effect": "Deny", + "Action": "s3:DeleteObject", + "Principal": "*", + "Resource": "arn:aws:s3:::bucket1/*", + } + ] + } + ) + + # The policy is a byte array -- need to encode in Python 3 -- for Python 2 just pass the raw string in: + if sys.version_info[0] > 2: + pass_policy = bytes(policy, "utf-8") + else: + pass_policy = policy + s3_config_query.backends["global"].set_bucket_policy("bucket1", pass_policy) + + # Get the us-west-2 bucket and verify that it works properly: + bucket1_result = s3_config_query.get_config_resource("bucket1") + + # Just verify a few things: + assert bucket1_result["arn"] == "arn:aws:s3:::bucket1" + assert bucket1_result["awsRegion"] == "us-west-2" + assert bucket1_result["resourceName"] == bucket1_result["resourceId"] == "bucket1" + assert bucket1_result["tags"] == { + "someTag": "someValue", + "someOtherTag": "someOtherValue", + } + assert json.loads( + bucket1_result["supplementaryConfiguration"]["BucketTaggingConfiguration"] + ) == {"tagSets": [{"tags": bucket1_result["tags"]}]} + assert isinstance(bucket1_result["configuration"], str) + exist_list = [ + "AccessControlList", + "BucketAccelerateConfiguration", + "BucketLoggingConfiguration", + "BucketPolicy", + "IsRequesterPaysEnabled", + "BucketNotificationConfiguration", + ] + for exist in exist_list: + assert isinstance(bucket1_result["supplementaryConfiguration"][exist], str) + + # Verify the logging config: + assert json.loads( + bucket1_result["supplementaryConfiguration"]["BucketLoggingConfiguration"] + ) == {"destinationBucketName": "logbucket", "logFilePrefix": ""} + + # Verify that the AccessControlList is a double-wrapped JSON string: + assert json.loads( + json.loads(bucket1_result["supplementaryConfiguration"]["AccessControlList"]) + ) == { + "grantSet": None, + "owner": { + "displayName": None, + "id": "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a", + }, + } + + # Verify the policy: + assert json.loads(bucket1_result["supplementaryConfiguration"]["BucketPolicy"]) == { + "policyText": policy + } + + # Filter by correct region: + assert bucket1_result == s3_config_query.get_config_resource( + "bucket1", resource_region="us-west-2" + ) + + # By incorrect region: + assert not s3_config_query.get_config_resource( + "bucket1", resource_region="eu-west-1" + ) + + # With correct resource ID and name: + assert bucket1_result == s3_config_query.get_config_resource( + "bucket1", resource_name="bucket1" + ) + + # With an incorrect resource name: + assert not s3_config_query.get_config_resource( + "bucket1", resource_name="eu-bucket-1" + ) + + # Verify that no bucket policy returns the proper value: + logging_bucket = s3_config_query.get_config_resource("logbucket") + assert json.loads(logging_bucket["supplementaryConfiguration"]["BucketPolicy"]) == { + "policyText": None + } + assert not logging_bucket["tags"] + assert not logging_bucket["supplementaryConfiguration"].get( + "BucketTaggingConfiguration" + ) diff --git a/tests/test_s3/test_s3_lifecycle.py b/tests/test_s3/test_s3_lifecycle.py index 6cb43e96f..260b248f1 100644 --- a/tests/test_s3/test_s3_lifecycle.py +++ b/tests/test_s3/test_s3_lifecycle.py @@ -1,387 +1,505 @@ -from __future__ import unicode_literals - -import boto -import boto3 -from boto.exception import S3ResponseError -from boto.s3.lifecycle import Lifecycle, Transition, Expiration, Rule - -import sure # noqa -from botocore.exceptions import ClientError -from datetime import datetime -from nose.tools import assert_raises - -from moto import mock_s3_deprecated, mock_s3 - - -@mock_s3_deprecated -def test_lifecycle_create(): - conn = boto.s3.connect_to_region("us-west-1") - bucket = conn.create_bucket("foobar") - - lifecycle = Lifecycle() - lifecycle.add_rule('myid', '', 'Enabled', 30) - bucket.configure_lifecycle(lifecycle) - response = bucket.get_lifecycle_config() - len(response).should.equal(1) - lifecycle = response[0] - lifecycle.id.should.equal('myid') - lifecycle.prefix.should.equal('') - lifecycle.status.should.equal('Enabled') - list(lifecycle.transition).should.equal([]) - - -@mock_s3 -def test_lifecycle_with_filters(): - client = boto3.client("s3") - client.create_bucket(Bucket="bucket") - - # Create a lifecycle rule with a Filter (no tags): - lfc = { - "Rules": [ - { - "Expiration": { - "Days": 7 - }, - "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Filter"]["Prefix"] == '' - assert not result["Rules"][0]["Filter"].get("And") - assert not result["Rules"][0]["Filter"].get("Tag") - with assert_raises(KeyError): - assert result["Rules"][0]["Prefix"] - - # With a tag: - lfc["Rules"][0]["Filter"]["Tag"] = { - "Key": "mytag", - "Value": "mytagvalue" - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Filter"]["Prefix"] == '' - assert not result["Rules"][0]["Filter"].get("And") - assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" - with assert_raises(KeyError): - assert result["Rules"][0]["Prefix"] - - # With And (single tag): - lfc["Rules"][0]["Filter"]["And"] = { - "Prefix": "some/prefix", - "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Filter"]["Prefix"] == "" - assert result["Rules"][0]["Filter"]["And"]["Prefix"] == "some/prefix" - assert len(result["Rules"][0]["Filter"]["And"]["Tags"]) == 1 - assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Value"] == "mytagvalue" - assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" - with assert_raises(KeyError): - assert result["Rules"][0]["Prefix"] - - # With multiple And tags: - lfc["Rules"][0]["Filter"]["And"] = { - "Prefix": "some/prefix", - "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - }, - { - "Key": "mytag2", - "Value": "mytagvalue2" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Filter"]["Prefix"] == "" - assert result["Rules"][0]["Filter"]["And"]["Prefix"] == "some/prefix" - assert len(result["Rules"][0]["Filter"]["And"]["Tags"]) == 2 - assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Value"] == "mytagvalue" - assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" - assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Key"] == "mytag2" - assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Value"] == "mytagvalue2" - assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" - assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" - with assert_raises(KeyError): - assert result["Rules"][0]["Prefix"] - - # Can't have both filter and prefix: - lfc["Rules"][0]["Prefix"] = '' - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - - lfc["Rules"][0]["Prefix"] = 'some/path' - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - - # No filters -- just a prefix: - del lfc["Rules"][0]["Filter"] - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert not result["Rules"][0].get("Filter") - assert result["Rules"][0]["Prefix"] == "some/path" - - -@mock_s3 -def test_lifecycle_with_eodm(): - client = boto3.client("s3") - client.create_bucket(Bucket="bucket") - - lfc = { - "Rules": [ - { - "Expiration": { - "ExpiredObjectDeleteMarker": True - }, - "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] - - # Set to False: - lfc["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] = False - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert not result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] - - # With failure: - lfc["Rules"][0]["Expiration"]["Days"] = 7 - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - del lfc["Rules"][0]["Expiration"]["Days"] - - lfc["Rules"][0]["Expiration"]["Date"] = datetime(2015, 1, 1) - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - - -@mock_s3 -def test_lifecycle_with_nve(): - client = boto3.client("s3") - client.create_bucket(Bucket="bucket") - - lfc = { - "Rules": [ - { - "NoncurrentVersionExpiration": { - "NoncurrentDays": 30 - }, - "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 30 - - # Change NoncurrentDays: - lfc["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] = 10 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 10 - - # TODO: Add test for failures due to missing children - - -@mock_s3 -def test_lifecycle_with_nvt(): - client = boto3.client("s3") - client.create_bucket(Bucket="bucket") - - lfc = { - "Rules": [ - { - "NoncurrentVersionTransitions": [{ - "NoncurrentDays": 30, - "StorageClass": "ONEZONE_IA" - }], - "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 30 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "ONEZONE_IA" - - # Change NoncurrentDays: - lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 10 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 10 - - # Change StorageClass: - lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] = "GLACIER" - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "GLACIER" - - # With failures for missing children: - del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 30 - - del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] - with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - assert err.exception.response["Error"]["Code"] == "MalformedXML" - - -@mock_s3 -def test_lifecycle_with_aimu(): - client = boto3.client("s3") - client.create_bucket(Bucket="bucket") - - lfc = { - "Rules": [ - { - "AbortIncompleteMultipartUpload": { - "DaysAfterInitiation": 7 - }, - "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" - } - ] - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 7 - - # Change DaysAfterInitiation: - lfc["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] = 30 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket") - assert len(result["Rules"]) == 1 - assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 30 - - # TODO: Add test for failures due to missing children - - -@mock_s3_deprecated -def test_lifecycle_with_glacier_transition(): - conn = boto.s3.connect_to_region("us-west-1") - bucket = conn.create_bucket("foobar") - - lifecycle = Lifecycle() - transition = Transition(days=30, storage_class='GLACIER') - rule = Rule('myid', prefix='', status='Enabled', expiration=None, - transition=transition) - lifecycle.append(rule) - bucket.configure_lifecycle(lifecycle) - response = bucket.get_lifecycle_config() - transition = response[0].transition - transition.days.should.equal(30) - transition.storage_class.should.equal('GLACIER') - transition.date.should.equal(None) - - -@mock_s3_deprecated -def test_lifecycle_multi(): - conn = boto.s3.connect_to_region("us-west-1") - bucket = conn.create_bucket("foobar") - - date = '2022-10-12T00:00:00.000Z' - sc = 'GLACIER' - lifecycle = Lifecycle() - lifecycle.add_rule("1", "1/", "Enabled", 1) - lifecycle.add_rule("2", "2/", "Enabled", Expiration(days=2)) - lifecycle.add_rule("3", "3/", "Enabled", Expiration(date=date)) - lifecycle.add_rule("4", "4/", "Enabled", None, - Transition(days=4, storage_class=sc)) - lifecycle.add_rule("5", "5/", "Enabled", None, - Transition(date=date, storage_class=sc)) - - bucket.configure_lifecycle(lifecycle) - # read the lifecycle back - rules = bucket.get_lifecycle_config() - - for rule in rules: - if rule.id == "1": - rule.prefix.should.equal("1/") - rule.expiration.days.should.equal(1) - elif rule.id == "2": - rule.prefix.should.equal("2/") - rule.expiration.days.should.equal(2) - elif rule.id == "3": - rule.prefix.should.equal("3/") - rule.expiration.date.should.equal(date) - elif rule.id == "4": - rule.prefix.should.equal("4/") - rule.transition.days.should.equal(4) - rule.transition.storage_class.should.equal(sc) - elif rule.id == "5": - rule.prefix.should.equal("5/") - rule.transition.date.should.equal(date) - rule.transition.storage_class.should.equal(sc) - else: - assert False, "Invalid rule id" - - -@mock_s3_deprecated -def test_lifecycle_delete(): - conn = boto.s3.connect_to_region("us-west-1") - bucket = conn.create_bucket("foobar") - - lifecycle = Lifecycle() - lifecycle.add_rule(expiration=30) - bucket.configure_lifecycle(lifecycle) - response = bucket.get_lifecycle_config() - response.should.have.length_of(1) - - bucket.delete_lifecycle_configuration() - bucket.get_lifecycle_config.when.called_with().should.throw(S3ResponseError) +from __future__ import unicode_literals + +import boto +import boto3 +from boto.exception import S3ResponseError +from boto.s3.lifecycle import Lifecycle, Transition, Expiration, Rule + +import sure # noqa +from botocore.exceptions import ClientError +from datetime import datetime +from nose.tools import assert_raises + +from moto import mock_s3_deprecated, mock_s3 + + +@mock_s3_deprecated +def test_lifecycle_create(): + conn = boto.s3.connect_to_region("us-west-1") + bucket = conn.create_bucket("foobar") + + lifecycle = Lifecycle() + lifecycle.add_rule("myid", "", "Enabled", 30) + bucket.configure_lifecycle(lifecycle) + response = bucket.get_lifecycle_config() + len(response).should.equal(1) + lifecycle = response[0] + lifecycle.id.should.equal("myid") + lifecycle.prefix.should.equal("") + lifecycle.status.should.equal("Enabled") + list(lifecycle.transition).should.equal([]) + + +@mock_s3 +def test_lifecycle_with_filters(): + client = boto3.client("s3") + client.create_bucket(Bucket="bucket") + + # Create a lifecycle rule with a Filter (no tags): + lfc = { + "Rules": [ + { + "Expiration": {"Days": 7}, + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["Filter"]["Prefix"] == "" + assert not result["Rules"][0]["Filter"].get("And") + assert not result["Rules"][0]["Filter"].get("Tag") + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # Without any prefixes and an empty filter (this is by default a prefix for the whole bucket): + lfc = { + "Rules": [ + { + "Expiration": {"Days": 7}, + "ID": "wholebucket", + "Filter": {}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # If we remove the filter -- and don't specify a Prefix, then this is bad: + lfc["Rules"][0].pop("Filter") + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + # With a tag: + lfc["Rules"][0]["Filter"] = {"Tag": {"Key": "mytag", "Value": "mytagvalue"}} + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + with assert_raises(KeyError): + assert result["Rules"][0]["Filter"]["Prefix"] + assert not result["Rules"][0]["Filter"].get("And") + assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" + assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # With And (single tag): + lfc["Rules"][0]["Filter"] = { + "And": { + "Prefix": "some/prefix", + "Tags": [{"Key": "mytag", "Value": "mytagvalue"}], + } + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert not result["Rules"][0]["Filter"].get("Prefix") + assert result["Rules"][0]["Filter"]["And"]["Prefix"] == "some/prefix" + assert len(result["Rules"][0]["Filter"]["And"]["Tags"]) == 1 + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Key"] == "mytag" + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Value"] == "mytagvalue" + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # With multiple And tags: + lfc["Rules"][0]["Filter"]["And"] = { + "Prefix": "some/prefix", + "Tags": [ + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, + ], + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert not result["Rules"][0]["Filter"].get("Prefix") + assert result["Rules"][0]["Filter"]["And"]["Prefix"] == "some/prefix" + assert len(result["Rules"][0]["Filter"]["And"]["Tags"]) == 2 + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Key"] == "mytag" + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Value"] == "mytagvalue" + assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Key"] == "mytag2" + assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Value"] == "mytagvalue2" + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # And filter without Prefix but multiple Tags: + lfc["Rules"][0]["Filter"]["And"] = { + "Tags": [ + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + with assert_raises(KeyError): + assert result["Rules"][0]["Filter"]["And"]["Prefix"] + assert len(result["Rules"][0]["Filter"]["And"]["Tags"]) == 2 + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Key"] == "mytag" + assert result["Rules"][0]["Filter"]["And"]["Tags"][0]["Value"] == "mytagvalue" + assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Key"] == "mytag2" + assert result["Rules"][0]["Filter"]["And"]["Tags"][1]["Value"] == "mytagvalue2" + with assert_raises(KeyError): + assert result["Rules"][0]["Prefix"] + + # Can't have both filter and prefix: + lfc["Rules"][0]["Prefix"] = "" + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + lfc["Rules"][0]["Prefix"] = "some/path" + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + # No filters -- just a prefix: + del lfc["Rules"][0]["Filter"] + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert not result["Rules"][0].get("Filter") + assert result["Rules"][0]["Prefix"] == "some/path" + + # Can't have Tag, Prefix, and And in a filter: + del lfc["Rules"][0]["Prefix"] + lfc["Rules"][0]["Filter"] = { + "Prefix": "some/prefix", + "Tag": {"Key": "mytag", "Value": "mytagvalue"}, + } + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + lfc["Rules"][0]["Filter"] = { + "Tag": {"Key": "mytag", "Value": "mytagvalue"}, + "And": { + "Prefix": "some/prefix", + "Tags": [ + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, + ], + }, + } + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + # Make sure multiple rules work: + lfc = { + "Rules": [ + { + "Expiration": {"Days": 7}, + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + }, + { + "Expiration": {"Days": 10}, + "ID": "Tags", + "Filter": {"Tag": {"Key": "somekey", "Value": "somevalue"}}, + "Status": "Enabled", + }, + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket")["Rules"] + assert len(result) == 2 + assert result[0]["ID"] == "wholebucket" + assert result[1]["ID"] == "Tags" + + +@mock_s3 +def test_lifecycle_with_eodm(): + client = boto3.client("s3") + client.create_bucket(Bucket="bucket") + + lfc = { + "Rules": [ + { + "Expiration": {"ExpiredObjectDeleteMarker": True}, + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] + + # Set to False: + lfc["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] = False + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert not result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] + + # With failure: + lfc["Rules"][0]["Expiration"]["Days"] = 7 + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + del lfc["Rules"][0]["Expiration"]["Days"] + + lfc["Rules"][0]["Expiration"]["Date"] = datetime(2015, 1, 1) + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + +@mock_s3 +def test_lifecycle_with_nve(): + client = boto3.client("s3") + client.create_bucket(Bucket="bucket") + + lfc = { + "Rules": [ + { + "NoncurrentVersionExpiration": {"NoncurrentDays": 30}, + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 30 + + # Change NoncurrentDays: + lfc["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] = 10 + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 10 + + # TODO: Add test for failures due to missing children + + +@mock_s3 +def test_lifecycle_with_nvt(): + client = boto3.client("s3") + client.create_bucket(Bucket="bucket") + + lfc = { + "Rules": [ + { + "NoncurrentVersionTransitions": [ + {"NoncurrentDays": 30, "StorageClass": "ONEZONE_IA"} + ], + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 30 + assert ( + result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] + == "ONEZONE_IA" + ) + + # Change NoncurrentDays: + lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 10 + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 10 + + # Change StorageClass: + lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] = "GLACIER" + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert ( + result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] + == "GLACIER" + ) + + # With failures for missing children: + del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 30 + + del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] + with assert_raises(ClientError) as err: + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + assert err.exception.response["Error"]["Code"] == "MalformedXML" + + +@mock_s3 +def test_lifecycle_with_aimu(): + client = boto3.client("s3") + client.create_bucket(Bucket="bucket") + + lfc = { + "Rules": [ + { + "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 7}, + "ID": "wholebucket", + "Filter": {"Prefix": ""}, + "Status": "Enabled", + } + ] + } + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert ( + result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 7 + ) + + # Change DaysAfterInitiation: + lfc["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] = 30 + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket") + assert len(result["Rules"]) == 1 + assert ( + result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] + == 30 + ) + + # TODO: Add test for failures due to missing children + + +@mock_s3_deprecated +def test_lifecycle_with_glacier_transition(): + conn = boto.s3.connect_to_region("us-west-1") + bucket = conn.create_bucket("foobar") + + lifecycle = Lifecycle() + transition = Transition(days=30, storage_class="GLACIER") + rule = Rule( + "myid", prefix="", status="Enabled", expiration=None, transition=transition + ) + lifecycle.append(rule) + bucket.configure_lifecycle(lifecycle) + response = bucket.get_lifecycle_config() + transition = response[0].transition + transition.days.should.equal(30) + transition.storage_class.should.equal("GLACIER") + transition.date.should.equal(None) + + +@mock_s3_deprecated +def test_lifecycle_multi(): + conn = boto.s3.connect_to_region("us-west-1") + bucket = conn.create_bucket("foobar") + + date = "2022-10-12T00:00:00.000Z" + sc = "GLACIER" + lifecycle = Lifecycle() + lifecycle.add_rule("1", "1/", "Enabled", 1) + lifecycle.add_rule("2", "2/", "Enabled", Expiration(days=2)) + lifecycle.add_rule("3", "3/", "Enabled", Expiration(date=date)) + lifecycle.add_rule("4", "4/", "Enabled", None, Transition(days=4, storage_class=sc)) + lifecycle.add_rule( + "5", "5/", "Enabled", None, Transition(date=date, storage_class=sc) + ) + + bucket.configure_lifecycle(lifecycle) + # read the lifecycle back + rules = bucket.get_lifecycle_config() + + for rule in rules: + if rule.id == "1": + rule.prefix.should.equal("1/") + rule.expiration.days.should.equal(1) + elif rule.id == "2": + rule.prefix.should.equal("2/") + rule.expiration.days.should.equal(2) + elif rule.id == "3": + rule.prefix.should.equal("3/") + rule.expiration.date.should.equal(date) + elif rule.id == "4": + rule.prefix.should.equal("4/") + rule.transition.days.should.equal(4) + rule.transition.storage_class.should.equal(sc) + elif rule.id == "5": + rule.prefix.should.equal("5/") + rule.transition.date.should.equal(date) + rule.transition.storage_class.should.equal(sc) + else: + assert False, "Invalid rule id" + + +@mock_s3_deprecated +def test_lifecycle_delete(): + conn = boto.s3.connect_to_region("us-west-1") + bucket = conn.create_bucket("foobar") + + lifecycle = Lifecycle() + lifecycle.add_rule(expiration=30) + bucket.configure_lifecycle(lifecycle) + response = bucket.get_lifecycle_config() + response.should.have.length_of(1) + + bucket.delete_lifecycle_configuration() + bucket.get_lifecycle_config.when.called_with().should.throw(S3ResponseError) diff --git a/tests/test_s3/test_s3_storageclass.py b/tests/test_s3/test_s3_storageclass.py index c72b773a9..dbdc85c42 100644 --- a/tests/test_s3/test_s3_storageclass.py +++ b/tests/test_s3/test_s3_storageclass.py @@ -11,30 +11,35 @@ from moto import mock_s3 @mock_s3 def test_s3_storage_class_standard(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # add an object to the bucket with standard storage + # add an object to the bucket with standard storage - s3.put_object(Bucket="Bucket", Key="my_key", Body="my_value") + s3.put_object(Bucket="Bucket", Key="my_key", Body="my_value") - list_of_objects = s3.list_objects(Bucket="Bucket") + list_of_objects = s3.list_objects(Bucket="Bucket") - list_of_objects['Contents'][0]["StorageClass"].should.equal("STANDARD") + list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") @mock_s3 def test_s3_storage_class_infrequent_access(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # add an object to the bucket with standard storage + # add an object to the bucket with standard storage - s3.put_object(Bucket="Bucket", Key="my_key_infrequent", Body="my_value_infrequent", StorageClass="STANDARD_IA") + s3.put_object( + Bucket="Bucket", + Key="my_key_infrequent", + Body="my_value_infrequent", + StorageClass="STANDARD_IA", + ) - D = s3.list_objects(Bucket="Bucket") + D = s3.list_objects(Bucket="Bucket") - D['Contents'][0]["StorageClass"].should.equal("STANDARD_IA") + D["Contents"][0]["StorageClass"].should.equal("STANDARD_IA") @mock_s3 @@ -42,74 +47,104 @@ def test_s3_storage_class_intelligent_tiering(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="my_key_infrequent", Body="my_value_infrequent", StorageClass="INTELLIGENT_TIERING") + s3.put_object( + Bucket="Bucket", + Key="my_key_infrequent", + Body="my_value_infrequent", + StorageClass="INTELLIGENT_TIERING", + ) objects = s3.list_objects(Bucket="Bucket") - objects['Contents'][0]["StorageClass"].should.equal("INTELLIGENT_TIERING") + objects["Contents"][0]["StorageClass"].should.equal("INTELLIGENT_TIERING") @mock_s3 def test_s3_storage_class_copy(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD" + ) - s3.create_bucket(Bucket="Bucket2") - # second object is originally of storage class REDUCED_REDUNDANCY - s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2") + s3.create_bucket(Bucket="Bucket2") + # second object is originally of storage class REDUCED_REDUNDANCY + s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2") - s3.copy_object(CopySource = {"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket2", Key="Second_Object", StorageClass="ONEZONE_IA") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket2", + Key="Second_Object", + StorageClass="ONEZONE_IA", + ) - list_of_copied_objects = s3.list_objects(Bucket="Bucket2") + list_of_copied_objects = s3.list_objects(Bucket="Bucket2") - # checks that a copied object can be properly copied - list_of_copied_objects["Contents"][0]["StorageClass"].should.equal("ONEZONE_IA") + # checks that a copied object can be properly copied + list_of_copied_objects["Contents"][0]["StorageClass"].should.equal("ONEZONE_IA") @mock_s3 def test_s3_invalid_copied_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD" + ) - s3.create_bucket(Bucket="Bucket2") - s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2", StorageClass="REDUCED_REDUNDANCY") + s3.create_bucket(Bucket="Bucket2") + s3.put_object( + Bucket="Bucket2", + Key="Second_Object", + Body="Body2", + StorageClass="REDUCED_REDUNDANCY", + ) - # Try to copy an object with an invalid storage class - with assert_raises(ClientError) as err: - s3.copy_object(CopySource = {"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket2", Key="Second_Object", StorageClass="STANDARD2") + # Try to copy an object with an invalid storage class + with assert_raises(ClientError) as err: + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket2", + Key="Second_Object", + StorageClass="STANDARD2", + ) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidStorageClass") - e.response["Error"]["Message"].should.equal("The storage class you specified is not valid") + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidStorageClass") + e.response["Error"]["Message"].should.equal( + "The storage class you specified is not valid" + ) @mock_s3 def test_s3_invalid_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # Try to add an object with an invalid storage class - with assert_raises(ClientError) as err: - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARDD") + # Try to add an object with an invalid storage class + with assert_raises(ClientError) as err: + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARDD" + ) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidStorageClass") - e.response["Error"]["Message"].should.equal("The storage class you specified is not valid") + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidStorageClass") + e.response["Error"]["Message"].should.equal( + "The storage class you specified is not valid" + ) @mock_s3 def test_s3_default_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body") + s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body") - list_of_objects = s3.list_objects(Bucket="Bucket") + list_of_objects = s3.list_objects(Bucket="Bucket") - # tests that the default storage class is still STANDARD - list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") + # tests that the default storage class is still STANDARD + list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") @mock_s3 @@ -117,10 +152,16 @@ def test_s3_copy_object_error_for_glacier_storage_class(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER" + ) with assert_raises(ClientError) as exc: - s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket", + Key="Second_Object", + ) exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError") @@ -130,9 +171,15 @@ def test_s3_copy_object_error_for_deep_archive_storage_class(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE" + ) with assert_raises(ClientError) as exc: - s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket", + Key="Second_Object", + ) exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError") diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py index 501137910..b90225597 100644 --- a/tests/test_s3/test_s3_utils.py +++ b/tests/test_s3/test_s3_utils.py @@ -1,80 +1,121 @@ -from __future__ import unicode_literals -import os -from sure import expect -from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url - - -def test_base_url(): - expect(bucket_name_from_url('https://s3.amazonaws.com/')).should.equal(None) - - -def test_localhost_bucket(): - expect(bucket_name_from_url('https://wfoobar.localhost:5000/abc') - ).should.equal("wfoobar") - - -def test_localhost_without_bucket(): - expect(bucket_name_from_url( - 'https://www.localhost:5000/def')).should.equal(None) - -def test_force_ignore_subdomain_for_bucketnames(): - os.environ['S3_IGNORE_SUBDOMAIN_BUCKETNAME'] = '1' - expect(bucket_name_from_url('https://subdomain.localhost:5000/abc/resource')).should.equal(None) - del(os.environ['S3_IGNORE_SUBDOMAIN_BUCKETNAME']) - - - -def test_versioned_key_store(): - d = _VersionedKeyStore() - - d.should.have.length_of(0) - - d['key'] = [1] - - d.should.have.length_of(1) - - d['key'] = 2 - d.should.have.length_of(1) - - d.should.have.key('key').being.equal(2) - - d.get.when.called_with('key').should.return_value(2) - d.get.when.called_with('badkey').should.return_value(None) - d.get.when.called_with('badkey', 'HELLO').should.return_value('HELLO') - - # Tests key[ - d.shouldnt.have.key('badkey') - d.__getitem__.when.called_with('badkey').should.throw(KeyError) - - d.getlist('key').should.have.length_of(2) - d.getlist('key').should.be.equal([[1], 2]) - d.getlist('badkey').should.be.none - - d.setlist('key', 1) - d.getlist('key').should.be.equal([1]) - - d.setlist('key', (1, 2)) - d.getlist('key').shouldnt.be.equal((1, 2)) - d.getlist('key').should.be.equal([1, 2]) - - d.setlist('key', [[1], [2]]) - d['key'].should.have.length_of(1) - d.getlist('key').should.be.equal([[1], [2]]) - - -def test_parse_region_from_url(): - expected = 'us-west-2' - for url in ['http://s3-us-west-2.amazonaws.com/bucket', - 'http://s3.us-west-2.amazonaws.com/bucket', - 'http://bucket.s3-us-west-2.amazonaws.com', - 'https://s3-us-west-2.amazonaws.com/bucket', - 'https://s3.us-west-2.amazonaws.com/bucket', - 'https://bucket.s3-us-west-2.amazonaws.com']: - parse_region_from_url(url).should.equal(expected) - - expected = 'us-east-1' - for url in ['http://s3.amazonaws.com/bucket', - 'http://bucket.s3.amazonaws.com', - 'https://s3.amazonaws.com/bucket', - 'https://bucket.s3.amazonaws.com']: - parse_region_from_url(url).should.equal(expected) +from __future__ import unicode_literals +import os +from sure import expect +from moto.s3.utils import ( + bucket_name_from_url, + _VersionedKeyStore, + parse_region_from_url, + clean_key_name, + undo_clean_key_name, +) +from parameterized import parameterized + + +def test_base_url(): + expect(bucket_name_from_url("https://s3.amazonaws.com/")).should.equal(None) + + +def test_localhost_bucket(): + expect(bucket_name_from_url("https://wfoobar.localhost:5000/abc")).should.equal( + "wfoobar" + ) + + +def test_localhost_without_bucket(): + expect(bucket_name_from_url("https://www.localhost:5000/def")).should.equal(None) + + +def test_force_ignore_subdomain_for_bucketnames(): + os.environ["S3_IGNORE_SUBDOMAIN_BUCKETNAME"] = "1" + expect( + bucket_name_from_url("https://subdomain.localhost:5000/abc/resource") + ).should.equal(None) + del os.environ["S3_IGNORE_SUBDOMAIN_BUCKETNAME"] + + +def test_versioned_key_store(): + d = _VersionedKeyStore() + + d.should.have.length_of(0) + + d["key"] = [1] + + d.should.have.length_of(1) + + d["key"] = 2 + d.should.have.length_of(1) + + d.should.have.key("key").being.equal(2) + + d.get.when.called_with("key").should.return_value(2) + d.get.when.called_with("badkey").should.return_value(None) + d.get.when.called_with("badkey", "HELLO").should.return_value("HELLO") + + # Tests key[ + d.shouldnt.have.key("badkey") + d.__getitem__.when.called_with("badkey").should.throw(KeyError) + + d.getlist("key").should.have.length_of(2) + d.getlist("key").should.be.equal([[1], 2]) + d.getlist("badkey").should.be.none + + d.setlist("key", 1) + d.getlist("key").should.be.equal([1]) + + d.setlist("key", (1, 2)) + d.getlist("key").shouldnt.be.equal((1, 2)) + d.getlist("key").should.be.equal([1, 2]) + + d.setlist("key", [[1], [2]]) + d["key"].should.have.length_of(1) + d.getlist("key").should.be.equal([[1], [2]]) + + +def test_parse_region_from_url(): + expected = "us-west-2" + for url in [ + "http://s3-us-west-2.amazonaws.com/bucket", + "http://s3.us-west-2.amazonaws.com/bucket", + "http://bucket.s3-us-west-2.amazonaws.com", + "https://s3-us-west-2.amazonaws.com/bucket", + "https://s3.us-west-2.amazonaws.com/bucket", + "https://bucket.s3-us-west-2.amazonaws.com", + ]: + parse_region_from_url(url).should.equal(expected) + + expected = "us-east-1" + for url in [ + "http://s3.amazonaws.com/bucket", + "http://bucket.s3.amazonaws.com", + "https://s3.amazonaws.com/bucket", + "https://bucket.s3.amazonaws.com", + ]: + parse_region_from_url(url).should.equal(expected) + + +@parameterized( + [ + ("foo/bar/baz", "foo/bar/baz"), + ("foo", "foo"), + ( + "foo/run_dt%3D2019-01-01%252012%253A30%253A00", + "foo/run_dt=2019-01-01%2012%3A30%3A00", + ), + ] +) +def test_clean_key_name(key, expected): + clean_key_name(key).should.equal(expected) + + +@parameterized( + [ + ("foo/bar/baz", "foo/bar/baz"), + ("foo", "foo"), + ( + "foo/run_dt%3D2019-01-01%252012%253A30%253A00", + "foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00", + ), + ] +) +def test_undo_clean_key_name(key, expected): + undo_clean_key_name(key).should.equal(expected) diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index b179a2329..56d46de09 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -6,16 +6,16 @@ import sure # noqa from flask.testing import FlaskClient import moto.server as server -''' +""" Test the different server responses -''' +""" class AuthenticatedClient(FlaskClient): def open(self, *args, **kwargs): - kwargs['headers'] = kwargs.get('headers', {}) - kwargs['headers']['Authorization'] = "Any authorization header" - kwargs['content_length'] = 0 # Fixes content-length complaints. + kwargs["headers"] = kwargs.get("headers", {}) + kwargs["headers"]["Authorization"] = "Any authorization header" + kwargs["content_length"] = 0 # Fixes content-length complaints. return super(AuthenticatedClient, self).open(*args, **kwargs) @@ -27,30 +27,29 @@ def authenticated_client(): def test_s3_server_get(): test_client = authenticated_client() - res = test_client.get('/') + res = test_client.get("/") - res.data.should.contain(b'ListAllMyBucketsResult') + res.data.should.contain(b"ListAllMyBucketsResult") def test_s3_server_bucket_create(): test_client = authenticated_client() - res = test_client.put('/', 'http://foobaz.localhost:5000/') + res = test_client.put("/", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobaz') + res = test_client.get("/") + res.data.should.contain(b"foobaz") - res = test_client.get('/', 'http://foobaz.localhost:5000/') + res = test_client.get("/", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.put( - '/bar', 'http://foobaz.localhost:5000/', data='test value') + res = test_client.put("/bar", "http://foobaz.localhost:5000/", data="test value") res.status_code.should.equal(200) - assert 'ETag' in dict(res.headers) + assert "ETag" in dict(res.headers) - res = test_client.get('/bar', 'http://foobaz.localhost:5000/') + res = test_client.get("/bar", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"test value") @@ -59,24 +58,24 @@ def test_s3_server_bucket_versioning(): test_client = authenticated_client() # Just enough XML to enable versioning - body = 'Enabled' - res = test_client.put( - '/?versioning', 'http://foobaz.localhost:5000', data=body) + body = "Enabled" + res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body) res.status_code.should.equal(200) def test_s3_server_post_to_bucket(): test_client = authenticated_client() - res = test_client.put('/', 'http://tester.localhost:5000/') + res = test_client.put("/", "http://tester.localhost:5000/") res.status_code.should.equal(200) - test_client.post('/', "https://tester.localhost:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/", + "https://tester.localhost:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/the-key', 'http://tester.localhost:5000/') + res = test_client.get("/the-key", "http://tester.localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -84,23 +83,28 @@ def test_s3_server_post_to_bucket(): def test_s3_server_post_without_content_length(): test_client = authenticated_client() - res = test_client.put('/', 'http://tester.localhost:5000/', environ_overrides={'CONTENT_LENGTH': ''}) + res = test_client.put( + "/", "http://tester.localhost:5000/", environ_overrides={"CONTENT_LENGTH": ""} + ) res.status_code.should.equal(411) - res = test_client.post('/', "https://tester.localhost:5000/", environ_overrides={'CONTENT_LENGTH': ''}) + res = test_client.post( + "/", "https://tester.localhost:5000/", environ_overrides={"CONTENT_LENGTH": ""} + ) res.status_code.should.equal(411) def test_s3_server_post_unicode_bucket_key(): # Make sure that we can deal with non-ascii characters in request URLs (e.g., S3 object names) dispatcher = server.DomainDispatcherApplication(server.create_backend_app) - backend_app = dispatcher.get_application({ - 'HTTP_HOST': 's3.amazonaws.com', - 'PATH_INFO': '/test-bucket/test-object-てすと' - }) + backend_app = dispatcher.get_application( + {"HTTP_HOST": "s3.amazonaws.com", "PATH_INFO": "/test-bucket/test-object-てすと"} + ) assert backend_app - backend_app = dispatcher.get_application({ - 'HTTP_HOST': 's3.amazonaws.com', - 'PATH_INFO': '/test-bucket/test-object-てすと'.encode('utf-8') - }) + backend_app = dispatcher.get_application( + { + "HTTP_HOST": "s3.amazonaws.com", + "PATH_INFO": "/test-bucket/test-object-てすと".encode("utf-8"), + } + ) assert backend_app diff --git a/tests/test_s3bucket_path/test_bucket_path_server.py b/tests/test_s3bucket_path/test_bucket_path_server.py index f6238dd28..2fe606799 100644 --- a/tests/test_s3bucket_path/test_bucket_path_server.py +++ b/tests/test_s3bucket_path/test_bucket_path_server.py @@ -4,16 +4,16 @@ import sure # noqa from flask.testing import FlaskClient import moto.server as server -''' +""" Test the different server responses -''' +""" class AuthenticatedClient(FlaskClient): def open(self, *args, **kwargs): - kwargs['headers'] = kwargs.get('headers', {}) - kwargs['headers']['Authorization'] = "Any authorization header" - kwargs['content_length'] = 0 # Fixes content-length complaints. + kwargs["headers"] = kwargs.get("headers", {}) + kwargs["headers"]["Authorization"] = "Any authorization header" + kwargs["content_length"] = 0 # Fixes content-length complaints. return super(AuthenticatedClient, self).open(*args, **kwargs) @@ -26,42 +26,41 @@ def authenticated_client(): def test_s3_server_get(): test_client = authenticated_client() - res = test_client.get('/') + res = test_client.get("/") - res.data.should.contain(b'ListAllMyBucketsResult') + res.data.should.contain(b"ListAllMyBucketsResult") def test_s3_server_bucket_create(): test_client = authenticated_client() - res = test_client.put('/foobar', 'http://localhost:5000') + res = test_client.put("/foobar", "http://localhost:5000") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobar') + res = test_client.get("/") + res.data.should.contain(b"foobar") - res = test_client.get('/foobar', 'http://localhost:5000') + res = test_client.get("/foobar", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.put('/foobar2/', 'http://localhost:5000') + res = test_client.put("/foobar2/", "http://localhost:5000") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobar2') + res = test_client.get("/") + res.data.should.contain(b"foobar2") - res = test_client.get('/foobar2/', 'http://localhost:5000') + res = test_client.get("/foobar2/", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.get('/missing-bucket', 'http://localhost:5000') + res = test_client.get("/missing-bucket", "http://localhost:5000") res.status_code.should.equal(404) - res = test_client.put( - '/foobar/bar', 'http://localhost:5000', data='test value') + res = test_client.put("/foobar/bar", "http://localhost:5000", data="test value") res.status_code.should.equal(200) - res = test_client.get('/foobar/bar', 'http://localhost:5000') + res = test_client.get("/foobar/bar", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.equal(b"test value") @@ -69,15 +68,16 @@ def test_s3_server_bucket_create(): def test_s3_server_post_to_bucket(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://localhost:5000/') + res = test_client.put("/foobar2", "http://localhost:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://localhost:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", + "https://localhost:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/foobar2/the-key', 'http://localhost:5000/') + res = test_client.get("/foobar2/the-key", "http://localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -85,15 +85,14 @@ def test_s3_server_post_to_bucket(): def test_s3_server_put_ipv6(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://[::]:5000/') + res = test_client.put("/foobar2", "http://[::]:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://[::]:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", "https://[::]:5000/", data={"key": "the-key", "file": "nothing"} + ) - res = test_client.get('/foobar2/the-key', 'http://[::]:5000/') + res = test_client.get("/foobar2/the-key", "http://[::]:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -101,14 +100,15 @@ def test_s3_server_put_ipv6(): def test_s3_server_put_ipv4(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://127.0.0.1:5000/') + res = test_client.put("/foobar2", "http://127.0.0.1:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://127.0.0.1:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", + "https://127.0.0.1:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/foobar2/the-key', 'http://127.0.0.1:5000/') + res = test_client.get("/foobar2/the-key", "http://127.0.0.1:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") diff --git a/tests/test_s3bucket_path/test_s3bucket_path.py b/tests/test_s3bucket_path/test_s3bucket_path.py index 2ec5e8f30..e204d0527 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path.py +++ b/tests/test_s3bucket_path/test_s3bucket_path.py @@ -1,321 +1,322 @@ -from __future__ import unicode_literals -from six.moves.urllib.request import urlopen -from six.moves.urllib.error import HTTPError - -import boto -from boto.exception import S3ResponseError -from boto.s3.key import Key -from boto.s3.connection import OrdinaryCallingFormat - -from freezegun import freeze_time -import requests - -import sure # noqa - -from moto import mock_s3, mock_s3_deprecated - - -def create_connection(key=None, secret=None): - return boto.connect_s3(key, secret, calling_format=OrdinaryCallingFormat()) - - -class MyModel(object): - - def __init__(self, name, value): - self.name = name - self.value = value - - def save(self): - conn = create_connection('the_key', 'the_secret') - bucket = conn.get_bucket('mybucket') - k = Key(bucket) - k.key = self.name - k.set_contents_from_string(self.value) - - -@mock_s3_deprecated -def test_my_model_save(): - # Create Bucket so that test can run - conn = create_connection('the_key', 'the_secret') - conn.create_bucket('mybucket') - #################################### - - model_instance = MyModel('steve', 'is awesome') - model_instance.save() - - conn.get_bucket('mybucket').get_key( - 'steve').get_contents_as_string().should.equal(b'is awesome') - - -@mock_s3_deprecated -def test_missing_key(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - bucket.get_key("the-key").should.equal(None) - - -@mock_s3_deprecated -def test_missing_key_urllib2(): - conn = create_connection('the_key', 'the_secret') - conn.create_bucket("foobar") - - urlopen.when.called_with( - "http://s3.amazonaws.com/foobar/the-key").should.throw(HTTPError) - - -@mock_s3_deprecated -def test_empty_key(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("") - - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') - - -@mock_s3_deprecated -def test_empty_key_set_on_existing_key(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("foobar") - - bucket.get_key("the-key").get_contents_as_string().should.equal(b'foobar') - - key.set_contents_from_string("") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') - - -@mock_s3_deprecated -def test_large_key_save(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("foobar" * 100000) - - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b'foobar' * 100000) - - -@mock_s3_deprecated -def test_copy_key(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("some value") - - bucket.copy_key('new-key', 'foobar', 'the-key') - - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") - - -@mock_s3_deprecated -def test_set_metadata(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = 'the-key' - key.set_metadata('md', 'Metadatastring') - key.set_contents_from_string("Testval") - - bucket.get_key('the-key').get_metadata('md').should.equal('Metadatastring') - - -@freeze_time("2012-01-01 12:00:00") -@mock_s3_deprecated -def test_last_modified(): - # See https://github.com/boto/boto/issues/466 - conn = create_connection() - bucket = conn.create_bucket("foobar") - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("some value") - - rs = bucket.get_all_keys() - rs[0].last_modified.should.equal('2012-01-01T12:00:00.000Z') - - bucket.get_key( - "the-key").last_modified.should.equal('Sun, 01 Jan 2012 12:00:00 GMT') - - -@mock_s3_deprecated -def test_missing_bucket(): - conn = create_connection('the_key', 'the_secret') - conn.get_bucket.when.called_with('mybucket').should.throw(S3ResponseError) - - -@mock_s3_deprecated -def test_bucket_with_dash(): - conn = create_connection('the_key', 'the_secret') - conn.get_bucket.when.called_with( - 'mybucket-test').should.throw(S3ResponseError) - - -@mock_s3_deprecated -def test_bucket_deletion(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - - key = Key(bucket) - key.key = "the-key" - key.set_contents_from_string("some value") - - # Try to delete a bucket that still has keys - conn.delete_bucket.when.called_with("foobar").should.throw(S3ResponseError) - - bucket.delete_key("the-key") - conn.delete_bucket("foobar") - - # Get non-existing bucket - conn.get_bucket.when.called_with("foobar").should.throw(S3ResponseError) - - # Delete non-existant bucket - conn.delete_bucket.when.called_with("foobar").should.throw(S3ResponseError) - - -@mock_s3_deprecated -def test_get_all_buckets(): - conn = create_connection('the_key', 'the_secret') - conn.create_bucket("foobar") - conn.create_bucket("foobar2") - buckets = conn.get_all_buckets() - - buckets.should.have.length_of(2) - - -@mock_s3 -@mock_s3_deprecated -def test_post_to_bucket(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - - requests.post("https://s3.amazonaws.com/foobar", { - 'key': 'the-key', - 'file': 'nothing' - }) - - bucket.get_key('the-key').get_contents_as_string().should.equal(b'nothing') - - -@mock_s3 -@mock_s3_deprecated -def test_post_with_metadata_to_bucket(): - conn = create_connection('the_key', 'the_secret') - bucket = conn.create_bucket("foobar") - - requests.post("https://s3.amazonaws.com/foobar", { - 'key': 'the-key', - 'file': 'nothing', - 'x-amz-meta-test': 'metadata' - }) - - bucket.get_key('the-key').get_metadata('test').should.equal('metadata') - - -@mock_s3_deprecated -def test_bucket_name_with_dot(): - conn = create_connection() - bucket = conn.create_bucket('firstname.lastname') - - k = Key(bucket, 'somekey') - k.set_contents_from_string('somedata') - - -@mock_s3_deprecated -def test_key_with_special_characters(): - conn = create_connection() - bucket = conn.create_bucket('test_bucket_name') - - key = Key(bucket, 'test_list_keys_2/*x+?^@~!y') - key.set_contents_from_string('value1') - - key_list = bucket.list('test_list_keys_2/', '/') - keys = [x for x in key_list] - keys[0].name.should.equal("test_list_keys_2/*x+?^@~!y") - - -@mock_s3_deprecated -def test_bucket_key_listing_order(): - conn = create_connection() - bucket = conn.create_bucket('test_bucket') - prefix = 'toplevel/' - - def store(name): - k = Key(bucket, prefix + name) - k.set_contents_from_string('somedata') - - names = ['x/key', 'y.key1', 'y.key2', 'y.key3', 'x/y/key', 'x/y/z/key'] - - for name in names: - store(name) - - delimiter = None - keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key', - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3' - ]) - - delimiter = '/' - keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3', 'toplevel/x/' - ]) - - # Test delimiter with no prefix - delimiter = '/' - keys = [x.name for x in bucket.list(prefix=None, delimiter=delimiter)] - keys.should.equal(['toplevel/']) - - delimiter = None - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal( - ['toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key']) - - delimiter = '/' - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal(['toplevel/x/']) - - -@mock_s3_deprecated -def test_delete_keys(): - conn = create_connection() - bucket = conn.create_bucket('foobar') - - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') - - result = bucket.delete_keys(['file2', 'file3']) - result.deleted.should.have.length_of(2) - result.errors.should.have.length_of(0) - keys = bucket.get_all_keys() - keys.should.have.length_of(2) - keys[0].name.should.equal('file1') - - -@mock_s3_deprecated -def test_delete_keys_with_invalid(): - conn = create_connection() - bucket = conn.create_bucket('foobar') - - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') - - result = bucket.delete_keys(['abc', 'file3']) - - result.deleted.should.have.length_of(1) - result.errors.should.have.length_of(1) - keys = bucket.get_all_keys() - keys.should.have.length_of(3) - keys[0].name.should.equal('file1') +from __future__ import unicode_literals +from six.moves.urllib.request import urlopen +from six.moves.urllib.error import HTTPError + +import boto +from boto.exception import S3ResponseError +from boto.s3.key import Key +from boto.s3.connection import OrdinaryCallingFormat + +from freezegun import freeze_time +import requests + +import sure # noqa + +from moto import mock_s3, mock_s3_deprecated + + +def create_connection(key=None, secret=None): + return boto.connect_s3(key, secret, calling_format=OrdinaryCallingFormat()) + + +class MyModel(object): + def __init__(self, name, value): + self.name = name + self.value = value + + def save(self): + conn = create_connection("the_key", "the_secret") + bucket = conn.get_bucket("mybucket") + k = Key(bucket) + k.key = self.name + k.set_contents_from_string(self.value) + + +@mock_s3_deprecated +def test_my_model_save(): + # Create Bucket so that test can run + conn = create_connection("the_key", "the_secret") + conn.create_bucket("mybucket") + #################################### + + model_instance = MyModel("steve", "is awesome") + model_instance.save() + + conn.get_bucket("mybucket").get_key("steve").get_contents_as_string().should.equal( + b"is awesome" + ) + + +@mock_s3_deprecated +def test_missing_key(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + bucket.get_key("the-key").should.equal(None) + + +@mock_s3_deprecated +def test_missing_key_urllib2(): + conn = create_connection("the_key", "the_secret") + conn.create_bucket("foobar") + + urlopen.when.called_with("http://s3.amazonaws.com/foobar/the-key").should.throw( + HTTPError + ) + + +@mock_s3_deprecated +def test_empty_key(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("") + + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") + + +@mock_s3_deprecated +def test_empty_key_set_on_existing_key(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("foobar") + + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar") + + key.set_contents_from_string("") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") + + +@mock_s3_deprecated +def test_large_key_save(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("foobar" * 100000) + + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar" * 100000) + + +@mock_s3_deprecated +def test_copy_key(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("some value") + + bucket.copy_key("new-key", "foobar", "the-key") + + bucket.get_key("the-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") + + +@mock_s3_deprecated +def test_set_metadata(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_metadata("md", "Metadatastring") + key.set_contents_from_string("Testval") + + bucket.get_key("the-key").get_metadata("md").should.equal("Metadatastring") + + +@freeze_time("2012-01-01 12:00:00") +@mock_s3_deprecated +def test_last_modified(): + # See https://github.com/boto/boto/issues/466 + conn = create_connection() + bucket = conn.create_bucket("foobar") + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("some value") + + rs = bucket.get_all_keys() + rs[0].last_modified.should.equal("2012-01-01T12:00:00.000Z") + + bucket.get_key("the-key").last_modified.should.equal( + "Sun, 01 Jan 2012 12:00:00 GMT" + ) + + +@mock_s3_deprecated +def test_missing_bucket(): + conn = create_connection("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket").should.throw(S3ResponseError) + + +@mock_s3_deprecated +def test_bucket_with_dash(): + conn = create_connection("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket-test").should.throw(S3ResponseError) + + +@mock_s3_deprecated +def test_bucket_deletion(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + + key = Key(bucket) + key.key = "the-key" + key.set_contents_from_string("some value") + + # Try to delete a bucket that still has keys + conn.delete_bucket.when.called_with("foobar").should.throw(S3ResponseError) + + bucket.delete_key("the-key") + conn.delete_bucket("foobar") + + # Get non-existing bucket + conn.get_bucket.when.called_with("foobar").should.throw(S3ResponseError) + + # Delete non-existant bucket + conn.delete_bucket.when.called_with("foobar").should.throw(S3ResponseError) + + +@mock_s3_deprecated +def test_get_all_buckets(): + conn = create_connection("the_key", "the_secret") + conn.create_bucket("foobar") + conn.create_bucket("foobar2") + buckets = conn.get_all_buckets() + + buckets.should.have.length_of(2) + + +@mock_s3 +@mock_s3_deprecated +def test_post_to_bucket(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + + requests.post( + "https://s3.amazonaws.com/foobar", {"key": "the-key", "file": "nothing"} + ) + + bucket.get_key("the-key").get_contents_as_string().should.equal(b"nothing") + + +@mock_s3 +@mock_s3_deprecated +def test_post_with_metadata_to_bucket(): + conn = create_connection("the_key", "the_secret") + bucket = conn.create_bucket("foobar") + + requests.post( + "https://s3.amazonaws.com/foobar", + {"key": "the-key", "file": "nothing", "x-amz-meta-test": "metadata"}, + ) + + bucket.get_key("the-key").get_metadata("test").should.equal("metadata") + + +@mock_s3_deprecated +def test_bucket_name_with_dot(): + conn = create_connection() + bucket = conn.create_bucket("firstname.lastname") + + k = Key(bucket, "somekey") + k.set_contents_from_string("somedata") + + +@mock_s3_deprecated +def test_key_with_special_characters(): + conn = create_connection() + bucket = conn.create_bucket("test_bucket_name") + + key = Key(bucket, "test_list_keys_2/*x+?^@~!y") + key.set_contents_from_string("value1") + + key_list = bucket.list("test_list_keys_2/", "/") + keys = [x for x in key_list] + keys[0].name.should.equal("test_list_keys_2/*x+?^@~!y") + + +@mock_s3_deprecated +def test_bucket_key_listing_order(): + conn = create_connection() + bucket = conn.create_bucket("test_bucket") + prefix = "toplevel/" + + def store(name): + k = Key(bucket, prefix + name) + k.set_contents_from_string("somedata") + + names = ["x/key", "y.key1", "y.key2", "y.key3", "x/y/key", "x/y/z/key"] + + for name in names: + store(name) + + delimiter = None + keys = [x.name for x in bucket.list(prefix, delimiter)] + keys.should.equal( + [ + "toplevel/x/key", + "toplevel/x/y/key", + "toplevel/x/y/z/key", + "toplevel/y.key1", + "toplevel/y.key2", + "toplevel/y.key3", + ] + ) + + delimiter = "/" + keys = [x.name for x in bucket.list(prefix, delimiter)] + keys.should.equal( + ["toplevel/y.key1", "toplevel/y.key2", "toplevel/y.key3", "toplevel/x/"] + ) + + # Test delimiter with no prefix + delimiter = "/" + keys = [x.name for x in bucket.list(prefix=None, delimiter=delimiter)] + keys.should.equal(["toplevel/"]) + + delimiter = None + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/key", "toplevel/x/y/key", "toplevel/x/y/z/key"]) + + delimiter = "/" + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/"]) + + +@mock_s3_deprecated +def test_delete_keys(): + conn = create_connection() + bucket = conn.create_bucket("foobar") + + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") + + result = bucket.delete_keys(["file2", "file3"]) + result.deleted.should.have.length_of(2) + result.errors.should.have.length_of(0) + keys = bucket.get_all_keys() + keys.should.have.length_of(2) + keys[0].name.should.equal("file1") + + +@mock_s3_deprecated +def test_delete_keys_with_invalid(): + conn = create_connection() + bucket = conn.create_bucket("foobar") + + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") + + result = bucket.delete_keys(["abc", "file3"]) + + result.deleted.should.have.length_of(1) + result.errors.should.have.length_of(1) + keys = bucket.get_all_keys() + keys.should.have.length_of(3) + keys[0].name.should.equal("file1") diff --git a/tests/test_s3bucket_path/test_s3bucket_path_combo.py b/tests/test_s3bucket_path/test_s3bucket_path_combo.py index 60dd58e85..2ca7107d9 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path_combo.py +++ b/tests/test_s3bucket_path/test_s3bucket_path_combo.py @@ -1,25 +1,25 @@ -from __future__ import unicode_literals - -import boto -from boto.s3.connection import OrdinaryCallingFormat - -from moto import mock_s3_deprecated - - -def create_connection(key=None, secret=None): - return boto.connect_s3(key, secret, calling_format=OrdinaryCallingFormat()) - - -def test_bucketpath_combo_serial(): - @mock_s3_deprecated - def make_bucket_path(): - conn = create_connection() - conn.create_bucket('mybucketpath') - - @mock_s3_deprecated - def make_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.create_bucket('mybucket') - - make_bucket() - make_bucket_path() +from __future__ import unicode_literals + +import boto +from boto.s3.connection import OrdinaryCallingFormat + +from moto import mock_s3_deprecated + + +def create_connection(key=None, secret=None): + return boto.connect_s3(key, secret, calling_format=OrdinaryCallingFormat()) + + +def test_bucketpath_combo_serial(): + @mock_s3_deprecated + def make_bucket_path(): + conn = create_connection() + conn.create_bucket("mybucketpath") + + @mock_s3_deprecated + def make_bucket(): + conn = boto.connect_s3("the_key", "the_secret") + conn.create_bucket("mybucket") + + make_bucket() + make_bucket_path() diff --git a/tests/test_s3bucket_path/test_s3bucket_path_utils.py b/tests/test_s3bucket_path/test_s3bucket_path_utils.py index 0bcc5cbe0..072968929 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path_utils.py +++ b/tests/test_s3bucket_path/test_s3bucket_path_utils.py @@ -1,16 +1,17 @@ -from __future__ import unicode_literals -from sure import expect -from moto.s3bucket_path.utils import bucket_name_from_url - - -def test_base_url(): - expect(bucket_name_from_url('https://s3.amazonaws.com/')).should.equal(None) - - -def test_localhost_bucket(): - expect(bucket_name_from_url('https://localhost:5000/wfoobar/abc') - ).should.equal("wfoobar") - - -def test_localhost_without_bucket(): - expect(bucket_name_from_url('https://www.localhost:5000')).should.equal(None) +from __future__ import unicode_literals +from sure import expect +from moto.s3bucket_path.utils import bucket_name_from_url + + +def test_base_url(): + expect(bucket_name_from_url("https://s3.amazonaws.com/")).should.equal(None) + + +def test_localhost_bucket(): + expect(bucket_name_from_url("https://localhost:5000/wfoobar/abc")).should.equal( + "wfoobar" + ) + + +def test_localhost_without_bucket(): + expect(bucket_name_from_url("https://www.localhost:5000")).should.equal(None) diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index 78b95ee6a..3b8c74e81 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals import boto3 @@ -5,461 +6,531 @@ import boto3 from moto import mock_secretsmanager from botocore.exceptions import ClientError import string -import unittest import pytz from datetime import datetime -from nose.tools import assert_raises +import sure # noqa +from nose.tools import assert_raises, assert_equal from six import b -DEFAULT_SECRET_NAME = 'test-secret' +DEFAULT_SECRET_NAME = "test-secret" @mock_secretsmanager def test_get_secret_value(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + create_secret = conn.create_secret( + Name="java-util-test-password", SecretString="foosecret" + ) + result = conn.get_secret_value(SecretId="java-util-test-password") + assert result["SecretString"] == "foosecret" + + +@mock_secretsmanager +def test_get_secret_value_by_arn(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + + secret_value = "test_get_secret_value_by_arn" + result = conn.create_secret( + Name="java-util-test-password", SecretString=secret_value + ) + result = conn.get_secret_value(SecretId=result["ARN"]) + assert result["SecretString"] == secret_value - create_secret = conn.create_secret(Name='java-util-test-password', - SecretString="foosecret") - result = conn.get_secret_value(SecretId='java-util-test-password') - assert result['SecretString'] == 'foosecret' @mock_secretsmanager def test_get_secret_value_binary(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + create_secret = conn.create_secret( + Name="java-util-test-password", SecretBinary=b("foosecret") + ) + result = conn.get_secret_value(SecretId="java-util-test-password") + assert result["SecretBinary"] == b("foosecret") - create_secret = conn.create_secret(Name='java-util-test-password', - SecretBinary=b("foosecret")) - result = conn.get_secret_value(SecretId='java-util-test-password') - assert result['SecretBinary'] == b('foosecret') @mock_secretsmanager def test_get_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + with assert_raises(ClientError) as cm: + result = conn.get_secret_value(SecretId="i-dont-exist") + + assert_equal( + "Secrets Manager can't find the specified secret.", + cm.exception.response["Error"]["Message"], + ) - with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-exist') @mock_secretsmanager def test_get_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - create_secret = conn.create_secret(Name='java-util-test-password', - SecretString="foosecret") + conn = boto3.client("secretsmanager", region_name="us-west-2") + create_secret = conn.create_secret( + Name="java-util-test-password", SecretString="foosecret" + ) - with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-match') + with assert_raises(ClientError) as cm: + result = conn.get_secret_value(SecretId="i-dont-match") + + assert_equal( + "Secrets Manager can't find the specified secret.", + cm.exception.response["Error"]["Message"], + ) @mock_secretsmanager def test_get_secret_value_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='test-secret') + result = conn.get_secret_value(SecretId="test-secret") + + +@mock_secretsmanager +def test_get_secret_that_has_no_value(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + + create_secret = conn.create_secret(Name="java-util-test-password") + + with assert_raises(ClientError) as cm: + result = conn.get_secret_value(SecretId="java-util-test-password") + + assert_equal( + "Secrets Manager can't find the specified secret value for staging label: AWSCURRENT", + cm.exception.response["Error"]["Message"], + ) @mock_secretsmanager def test_create_secret(): - conn = boto3.client('secretsmanager', region_name='us-east-1') + conn = boto3.client("secretsmanager", region_name="us-east-1") + + result = conn.create_secret(Name="test-secret", SecretString="foosecret") + assert result["ARN"] + assert result["Name"] == "test-secret" + secret = conn.get_secret_value(SecretId="test-secret") + assert secret["SecretString"] == "foosecret" - result = conn.create_secret(Name='test-secret', SecretString="foosecret") - assert result['ARN'] - assert result['Name'] == 'test-secret' - secret = conn.get_secret_value(SecretId='test-secret') - assert secret['SecretString'] == 'foosecret' @mock_secretsmanager def test_create_secret_with_tags(): - conn = boto3.client('secretsmanager', region_name='us-east-1') - secret_name = 'test-secret-with-tags' + conn = boto3.client("secretsmanager", region_name="us-east-1") + secret_name = "test-secret-with-tags" result = conn.create_secret( Name=secret_name, SecretString="foosecret", - Tags=[{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}] + Tags=[{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}], ) - assert result['ARN'] - assert result['Name'] == secret_name + assert result["ARN"] + assert result["Name"] == secret_name secret_value = conn.get_secret_value(SecretId=secret_name) - assert secret_value['SecretString'] == 'foosecret' + assert secret_value["SecretString"] == "foosecret" secret_details = conn.describe_secret(SecretId=secret_name) - assert secret_details['Tags'] == [{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}] + assert secret_details["Tags"] == [ + {"Key": "Foo", "Value": "Bar"}, + {"Key": "Mykey", "Value": "Myvalue"}, + ] @mock_secretsmanager def test_delete_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - deleted_secret = conn.delete_secret(SecretId='test-secret') + deleted_secret = conn.delete_secret(SecretId="test-secret") - assert deleted_secret['ARN'] - assert deleted_secret['Name'] == 'test-secret' - assert deleted_secret['DeletionDate'] > datetime.fromtimestamp(1, pytz.utc) + assert deleted_secret["ARN"] + assert deleted_secret["Name"] == "test-secret" + assert deleted_secret["DeletionDate"] > datetime.fromtimestamp(1, pytz.utc) - secret_details = conn.describe_secret(SecretId='test-secret') + secret_details = conn.describe_secret(SecretId="test-secret") - assert secret_details['ARN'] - assert secret_details['Name'] == 'test-secret' - assert secret_details['DeletedDate'] > datetime.fromtimestamp(1, pytz.utc) + assert secret_details["ARN"] + assert secret_details["Name"] == "test-secret" + assert secret_details["DeletedDate"] > datetime.fromtimestamp(1, pytz.utc) @mock_secretsmanager def test_delete_secret_force(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - result = conn.delete_secret(SecretId='test-secret', ForceDeleteWithoutRecovery=True) + result = conn.delete_secret(SecretId="test-secret", ForceDeleteWithoutRecovery=True) - assert result['ARN'] - assert result['DeletionDate'] > datetime.fromtimestamp(1, pytz.utc) - assert result['Name'] == 'test-secret' + assert result["ARN"] + assert result["DeletionDate"] > datetime.fromtimestamp(1, pytz.utc) + assert result["Name"] == "test-secret" with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='test-secret') + result = conn.get_secret_value(SecretId="test-secret") @mock_secretsmanager def test_delete_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='i-dont-exist', ForceDeleteWithoutRecovery=True) + result = conn.delete_secret( + SecretId="i-dont-exist", ForceDeleteWithoutRecovery=True + ) @mock_secretsmanager def test_delete_secret_fails_with_both_force_delete_flag_and_recovery_window_flag(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=1, ForceDeleteWithoutRecovery=True) + result = conn.delete_secret( + SecretId="test-secret", + RecoveryWindowInDays=1, + ForceDeleteWithoutRecovery=True, + ) @mock_secretsmanager def test_delete_secret_recovery_window_too_short(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=6) + result = conn.delete_secret(SecretId="test-secret", RecoveryWindowInDays=6) @mock_secretsmanager def test_delete_secret_recovery_window_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=31) + result = conn.delete_secret(SecretId="test-secret", RecoveryWindowInDays=31) @mock_secretsmanager def test_delete_secret_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - deleted_secret = conn.delete_secret(SecretId='test-secret') + deleted_secret = conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret') + result = conn.delete_secret(SecretId="test-secret") @mock_secretsmanager def test_get_random_password_default_length(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password() - assert len(random_password['RandomPassword']) == 32 + assert len(random_password["RandomPassword"]) == 32 + @mock_secretsmanager def test_get_random_password_default_requirements(): # When require_each_included_type, default true - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password() # Should contain lowercase, upppercase, digit, special character - assert any(c.islower() for c in random_password['RandomPassword']) - assert any(c.isupper() for c in random_password['RandomPassword']) - assert any(c.isdigit() for c in random_password['RandomPassword']) - assert any(c in string.punctuation - for c in random_password['RandomPassword']) + assert any(c.islower() for c in random_password["RandomPassword"]) + assert any(c.isupper() for c in random_password["RandomPassword"]) + assert any(c.isdigit() for c in random_password["RandomPassword"]) + assert any(c in string.punctuation for c in random_password["RandomPassword"]) + @mock_secretsmanager def test_get_random_password_custom_length(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password(PasswordLength=50) - assert len(random_password['RandomPassword']) == 50 + assert len(random_password["RandomPassword"]) == 50 + @mock_secretsmanager def test_get_random_exclude_lowercase(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password(PasswordLength=55, ExcludeLowercase=True) + assert any(c.islower() for c in random_password["RandomPassword"]) == False - random_password = conn.get_random_password(PasswordLength=55, - ExcludeLowercase=True) - assert any(c.islower() for c in random_password['RandomPassword']) == False @mock_secretsmanager def test_get_random_exclude_uppercase(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password(PasswordLength=55, ExcludeUppercase=True) + assert any(c.isupper() for c in random_password["RandomPassword"]) == False - random_password = conn.get_random_password(PasswordLength=55, - ExcludeUppercase=True) - assert any(c.isupper() for c in random_password['RandomPassword']) == False @mock_secretsmanager def test_get_random_exclude_characters_and_symbols(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password( + PasswordLength=20, ExcludeCharacters="xyzDje@?!." + ) + assert any(c in "xyzDje@?!." for c in random_password["RandomPassword"]) == False - random_password = conn.get_random_password(PasswordLength=20, - ExcludeCharacters='xyzDje@?!.') - assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False @mock_secretsmanager def test_get_random_exclude_numbers(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password(PasswordLength=100, ExcludeNumbers=True) + assert any(c.isdigit() for c in random_password["RandomPassword"]) == False - random_password = conn.get_random_password(PasswordLength=100, - ExcludeNumbers=True) - assert any(c.isdigit() for c in random_password['RandomPassword']) == False @mock_secretsmanager def test_get_random_exclude_punctuation(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password( + PasswordLength=100, ExcludePunctuation=True + ) + assert ( + any(c in string.punctuation for c in random_password["RandomPassword"]) == False + ) - random_password = conn.get_random_password(PasswordLength=100, - ExcludePunctuation=True) - assert any(c in string.punctuation - for c in random_password['RandomPassword']) == False @mock_secretsmanager def test_get_random_include_space_false(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password(PasswordLength=300) - assert any(c.isspace() for c in random_password['RandomPassword']) == False + assert any(c.isspace() for c in random_password["RandomPassword"]) == False + @mock_secretsmanager def test_get_random_include_space_true(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password(PasswordLength=4, IncludeSpace=True) + assert any(c.isspace() for c in random_password["RandomPassword"]) == True - random_password = conn.get_random_password(PasswordLength=4, - IncludeSpace=True) - assert any(c.isspace() for c in random_password['RandomPassword']) == True @mock_secretsmanager def test_get_random_require_each_included_type(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + + random_password = conn.get_random_password( + PasswordLength=4, RequireEachIncludedType=True + ) + assert ( + any(c in string.punctuation for c in random_password["RandomPassword"]) == True + ) + assert ( + any(c in string.ascii_lowercase for c in random_password["RandomPassword"]) + == True + ) + assert ( + any(c in string.ascii_uppercase for c in random_password["RandomPassword"]) + == True + ) + assert any(c in string.digits for c in random_password["RandomPassword"]) == True - random_password = conn.get_random_password(PasswordLength=4, - RequireEachIncludedType=True) - assert any(c in string.punctuation for c in random_password['RandomPassword']) == True - assert any(c in string.ascii_lowercase for c in random_password['RandomPassword']) == True - assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True - assert any(c in string.digits for c in random_password['RandomPassword']) == True @mock_secretsmanager def test_get_random_too_short_password(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): random_password = conn.get_random_password(PasswordLength=3) + @mock_secretsmanager def test_get_random_too_long_password(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(Exception): random_password = conn.get_random_password(PasswordLength=5555) + @mock_secretsmanager def test_describe_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') - - conn.create_secret(Name='test-secret-2', - SecretString='barsecret') - - secret_description = conn.describe_secret(SecretId='test-secret') - secret_description_2 = conn.describe_secret(SecretId='test-secret-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") + + conn.create_secret(Name="test-secret-2", SecretString="barsecret") + + secret_description = conn.describe_secret(SecretId="test-secret") + secret_description_2 = conn.describe_secret(SecretId="test-secret-2") + + assert secret_description # Returned dict is not empty + assert secret_description["Name"] == ("test-secret") + assert secret_description["ARN"] != "" # Test arn not empty + assert secret_description_2["Name"] == ("test-secret-2") + assert secret_description_2["ARN"] != "" # Test arn not empty + + +@mock_secretsmanager +def test_describe_secret_with_arn(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + results = conn.create_secret(Name="test-secret", SecretString="foosecret") + + secret_description = conn.describe_secret(SecretId=results["ARN"]) + + assert secret_description # Returned dict is not empty + assert secret_description["Name"] == ("test-secret") + assert secret_description["ARN"] != results["ARN"] - assert secret_description # Returned dict is not empty - assert secret_description['Name'] == ('test-secret') - assert secret_description['ARN'] != '' # Test arn not empty - assert secret_description_2['Name'] == ('test-secret-2') - assert secret_description_2['ARN'] != '' # Test arn not empty @mock_secretsmanager def test_describe_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-exist') + result = conn.get_secret_value(SecretId="i-dont-exist") + @mock_secretsmanager def test_describe_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') - + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") + with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-match') + result = conn.get_secret_value(SecretId="i-dont-match") @mock_secretsmanager def test_list_secrets_empty(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") secrets = conn.list_secrets() - assert secrets['SecretList'] == [] + assert secrets["SecretList"] == [] @mock_secretsmanager def test_list_secrets(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.create_secret(Name='test-secret-2', - SecretString='barsecret', - Tags=[{ - 'Key': 'a', - 'Value': '1' - }]) + conn.create_secret( + Name="test-secret-2", + SecretString="barsecret", + Tags=[{"Key": "a", "Value": "1"}], + ) secrets = conn.list_secrets() - assert secrets['SecretList'][0]['ARN'] is not None - assert secrets['SecretList'][0]['Name'] == 'test-secret' - assert secrets['SecretList'][1]['ARN'] is not None - assert secrets['SecretList'][1]['Name'] == 'test-secret-2' - assert secrets['SecretList'][1]['Tags'] == [{ - 'Key': 'a', - 'Value': '1' - }] + assert secrets["SecretList"][0]["ARN"] is not None + assert secrets["SecretList"][0]["Name"] == "test-secret" + assert secrets["SecretList"][1]["ARN"] is not None + assert secrets["SecretList"][1]["Name"] == "test-secret-2" + assert secrets["SecretList"][1]["Tags"] == [{"Key": "a", "Value": "1"}] @mock_secretsmanager def test_restore_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") - described_secret_before = conn.describe_secret(SecretId='test-secret') - assert described_secret_before['DeletedDate'] > datetime.fromtimestamp(1, pytz.utc) + described_secret_before = conn.describe_secret(SecretId="test-secret") + assert described_secret_before["DeletedDate"] > datetime.fromtimestamp(1, pytz.utc) - restored_secret = conn.restore_secret(SecretId='test-secret') - assert restored_secret['ARN'] - assert restored_secret['Name'] == 'test-secret' + restored_secret = conn.restore_secret(SecretId="test-secret") + assert restored_secret["ARN"] + assert restored_secret["Name"] == "test-secret" - described_secret_after = conn.describe_secret(SecretId='test-secret') - assert 'DeletedDate' not in described_secret_after + described_secret_after = conn.describe_secret(SecretId="test-secret") + assert "DeletedDate" not in described_secret_after @mock_secretsmanager def test_restore_secret_that_is_not_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - restored_secret = conn.restore_secret(SecretId='test-secret') - assert restored_secret['ARN'] - assert restored_secret['Name'] == 'test-secret' + restored_secret = conn.restore_secret(SecretId="test-secret") + assert restored_secret["ARN"] + assert restored_secret["Name"] == "test-secret" @mock_secretsmanager def test_restore_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.restore_secret(SecretId='i-dont-exist') + result = conn.restore_secret(SecretId="i-dont-exist") @mock_secretsmanager def test_rotate_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") rotated_secret = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME) assert rotated_secret - assert rotated_secret['ARN'] != '' # Test arn not empty - assert rotated_secret['Name'] == DEFAULT_SECRET_NAME - assert rotated_secret['VersionId'] != '' + assert rotated_secret["ARN"] != "" # Test arn not empty + assert rotated_secret["Name"] == DEFAULT_SECRET_NAME + assert rotated_secret["VersionId"] != "" + @mock_secretsmanager def test_rotate_secret_enable_rotation(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") initial_description = conn.describe_secret(SecretId=DEFAULT_SECRET_NAME) assert initial_description - assert initial_description['RotationEnabled'] is False - assert initial_description['RotationRules']['AutomaticallyAfterDays'] == 0 + assert initial_description["RotationEnabled"] is False + assert initial_description["RotationRules"]["AutomaticallyAfterDays"] == 0 - conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationRules={'AutomaticallyAfterDays': 42}) + conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationRules={"AutomaticallyAfterDays": 42} + ) rotated_description = conn.describe_secret(SecretId=DEFAULT_SECRET_NAME) assert rotated_description - assert rotated_description['RotationEnabled'] is True - assert rotated_description['RotationRules']['AutomaticallyAfterDays'] == 42 + assert rotated_description["RotationEnabled"] is True + assert rotated_description["RotationRules"]["AutomaticallyAfterDays"] == 42 @mock_secretsmanager def test_rotate_secret_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='test-secret') + result = conn.rotate_secret(SecretId="test-secret") @mock_secretsmanager def test_rotate_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', 'us-west-2') + conn = boto3.client("secretsmanager", "us-west-2") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='i-dont-exist') + result = conn.rotate_secret(SecretId="i-dont-exist") + @mock_secretsmanager def test_rotate_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='i-dont-match') + result = conn.rotate_secret(SecretId="i-dont-match") + @mock_secretsmanager def test_rotate_secret_client_request_token_too_short(): @@ -468,30 +539,32 @@ def test_rotate_secret_client_request_token_too_short(): # test_server actually handles this error. assert True + @mock_secretsmanager def test_rotate_secret_client_request_token_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") client_request_token = ( - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-' - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C' + "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-" "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C" ) with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - ClientRequestToken=client_request_token) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, ClientRequestToken=client_request_token + ) + @mock_secretsmanager def test_rotate_secret_rotation_lambda_arn_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") - rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters + rotation_lambda_arn = "85B7-446A-B7E4" * 147 # == 2058 characters with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationLambdaARN=rotation_lambda_arn) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationLambdaARN=rotation_lambda_arn + ) + @mock_secretsmanager def test_rotate_secret_rotation_period_zero(): @@ -500,80 +573,141 @@ def test_rotate_secret_rotation_period_zero(): # test_server actually handles this error. assert True + @mock_secretsmanager def test_rotate_secret_rotation_period_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") - rotation_rules = {'AutomaticallyAfterDays': 1001} + rotation_rules = {"AutomaticallyAfterDays": 1001} with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationRules=rotation_rules) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationRules=rotation_rules + ) + @mock_secretsmanager def test_put_secret_value_puts_new_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='foosecret', - VersionStages=['AWSCURRENT']) - version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="foosecret", + VersionStages=["AWSCURRENT"], + ) + version_id = put_secret_value_dict["VersionId"] - get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, - VersionId=version_id, - VersionStage='AWSCURRENT') + get_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=version_id, VersionStage="AWSCURRENT" + ) assert get_secret_value_dict - assert get_secret_value_dict['SecretString'] == 'foosecret' + assert get_secret_value_dict["SecretString"] == "foosecret" + + +@mock_secretsmanager +def test_put_secret_binary_value_puts_new_secret(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretBinary=b("foosecret"), + VersionStages=["AWSCURRENT"], + ) + version_id = put_secret_value_dict["VersionId"] + + get_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=version_id, VersionStage="AWSCURRENT" + ) + + assert get_secret_value_dict + assert get_secret_value_dict["SecretBinary"] == b("foosecret") + + +@mock_secretsmanager +def test_create_and_put_secret_binary_value_puts_new_secret(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret")) + conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret_update") + ) + + latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME) + + assert latest_secret + assert latest_secret["SecretBinary"] == b("foosecret_update") + + +@mock_secretsmanager +def test_put_secret_binary_requires_either_string_or_binary(): + conn = boto3.client("secretsmanager", region_name="us-west-2") + with assert_raises(ClientError) as ire: + conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME) + + ire.exception.response["Error"]["Code"].should.equal("InvalidRequestException") + ire.exception.response["Error"]["Message"].should.equal( + "You must provide either SecretString or SecretBinary." + ) + @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='first_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='second_secret', - VersionStages=['AWSCURRENT']) + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="first_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="second_secret", + VersionStages=["AWSCURRENT"], + ) - first_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, - VersionId=first_version_id) - first_secret_value = first_secret_value_dict['SecretString'] + first_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=first_version_id + ) + first_secret_value = first_secret_value_dict["SecretString"] - assert first_secret_value == 'first_secret' + assert first_secret_value == "first_secret" @mock_secretsmanager def test_put_secret_value_versions_differ_if_same_secret_put_twice(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - second_version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + second_version_id = put_secret_value_dict["VersionId"] assert first_version_id != second_version_id @mock_secretsmanager def test_can_list_secret_version_ids(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - second_version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + second_version_id = put_secret_value_dict["VersionId"] versions_list = conn.list_secret_version_ids(SecretId=DEFAULT_SECRET_NAME) - returned_version_ids = [v['VersionId'] for v in versions_list['Versions']] + returned_version_ids = [v["VersionId"] for v in versions_list["Versions"]] assert [first_version_id, second_version_id].sort() == returned_version_ids.sort() - diff --git a/tests/test_secretsmanager/test_server.py b/tests/test_secretsmanager/test_server.py index 23d823239..81cb641bd 100644 --- a/tests/test_secretsmanager/test_server.py +++ b/tests/test_secretsmanager/test_server.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals import json @@ -6,11 +7,11 @@ import sure # noqa import moto.server as server from moto import mock_secretsmanager -''' +""" Test the different server responses for secretsmanager -''' +""" -DEFAULT_SECRET_NAME = 'test-secret' +DEFAULT_SECRET_NAME = "test-secret" @mock_secretsmanager @@ -19,22 +20,21 @@ def test_get_secret_value(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - get_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME, "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['SecretString'] == 'foo-secret' + assert json_data["SecretString"] == "foo-secret" + @mock_secretsmanager def test_get_secret_that_does_not_exist(): @@ -42,36 +42,59 @@ def test_get_secret_that_does_not_exist(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - get_secret = test_client.post('/', - data={"SecretId": "i-dont-exist", - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + get_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist", "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_get_secret_that_does_not_match(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - get_secret = test_client.post('/', - data={"SecretId": "i-dont-match", - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match", "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + + +@mock_secretsmanager +def test_get_secret_that_has_no_value(): + backend = server.create_backend_app("secretsmanager") + test_client = backend.test_client() + + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) + + json_data = json.loads(get_secret.data.decode("utf-8")) + assert ( + json_data["message"] + == "Secrets Manager can't find the specified secret value for staging label: AWSCURRENT" + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_create_secret(): @@ -79,139 +102,131 @@ def test_create_secret(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - res = test_client.post('/', - data={"Name": "test-secret", - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - res_2 = test_client.post('/', - data={"Name": "test-secret-2", - "SecretString": "bar-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) + res = test_client.post( + "/", + data={"Name": "test-secret", "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + res_2 = test_client.post( + "/", + data={"Name": "test-secret-2", "SecretString": "bar-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) json_data = json.loads(res.data.decode("utf-8")) - assert json_data['ARN'] != '' - assert json_data['Name'] == 'test-secret' - + assert json_data["ARN"] != "" + assert json_data["Name"] == "test-secret" + json_data_2 = json.loads(res_2.data.decode("utf-8")) - assert json_data_2['ARN'] != '' - assert json_data_2['Name'] == 'test-secret-2' + assert json_data_2["ARN"] != "" + assert json_data_2["Name"] == "test-secret-2" + @mock_secretsmanager def test_describe_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": "test-secret", - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret = test_client.post('/', - data={"SecretId": "test-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) - - create_secret_2 = test_client.post('/', - data={"Name": "test-secret-2", - "SecretString": "barsecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret_2 = test_client.post('/', - data={"SecretId": "test-secret-2"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": "test-secret", "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret = test_client.post( + "/", + data={"SecretId": "test-secret"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) + + create_secret_2 = test_client.post( + "/", + data={"Name": "test-secret-2", "SecretString": "barsecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret_2 = test_client.post( + "/", + data={"SecretId": "test-secret-2"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data # Returned dict is not empty - assert json_data['ARN'] != '' - assert json_data['Name'] == 'test-secret' - + assert json_data # Returned dict is not empty + assert json_data["ARN"] != "" + assert json_data["Name"] == "test-secret" + json_data_2 = json.loads(describe_secret_2.data.decode("utf-8")) - assert json_data_2 # Returned dict is not empty - assert json_data_2['ARN'] != '' - assert json_data_2['Name'] == 'test-secret-2' + assert json_data_2 # Returned dict is not empty + assert json_data_2["ARN"] != "" + assert json_data_2["Name"] == "test-secret-2" + @mock_secretsmanager def test_describe_secret_that_does_not_exist(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - describe_secret = test_client.post('/', - data={"SecretId": "i-dont-exist"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + describe_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_describe_secret_that_does_not_match(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret = test_client.post('/', - data={"SecretId": "i-dont-match"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = "EXAMPLE2-90ab-cdef-fedc-ba987SECRET2" - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data # Returned dict is not empty - assert json_data['ARN'] != '' - assert json_data['Name'] == DEFAULT_SECRET_NAME - assert json_data['VersionId'] == client_request_token + assert json_data # Returned dict is not empty + assert json_data["ARN"] != "" + assert json_data["Name"] == DEFAULT_SECRET_NAME + assert json_data["VersionId"] == client_request_token + # @mock_secretsmanager # def test_rotate_secret_enable_rotation(): @@ -270,291 +285,335 @@ def test_rotate_secret(): # assert json_data['RotationEnabled'] is True # assert json_data['RotationRules']['AutomaticallyAfterDays'] == 42 + @mock_secretsmanager def test_rotate_secret_that_does_not_exist(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - rotate_secret = test_client.post('/', - data={"SecretId": "i-dont-exist"}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist"}, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret_that_does_not_match(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) - rotate_secret = test_client.post('/', - data={"SecretId": "i-dont-match"}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match"}, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "Secrets Manager can't find the specified secret" - assert json_data['__type'] == 'ResourceNotFoundException' + assert json_data["message"] == "Secrets Manager can't find the specified secret." + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret_client_request_token_too_short(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = "ED9F8B6C-85B7-B7E4-38F2A3BEB13C" - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "ClientRequestToken must be 32-64 characters long." - assert json_data['__type'] == 'InvalidParameterException' + assert json_data["message"] == "ClientRequestToken must be 32-64 characters long." + assert json_data["__type"] == "InvalidParameterException" + @mock_secretsmanager def test_rotate_secret_client_request_token_too_long(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = ( - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-' - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C' + "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-" "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C" + ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, ) - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "ClientRequestToken must be 32-64 characters long." - assert json_data['__type'] == 'InvalidParameterException' + assert json_data["message"] == "ClientRequestToken must be 32-64 characters long." + assert json_data["__type"] == "InvalidParameterException" + @mock_secretsmanager def test_rotate_secret_rotation_lambda_arn_too_long(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) - rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "RotationLambdaARN": rotation_lambda_arn}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotation_lambda_arn = "85B7-446A-B7E4" * 147 # == 2058 characters + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "RotationLambdaARN": rotation_lambda_arn, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "RotationLambdaARN must <= 2048 characters long." - assert json_data['__type'] == 'InvalidParameterException' - - - + assert json_data["message"] == "RotationLambdaARN must <= 2048 characters long." + assert json_data["__type"] == "InvalidParameterException" @mock_secretsmanager def test_put_secret_value_puts_new_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "foosecret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "foosecret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "foosecret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "foosecret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) - version_id = second_secret_json_data['VersionId'] + version_id = second_secret_json_data["VersionId"] - secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "VersionId": version_id, - "VersionStage": 'AWSCURRENT'}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "VersionId": version_id, + "VersionStage": "AWSCURRENT", + }, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) second_secret_json_data = json.loads(secret_value_json.data.decode("utf-8")) assert second_secret_json_data - assert second_secret_json_data['SecretString'] == 'foosecret' + assert second_secret_json_data["SecretString"] == "foosecret" @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - first_secret_string = 'first_secret' - second_secret_string = 'second_secret' + first_secret_string = "first_secret" + second_secret_string = "second_secret" - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": first_secret_string, - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": first_secret_string, + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) - first_secret_version_id = first_secret_json_data['VersionId'] + first_secret_version_id = first_secret_json_data["VersionId"] - test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": second_secret_string, - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": second_secret_string, + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - get_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "VersionId": first_secret_version_id, - "VersionStage": 'AWSCURRENT'}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + get_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "VersionId": first_secret_version_id, + "VersionStage": "AWSCURRENT", + }, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) - get_first_secret_json_data = json.loads(get_first_secret_value_json.data.decode("utf-8")) + get_first_secret_json_data = json.loads( + get_first_secret_value_json.data.decode("utf-8") + ) assert get_first_secret_json_data - assert get_first_secret_json_data['SecretString'] == first_secret_string + assert get_first_secret_json_data["SecretString"] == first_secret_string @mock_secretsmanager def test_put_secret_value_versions_differ_if_same_secret_put_twice(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) - first_secret_version_id = first_secret_json_data['VersionId'] + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) + first_secret_version_id = first_secret_json_data["VersionId"] - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) - second_secret_version_id = second_secret_json_data['VersionId'] + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) + second_secret_version_id = second_secret_json_data["VersionId"] assert first_secret_version_id != second_secret_version_id @mock_secretsmanager def test_can_list_secret_version_ids(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) - first_secret_version_id = first_secret_json_data['VersionId'] - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) - second_secret_version_id = second_secret_json_data['VersionId'] + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) + first_secret_version_id = first_secret_json_data["VersionId"] + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) + second_secret_version_id = second_secret_json_data["VersionId"] - list_secret_versions_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, }, - headers={ - "X-Amz-Target": "secretsmanager.ListSecretVersionIds"}, - ) + list_secret_versions_json = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.ListSecretVersionIds"}, + ) versions_list = json.loads(list_secret_versions_json.data.decode("utf-8")) - returned_version_ids = [v['VersionId'] for v in versions_list['Versions']] + returned_version_ids = [v["VersionId"] for v in versions_list["Versions"]] + + assert [ + first_secret_version_id, + second_secret_version_id, + ].sort() == returned_version_ids.sort() + + +@mock_secretsmanager +def test_get_resource_policy_secret(): + + backend = server.create_backend_app("secretsmanager") + test_client = backend.test_client() + + create_secret = test_client.post( + "/", + data={"Name": "test-secret", "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret = test_client.post( + "/", + data={"SecretId": "test-secret"}, + headers={"X-Amz-Target": "secretsmanager.GetResourcePolicy"}, + ) + + json_data = json.loads(describe_secret.data.decode("utf-8")) + assert json_data # Returned dict is not empty + assert json_data["ARN"] != "" + assert json_data["Name"] == "test-secret" - assert [first_secret_version_id, second_secret_version_id].sort() == returned_version_ids.sort() # # The following tests should work, but fail on the embedded dict in # RotationRules. The error message suggests a problem deeper in the code, which # needs further investigation. -# +# # @mock_secretsmanager # def test_rotate_secret_rotation_period_zero(): diff --git a/tests/test_ses/test_server.py b/tests/test_ses/test_server.py index e679f06fb..b9d2252ce 100644 --- a/tests/test_ses/test_server.py +++ b/tests/test_ses/test_server.py @@ -1,16 +1,16 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_ses_list_identities(): - backend = server.create_backend_app("ses") - test_client = backend.test_client() - - res = test_client.get('/?Action=ListIdentities') - res.data.should.contain(b"ListIdentitiesResponse") +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_ses_list_identities(): + backend = server.create_backend_app("ses") + test_client = backend.test_client() + + res = test_client.get("/?Action=ListIdentities") + res.data.should.contain(b"ListIdentitiesResponse") diff --git a/tests/test_ses/test_ses.py b/tests/test_ses/test_ses.py index 4514267c3..851327b9d 100644 --- a/tests/test_ses/test_ses.py +++ b/tests/test_ses/test_ses.py @@ -1,116 +1,129 @@ -from __future__ import unicode_literals -import email - -import boto -from boto.exception import BotoServerError - -import sure # noqa - -from moto import mock_ses_deprecated - - -@mock_ses_deprecated -def test_verify_email_identity(): - conn = boto.connect_ses('the_key', 'the_secret') - conn.verify_email_identity("test@example.com") - - identities = conn.list_identities() - address = identities['ListIdentitiesResponse'][ - 'ListIdentitiesResult']['Identities'][0] - address.should.equal('test@example.com') - - -@mock_ses_deprecated -def test_domain_verify(): - conn = boto.connect_ses('the_key', 'the_secret') - - conn.verify_domain_dkim("domain1.com") - conn.verify_domain_identity("domain2.com") - - identities = conn.list_identities() - domains = list(identities['ListIdentitiesResponse'][ - 'ListIdentitiesResult']['Identities']) - domains.should.equal(['domain1.com', 'domain2.com']) - - -@mock_ses_deprecated -def test_delete_identity(): - conn = boto.connect_ses('the_key', 'the_secret') - conn.verify_email_identity("test@example.com") - - conn.list_identities()['ListIdentitiesResponse']['ListIdentitiesResult'][ - 'Identities'].should.have.length_of(1) - conn.delete_identity("test@example.com") - conn.list_identities()['ListIdentitiesResponse']['ListIdentitiesResult'][ - 'Identities'].should.have.length_of(0) - - -@mock_ses_deprecated -def test_send_email(): - conn = boto.connect_ses('the_key', 'the_secret') - - conn.send_email.when.called_with( - "test@example.com", "test subject", - "test body", "test_to@example.com").should.throw(BotoServerError) - - conn.verify_email_identity("test@example.com") - conn.send_email("test@example.com", "test subject", - "test body", "test_to@example.com") - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) - sent_count.should.equal(1) - - -@mock_ses_deprecated -def test_send_html_email(): - conn = boto.connect_ses('the_key', 'the_secret') - - conn.send_email.when.called_with( - "test@example.com", "test subject", - "test body", "test_to@example.com", format="html").should.throw(BotoServerError) - - conn.verify_email_identity("test@example.com") - conn.send_email("test@example.com", "test subject", - "test body", "test_to@example.com", format="html") - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) - sent_count.should.equal(1) - - -@mock_ses_deprecated -def test_send_raw_email(): - conn = boto.connect_ses('the_key', 'the_secret') - - message = email.mime.multipart.MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com' - - # Message body - part = email.mime.text.MIMEText('test file attached') - message.attach(part) - - # Attachment - part = email.mime.text.MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') - message.attach(part) - - conn.send_raw_email.when.called_with( - source=message['From'], - raw_message=message.as_string(), - ).should.throw(BotoServerError) - - conn.verify_email_identity("test@example.com") - conn.send_raw_email( - source=message['From'], - raw_message=message.as_string(), - ) - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) - sent_count.should.equal(1) +from __future__ import unicode_literals +import email + +import boto +from boto.exception import BotoServerError + +import sure # noqa + +from moto import mock_ses_deprecated + + +@mock_ses_deprecated +def test_verify_email_identity(): + conn = boto.connect_ses("the_key", "the_secret") + conn.verify_email_identity("test@example.com") + + identities = conn.list_identities() + address = identities["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ][0] + address.should.equal("test@example.com") + + +@mock_ses_deprecated +def test_domain_verify(): + conn = boto.connect_ses("the_key", "the_secret") + + conn.verify_domain_dkim("domain1.com") + conn.verify_domain_identity("domain2.com") + + identities = conn.list_identities() + domains = list( + identities["ListIdentitiesResponse"]["ListIdentitiesResult"]["Identities"] + ) + domains.should.equal(["domain1.com", "domain2.com"]) + + +@mock_ses_deprecated +def test_delete_identity(): + conn = boto.connect_ses("the_key", "the_secret") + conn.verify_email_identity("test@example.com") + + conn.list_identities()["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ].should.have.length_of(1) + conn.delete_identity("test@example.com") + conn.list_identities()["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ].should.have.length_of(0) + + +@mock_ses_deprecated +def test_send_email(): + conn = boto.connect_ses("the_key", "the_secret") + + conn.send_email.when.called_with( + "test@example.com", "test subject", "test body", "test_to@example.com" + ).should.throw(BotoServerError) + + conn.verify_email_identity("test@example.com") + conn.send_email( + "test@example.com", "test subject", "test body", "test_to@example.com" + ) + + send_quota = conn.get_send_quota() + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) + sent_count.should.equal(1) + + +@mock_ses_deprecated +def test_send_html_email(): + conn = boto.connect_ses("the_key", "the_secret") + + conn.send_email.when.called_with( + "test@example.com", + "test subject", + "test body", + "test_to@example.com", + format="html", + ).should.throw(BotoServerError) + + conn.verify_email_identity("test@example.com") + conn.send_email( + "test@example.com", + "test subject", + "test body", + "test_to@example.com", + format="html", + ) + + send_quota = conn.get_send_quota() + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) + sent_count.should.equal(1) + + +@mock_ses_deprecated +def test_send_raw_email(): + conn = boto.connect_ses("the_key", "the_secret") + + message = email.mime.multipart.MIMEMultipart() + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com" + + # Message body + part = email.mime.text.MIMEText("test file attached") + message.attach(part) + + # Attachment + part = email.mime.text.MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") + message.attach(part) + + conn.send_raw_email.when.called_with( + source=message["From"], raw_message=message.as_string() + ).should.throw(BotoServerError) + + conn.verify_email_identity("test@example.com") + conn.send_raw_email(source=message["From"], raw_message=message.as_string()) + + send_quota = conn.get_send_quota() + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) + sent_count.should.equal(1) diff --git a/tests/test_ses/test_ses_boto3.py b/tests/test_ses/test_ses_boto3.py index 00d44bffa..ee7c92aa1 100644 --- a/tests/test_ses/test_ses_boto3.py +++ b/tests/test_ses/test_ses_boto3.py @@ -1,194 +1,216 @@ -from __future__ import unicode_literals - -import boto3 -from botocore.exceptions import ClientError -from six.moves.email_mime_multipart import MIMEMultipart -from six.moves.email_mime_text import MIMEText - -import sure # noqa - -from moto import mock_ses - - -@mock_ses -def test_verify_email_identity(): - conn = boto3.client('ses', region_name='us-east-1') - conn.verify_email_identity(EmailAddress="test@example.com") - - identities = conn.list_identities() - address = identities['Identities'][0] - address.should.equal('test@example.com') - -@mock_ses -def test_verify_email_address(): - conn = boto3.client('ses', region_name='us-east-1') - conn.verify_email_address(EmailAddress="test@example.com") - email_addresses = conn.list_verified_email_addresses() - email = email_addresses['VerifiedEmailAddresses'][0] - email.should.equal('test@example.com') - -@mock_ses -def test_domain_verify(): - conn = boto3.client('ses', region_name='us-east-1') - - conn.verify_domain_dkim(Domain="domain1.com") - conn.verify_domain_identity(Domain="domain2.com") - - identities = conn.list_identities() - domains = list(identities['Identities']) - domains.should.equal(['domain1.com', 'domain2.com']) - - -@mock_ses -def test_delete_identity(): - conn = boto3.client('ses', region_name='us-east-1') - conn.verify_email_identity(EmailAddress="test@example.com") - - conn.list_identities()['Identities'].should.have.length_of(1) - conn.delete_identity(Identity="test@example.com") - conn.list_identities()['Identities'].should.have.length_of(0) - - -@mock_ses -def test_send_email(): - conn = boto3.client('ses', region_name='us-east-1') - - kwargs = dict( - Source="test@example.com", - Destination={ - "ToAddresses": ["test_to@example.com"], - "CcAddresses": ["test_cc@example.com"], - "BccAddresses": ["test_bcc@example.com"], - }, - Message={ - "Subject": {"Data": "test subject"}, - "Body": {"Text": {"Data": "test body"}} - } - ) - conn.send_email.when.called_with(**kwargs).should.throw(ClientError) - - conn.verify_domain_identity(Domain='example.com') - conn.send_email(**kwargs) - - too_many_addresses = list('to%s@example.com' % i for i in range(51)) - conn.send_email.when.called_with( - **dict(kwargs, Destination={'ToAddresses': too_many_addresses}) - ).should.throw(ClientError) - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) - sent_count.should.equal(3) - - -@mock_ses -def test_send_html_email(): - conn = boto3.client('ses', region_name='us-east-1') - - kwargs = dict( - Source="test@example.com", - Destination={ - "ToAddresses": ["test_to@example.com"] - }, - Message={ - "Subject": {"Data": "test subject"}, - "Body": {"Html": {"Data": "test body"}} - } - ) - - conn.send_email.when.called_with(**kwargs).should.throw(ClientError) - - conn.verify_email_identity(EmailAddress="test@example.com") - conn.send_email(**kwargs) - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) - sent_count.should.equal(1) - - -@mock_ses -def test_send_raw_email(): - conn = boto3.client('ses', region_name='us-east-1') - - message = MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com, foo@example.com' - - # Message body - part = MIMEText('test file attached') - message.attach(part) - - # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') - message.attach(part) - - kwargs = dict( - Source=message['From'], - RawMessage={'Data': message.as_string()}, - ) - - conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) - - conn.verify_email_identity(EmailAddress="test@example.com") - conn.send_raw_email(**kwargs) - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) - sent_count.should.equal(2) - - -@mock_ses -def test_send_raw_email_without_source(): - conn = boto3.client('ses', region_name='us-east-1') - - message = MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com, foo@example.com' - - # Message body - part = MIMEText('test file attached') - message.attach(part) - - # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') - message.attach(part) - - kwargs = dict( - RawMessage={'Data': message.as_string()}, - ) - - conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) - - conn.verify_email_identity(EmailAddress="test@example.com") - conn.send_raw_email(**kwargs) - - send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) - sent_count.should.equal(2) - - -@mock_ses -def test_send_raw_email_without_source_or_from(): - conn = boto3.client('ses', region_name='us-east-1') - - message = MIMEMultipart() - message['Subject'] = 'Test' - message['To'] = 'to@example.com, foo@example.com' - - # Message body - part = MIMEText('test file attached') - message.attach(part) - # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') - message.attach(part) - - kwargs = dict( - RawMessage={'Data': message.as_string()}, - ) - - conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) - +from __future__ import unicode_literals + +import boto3 +from botocore.exceptions import ClientError +from six.moves.email_mime_multipart import MIMEMultipart +from six.moves.email_mime_text import MIMEText + +import sure # noqa + +from moto import mock_ses + + +@mock_ses +def test_verify_email_identity(): + conn = boto3.client("ses", region_name="us-east-1") + conn.verify_email_identity(EmailAddress="test@example.com") + + identities = conn.list_identities() + address = identities["Identities"][0] + address.should.equal("test@example.com") + + +@mock_ses +def test_verify_email_address(): + conn = boto3.client("ses", region_name="us-east-1") + conn.verify_email_address(EmailAddress="test@example.com") + email_addresses = conn.list_verified_email_addresses() + email = email_addresses["VerifiedEmailAddresses"][0] + email.should.equal("test@example.com") + + +@mock_ses +def test_domain_verify(): + conn = boto3.client("ses", region_name="us-east-1") + + conn.verify_domain_dkim(Domain="domain1.com") + conn.verify_domain_identity(Domain="domain2.com") + + identities = conn.list_identities() + domains = list(identities["Identities"]) + domains.should.equal(["domain1.com", "domain2.com"]) + + +@mock_ses +def test_delete_identity(): + conn = boto3.client("ses", region_name="us-east-1") + conn.verify_email_identity(EmailAddress="test@example.com") + + conn.list_identities()["Identities"].should.have.length_of(1) + conn.delete_identity(Identity="test@example.com") + conn.list_identities()["Identities"].should.have.length_of(0) + + +@mock_ses +def test_send_email(): + conn = boto3.client("ses", region_name="us-east-1") + + kwargs = dict( + Source="test@example.com", + Destination={ + "ToAddresses": ["test_to@example.com"], + "CcAddresses": ["test_cc@example.com"], + "BccAddresses": ["test_bcc@example.com"], + }, + Message={ + "Subject": {"Data": "test subject"}, + "Body": {"Text": {"Data": "test body"}}, + }, + ) + conn.send_email.when.called_with(**kwargs).should.throw(ClientError) + + conn.verify_domain_identity(Domain="example.com") + conn.send_email(**kwargs) + + too_many_addresses = list("to%s@example.com" % i for i in range(51)) + conn.send_email.when.called_with( + **dict(kwargs, Destination={"ToAddresses": too_many_addresses}) + ).should.throw(ClientError) + + send_quota = conn.get_send_quota() + sent_count = int(send_quota["SentLast24Hours"]) + sent_count.should.equal(3) + + +@mock_ses +def test_send_templated_email(): + conn = boto3.client("ses", region_name="us-east-1") + + kwargs = dict( + Source="test@example.com", + Destination={ + "ToAddresses": ["test_to@example.com"], + "CcAddresses": ["test_cc@example.com"], + "BccAddresses": ["test_bcc@example.com"], + }, + Template="test_template", + TemplateData='{"name": "test"}', + ) + + conn.send_templated_email.when.called_with(**kwargs).should.throw(ClientError) + + conn.verify_domain_identity(Domain="example.com") + conn.send_templated_email(**kwargs) + + too_many_addresses = list("to%s@example.com" % i for i in range(51)) + conn.send_templated_email.when.called_with( + **dict(kwargs, Destination={"ToAddresses": too_many_addresses}) + ).should.throw(ClientError) + + send_quota = conn.get_send_quota() + sent_count = int(send_quota["SentLast24Hours"]) + sent_count.should.equal(3) + + +@mock_ses +def test_send_html_email(): + conn = boto3.client("ses", region_name="us-east-1") + + kwargs = dict( + Source="test@example.com", + Destination={"ToAddresses": ["test_to@example.com"]}, + Message={ + "Subject": {"Data": "test subject"}, + "Body": {"Html": {"Data": "test body"}}, + }, + ) + + conn.send_email.when.called_with(**kwargs).should.throw(ClientError) + + conn.verify_email_identity(EmailAddress="test@example.com") + conn.send_email(**kwargs) + + send_quota = conn.get_send_quota() + sent_count = int(send_quota["SentLast24Hours"]) + sent_count.should.equal(1) + + +@mock_ses +def test_send_raw_email(): + conn = boto3.client("ses", region_name="us-east-1") + + message = MIMEMultipart() + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com, foo@example.com" + + # Message body + part = MIMEText("test file attached") + message.attach(part) + + # Attachment + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") + message.attach(part) + + kwargs = dict(Source=message["From"], RawMessage={"Data": message.as_string()}) + + conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) + + conn.verify_email_identity(EmailAddress="test@example.com") + conn.send_raw_email(**kwargs) + + send_quota = conn.get_send_quota() + sent_count = int(send_quota["SentLast24Hours"]) + sent_count.should.equal(2) + + +@mock_ses +def test_send_raw_email_without_source(): + conn = boto3.client("ses", region_name="us-east-1") + + message = MIMEMultipart() + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com, foo@example.com" + + # Message body + part = MIMEText("test file attached") + message.attach(part) + + # Attachment + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") + message.attach(part) + + kwargs = dict(RawMessage={"Data": message.as_string()}) + + conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) + + conn.verify_email_identity(EmailAddress="test@example.com") + conn.send_raw_email(**kwargs) + + send_quota = conn.get_send_quota() + sent_count = int(send_quota["SentLast24Hours"]) + sent_count.should.equal(2) + + +@mock_ses +def test_send_raw_email_without_source_or_from(): + conn = boto3.client("ses", region_name="us-east-1") + + message = MIMEMultipart() + message["Subject"] = "Test" + message["To"] = "to@example.com, foo@example.com" + + # Message body + part = MIMEText("test file attached") + message.attach(part) + # Attachment + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") + message.attach(part) + + kwargs = dict(RawMessage={"Data": message.as_string()}) + + conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py index 37f79a8b0..fc58d88aa 100644 --- a/tests/test_ses/test_ses_sns_boto3.py +++ b/tests/test_ses/test_ses_sns_boto3.py @@ -10,23 +10,21 @@ import sure # noqa from nose import tools from moto import mock_ses, mock_sns, mock_sqs from moto.ses.models import SESFeedback +from moto.core import ACCOUNT_ID @mock_ses def test_enable_disable_ses_sns_communication(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.set_identity_notification_topic( - Identity='test.com', - NotificationType='Bounce', - SnsTopic='the-arn' - ) - conn.set_identity_notification_topic( - Identity='test.com', - NotificationType='Bounce' + Identity="test.com", NotificationType="Bounce", SnsTopic="the-arn" ) + conn.set_identity_notification_topic(Identity="test.com", NotificationType="Bounce") -def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg): +def __setup_feedback_env__( + ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg +): """Setup the AWS environment to test the SES SNS Feedback""" # Environment setup # Create SQS queue @@ -35,30 +33,32 @@ def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, r create_topic_response = sns_conn.create_topic(Name=topic) topic_arn = create_topic_response["TopicArn"] # Subscribe the SNS topic to the SQS queue - sns_conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue)) + sns_conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:%s:%s:%s" % (region, ACCOUNT_ID, queue), + ) # Verify SES domain ses_conn.verify_domain_identity(Domain=domain) # Setup SES notification topic if expected_msg is not None: ses_conn.set_identity_notification_topic( - Identity=domain, - NotificationType=expected_msg, - SnsTopic=topic_arn + Identity=domain, NotificationType=expected_msg, SnsTopic=topic_arn ) def __test_sns_feedback__(addr, expected_msg): region_name = "us-east-1" - ses_conn = boto3.client('ses', region_name=region_name) - sns_conn = boto3.client('sns', region_name=region_name) - sqs_conn = boto3.resource('sqs', region_name=region_name) + ses_conn = boto3.client("ses", region_name=region_name) + sns_conn = boto3.client("sns", region_name=region_name) + sqs_conn = boto3.resource("sqs", region_name=region_name) domain = "example.com" topic = "bounce-arn-feedback" queue = "feedback-test-queue" - __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg) + __setup_feedback_env__( + ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg + ) # Send the message kwargs = dict( @@ -70,8 +70,8 @@ def __test_sns_feedback__(addr, expected_msg): }, Message={ "Subject": {"Data": "test subject"}, - "Body": {"Text": {"Data": "test body"}} - } + "Body": {"Text": {"Data": "test body"}}, + }, ) ses_conn.send_email(**kwargs) diff --git a/tests/test_sns/test_application.py b/tests/test_sns/test_application.py index e8b5838c0..e4fe93d53 100644 --- a/tests/test_sns/test_application.py +++ b/tests/test_sns/test_application.py @@ -1,308 +1,308 @@ -from __future__ import unicode_literals - -import boto -from boto.exception import BotoServerError -from moto import mock_sns_deprecated -import sure # noqa - - -@mock_sns_deprecated -def test_create_platform_application(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - attributes={ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }, - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - application_arn.should.equal( - 'arn:aws:sns:us-east-1:123456789012:app/APNS/my-application') - - -@mock_sns_deprecated -def test_get_platform_application_attributes(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - attributes={ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }, - ) - arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse'][ - 'GetPlatformApplicationAttributesResult']['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }) - - -@mock_sns_deprecated -def test_get_missing_platform_application_attributes(): - conn = boto.connect_sns() - conn.get_platform_application_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) - - -@mock_sns_deprecated -def test_set_platform_application_attributes(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - attributes={ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }, - ) - arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - conn.set_platform_application_attributes(arn, - {"PlatformPrincipal": "other"} - ) - attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse'][ - 'GetPlatformApplicationAttributesResult']['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "other", - }) - - -@mock_sns_deprecated -def test_list_platform_applications(): - conn = boto.connect_sns() - conn.create_platform_application( - name="application1", - platform="APNS", - ) - conn.create_platform_application( - name="application2", - platform="APNS", - ) - - applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] - applications.should.have.length_of(2) - - -@mock_sns_deprecated -def test_delete_platform_application(): - conn = boto.connect_sns() - conn.create_platform_application( - name="application1", - platform="APNS", - ) - conn.create_platform_application( - name="application2", - platform="APNS", - ) - - applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] - applications.should.have.length_of(2) - - application_arn = applications[0]['PlatformApplicationArn'] - conn.delete_platform_application(application_arn) - - applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] - applications.should.have.length_of(1) - - -@mock_sns_deprecated -def test_create_platform_endpoint(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "Enabled": False, - }, - ) - - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - endpoint_arn.should.contain( - "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") - - -@mock_sns_deprecated -def test_get_list_endpoints_by_platform_application(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "CustomUserData": "some data", - }, - ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - - endpoint_list = conn.list_endpoints_by_platform_application( - platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] - - endpoint_list.should.have.length_of(1) - endpoint_list[0]['Attributes']['CustomUserData'].should.equal('some data') - endpoint_list[0]['EndpointArn'].should.equal(endpoint_arn) - - -@mock_sns_deprecated -def test_get_endpoint_attributes(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, - ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - - attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse'][ - 'GetEndpointAttributesResult']['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'False', - "CustomUserData": "some data", - }) - - -@mock_sns_deprecated -def test_get_missing_endpoint_attributes(): - conn = boto.connect_sns() - conn.get_endpoint_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) - - -@mock_sns_deprecated -def test_set_endpoint_attributes(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, - ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - - conn.set_endpoint_attributes(endpoint_arn, - {"CustomUserData": "other data"} - ) - attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse'][ - 'GetEndpointAttributesResult']['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'False', - "CustomUserData": "other data", - }) - - -@mock_sns_deprecated -def test_delete_endpoint(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, - ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - - endpoint_list = conn.list_endpoints_by_platform_application( - platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] - - endpoint_list.should.have.length_of(1) - - conn.delete_endpoint(endpoint_arn) - - endpoint_list = conn.list_endpoints_by_platform_application( - platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] - endpoint_list.should.have.length_of(0) - - -@mock_sns_deprecated -def test_publish_to_platform_endpoint(): - conn = boto.connect_sns() - platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", - ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - - endpoint = conn.create_platform_endpoint( - platform_application_arn=application_arn, - token="some_unique_id", - custom_user_data="some user data", - attributes={ - "Enabled": True, - }, - ) - - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] - - conn.publish(message="some message", message_structure="json", - target_arn=endpoint_arn) +from __future__ import unicode_literals + +import boto +from boto.exception import BotoServerError +from moto import mock_sns_deprecated +from moto.core import ACCOUNT_ID +import sure # noqa + + +@mock_sns_deprecated +def test_create_platform_application(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + application_arn.should.equal( + "arn:aws:sns:us-east-1:{}:app/APNS/my-application".format(ACCOUNT_ID) + ) + + +@mock_sns_deprecated +def test_get_platform_application_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + attributes = conn.get_platform_application_attributes(arn)[ + "GetPlatformApplicationAttributesResponse" + ]["GetPlatformApplicationAttributesResult"]["Attributes"] + attributes.should.equal( + { + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + } + ) + + +@mock_sns_deprecated +def test_get_missing_platform_application_attributes(): + conn = boto.connect_sns() + conn.get_platform_application_attributes.when.called_with( + "a-fake-arn" + ).should.throw(BotoServerError) + + +@mock_sns_deprecated +def test_set_platform_application_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", + platform="APNS", + attributes={ + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + }, + ) + arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + conn.set_platform_application_attributes(arn, {"PlatformPrincipal": "other"}) + attributes = conn.get_platform_application_attributes(arn)[ + "GetPlatformApplicationAttributesResponse" + ]["GetPlatformApplicationAttributesResult"]["Attributes"] + attributes.should.equal( + {"PlatformCredential": "platform_credential", "PlatformPrincipal": "other"} + ) + + +@mock_sns_deprecated +def test_list_platform_applications(): + conn = boto.connect_sns() + conn.create_platform_application(name="application1", platform="APNS") + conn.create_platform_application(name="application2", platform="APNS") + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] + applications.should.have.length_of(2) + + +@mock_sns_deprecated +def test_delete_platform_application(): + conn = boto.connect_sns() + conn.create_platform_application(name="application1", platform="APNS") + conn.create_platform_application(name="application2", platform="APNS") + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] + applications.should.have.length_of(2) + + application_arn = applications[0]["PlatformApplicationArn"] + conn.delete_platform_application(application_arn) + + applications_repsonse = conn.list_platform_applications() + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] + applications.should.have.length_of(1) + + +@mock_sns_deprecated +def test_create_platform_endpoint(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"Enabled": False}, + ) + + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + endpoint_arn.should.contain( + "arn:aws:sns:us-east-1:{}:endpoint/APNS/my-application/".format(ACCOUNT_ID) + ) + + +@mock_sns_deprecated +def test_get_list_endpoints_by_platform_application(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"CustomUserData": "some data"}, + ) + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + + endpoint_list = conn.list_endpoints_by_platform_application( + platform_application_arn=application_arn + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] + + endpoint_list.should.have.length_of(1) + endpoint_list[0]["Attributes"]["CustomUserData"].should.equal("some data") + endpoint_list[0]["EndpointArn"].should.equal(endpoint_arn) + + +@mock_sns_deprecated +def test_get_endpoint_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"Enabled": False, "CustomUserData": "some data"}, + ) + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + + attributes = conn.get_endpoint_attributes(endpoint_arn)[ + "GetEndpointAttributesResponse" + ]["GetEndpointAttributesResult"]["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "False", "CustomUserData": "some data"} + ) + + +@mock_sns_deprecated +def test_get_missing_endpoint_attributes(): + conn = boto.connect_sns() + conn.get_endpoint_attributes.when.called_with("a-fake-arn").should.throw( + BotoServerError + ) + + +@mock_sns_deprecated +def test_set_endpoint_attributes(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"Enabled": False, "CustomUserData": "some data"}, + ) + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + + conn.set_endpoint_attributes(endpoint_arn, {"CustomUserData": "other data"}) + attributes = conn.get_endpoint_attributes(endpoint_arn)[ + "GetEndpointAttributesResponse" + ]["GetEndpointAttributesResult"]["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "False", "CustomUserData": "other data"} + ) + + +@mock_sns_deprecated +def test_delete_endpoint(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"Enabled": False, "CustomUserData": "some data"}, + ) + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + + endpoint_list = conn.list_endpoints_by_platform_application( + platform_application_arn=application_arn + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] + + endpoint_list.should.have.length_of(1) + + conn.delete_endpoint(endpoint_arn) + + endpoint_list = conn.list_endpoints_by_platform_application( + platform_application_arn=application_arn + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] + endpoint_list.should.have.length_of(0) + + +@mock_sns_deprecated +def test_publish_to_platform_endpoint(): + conn = boto.connect_sns() + platform_application = conn.create_platform_application( + name="my-application", platform="APNS" + ) + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + + endpoint = conn.create_platform_endpoint( + platform_application_arn=application_arn, + token="some_unique_id", + custom_user_data="some user data", + attributes={"Enabled": True}, + ) + + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] + + conn.publish( + message="some message", message_structure="json", target_arn=endpoint_arn + ) diff --git a/tests/test_sns/test_application_boto3.py b/tests/test_sns/test_application_boto3.py index 6ba2ed89d..fbf2f725f 100644 --- a/tests/test_sns/test_application_boto3.py +++ b/tests/test_sns/test_application_boto3.py @@ -4,11 +4,12 @@ import boto3 from botocore.exceptions import ClientError from moto import mock_sns import sure # noqa +from moto.core import ACCOUNT_ID @mock_sns def test_create_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -17,14 +18,15 @@ def test_create_platform_application(): "PlatformPrincipal": "platform_principal", }, ) - application_arn = response['PlatformApplicationArn'] + application_arn = response["PlatformApplicationArn"] application_arn.should.equal( - 'arn:aws:sns:us-east-1:123456789012:app/APNS/my-application') + "arn:aws:sns:us-east-1:{}:app/APNS/my-application".format(ACCOUNT_ID) + ) @mock_sns def test_get_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -33,25 +35,29 @@ def test_get_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['PlatformApplicationArn'] - attributes = conn.get_platform_application_attributes( - PlatformApplicationArn=arn)['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }) + arn = platform_application["PlatformApplicationArn"] + attributes = conn.get_platform_application_attributes(PlatformApplicationArn=arn)[ + "Attributes" + ] + attributes.should.equal( + { + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + } + ) @mock_sns def test_get_missing_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.get_platform_application_attributes.when.called_with( - PlatformApplicationArn="a-fake-arn").should.throw(ClientError) + PlatformApplicationArn="a-fake-arn" + ).should.throw(ClientError) @mock_sns def test_set_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -60,291 +66,249 @@ def test_set_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['PlatformApplicationArn'] - conn.set_platform_application_attributes(PlatformApplicationArn=arn, - Attributes={ - "PlatformPrincipal": "other"} - ) - attributes = conn.get_platform_application_attributes( - PlatformApplicationArn=arn)['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "other", - }) + arn = platform_application["PlatformApplicationArn"] + conn.set_platform_application_attributes( + PlatformApplicationArn=arn, Attributes={"PlatformPrincipal": "other"} + ) + attributes = conn.get_platform_application_attributes(PlatformApplicationArn=arn)[ + "Attributes" + ] + attributes.should.equal( + {"PlatformCredential": "platform_credential", "PlatformPrincipal": "other"} + ) @mock_sns def test_list_platform_applications(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_platform_application( - Name="application1", - Platform="APNS", - Attributes={}, + Name="application1", Platform="APNS", Attributes={} ) conn.create_platform_application( - Name="application2", - Platform="APNS", - Attributes={}, + Name="application2", Platform="APNS", Attributes={} ) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(2) @mock_sns def test_delete_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_platform_application( - Name="application1", - Platform="APNS", - Attributes={}, + Name="application1", Platform="APNS", Attributes={} ) conn.create_platform_application( - Name="application2", - Platform="APNS", - Attributes={}, + Name="application2", Platform="APNS", Attributes={} ) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(2) - application_arn = applications[0]['PlatformApplicationArn'] + application_arn = applications[0]["PlatformApplicationArn"] conn.delete_platform_application(PlatformApplicationArn=application_arn) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(1) @mock_sns def test_create_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] endpoint_arn.should.contain( - "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") + "arn:aws:sns:us-east-1:{}:endpoint/APNS/my-application/".format(ACCOUNT_ID) + ) @mock_sns def test_create_duplicate_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) endpoint = conn.create_platform_endpoint.when.called_with( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ).should.throw(ClientError) @mock_sns def test_get_list_endpoints_by_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "CustomUserData": "some data", - }, + Attributes={"CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] endpoint_list = conn.list_endpoints_by_platform_application( PlatformApplicationArn=application_arn - )['Endpoints'] + )["Endpoints"] endpoint_list.should.have.length_of(1) - endpoint_list[0]['Attributes']['CustomUserData'].should.equal('some data') - endpoint_list[0]['EndpointArn'].should.equal(endpoint_arn) + endpoint_list[0]["Attributes"]["CustomUserData"].should.equal("some data") + endpoint_list[0]["EndpointArn"].should.equal(endpoint_arn) @mock_sns def test_get_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - "CustomUserData": "some data", - }, + Attributes={"Enabled": "false", "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - attributes = conn.get_endpoint_attributes( - EndpointArn=endpoint_arn)['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'false', - "CustomUserData": "some data", - }) + attributes = conn.get_endpoint_attributes(EndpointArn=endpoint_arn)["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "false", "CustomUserData": "some data"} + ) @mock_sns def test_get_missing_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.get_endpoint_attributes.when.called_with( - EndpointArn="a-fake-arn").should.throw(ClientError) + EndpointArn="a-fake-arn" + ).should.throw(ClientError) @mock_sns def test_set_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - "CustomUserData": "some data", - }, + Attributes={"Enabled": "false", "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - conn.set_endpoint_attributes(EndpointArn=endpoint_arn, - Attributes={"CustomUserData": "other data"} - ) - attributes = conn.get_endpoint_attributes( - EndpointArn=endpoint_arn)['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'false', - "CustomUserData": "other data", - }) + conn.set_endpoint_attributes( + EndpointArn=endpoint_arn, Attributes={"CustomUserData": "other data"} + ) + attributes = conn.get_endpoint_attributes(EndpointArn=endpoint_arn)["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "false", "CustomUserData": "other data"} + ) @mock_sns def test_publish_to_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'true', - }, + Attributes={"Enabled": "true"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - conn.publish(Message="some message", - MessageStructure="json", TargetArn=endpoint_arn) + conn.publish( + Message="some message", MessageStructure="json", TargetArn=endpoint_arn + ) @mock_sns def test_publish_to_disabled_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] conn.publish.when.called_with( - Message="some message", - MessageStructure="json", - TargetArn=endpoint_arn, + Message="some message", MessageStructure="json", TargetArn=endpoint_arn ).should.throw(ClientError) @mock_sns def test_set_sms_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") - conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + conn.set_sms_attributes( + attributes={"DefaultSMSType": "Transactional", "test": "test"} + ) response = conn.get_sms_attributes() - response.should.contain('attributes') - response['attributes'].should.contain('DefaultSMSType') - response['attributes'].should.contain('test') - response['attributes']['DefaultSMSType'].should.equal('Transactional') - response['attributes']['test'].should.equal('test') + response.should.contain("attributes") + response["attributes"].should.contain("DefaultSMSType") + response["attributes"].should.contain("test") + response["attributes"]["DefaultSMSType"].should.equal("Transactional") + response["attributes"]["test"].should.equal("test") @mock_sns def test_get_sms_attributes_filtered(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") - conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + conn.set_sms_attributes( + attributes={"DefaultSMSType": "Transactional", "test": "test"} + ) - response = conn.get_sms_attributes(attributes=['DefaultSMSType']) - response.should.contain('attributes') - response['attributes'].should.contain('DefaultSMSType') - response['attributes'].should_not.contain('test') - response['attributes']['DefaultSMSType'].should.equal('Transactional') + response = conn.get_sms_attributes(attributes=["DefaultSMSType"]) + response.should.contain("attributes") + response["attributes"].should.contain("DefaultSMSType") + response["attributes"].should_not.contain("test") + response["attributes"]["DefaultSMSType"].should.equal("Transactional") diff --git a/tests/test_sns/test_publishing.py b/tests/test_sns/test_publishing.py index d04cf5acc..30fa80f15 100644 --- a/tests/test_sns/test_publishing.py +++ b/tests/test_sns/test_publishing.py @@ -1,69 +1,105 @@ -from __future__ import unicode_literals - -import boto -import json -import re -from freezegun import freeze_time -import sure # noqa - -from moto import mock_sns_deprecated, mock_sqs_deprecated - - -MESSAGE_FROM_SQS_TEMPLATE = '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "%s",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:123456789012:some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' - - -@mock_sqs_deprecated -@mock_sns_deprecated -def test_publish_to_sqs(): - conn = boto.connect_sns() - conn.create_topic("some-topic") - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - - sqs_conn = boto.connect_sqs() - sqs_conn.create_queue("test-queue") - - conn.subscribe(topic_arn, "sqs", - "arn:aws:sqs:us-east-1:123456789012:test-queue") - - message_to_publish = 'my message' - subject_to_publish = "test subject" - with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(topic=topic_arn, message=message_to_publish, subject=subject_to_publish) - published_message_id = published_message['PublishResponse']['PublishResult']['MessageId'] - - queue = sqs_conn.get_queue("test-queue") - message = queue.read(1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message_to_publish, published_message_id, subject_to_publish, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", '2015-01-01T12:00:00.000Z', message.get_body()) - acquired_message.should.equal(expected) - - -@mock_sqs_deprecated -@mock_sns_deprecated -def test_publish_to_sqs_in_different_region(): - conn = boto.sns.connect_to_region("us-west-1") - conn.create_topic("some-topic") - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - - sqs_conn = boto.sqs.connect_to_region("us-west-2") - sqs_conn.create_queue("test-queue") - - conn.subscribe(topic_arn, "sqs", - "arn:aws:sqs:us-west-2:123456789012:test-queue") - - message_to_publish = 'my message' - subject_to_publish = "test subject" - with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(topic=topic_arn, message=message_to_publish, subject=subject_to_publish) - published_message_id = published_message['PublishResponse']['PublishResult']['MessageId'] - - queue = sqs_conn.get_queue("test-queue") - message = queue.read(1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message_to_publish, published_message_id, subject_to_publish, 'us-west-1') - - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", '2015-01-01T12:00:00.000Z', message.get_body()) - acquired_message.should.equal(expected) +from __future__ import unicode_literals + +import boto +import json +import re +from freezegun import freeze_time +import sure # noqa + +from moto import mock_sns_deprecated, mock_sqs_deprecated +from moto.core import ACCOUNT_ID + +MESSAGE_FROM_SQS_TEMPLATE = ( + '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "%s",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:' + + ACCOUNT_ID + + ':some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:' + + ACCOUNT_ID + + ':some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' +) + + +@mock_sqs_deprecated +@mock_sns_deprecated +def test_publish_to_sqs(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + + sqs_conn = boto.connect_sqs() + sqs_conn.create_queue("test-queue") + + conn.subscribe( + topic_arn, "sqs", "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID) + ) + + message_to_publish = "my message" + subject_to_publish = "test subject" + with freeze_time("2015-01-01 12:00:00"): + published_message = conn.publish( + topic=topic_arn, message=message_to_publish, subject=subject_to_publish + ) + published_message_id = published_message["PublishResponse"]["PublishResult"][ + "MessageId" + ] + + queue = sqs_conn.get_queue("test-queue") + message = queue.read(1) + expected = MESSAGE_FROM_SQS_TEMPLATE % ( + message_to_publish, + published_message_id, + subject_to_publish, + "us-east-1", + ) + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + message.get_body(), + ) + acquired_message.should.equal(expected) + + +@mock_sqs_deprecated +@mock_sns_deprecated +def test_publish_to_sqs_in_different_region(): + conn = boto.sns.connect_to_region("us-west-1") + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + + sqs_conn = boto.sqs.connect_to_region("us-west-2") + sqs_conn.create_queue("test-queue") + + conn.subscribe( + topic_arn, "sqs", "arn:aws:sqs:us-west-2:{}:test-queue".format(ACCOUNT_ID) + ) + + message_to_publish = "my message" + subject_to_publish = "test subject" + with freeze_time("2015-01-01 12:00:00"): + published_message = conn.publish( + topic=topic_arn, message=message_to_publish, subject=subject_to_publish + ) + published_message_id = published_message["PublishResponse"]["PublishResult"][ + "MessageId" + ] + + queue = sqs_conn.get_queue("test-queue") + message = queue.read(1) + expected = MESSAGE_FROM_SQS_TEMPLATE % ( + message_to_publish, + published_message_id, + subject_to_publish, + "us-west-1", + ) + + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + message.get_body(), + ) + acquired_message.should.equal(expected) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index e146ec3c9..d85c8fefe 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -1,489 +1,982 @@ -from __future__ import unicode_literals - -import base64 -import json - -import boto3 -import re -from freezegun import freeze_time -import sure # noqa - -import responses -from botocore.exceptions import ClientError -from nose.tools import assert_raises -from moto import mock_sns, mock_sqs - - -MESSAGE_FROM_SQS_TEMPLATE = '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "my subject",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:123456789012:some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' - - -@mock_sqs -@mock_sns -def test_publish_to_sqs(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-east-1') - sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' - with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] - - queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) - acquired_message.should.equal(expected) - - -@mock_sqs -@mock_sns -def test_publish_to_sqs_raw(): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') - - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') - - subscription = topic.subscribe( - Protocol='sqs', Endpoint=queue.attributes['QueueArn']) - - subscription.set_attributes( - AttributeName='RawMessageDelivery', AttributeValue='true') - - message = 'my message' - with freeze_time("2015-01-01 12:00:00"): - topic.publish(Message=message) - - messages = queue.receive_messages(MaxNumberOfMessages=1) - messages[0].body.should.equal(message) - - -@mock_sqs -@mock_sns -def test_publish_to_sqs_bad(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-east-1') - sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' - try: - # Test missing Value - conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': {'DataType': 'String'}}) - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') - try: - # Test empty DataType (if the DataType field is missing entirely - # botocore throws an exception during validation) - conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': '', - 'StringValue': 'example_corp' - }}) - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') - try: - # Test empty Value - conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': 'String', - 'StringValue': '' - }}) - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') - - -@mock_sqs -@mock_sns -def test_publish_to_sqs_msg_attr_byte_value(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-east-1') - queue = sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' - conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': 'Binary', - 'BinaryValue': b'\x02\x03\x04' - }}) - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': { - 'Type': 'Binary', - 'Value': base64.b64encode(b'\x02\x03\x04').decode() - } - }]) - - -@mock_sns -def test_publish_sms(): - client = boto3.client('sns', region_name='us-east-1') - client.create_topic(Name="some-topic") - resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] - - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) - - result = client.publish(PhoneNumber="+15551234567", Message="my message") - result.should.contain('MessageId') - - -@mock_sns -def test_publish_bad_sms(): - client = boto3.client('sns', region_name='us-east-1') - client.create_topic(Name="some-topic") - resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] - - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) - - try: - # Test invalid number - client.publish(PhoneNumber="NAA+15551234567", Message="my message") - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - - try: - # Test not found number - client.publish(PhoneNumber="+44001234567", Message="my message") - except ClientError as err: - err.response['Error']['Code'].should.equal('ParameterValueInvalid') - - -@mock_sqs -@mock_sns -def test_publish_to_sqs_dump_json(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-east-1') - sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - - message = json.dumps({ - "Records": [{ - "eventVersion": "2.0", - "eventSource": "aws:s3", - "s3": { - "s3SchemaVersion": "1.0" - } - }] - }, sort_keys=True) - with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] - - queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) - - escaped = message.replace('"', '\\"') - expected = MESSAGE_FROM_SQS_TEMPLATE % (escaped, published_message_id, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) - acquired_message.should.equal(expected) - - -@mock_sqs -@mock_sns -def test_publish_to_sqs_in_different_region(): - conn = boto3.client('sns', region_name='us-west-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-west-2') - sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-west-2:123456789012:test-queue") - - message = 'my message' - with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] - - queue = sqs_conn.get_queue_by_name(QueueName="test-queue") - messages = queue.receive_messages(MaxNumberOfMessages=1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, 'us-west-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) - acquired_message.should.equal(expected) - - -@freeze_time("2013-01-01") -@mock_sns -def test_publish_to_http(): - def callback(request): - request.headers["Content-Type"].should.equal("text/plain; charset=UTF-8") - json.loads.when.called_with( - request.body.decode() - ).should_not.throw(Exception) - return 200, {}, "" - - responses.add_callback( - method="POST", - url="http://example.com/foobar", - callback=callback, - ) - - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/foobar") - - response = conn.publish( - TopicArn=topic_arn, Message="my message", Subject="my subject") - - -@mock_sqs -@mock_sns -def test_publish_subject(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - sqs_conn = boto3.resource('sqs', region_name='us-east-1') - sqs_conn.create_queue(QueueName="test-queue") - - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' - subject1 = 'test subject' - subject2 = 'test subject' * 20 - with freeze_time("2015-01-01 12:00:00"): - conn.publish(TopicArn=topic_arn, Message=message, Subject=subject1) - - # Just that it doesnt error is a pass - try: - with freeze_time("2015-01-01 12:00:00"): - conn.publish(TopicArn=topic_arn, Message=message, Subject=subject2) - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - else: - raise RuntimeError('Should have raised an InvalidParameter exception') - - -@mock_sns -def test_publish_message_too_long(): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') - - with assert_raises(ClientError): - topic.publish( - Message="".join(["." for i in range(0, 262145)])) - - # message short enough - does not raise an error - topic.publish( - Message="".join(["." for i in range(0, 262144)])) - - -def _setup_filter_policy_test(filter_policy): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') - - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') - - subscription = topic.subscribe( - Protocol='sqs', Endpoint=queue.attributes['QueueArn']) - - subscription.set_attributes( - AttributeName='FilterPolicy', AttributeValue=json.dumps(filter_policy)) - - return topic, subscription, queue - - -@mock_sqs -@mock_sns -def test_filtering_exact_string(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) - - topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal( - [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_multiple_message_attributes(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) - - topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': {'Type': 'String', 'Value': 'example_corp'}, - 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_OR_matching(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp', 'different_corp']}) - - topic.publish( - Message='match example_corp', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) - topic.publish( - Message='match different_corp', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'different_corp'}}) - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal( - ['match example_corp', 'match different_corp']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([ - {'store': {'Type': 'String', 'Value': 'example_corp'}}, - {'store': {'Type': 'String', 'Value': 'different_corp'}}]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_AND_matching_positive(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp'], - 'event': ['order_cancelled']}) - - topic.publish( - Message='match example_corp order_cancelled', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal( - ['match example_corp order_cancelled']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': {'Type': 'String', 'Value': 'example_corp'}, - 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_AND_matching_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp'], - 'event': ['order_cancelled']}) - - topic.publish( - Message='match example_corp order_accepted', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_accepted'}}) - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) - - topic.publish( - Message='no match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'different_corp'}}) - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([]) - - -@mock_sqs -@mock_sns -def test_filtering_exact_string_no_attributes_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) - - topic.publish(Message='no match') - - messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([]) +from __future__ import unicode_literals + +import base64 +import json + +import boto3 +import re +from freezegun import freeze_time +import sure # noqa + +import responses +from botocore.exceptions import ClientError +from nose.tools import assert_raises +from moto import mock_sns, mock_sqs +from moto.core import ACCOUNT_ID + +MESSAGE_FROM_SQS_TEMPLATE = ( + '{\n "Message": "%s",\n "MessageId": "%s",\n "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=",\n "SignatureVersion": "1",\n "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem",\n "Subject": "my subject",\n "Timestamp": "2015-01-01T12:00:00.000Z",\n "TopicArn": "arn:aws:sns:%s:' + + ACCOUNT_ID + + ':some-topic",\n "Type": "Notification",\n "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:' + + ACCOUNT_ID + + ':some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55"\n}' +) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + ) + message = "my message" + with freeze_time("2015-01-01 12:00:00"): + published_message = conn.publish(TopicArn=topic_arn, Message=message) + published_message_id = published_message["MessageId"] + + queue = sqs_conn.get_queue_by_name(QueueName="test-queue") + messages = queue.receive_messages(MaxNumberOfMessages=1) + expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-east-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) + acquired_message.should.equal(expected) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_raw(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") + + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") + + subscription = topic.subscribe( + Protocol="sqs", Endpoint=queue.attributes["QueueArn"] + ) + + subscription.set_attributes( + AttributeName="RawMessageDelivery", AttributeValue="true" + ) + + message = "my message" + with freeze_time("2015-01-01 12:00:00"): + topic.publish(Message=message) + + messages = queue.receive_messages(MaxNumberOfMessages=1) + messages[0].body.should.equal(message) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_bad(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + ) + message = "my message" + try: + # Test missing Value + conn.publish( + TopicArn=topic_arn, + Message=message, + MessageAttributes={"store": {"DataType": "String"}}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameterValue") + try: + # Test empty DataType (if the DataType field is missing entirely + # botocore throws an exception during validation) + conn.publish( + TopicArn=topic_arn, + Message=message, + MessageAttributes={ + "store": {"DataType": "", "StringValue": "example_corp"} + }, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameterValue") + try: + # Test empty Value + conn.publish( + TopicArn=topic_arn, + Message=message, + MessageAttributes={"store": {"DataType": "String", "StringValue": ""}}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameterValue") + try: + # Test Number DataType, with a non numeric value + conn.publish( + TopicArn=topic_arn, + Message=message, + MessageAttributes={"price": {"DataType": "Number", "StringValue": "error"}}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameterValue") + err.response["Error"]["Message"].should.equal( + "An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number." + ) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_msg_attr_byte_value(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + queue = sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + ) + message = "my message" + conn.publish( + TopicArn=topic_arn, + Message=message, + MessageAttributes={ + "store": {"DataType": "Binary", "BinaryValue": b"\x02\x03\x04"} + }, + ) + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": { + "Type": "Binary", + "Value": base64.b64encode(b"\x02\x03\x04").decode(), + } + } + ] + ) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_msg_attr_number_type(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="test-topic") + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") + topic.subscribe(Protocol="sqs", Endpoint=queue.attributes["QueueArn"]) + + topic.publish( + Message="test message", + MessageAttributes={"retries": {"DataType": "Number", "StringValue": "0"}}, + ) + + message = json.loads(queue.receive_messages()[0].body) + message["Message"].should.equal("test message") + message["MessageAttributes"].should.equal( + {"retries": {"Type": "Number", "Value": 0}} + ) + + +@mock_sns +def test_publish_sms(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") + + result = client.publish(PhoneNumber="+15551234567", Message="my message") + result.should.contain("MessageId") + + +@mock_sns +def test_publish_bad_sms(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") + + try: + # Test invalid number + client.publish(PhoneNumber="NAA+15551234567", Message="my message") + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + + try: + # Test not found number + client.publish(PhoneNumber="+44001234567", Message="my message") + except ClientError as err: + err.response["Error"]["Code"].should.equal("ParameterValueInvalid") + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_dump_json(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + ) + + message = json.dumps( + { + "Records": [ + { + "eventVersion": "2.0", + "eventSource": "aws:s3", + "s3": {"s3SchemaVersion": "1.0"}, + } + ] + }, + sort_keys=True, + ) + with freeze_time("2015-01-01 12:00:00"): + published_message = conn.publish(TopicArn=topic_arn, Message=message) + published_message_id = published_message["MessageId"] + + queue = sqs_conn.get_queue_by_name(QueueName="test-queue") + messages = queue.receive_messages(MaxNumberOfMessages=1) + + escaped = message.replace('"', '\\"') + expected = MESSAGE_FROM_SQS_TEMPLATE % (escaped, published_message_id, "us-east-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) + acquired_message.should.equal(expected) + + +@mock_sqs +@mock_sns +def test_publish_to_sqs_in_different_region(): + conn = boto3.client("sns", region_name="us-west-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-west-2") + sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-west-2:{}:test-queue".format(ACCOUNT_ID), + ) + + message = "my message" + with freeze_time("2015-01-01 12:00:00"): + published_message = conn.publish(TopicArn=topic_arn, Message=message) + published_message_id = published_message["MessageId"] + + queue = sqs_conn.get_queue_by_name(QueueName="test-queue") + messages = queue.receive_messages(MaxNumberOfMessages=1) + expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-west-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) + acquired_message.should.equal(expected) + + +@freeze_time("2013-01-01") +@mock_sns +def test_publish_to_http(): + def callback(request): + request.headers["Content-Type"].should.equal("text/plain; charset=UTF-8") + json.loads.when.called_with(request.body.decode()).should_not.throw(Exception) + return 200, {}, "" + + responses.add_callback( + method="POST", url="http://example.com/foobar", callback=callback + ) + + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + conn.subscribe( + TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar" + ) + + response = conn.publish( + TopicArn=topic_arn, Message="my message", Subject="my subject" + ) + + +@mock_sqs +@mock_sns +def test_publish_subject(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + sqs_conn = boto3.resource("sqs", region_name="us-east-1") + sqs_conn.create_queue(QueueName="test-queue") + + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + ) + message = "my message" + subject1 = "test subject" + subject2 = "test subject" * 20 + with freeze_time("2015-01-01 12:00:00"): + conn.publish(TopicArn=topic_arn, Message=message, Subject=subject1) + + # Just that it doesnt error is a pass + try: + with freeze_time("2015-01-01 12:00:00"): + conn.publish(TopicArn=topic_arn, Message=message, Subject=subject2) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + else: + raise RuntimeError("Should have raised an InvalidParameter exception") + + +@mock_sns +def test_publish_message_too_long(): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") + + with assert_raises(ClientError): + topic.publish(Message="".join(["." for i in range(0, 262145)])) + + # message short enough - does not raise an error + topic.publish(Message="".join(["." for i in range(0, 262144)])) + + +def _setup_filter_policy_test(filter_policy): + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") + + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") + + subscription = topic.subscribe( + Protocol="sqs", Endpoint=queue.attributes["QueueArn"] + ) + + subscription.set_attributes( + AttributeName="FilterPolicy", AttributeValue=json.dumps(filter_policy) + ) + + return topic, subscription, queue + + +@mock_sqs +@mock_sns +def test_filtering_exact_string(): + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) + + topic.publish( + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"store": {"Type": "String", "Value": "example_corp"}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_multiple_message_attributes(): + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) + + topic.publish( + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + } + ] + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_OR_matching(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": ["example_corp", "different_corp"]} + ) + + topic.publish( + Message="match example_corp", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) + topic.publish( + Message="match different_corp", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "different_corp"} + }, + ) + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match example_corp", "match different_corp"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + {"store": {"Type": "String", "Value": "example_corp"}}, + {"store": {"Type": "String", "Value": "different_corp"}}, + ] + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_positive(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": ["example_corp"], "event": ["order_cancelled"]} + ) + + topic.publish( + Message="match example_corp order_cancelled", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match example_corp order_cancelled"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + } + ] + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_AND_matching_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": ["example_corp"], "event": ["order_cancelled"]} + ) + + topic.publish( + Message="match example_corp order_accepted", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_accepted"}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) + + topic.publish( + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "different_corp"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_string_no_attributes_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) + + topic.publish(Message="no match") + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_int(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) + + topic.publish( + Message="match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "100"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100.1]}) + + topic.publish( + Message="match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "100.1"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100.1}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float_accuracy(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100.123456789]}) + + topic.publish( + Message="match", + MessageAttributes={ + "price": {"DataType": "Number", "StringValue": "100.1234561"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"price": {"Type": "Number", "Value": 100.1234561}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) + + topic.publish( + Message="no match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "101"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_with_string_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) + + topic.publish( + Message="no match", + MessageAttributes={"price": {"DataType": "String", "StringValue": "100"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"customer_interests": ["basketball", "baseball"]} + ) + + topic.publish( + Message="match", + MessageAttributes={ + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + } + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "customer_interests": { + "Type": "String.Array", + "Value": json.dumps(["basketball", "rugby"]), + } + } + ] + ) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"customer_interests": ["baseball"]} + ) + + topic.publish( + Message="no_match", + MessageAttributes={ + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + } + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100, 500]}) + + topic.publish( + Message="match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": json.dumps([100, 50])} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"price": {"Type": "String.Array", "Value": json.dumps([100, 50])}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_float_accuracy_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"price": [100.123456789, 500]} + ) + + topic.publish( + Message="match", + MessageAttributes={ + "price": { + "DataType": "String.Array", + "StringValue": json.dumps([100.1234561, 50]), + } + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"price": {"Type": "String.Array", "Value": json.dumps([100.1234561, 50])}}] + ) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_number_no_array_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100, 500]}) + + topic.publish( + Message="match", + MessageAttributes={"price": {"DataType": "String.Array", "StringValue": "100"}}, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"price": {"Type": "String.Array", "Value": "100"}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [500]}) + + topic.publish( + Message="no_match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": json.dumps([100, 50])} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_string_no_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) + + topic.publish( + Message="no_match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": "one hundread"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": [{"exists": True}]} + ) + + topic.publish( + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"store": {"Type": "String", "Value": "example_corp"}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": [{"exists": True}]} + ) + + topic.publish( + Message="no match", + MessageAttributes={ + "event": {"DataType": "String", "StringValue": "order_cancelled"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": [{"exists": False}]} + ) + + topic.publish( + Message="match", + MessageAttributes={ + "event": {"DataType": "String", "StringValue": "order_cancelled"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [{"event": {"Type": "String", "Value": "order_cancelled"}}] + ) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {"store": [{"exists": False}]} + ) + + topic.publish( + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_match(): + topic, subscription, queue = _setup_filter_policy_test( + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + } + ) + + topic.publish( + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + }, + "price": {"DataType": "Number", "StringValue": "100"}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + "customer_interests": { + "Type": "String.Array", + "Value": json.dumps(["basketball", "rugby"]), + }, + "price": {"Type": "Number", "Value": 100}, + } + ] + ) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + "encrypted": [False], + } + ) + + topic.publish( + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + }, + "price": {"DataType": "Number", "StringValue": "100"}, + }, + ) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal([]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([]) diff --git a/tests/test_sns/test_server.py b/tests/test_sns/test_server.py index bdaefa453..78bc147df 100644 --- a/tests/test_sns/test_server.py +++ b/tests/test_sns/test_server.py @@ -1,24 +1,27 @@ -from __future__ import unicode_literals - -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_sns_server_get(): - backend = server.create_backend_app("sns") - test_client = backend.test_client() - - topic_data = test_client.action_data("CreateTopic", Name="testtopic") - topic_data.should.contain("CreateTopicResult") - topic_data.should.contain( - "arn:aws:sns:us-east-1:123456789012:testtopic") - - topics_data = test_client.action_data("ListTopics") - topics_data.should.contain("ListTopicsResult") - topic_data.should.contain( - "arn:aws:sns:us-east-1:123456789012:testtopic") +from __future__ import unicode_literals +from moto.core import ACCOUNT_ID + +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_sns_server_get(): + backend = server.create_backend_app("sns") + test_client = backend.test_client() + + topic_data = test_client.action_data("CreateTopic", Name="testtopic") + topic_data.should.contain("CreateTopicResult") + topic_data.should.contain( + "arn:aws:sns:us-east-1:{}:testtopic".format(ACCOUNT_ID) + ) + + topics_data = test_client.action_data("ListTopics") + topics_data.should.contain("ListTopicsResult") + topic_data.should.contain( + "arn:aws:sns:us-east-1:{}:testtopic".format(ACCOUNT_ID) + ) diff --git a/tests/test_sns/test_subscriptions.py b/tests/test_sns/test_subscriptions.py index 3a40ba9ad..fbd4274f4 100644 --- a/tests/test_sns/test_subscriptions.py +++ b/tests/test_sns/test_subscriptions.py @@ -1,135 +1,149 @@ -from __future__ import unicode_literals -import boto - -import sure # noqa - -from moto import mock_sns_deprecated -from moto.sns.models import DEFAULT_PAGE_SIZE - - -@mock_sns_deprecated -def test_creating_subscription(): - conn = boto.connect_sns() - conn.create_topic("some-topic") - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - - conn.subscribe(topic_arn, "http", "http://example.com/") - - subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - # Now unsubscribe the subscription - conn.unsubscribe(subscription["SubscriptionArn"]) - - # And there should be zero subscriptions left - subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] - subscriptions.should.have.length_of(0) - - -@mock_sns_deprecated -def test_deleting_subscriptions_by_deleting_topic(): - conn = boto.connect_sns() - conn.create_topic("some-topic") - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - - conn.subscribe(topic_arn, "http", "http://example.com/") - - subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - # Now delete the topic - conn.delete_topic(topic_arn) - - # And there should now be 0 topics - topics_json = conn.get_all_topics() - topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topics.should.have.length_of(0) - - # And there should be zero subscriptions left - subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] - subscriptions.should.have.length_of(0) - - -@mock_sns_deprecated -def test_getting_subscriptions_by_topic(): - conn = boto.connect_sns() - conn.create_topic("topic1") - conn.create_topic("topic2") - - topics_json = conn.get_all_topics() - topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] - - conn.subscribe(topic1_arn, "http", "http://example1.com/") - conn.subscribe(topic2_arn, "http", "http://example2.com/") - - topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn)[ - "ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"]["Subscriptions"] - topic1_subscriptions.should.have.length_of(1) - topic1_subscriptions[0]['Endpoint'].should.equal("http://example1.com/") - - -@mock_sns_deprecated -def test_subscription_paging(): - conn = boto.connect_sns() - conn.create_topic("topic1") - conn.create_topic("topic2") - - topics_json = conn.get_all_topics() - topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] - - for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): - conn.subscribe(topic1_arn, 'email', 'email_' + - str(index) + '@test.com') - conn.subscribe(topic2_arn, 'email', 'email_' + - str(index) + '@test.com') - - all_subscriptions = conn.get_all_subscriptions() - all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ - "Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) - next_token = all_subscriptions["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["NextToken"] - next_token.should.equal(DEFAULT_PAGE_SIZE) - - all_subscriptions = conn.get_all_subscriptions(next_token=next_token * 2) - all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ - "Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE * 2 / 3)) - next_token = all_subscriptions["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["NextToken"] - next_token.should.equal(None) - - topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn) - topic1_subscriptions["ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"][ - "Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) - next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ - "ListSubscriptionsByTopicResult"]["NextToken"] - next_token.should.equal(DEFAULT_PAGE_SIZE) - - topic1_subscriptions = conn.get_all_subscriptions_by_topic( - topic1_arn, next_token=next_token) - topic1_subscriptions["ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"][ - "Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) - next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ - "ListSubscriptionsByTopicResult"]["NextToken"] - next_token.should.equal(None) +from __future__ import unicode_literals +import boto + +import sure # noqa + +from moto import mock_sns_deprecated +from moto.sns.models import DEFAULT_PAGE_SIZE + + +@mock_sns_deprecated +def test_creating_subscription(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + + conn.subscribe(topic_arn, "http", "http://example.com/") + + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now unsubscribe the subscription + conn.unsubscribe(subscription["SubscriptionArn"]) + + # And there should be zero subscriptions left + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["Subscriptions"] + subscriptions.should.have.length_of(0) + + +@mock_sns_deprecated +def test_deleting_subscriptions_by_deleting_topic(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + + conn.subscribe(topic_arn, "http", "http://example.com/") + + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now delete the topic + conn.delete_topic(topic_arn) + + # And there should now be 0 topics + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topics.should.have.length_of(0) + + # And there should be zero subscriptions left + subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["Subscriptions"] + subscriptions.should.have.length_of(0) + + +@mock_sns_deprecated +def test_getting_subscriptions_by_topic(): + conn = boto.connect_sns() + conn.create_topic("topic1") + conn.create_topic("topic2") + + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] + + conn.subscribe(topic1_arn, "http", "http://example1.com/") + conn.subscribe(topic2_arn, "http", "http://example2.com/") + + topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn)[ + "ListSubscriptionsByTopicResponse" + ]["ListSubscriptionsByTopicResult"]["Subscriptions"] + topic1_subscriptions.should.have.length_of(1) + topic1_subscriptions[0]["Endpoint"].should.equal("http://example1.com/") + + +@mock_sns_deprecated +def test_subscription_paging(): + conn = boto.connect_sns() + conn.create_topic("topic1") + conn.create_topic("topic2") + + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] + + for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): + conn.subscribe(topic1_arn, "email", "email_" + str(index) + "@test.com") + conn.subscribe(topic2_arn, "email", "email_" + str(index) + "@test.com") + + all_subscriptions = conn.get_all_subscriptions() + all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ + "Subscriptions" + ].should.have.length_of(DEFAULT_PAGE_SIZE) + next_token = all_subscriptions["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["NextToken"] + next_token.should.equal(DEFAULT_PAGE_SIZE) + + all_subscriptions = conn.get_all_subscriptions(next_token=next_token * 2) + all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ + "Subscriptions" + ].should.have.length_of(int(DEFAULT_PAGE_SIZE * 2 / 3)) + next_token = all_subscriptions["ListSubscriptionsResponse"][ + "ListSubscriptionsResult" + ]["NextToken"] + next_token.should.equal(None) + + topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn) + topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) + next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["NextToken"] + next_token.should.equal(DEFAULT_PAGE_SIZE) + + topic1_subscriptions = conn.get_all_subscriptions_by_topic( + topic1_arn, next_token=next_token + ) + topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) + next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["NextToken"] + next_token.should.equal(None) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index d7a32e0c6..faf3ae4a5 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -1,396 +1,507 @@ -from __future__ import unicode_literals -import boto3 -import json - -import sure # noqa - -from botocore.exceptions import ClientError -from nose.tools import assert_raises - -from moto import mock_sns -from moto.sns.models import DEFAULT_PAGE_SIZE - - -@mock_sns -def test_subscribe_sms(): - client = boto3.client('sns', region_name='us-east-1') - client.create_topic(Name="some-topic") - resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] - - resp = client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) - resp.should.contain('SubscriptionArn') - -@mock_sns -def test_double_subscription(): - client = boto3.client('sns', region_name='us-east-1') - client.create_topic(Name="some-topic") - resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] - - do_subscribe_sqs = lambda sqs_arn: client.subscribe( - TopicArn=arn, - Protocol='sqs', - Endpoint=sqs_arn - ) - resp1 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') - resp2 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') - - resp1['SubscriptionArn'].should.equal(resp2['SubscriptionArn']) - - -@mock_sns -def test_subscribe_bad_sms(): - client = boto3.client('sns', region_name='us-east-1') - client.create_topic(Name="some-topic") - resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] - - try: - # Test invalid number - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='NAA+15551234567' - ) - except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - - -@mock_sns -def test_creating_subscription(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") - - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - # Now unsubscribe the subscription - conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) - - # And there should be zero subscriptions left - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(0) - - -@mock_sns -def test_deleting_subscriptions_by_deleting_topic(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") - - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - # Now delete the topic - conn.delete_topic(TopicArn=topic_arn) - - # And there should now be 0 topics - topics_json = conn.list_topics() - topics = topics_json["Topics"] - topics.should.have.length_of(0) - - # And there should be zero subscriptions left - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(0) - - -@mock_sns -def test_getting_subscriptions_by_topic(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="topic1") - conn.create_topic(Name="topic2") - - response = conn.list_topics() - topics = response["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] - - conn.subscribe(TopicArn=topic1_arn, - Protocol="http", - Endpoint="http://example1.com/") - conn.subscribe(TopicArn=topic2_arn, - Protocol="http", - Endpoint="http://example2.com/") - - topic1_subscriptions = conn.list_subscriptions_by_topic(TopicArn=topic1_arn)[ - "Subscriptions"] - topic1_subscriptions.should.have.length_of(1) - topic1_subscriptions[0]['Endpoint'].should.equal("http://example1.com/") - - -@mock_sns -def test_subscription_paging(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="topic1") - - response = conn.list_topics() - topics = response["Topics"] - topic1_arn = topics[0]['TopicArn'] - - for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): - conn.subscribe(TopicArn=topic1_arn, - Protocol='email', - Endpoint='email_' + str(index) + '@test.com') - - all_subscriptions = conn.list_subscriptions() - all_subscriptions["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) - next_token = all_subscriptions["NextToken"] - next_token.should.equal(str(DEFAULT_PAGE_SIZE)) - - all_subscriptions = conn.list_subscriptions(NextToken=next_token) - all_subscriptions["Subscriptions"].should.have.length_of( - int(DEFAULT_PAGE_SIZE / 3)) - all_subscriptions.shouldnt.have("NextToken") - - topic1_subscriptions = conn.list_subscriptions_by_topic( - TopicArn=topic1_arn) - topic1_subscriptions["Subscriptions"].should.have.length_of( - DEFAULT_PAGE_SIZE) - next_token = topic1_subscriptions["NextToken"] - next_token.should.equal(str(DEFAULT_PAGE_SIZE)) - - topic1_subscriptions = conn.list_subscriptions_by_topic( - TopicArn=topic1_arn, NextToken=next_token) - topic1_subscriptions["Subscriptions"].should.have.length_of( - int(DEFAULT_PAGE_SIZE / 3)) - topic1_subscriptions.shouldnt.have("NextToken") - - -@mock_sns -def test_creating_subscription_with_attributes(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - delivery_policy = json.dumps({ - 'healthyRetryPolicy': { - "numRetries": 10, - "minDelayTarget": 1, - "maxDelayTarget":2 - } - }) - - filter_policy = json.dumps({ - "store": ["example_corp"], - "event": ["order_cancelled"], - "encrypted": [False], - "customer_interests": ["basketball", "baseball"] - }) - - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/", - Attributes={ - 'RawMessageDelivery': 'true', - 'DeliveryPolicy': delivery_policy, - 'FilterPolicy': filter_policy - }) - - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - # Test the subscription attributes have been set - subscription_arn = subscription["SubscriptionArn"] - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) - - attrs['Attributes']['RawMessageDelivery'].should.equal('true') - attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) - attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) - - # Now unsubscribe the subscription - conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) - - # And there should be zero subscriptions left - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(0) - - # invalid attr name - with assert_raises(ClientError): - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/", - Attributes={ - 'InvalidName': 'true' - }) - - -@mock_sns -def test_set_subscription_attributes(): - conn = boto3.client('sns', region_name='us-east-1') - conn.create_topic(Name="some-topic") - response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] - - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") - - subscriptions = conn.list_subscriptions()["Subscriptions"] - subscriptions.should.have.length_of(1) - subscription = subscriptions[0] - subscription["TopicArn"].should.equal(topic_arn) - subscription["Protocol"].should.equal("http") - subscription["SubscriptionArn"].should.contain(topic_arn) - subscription["Endpoint"].should.equal("http://example.com/") - - subscription_arn = subscription["SubscriptionArn"] - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) - attrs.should.have.key('Attributes') - conn.set_subscription_attributes( - SubscriptionArn=subscription_arn, - AttributeName='RawMessageDelivery', - AttributeValue='true' - ) - delivery_policy = json.dumps({ - 'healthyRetryPolicy': { - "numRetries": 10, - "minDelayTarget": 1, - "maxDelayTarget":2 - } - }) - conn.set_subscription_attributes( - SubscriptionArn=subscription_arn, - AttributeName='DeliveryPolicy', - AttributeValue=delivery_policy - ) - - filter_policy = json.dumps({ - "store": ["example_corp"], - "event": ["order_cancelled"], - "encrypted": [False], - "customer_interests": ["basketball", "baseball"] - }) - conn.set_subscription_attributes( - SubscriptionArn=subscription_arn, - AttributeName='FilterPolicy', - AttributeValue=filter_policy - ) - - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) - - attrs['Attributes']['RawMessageDelivery'].should.equal('true') - attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) - attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) - - # not existing subscription - with assert_raises(ClientError): - conn.set_subscription_attributes( - SubscriptionArn='invalid', - AttributeName='RawMessageDelivery', - AttributeValue='true' - ) - with assert_raises(ClientError): - attrs = conn.get_subscription_attributes( - SubscriptionArn='invalid' - ) - - - # invalid attr name - with assert_raises(ClientError): - conn.set_subscription_attributes( - SubscriptionArn=subscription_arn, - AttributeName='InvalidName', - AttributeValue='true' - ) - - -@mock_sns -def test_check_not_opted_out(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.check_if_phone_number_is_opted_out(phoneNumber='+447428545375') - - response.should.contain('isOptedOut') - response['isOptedOut'].should.be(False) - - -@mock_sns -def test_check_opted_out(): - # Phone number ends in 99 so is hardcoded in the endpoint to return opted - # out status - conn = boto3.client('sns', region_name='us-east-1') - response = conn.check_if_phone_number_is_opted_out(phoneNumber='+447428545399') - - response.should.contain('isOptedOut') - response['isOptedOut'].should.be(True) - - -@mock_sns -def test_check_opted_out_invalid(): - conn = boto3.client('sns', region_name='us-east-1') - - # Invalid phone number - with assert_raises(ClientError): - conn.check_if_phone_number_is_opted_out(phoneNumber='+44742LALALA') - - -@mock_sns -def test_list_opted_out(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.list_phone_numbers_opted_out() - - response.should.contain('phoneNumbers') - len(response['phoneNumbers']).should.be.greater_than(0) - - -@mock_sns -def test_opt_in(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.list_phone_numbers_opted_out() - current_len = len(response['phoneNumbers']) - assert current_len > 0 - - conn.opt_in_phone_number(phoneNumber=response['phoneNumbers'][0]) - - response = conn.list_phone_numbers_opted_out() - len(response['phoneNumbers']).should.be.greater_than(0) - len(response['phoneNumbers']).should.be.lower_than(current_len) - - -@mock_sns -def test_confirm_subscription(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.create_topic(Name='testconfirm') - - conn.confirm_subscription( - TopicArn=response['TopicArn'], - Token='2336412f37fb687f5d51e6e241d59b68c4e583a5cee0be6f95bbf97ab8d2441cf47b99e848408adaadf4c197e65f03473d53c4ba398f6abbf38ce2e8ebf7b4ceceb2cd817959bcde1357e58a2861b05288c535822eb88cac3db04f592285249971efc6484194fc4a4586147f16916692', - AuthenticateOnUnsubscribe='true' - ) +from __future__ import unicode_literals +import boto3 +import json + +import sure # noqa + +from botocore.exceptions import ClientError +from nose.tools import assert_raises + +from moto import mock_sns +from moto.sns.models import ( + DEFAULT_PAGE_SIZE, + DEFAULT_EFFECTIVE_DELIVERY_POLICY, + DEFAULT_ACCOUNT_ID, +) + + +@mock_sns +def test_subscribe_sms(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + resp = client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") + resp.should.have.key("SubscriptionArn") + + resp = client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15/55-123.4567") + resp.should.have.key("SubscriptionArn") + + +@mock_sns +def test_double_subscription(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + do_subscribe_sqs = lambda sqs_arn: client.subscribe( + TopicArn=arn, Protocol="sqs", Endpoint=sqs_arn + ) + resp1 = do_subscribe_sqs("arn:aws:sqs:elasticmq:000000000000:foo") + resp2 = do_subscribe_sqs("arn:aws:sqs:elasticmq:000000000000:foo") + + resp1["SubscriptionArn"].should.equal(resp2["SubscriptionArn"]) + + +@mock_sns +def test_subscribe_bad_sms(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + try: + # Test invalid number + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="NAA+15551234567") + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + + client.subscribe.when.called_with( + TopicArn=arn, Protocol="sms", Endpoint="+15--551234567" + ).should.throw(ClientError, "Invalid SMS endpoint: +15--551234567") + + client.subscribe.when.called_with( + TopicArn=arn, Protocol="sms", Endpoint="+15551234567." + ).should.throw(ClientError, "Invalid SMS endpoint: +15551234567.") + + client.subscribe.when.called_with( + TopicArn=arn, Protocol="sms", Endpoint="/+15551234567" + ).should.throw(ClientError, "Invalid SMS endpoint: /+15551234567") + + +@mock_sns +def test_creating_subscription(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now unsubscribe the subscription + conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) + + # And there should be zero subscriptions left + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(0) + + +@mock_sns +def test_deleting_subscriptions_by_deleting_topic(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Now delete the topic + conn.delete_topic(TopicArn=topic_arn) + + # And there should now be 0 topics + topics_json = conn.list_topics() + topics = topics_json["Topics"] + topics.should.have.length_of(0) + + # And there should be zero subscriptions left + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(0) + + +@mock_sns +def test_getting_subscriptions_by_topic(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="topic1") + conn.create_topic(Name="topic2") + + response = conn.list_topics() + topics = response["Topics"] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] + + conn.subscribe( + TopicArn=topic1_arn, Protocol="http", Endpoint="http://example1.com/" + ) + conn.subscribe( + TopicArn=topic2_arn, Protocol="http", Endpoint="http://example2.com/" + ) + + topic1_subscriptions = conn.list_subscriptions_by_topic(TopicArn=topic1_arn)[ + "Subscriptions" + ] + topic1_subscriptions.should.have.length_of(1) + topic1_subscriptions[0]["Endpoint"].should.equal("http://example1.com/") + + +@mock_sns +def test_subscription_paging(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="topic1") + + response = conn.list_topics() + topics = response["Topics"] + topic1_arn = topics[0]["TopicArn"] + + for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): + conn.subscribe( + TopicArn=topic1_arn, + Protocol="email", + Endpoint="email_" + str(index) + "@test.com", + ) + + all_subscriptions = conn.list_subscriptions() + all_subscriptions["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) + next_token = all_subscriptions["NextToken"] + next_token.should.equal(str(DEFAULT_PAGE_SIZE)) + + all_subscriptions = conn.list_subscriptions(NextToken=next_token) + all_subscriptions["Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) + all_subscriptions.shouldnt.have("NextToken") + + topic1_subscriptions = conn.list_subscriptions_by_topic(TopicArn=topic1_arn) + topic1_subscriptions["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) + next_token = topic1_subscriptions["NextToken"] + next_token.should.equal(str(DEFAULT_PAGE_SIZE)) + + topic1_subscriptions = conn.list_subscriptions_by_topic( + TopicArn=topic1_arn, NextToken=next_token + ) + topic1_subscriptions["Subscriptions"].should.have.length_of( + int(DEFAULT_PAGE_SIZE / 3) + ) + topic1_subscriptions.shouldnt.have("NextToken") + + +@mock_sns +def test_subscribe_attributes(): + client = boto3.client("sns", region_name="us-east-1") + client.create_topic(Name="some-topic") + resp = client.create_topic(Name="some-topic") + arn = resp["TopicArn"] + + resp = client.subscribe(TopicArn=arn, Protocol="http", Endpoint="http://test.com") + + response = client.get_subscription_attributes( + SubscriptionArn=resp["SubscriptionArn"] + ) + + response.should.contain("Attributes") + attributes = response["Attributes"] + attributes["PendingConfirmation"].should.equal("false") + attributes["ConfirmationWasAuthenticated"].should.equal("true") + attributes["Endpoint"].should.equal("http://test.com") + attributes["TopicArn"].should.equal(arn) + attributes["Protocol"].should.equal("http") + attributes["SubscriptionArn"].should.equal(resp["SubscriptionArn"]) + attributes["Owner"].should.equal(str(DEFAULT_ACCOUNT_ID)) + attributes["RawMessageDelivery"].should.equal("false") + json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( + DEFAULT_EFFECTIVE_DELIVERY_POLICY + ) + + +@mock_sns +def test_creating_subscription_with_attributes(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + delivery_policy = json.dumps( + { + "healthyRetryPolicy": { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget": 2, + } + } + ) + + filter_policy = json.dumps( + { + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None], + } + ) + + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "RawMessageDelivery": "true", + "DeliveryPolicy": delivery_policy, + "FilterPolicy": filter_policy, + }, + ) + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + # Test the subscription attributes have been set + subscription_arn = subscription["SubscriptionArn"] + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + + attrs["Attributes"]["RawMessageDelivery"].should.equal("true") + attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy) + attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy) + + # Now unsubscribe the subscription + conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) + + # And there should be zero subscriptions left + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(0) + + # invalid attr name + with assert_raises(ClientError): + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"InvalidName": "true"}, + ) + + +@mock_sns +def test_set_subscription_attributes(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") + + subscriptions = conn.list_subscriptions()["Subscriptions"] + subscriptions.should.have.length_of(1) + subscription = subscriptions[0] + subscription["TopicArn"].should.equal(topic_arn) + subscription["Protocol"].should.equal("http") + subscription["SubscriptionArn"].should.contain(topic_arn) + subscription["Endpoint"].should.equal("http://example.com/") + + subscription_arn = subscription["SubscriptionArn"] + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + attrs.should.have.key("Attributes") + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="RawMessageDelivery", + AttributeValue="true", + ) + delivery_policy = json.dumps( + { + "healthyRetryPolicy": { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget": 2, + } + } + ) + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="DeliveryPolicy", + AttributeValue=delivery_policy, + ) + + filter_policy = json.dumps( + { + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None], + } + ) + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="FilterPolicy", + AttributeValue=filter_policy, + ) + + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + + attrs["Attributes"]["RawMessageDelivery"].should.equal("true") + attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy) + attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy) + + # not existing subscription + with assert_raises(ClientError): + conn.set_subscription_attributes( + SubscriptionArn="invalid", + AttributeName="RawMessageDelivery", + AttributeValue="true", + ) + with assert_raises(ClientError): + attrs = conn.get_subscription_attributes(SubscriptionArn="invalid") + + # invalid attr name + with assert_raises(ClientError): + conn.set_subscription_attributes( + SubscriptionArn=subscription_arn, + AttributeName="InvalidName", + AttributeValue="true", + ) + + +@mock_sns +def test_subscribe_invalid_filter_policy(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") + response = conn.list_topics() + topic_arn = response["Topics"][0]["TopicArn"] + + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "FilterPolicy": json.dumps({"store": [str(i) for i in range(151)]}) + }, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) + + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [["example_corp"]]})}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" + ) + + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [{"exists": None}]})}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: exists match pattern must be either true or false." + ) + + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [{"error": True}]})}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Unrecognized match type error" + ) + + try: + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [1000000001]})}, + ) + except ClientError as err: + err.response["Error"]["Code"].should.equal("InternalFailure") + + +@mock_sns +def test_check_not_opted_out(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.check_if_phone_number_is_opted_out(phoneNumber="+447428545375") + + response.should.contain("isOptedOut") + response["isOptedOut"].should.be(False) + + +@mock_sns +def test_check_opted_out(): + # Phone number ends in 99 so is hardcoded in the endpoint to return opted + # out status + conn = boto3.client("sns", region_name="us-east-1") + response = conn.check_if_phone_number_is_opted_out(phoneNumber="+447428545399") + + response.should.contain("isOptedOut") + response["isOptedOut"].should.be(True) + + +@mock_sns +def test_check_opted_out_invalid(): + conn = boto3.client("sns", region_name="us-east-1") + + # Invalid phone number + with assert_raises(ClientError): + conn.check_if_phone_number_is_opted_out(phoneNumber="+44742LALALA") + + +@mock_sns +def test_list_opted_out(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.list_phone_numbers_opted_out() + + response.should.contain("phoneNumbers") + len(response["phoneNumbers"]).should.be.greater_than(0) + + +@mock_sns +def test_opt_in(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.list_phone_numbers_opted_out() + current_len = len(response["phoneNumbers"]) + assert current_len > 0 + + conn.opt_in_phone_number(phoneNumber=response["phoneNumbers"][0]) + + response = conn.list_phone_numbers_opted_out() + len(response["phoneNumbers"]).should.be.greater_than(0) + len(response["phoneNumbers"]).should.be.lower_than(current_len) + + +@mock_sns +def test_confirm_subscription(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic(Name="testconfirm") + + conn.confirm_subscription( + TopicArn=response["TopicArn"], + Token="2336412f37fb687f5d51e6e241d59b68c4e583a5cee0be6f95bbf97ab8d2441cf47b99e848408adaadf4c197e65f03473d53c4ba398f6abbf38ce2e8ebf7b4ceceb2cd817959bcde1357e58a2861b05288c535822eb88cac3db04f592285249971efc6484194fc4a4586147f16916692", + AuthenticateOnUnsubscribe="true", + ) diff --git a/tests/test_sns/test_topics.py b/tests/test_sns/test_topics.py index 928db8d02..e91ab6e2d 100644 --- a/tests/test_sns/test_topics.py +++ b/tests/test_sns/test_topics.py @@ -1,133 +1,164 @@ -from __future__ import unicode_literals -import boto -import json -import six - -import sure # noqa - -from boto.exception import BotoServerError -from moto import mock_sns_deprecated -from moto.sns.models import DEFAULT_TOPIC_POLICY, DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE - - -@mock_sns_deprecated -def test_create_and_delete_topic(): - conn = boto.connect_sns() - conn.create_topic("some-topic") - - topics_json = conn.get_all_topics() - topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topics.should.have.length_of(1) - topics[0]['TopicArn'].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn.region.name) - ) - - # Delete the topic - conn.delete_topic(topics[0]['TopicArn']) - - # And there should now be 0 topics - topics_json = conn.get_all_topics() - topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topics.should.have.length_of(0) - - -@mock_sns_deprecated -def test_get_missing_topic(): - conn = boto.connect_sns() - conn.get_topic_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) - - -@mock_sns_deprecated -def test_create_topic_in_multiple_regions(): - for region in ['us-west-1', 'us-west-2']: - conn = boto.sns.connect_to_region(region) - conn.create_topic("some-topic") - list(conn.get_all_topics()["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"]).should.have.length_of(1) - - -@mock_sns_deprecated -def test_topic_corresponds_to_region(): - for region in ['us-east-1', 'us-west-2']: - conn = boto.sns.connect_to_region(region) - conn.create_topic("some-topic") - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - topic_arn.should.equal( - "arn:aws:sns:{0}:123456789012:some-topic".format(region)) - - -@mock_sns_deprecated -def test_topic_attributes(): - conn = boto.connect_sns() - conn.create_topic("some-topic") - - topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - - attributes = conn.get_topic_attributes(topic_arn)['GetTopicAttributesResponse'][ - 'GetTopicAttributesResult']['Attributes'] - attributes["TopicArn"].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn.region.name) - ) - attributes["Owner"].should.equal(123456789012) - json.loads(attributes["Policy"]).should.equal(DEFAULT_TOPIC_POLICY) - attributes["DisplayName"].should.equal("") - attributes["SubscriptionsPending"].should.equal(0) - attributes["SubscriptionsConfirmed"].should.equal(0) - attributes["SubscriptionsDeleted"].should.equal(0) - attributes["DeliveryPolicy"].should.equal("") - json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( - DEFAULT_EFFECTIVE_DELIVERY_POLICY) - - # boto can't handle prefix-mandatory strings: - # i.e. unicode on Python 2 -- u"foobar" - # and bytes on Python 3 -- b"foobar" - if six.PY2: - policy = {b"foo": b"bar"} - displayname = b"My display name" - delivery = {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}} - else: - policy = {u"foo": u"bar"} - displayname = u"My display name" - delivery = {u"http": {u"defaultHealthyRetryPolicy": {u"numRetries": 5}}} - conn.set_topic_attributes(topic_arn, "Policy", policy) - conn.set_topic_attributes(topic_arn, "DisplayName", displayname) - conn.set_topic_attributes(topic_arn, "DeliveryPolicy", delivery) - - attributes = conn.get_topic_attributes(topic_arn)['GetTopicAttributesResponse'][ - 'GetTopicAttributesResult']['Attributes'] - attributes["Policy"].should.equal("{'foo': 'bar'}") - attributes["DisplayName"].should.equal("My display name") - attributes["DeliveryPolicy"].should.equal( - "{'http': {'defaultHealthyRetryPolicy': {'numRetries': 5}}}") - - -@mock_sns_deprecated -def test_topic_paging(): - conn = boto.connect_sns() - for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 2)): - conn.create_topic("some-topic_" + str(index)) - - topics_json = conn.get_all_topics() - topics_list = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] - next_token = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["NextToken"] - - len(topics_list).should.equal(DEFAULT_PAGE_SIZE) - next_token.should.equal(DEFAULT_PAGE_SIZE) - - topics_json = conn.get_all_topics(next_token=next_token) - topics_list = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] - next_token = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["NextToken"] - - topics_list.should.have.length_of(int(DEFAULT_PAGE_SIZE / 2)) - next_token.should.equal(None) +from __future__ import unicode_literals +import boto +import json +import six + +import sure # noqa + +from boto.exception import BotoServerError +from moto import mock_sns_deprecated +from moto.sns.models import DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE +from moto.core import ACCOUNT_ID + + +@mock_sns_deprecated +def test_create_and_delete_topic(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topics.should.have.length_of(1) + topics[0]["TopicArn"].should.equal( + "arn:aws:sns:{0}:{1}:some-topic".format(conn.region.name, ACCOUNT_ID) + ) + + # Delete the topic + conn.delete_topic(topics[0]["TopicArn"]) + + # And there should now be 0 topics + topics_json = conn.get_all_topics() + topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + topics.should.have.length_of(0) + + +@mock_sns_deprecated +def test_get_missing_topic(): + conn = boto.connect_sns() + conn.get_topic_attributes.when.called_with("a-fake-arn").should.throw( + BotoServerError + ) + + +@mock_sns_deprecated +def test_create_topic_in_multiple_regions(): + for region in ["us-west-1", "us-west-2"]: + conn = boto.sns.connect_to_region(region) + conn.create_topic("some-topic") + list( + conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + ).should.have.length_of(1) + + +@mock_sns_deprecated +def test_topic_corresponds_to_region(): + for region in ["us-east-1", "us-west-2"]: + conn = boto.sns.connect_to_region(region) + conn.create_topic("some-topic") + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + topic_arn.should.equal( + "arn:aws:sns:{0}:{1}:some-topic".format(region, ACCOUNT_ID) + ) + + +@mock_sns_deprecated +def test_topic_attributes(): + conn = boto.connect_sns() + conn.create_topic("some-topic") + + topics_json = conn.get_all_topics() + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + + attributes = conn.get_topic_attributes(topic_arn)["GetTopicAttributesResponse"][ + "GetTopicAttributesResult" + ]["Attributes"] + attributes["TopicArn"].should.equal( + "arn:aws:sns:{0}:{1}:some-topic".format(conn.region.name, ACCOUNT_ID) + ) + attributes["Owner"].should.equal(ACCOUNT_ID) + json.loads(attributes["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:{}:some-topic".format( + ACCOUNT_ID + ), + "Condition": {"StringEquals": {"AWS:SourceOwner": ACCOUNT_ID}}, + } + ], + } + ) + attributes["DisplayName"].should.equal("") + attributes["SubscriptionsPending"].should.equal(0) + attributes["SubscriptionsConfirmed"].should.equal(0) + attributes["SubscriptionsDeleted"].should.equal(0) + attributes["DeliveryPolicy"].should.equal("") + json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( + DEFAULT_EFFECTIVE_DELIVERY_POLICY + ) + + # boto can't handle prefix-mandatory strings: + # i.e. unicode on Python 2 -- u"foobar" + # and bytes on Python 3 -- b"foobar" + if six.PY2: + policy = json.dumps({b"foo": b"bar"}) + displayname = b"My display name" + delivery = {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}} + else: + policy = json.dumps({"foo": "bar"}) + displayname = "My display name" + delivery = {"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}} + conn.set_topic_attributes(topic_arn, "Policy", policy) + conn.set_topic_attributes(topic_arn, "DisplayName", displayname) + conn.set_topic_attributes(topic_arn, "DeliveryPolicy", delivery) + + attributes = conn.get_topic_attributes(topic_arn)["GetTopicAttributesResponse"][ + "GetTopicAttributesResult" + ]["Attributes"] + attributes["Policy"].should.equal('{"foo": "bar"}') + attributes["DisplayName"].should.equal("My display name") + attributes["DeliveryPolicy"].should.equal( + "{'http': {'defaultHealthyRetryPolicy': {'numRetries': 5}}}" + ) + + +@mock_sns_deprecated +def test_topic_paging(): + conn = boto.connect_sns() + for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 2)): + conn.create_topic("some-topic_" + str(index)) + + topics_json = conn.get_all_topics() + topics_list = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + next_token = topics_json["ListTopicsResponse"]["ListTopicsResult"]["NextToken"] + + len(topics_list).should.equal(DEFAULT_PAGE_SIZE) + next_token.should.equal(DEFAULT_PAGE_SIZE) + + topics_json = conn.get_all_topics(next_token=next_token) + topics_list = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + next_token = topics_json["ListTopicsResponse"]["ListTopicsResult"]["NextToken"] + + topics_list.should.have.length_of(int(DEFAULT_PAGE_SIZE / 2)) + next_token.should.equal(None) diff --git a/tests/test_sns/test_topics_boto3.py b/tests/test_sns/test_topics_boto3.py index 870fa6f6e..87800bd84 100644 --- a/tests/test_sns/test_topics_boto3.py +++ b/tests/test_sns/test_topics_boto3.py @@ -7,25 +7,27 @@ import sure # noqa from botocore.exceptions import ClientError from moto import mock_sns -from moto.sns.models import DEFAULT_TOPIC_POLICY, DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE +from moto.sns.models import DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE +from moto.core import ACCOUNT_ID @mock_sns def test_create_and_delete_topic(): conn = boto3.client("sns", region_name="us-east-1") - for topic_name in ('some-topic', '-some-topic-', '_some-topic_', 'a' * 256): + for topic_name in ("some-topic", "-some-topic-", "_some-topic_", "a" * 256): conn.create_topic(Name=topic_name) topics_json = conn.list_topics() topics = topics_json["Topics"] topics.should.have.length_of(1) - topics[0]['TopicArn'].should.equal( - "arn:aws:sns:{0}:123456789012:{1}" - .format(conn._client_config.region_name, topic_name) + topics[0]["TopicArn"].should.equal( + "arn:aws:sns:{0}:{1}:{2}".format( + conn._client_config.region_name, ACCOUNT_ID, topic_name + ) ) # Delete the topic - conn.delete_topic(TopicArn=topics[0]['TopicArn']) + conn.delete_topic(TopicArn=topics[0]["TopicArn"]) # And there should now be 0 topics topics_json = conn.list_topics() @@ -36,66 +38,89 @@ def test_create_and_delete_topic(): @mock_sns def test_create_topic_with_attributes(): conn = boto3.client("sns", region_name="us-east-1") - conn.create_topic(Name='some-topic-with-attribute', Attributes={'DisplayName': 'test-topic'}) + conn.create_topic( + Name="some-topic-with-attribute", Attributes={"DisplayName": "test-topic"} + ) topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] + topic_arn = topics_json["Topics"][0]["TopicArn"] - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] - attributes['DisplayName'].should.equal('test-topic') + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] + attributes["DisplayName"].should.equal("test-topic") + + +@mock_sns +def test_create_topic_with_tags(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic( + Name="some-topic-with-tags", + Tags=[ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ], + ) + topic_arn = response["TopicArn"] + + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ] + ) @mock_sns def test_create_topic_should_be_indempodent(): conn = boto3.client("sns", region_name="us-east-1") - topic_arn = conn.create_topic(Name="some-topic")['TopicArn'] + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] conn.set_topic_attributes( - TopicArn=topic_arn, - AttributeName="DisplayName", - AttributeValue="should_be_set" + TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue="should_be_set" ) - topic_display_name = conn.get_topic_attributes( - TopicArn=topic_arn - )['Attributes']['DisplayName'] + topic_display_name = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"][ + "DisplayName" + ] topic_display_name.should.be.equal("should_be_set") - #recreate topic to prove indempodentcy - topic_arn = conn.create_topic(Name="some-topic")['TopicArn'] - topic_display_name = conn.get_topic_attributes( - TopicArn=topic_arn - )['Attributes']['DisplayName'] + # recreate topic to prove indempodentcy + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] + topic_display_name = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"][ + "DisplayName" + ] topic_display_name.should.be.equal("should_be_set") + @mock_sns def test_get_missing_topic(): conn = boto3.client("sns", region_name="us-east-1") - conn.get_topic_attributes.when.called_with( - TopicArn="a-fake-arn").should.throw(ClientError) + conn.get_topic_attributes.when.called_with(TopicArn="a-fake-arn").should.throw( + ClientError + ) + @mock_sns def test_create_topic_must_meet_constraints(): conn = boto3.client("sns", region_name="us-east-1") - common_random_chars = [':', ";", "!", "@", "|", "^", "%"] + common_random_chars = [":", ";", "!", "@", "|", "^", "%"] for char in common_random_chars: - conn.create_topic.when.called_with( - Name="no%s_invalidchar" % char).should.throw(ClientError) - conn.create_topic.when.called_with( - Name="no spaces allowed").should.throw(ClientError) + conn.create_topic.when.called_with(Name="no%s_invalidchar" % char).should.throw( + ClientError + ) + conn.create_topic.when.called_with(Name="no spaces allowed").should.throw( + ClientError + ) @mock_sns def test_create_topic_should_be_of_certain_length(): conn = boto3.client("sns", region_name="us-east-1") too_short = "" - conn.create_topic.when.called_with( - Name=too_short).should.throw(ClientError) + conn.create_topic.when.called_with(Name=too_short).should.throw(ClientError) too_long = "x" * 257 - conn.create_topic.when.called_with( - Name=too_long).should.throw(ClientError) + conn.create_topic.when.called_with(Name=too_long).should.throw(ClientError) @mock_sns def test_create_topic_in_multiple_regions(): - for region in ['us-west-1', 'us-west-2']: + for region in ["us-west-1", "us-west-2"]: conn = boto3.client("sns", region_name=region) conn.create_topic(Name="some-topic") list(conn.list_topics()["Topics"]).should.have.length_of(1) @@ -103,13 +128,14 @@ def test_create_topic_in_multiple_regions(): @mock_sns def test_topic_corresponds_to_region(): - for region in ['us-east-1', 'us-west-2']: + for region in ["us-east-1", "us-west-2"]: conn = boto3.client("sns", region_name=region) conn.create_topic(Name="some-topic") topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] + topic_arn = topics_json["Topics"][0]["TopicArn"] topic_arn.should.equal( - "arn:aws:sns:{0}:123456789012:some-topic".format(region)) + "arn:aws:sns:{0}:{1}:some-topic".format(region, ACCOUNT_ID) + ) @mock_sns @@ -118,22 +144,51 @@ def test_topic_attributes(): conn.create_topic(Name="some-topic") topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] + topic_arn = topics_json["Topics"][0]["TopicArn"] - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] attributes["TopicArn"].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn._client_config.region_name) + "arn:aws:sns:{0}:{1}:some-topic".format( + conn._client_config.region_name, ACCOUNT_ID + ) + ) + attributes["Owner"].should.equal(ACCOUNT_ID) + json.loads(attributes["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:{}:some-topic".format( + ACCOUNT_ID + ), + "Condition": {"StringEquals": {"AWS:SourceOwner": ACCOUNT_ID}}, + } + ], + } ) - attributes["Owner"].should.equal('123456789012') - json.loads(attributes["Policy"]).should.equal(DEFAULT_TOPIC_POLICY) attributes["DisplayName"].should.equal("") - attributes["SubscriptionsPending"].should.equal('0') - attributes["SubscriptionsConfirmed"].should.equal('0') - attributes["SubscriptionsDeleted"].should.equal('0') + attributes["SubscriptionsPending"].should.equal("0") + attributes["SubscriptionsConfirmed"].should.equal("0") + attributes["SubscriptionsDeleted"].should.equal("0") attributes["DeliveryPolicy"].should.equal("") json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( - DEFAULT_EFFECTIVE_DELIVERY_POLICY) + DEFAULT_EFFECTIVE_DELIVERY_POLICY + ) # boto can't handle prefix-mandatory strings: # i.e. unicode on Python 2 -- u"foobar" @@ -142,27 +197,30 @@ def test_topic_attributes(): policy = json.dumps({b"foo": b"bar"}) displayname = b"My display name" delivery = json.dumps( - {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}}) + {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}} + ) else: - policy = json.dumps({u"foo": u"bar"}) - displayname = u"My display name" + policy = json.dumps({"foo": "bar"}) + displayname = "My display name" delivery = json.dumps( - {u"http": {u"defaultHealthyRetryPolicy": {u"numRetries": 5}}}) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="Policy", - AttributeValue=policy) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="DisplayName", - AttributeValue=displayname) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="DeliveryPolicy", - AttributeValue=delivery) + {"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}} + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="Policy", AttributeValue=policy + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue=displayname + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="DeliveryPolicy", AttributeValue=delivery + ) - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] attributes["Policy"].should.equal('{"foo": "bar"}') attributes["DisplayName"].should.equal("My display name") attributes["DeliveryPolicy"].should.equal( - '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}') + '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}' + ) @mock_sns @@ -187,16 +245,269 @@ def test_topic_paging(): @mock_sns def test_add_remove_permissions(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.create_topic(Name='testpermissions') + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] - conn.add_permission( - TopicArn=response['TopicArn'], - Label='Test1234', - AWSAccountId=['999999999999'], - ActionName=['AddPermission'] + client.add_permission( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], ) - conn.remove_permission( - TopicArn=response['TopicArn'], - Label='Test1234' + + response = client.get_topic_attributes(TopicArn=topic_arn) + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:{}:test-permissions".format( + ACCOUNT_ID + ), + "Condition": {"StringEquals": {"AWS:SourceOwner": ACCOUNT_ID}}, + }, + { + "Sid": "test", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::999999999999:root"}, + "Action": "SNS:Publish", + "Resource": "arn:aws:sns:us-east-1:{}:test-permissions".format( + ACCOUNT_ID + ), + }, + ], + } ) + + client.remove_permission(TopicArn=topic_arn, Label="test") + + response = client.get_topic_attributes(TopicArn=topic_arn) + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:{}:test-permissions".format( + ACCOUNT_ID + ), + "Condition": {"StringEquals": {"AWS:SourceOwner": ACCOUNT_ID}}, + } + ], + } + ) + + client.add_permission( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["888888888888", "999999999999"], + ActionName=["Publish", "Subscribe"], + ) + + response = client.get_topic_attributes(TopicArn=topic_arn) + json.loads(response["Attributes"]["Policy"])["Statement"][1].should.equal( + { + "Sid": "test", + "Effect": "Allow", + "Principal": { + "AWS": [ + "arn:aws:iam::888888888888:root", + "arn:aws:iam::999999999999:root", + ] + }, + "Action": ["SNS:Publish", "SNS:Subscribe"], + "Resource": "arn:aws:sns:us-east-1:{}:test-permissions".format(ACCOUNT_ID), + } + ) + + # deleting non existing permission should be successful + client.remove_permission(TopicArn=topic_arn, Label="non-existing") + + +@mock_sns +def test_add_permission_errors(): + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] + client.add_permission( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], + ) + + client.add_permission.when.called_with( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["999999999999"], + ActionName=["AddPermission"], + ).should.throw(ClientError, "Statement already exists") + + client.add_permission.when.called_with( + TopicArn=topic_arn + "-not-existing", + Label="test-2", + AWSAccountId=["999999999999"], + ActionName=["AddPermission"], + ).should.throw(ClientError, "Topic does not exist") + + client.add_permission.when.called_with( + TopicArn=topic_arn, + Label="test-2", + AWSAccountId=["999999999999"], + ActionName=["NotExistingAction"], + ).should.throw(ClientError, "Policy statement action out of service scope!") + + +@mock_sns +def test_remove_permission_errors(): + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] + client.add_permission( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], + ) + + client.remove_permission.when.called_with( + TopicArn=topic_arn + "-not-existing", Label="test" + ).should.throw(ClientError, "Topic does not exist") + + +@mock_sns +def test_tag_topic(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic(Name="some-topic-with-tags") + topic_arn = response["TopicArn"] + + conn.tag_resource( + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_1", "Value": "tag_value_1"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_1", "Value": "tag_value_1"}] + ) + + conn.tag_resource( + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_2", "Value": "tag_value_2"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ] + ) + + conn.tag_resource( + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_X"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ] + ) + + +@mock_sns +def test_untag_topic(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic( + Name="some-topic-with-tags", + Tags=[ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ], + ) + topic_arn = response["TopicArn"] + + conn.untag_resource(ResourceArn=topic_arn, TagKeys=["tag_key_1"]) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_2", "Value": "tag_value_2"}] + ) + + # removing a non existing tag should not raise any error + conn.untag_resource(ResourceArn=topic_arn, TagKeys=["not-existing-tag"]) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_2", "Value": "tag_value_2"}] + ) + + +@mock_sns +def test_list_tags_for_resource_error(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic( + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + + conn.list_tags_for_resource.when.called_with( + ResourceArn="not-existing-topic" + ).should.throw(ClientError, "Resource does not exist") + + +@mock_sns +def test_tag_resource_errors(): + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic( + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + topic_arn = response["TopicArn"] + + conn.tag_resource.when.called_with( + ResourceArn="not-existing-topic", + Tags=[{"Key": "tag_key_1", "Value": "tag_value_1"}], + ).should.throw(ClientError, "Resource does not exist") + + too_many_tags = [ + {"Key": "tag_key_{}".format(i), "Value": "tag_value_{}".format(i)} + for i in range(51) + ] + conn.tag_resource.when.called_with( + ResourceArn=topic_arn, Tags=too_many_tags + ).should.throw( + ClientError, "Could not complete request: tag quota of per resource exceeded" + ) + + # when the request fails, the tags should not be updated + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + + +@mock_sns +def test_untag_resource_error(): + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic( + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + + conn.untag_resource.when.called_with( + ResourceArn="not-existing-topic", TagKeys=["tag_key_1"] + ).should.throw(ClientError, "Resource does not exist") diff --git a/tests/test_sqs/test_server.py b/tests/test_sqs/test_server.py index b2b233bde..0116a93ef 100644 --- a/tests/test_sqs/test_server.py +++ b/tests/test_sqs/test_server.py @@ -1,85 +1,84 @@ -from __future__ import unicode_literals - -import re -import sure # noqa -import threading -import time - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_sqs_list_identities(): - backend = server.create_backend_app("sqs") - test_client = backend.test_client() - - res = test_client.get('/?Action=ListQueues') - res.data.should.contain(b"ListQueuesResponse") - - # Make sure that we can receive messages from queues whose name contains dots (".") - # The AWS API mandates that the names of FIFO queues use the suffix ".fifo" - # See: https://github.com/spulec/moto/issues/866 - - for queue_name in ('testqueue', 'otherqueue.fifo'): - - res = test_client.put('/?Action=CreateQueue&QueueName=%s' % queue_name) - - - res = test_client.put( - '/123/%s?MessageBody=test-message&Action=SendMessage' % queue_name) - - res = test_client.get( - '/123/%s?Action=ReceiveMessage&MaxNumberOfMessages=1' % queue_name) - - message = re.search("(.*?)", - res.data.decode('utf-8')).groups()[0] - message.should.equal('test-message') - - res = test_client.get('/?Action=ListQueues&QueueNamePrefix=other') - res.data.should.contain(b'otherqueue.fifo') - res.data.should_not.contain(b'testqueue') - - -def test_messages_polling(): - backend = server.create_backend_app("sqs") - test_client = backend.test_client() - messages = [] - - test_client.put('/?Action=CreateQueue&QueueName=testqueue') - - def insert_messages(): - messages_count = 5 - while messages_count > 0: - test_client.put( - '/123/testqueue?MessageBody=test-message&Action=SendMessage' - '&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10' - ) - messages_count -= 1 - time.sleep(.5) - - def get_messages(): - count = 0 - while count < 5: - msg_res = test_client.get( - '/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5' - ) - new_msgs = re.findall("(.*?)", - msg_res.data.decode('utf-8')) - count += len(new_msgs) - messages.append(new_msgs) - - get_messages_thread = threading.Thread(target=get_messages) - insert_messages_thread = threading.Thread(target=insert_messages) - - get_messages_thread.start() - insert_messages_thread.start() - - get_messages_thread.join() - insert_messages_thread.join() - - # got each message in a separate call to ReceiveMessage, despite the long - # WaitTimeSeconds - assert len(messages) == 5 +from __future__ import unicode_literals + +import re +import sure # noqa +import threading +import time + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_sqs_list_identities(): + backend = server.create_backend_app("sqs") + test_client = backend.test_client() + + res = test_client.get("/?Action=ListQueues") + res.data.should.contain(b"ListQueuesResponse") + + # Make sure that we can receive messages from queues whose name contains dots (".") + # The AWS API mandates that the names of FIFO queues use the suffix ".fifo" + # See: https://github.com/spulec/moto/issues/866 + + for queue_name in ("testqueue", "otherqueue.fifo"): + + res = test_client.put("/?Action=CreateQueue&QueueName=%s" % queue_name) + + res = test_client.put( + "/123/%s?MessageBody=test-message&Action=SendMessage" % queue_name + ) + + res = test_client.get( + "/123/%s?Action=ReceiveMessage&MaxNumberOfMessages=1" % queue_name + ) + + message = re.search("(.*?)", res.data.decode("utf-8")).groups()[0] + message.should.equal("test-message") + + res = test_client.get("/?Action=ListQueues&QueueNamePrefix=other") + res.data.should.contain(b"otherqueue.fifo") + res.data.should_not.contain(b"testqueue") + + +def test_messages_polling(): + backend = server.create_backend_app("sqs") + test_client = backend.test_client() + messages = [] + + test_client.put("/?Action=CreateQueue&QueueName=testqueue") + + def insert_messages(): + messages_count = 5 + while messages_count > 0: + test_client.put( + "/123/testqueue?MessageBody=test-message&Action=SendMessage" + "&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10" + ) + messages_count -= 1 + time.sleep(0.5) + + def get_messages(): + count = 0 + while count < 5: + msg_res = test_client.get( + "/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5" + ) + new_msgs = re.findall("(.*?)", msg_res.data.decode("utf-8")) + count += len(new_msgs) + messages.append(new_msgs) + + get_messages_thread = threading.Thread(target=get_messages) + insert_messages_thread = threading.Thread(target=insert_messages) + + get_messages_thread.start() + insert_messages_thread.start() + + get_messages_thread.join() + insert_messages_thread.join() + + # got each message in a separate call to ReceiveMessage, despite the long + # WaitTimeSeconds + assert len(messages) == 5 diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index d53ae50f7..639d6e51c 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -1,174 +1,180 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals + +import base64 +import json import os +import time +import uuid import boto import boto3 import botocore.exceptions -from botocore.exceptions import ClientError -from boto.exception import SQSError -from boto.sqs.message import RawMessage, Message - -from freezegun import freeze_time -import base64 -import json +import six import sure # noqa -import time -import uuid - -from moto import settings, mock_sqs, mock_sqs_deprecated -from tests.helpers import requires_boto_gte import tests.backport_assert_raises # noqa -from nose.tools import assert_raises +from boto.exception import SQSError +from boto.sqs.message import Message, RawMessage +from botocore.exceptions import ClientError +from freezegun import freeze_time +from moto import mock_sqs, mock_sqs_deprecated, settings from nose import SkipTest +from nose.tools import assert_raises +from tests.helpers import requires_boto_gte +from moto.core import ACCOUNT_ID @mock_sqs def test_create_fifo_queue_fail(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") try: - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'FifoQueue': 'true', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"FifoQueue": "true"}) except botocore.exceptions.ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised InvalidParameterValue Exception') + raise RuntimeError("Should of raised InvalidParameterValue Exception") @mock_sqs def test_create_queue_with_same_attributes(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - dlq_url = sqs.create_queue(QueueName='test-queue-dlq')['QueueUrl'] - dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url)['Attributes']['QueueArn'] + dlq_url = sqs.create_queue(QueueName="test-queue-dlq")["QueueUrl"] + dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url)["Attributes"]["QueueArn"] attributes = { - 'DelaySeconds': '900', - 'MaximumMessageSize': '262144', - 'MessageRetentionPeriod': '1209600', - 'ReceiveMessageWaitTimeSeconds': '20', - 'RedrivePolicy': '{"deadLetterTargetArn": "%s", "maxReceiveCount": 100}' % (dlq_arn), - 'VisibilityTimeout': '43200' + "DelaySeconds": "900", + "MaximumMessageSize": "262144", + "MessageRetentionPeriod": "1209600", + "ReceiveMessageWaitTimeSeconds": "20", + "RedrivePolicy": '{"deadLetterTargetArn": "%s", "maxReceiveCount": 100}' + % (dlq_arn), + "VisibilityTimeout": "43200", } - sqs.create_queue( - QueueName='test-queue', - Attributes=attributes - ) + sqs.create_queue(QueueName="test-queue", Attributes=attributes) - sqs.create_queue( - QueueName='test-queue', - Attributes=attributes - ) + sqs.create_queue(QueueName="test-queue", Attributes=attributes) @mock_sqs def test_create_queue_with_different_attributes_fail(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'VisibilityTimeout': '10', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "10"}) try: - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'VisibilityTimeout': '60', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "60"}) except botocore.exceptions.ClientError as err: - err.response['Error']['Code'].should.equal('QueueAlreadyExists') + err.response["Error"]["Code"].should.equal("QueueAlreadyExists") else: - raise RuntimeError('Should of raised QueueAlreadyExists Exception') + raise RuntimeError("Should of raised QueueAlreadyExists Exception") @mock_sqs def test_create_fifo_queue(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-queue.fifo', - Attributes={ - 'FifoQueue': 'true', - } + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] response = sqs.get_queue_attributes(QueueUrl=queue_url) - response['Attributes'].should.contain('FifoQueue') - response['Attributes']['FifoQueue'].should.equal('true') + response["Attributes"].should.contain("FifoQueue") + response["Attributes"]["FifoQueue"].should.equal("true") @mock_sqs def test_create_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") - new_queue = sqs.create_queue(QueueName='test-queue') + new_queue = sqs.create_queue(QueueName="test-queue") new_queue.should_not.be.none - new_queue.should.have.property('url').should.contain('test-queue') + new_queue.should.have.property("url").should.contain("test-queue") - queue = sqs.get_queue_by_name(QueueName='test-queue') - queue.attributes.get('QueueArn').should_not.be.none - queue.attributes.get('QueueArn').split(':')[-1].should.equal('test-queue') - queue.attributes.get('QueueArn').split(':')[3].should.equal('us-east-1') - queue.attributes.get('VisibilityTimeout').should_not.be.none - queue.attributes.get('VisibilityTimeout').should.equal('30') + queue = sqs.get_queue_by_name(QueueName="test-queue") + queue.attributes.get("QueueArn").should_not.be.none + queue.attributes.get("QueueArn").split(":")[-1].should.equal("test-queue") + queue.attributes.get("QueueArn").split(":")[3].should.equal("us-east-1") + queue.attributes.get("VisibilityTimeout").should_not.be.none + queue.attributes.get("VisibilityTimeout").should.equal("30") @mock_sqs def test_create_queue_kms(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") new_queue = sqs.create_queue( - QueueName='test-queue', + QueueName="test-queue", Attributes={ - 'KmsMasterKeyId': 'master-key-id', - 'KmsDataKeyReusePeriodSeconds': '600' - }) + "KmsMasterKeyId": "master-key-id", + "KmsDataKeyReusePeriodSeconds": "600", + }, + ) new_queue.should_not.be.none - queue = sqs.get_queue_by_name(QueueName='test-queue') + queue = sqs.get_queue_by_name(QueueName="test-queue") - queue.attributes.get('KmsMasterKeyId').should.equal('master-key-id') - queue.attributes.get('KmsDataKeyReusePeriodSeconds').should.equal('600') + queue.attributes.get("KmsMasterKeyId").should.equal("master-key-id") + queue.attributes.get("KmsDataKeyReusePeriodSeconds").should.equal("600") + + +@mock_sqs +def test_create_queue_with_tags(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue( + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_1"} + ) + queue_url = response["QueueUrl"] + + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal( + {"tag_key_1": "tag_value_1"} + ) + + +@mock_sqs +def test_get_queue_url(): + client = boto3.client("sqs", region_name="us-east-1") + client.create_queue(QueueName="test-queue") + + response = client.get_queue_url(QueueName="test-queue") + + response.should.have.key("QueueUrl").which.should.contain("test-queue") + + +@mock_sqs +def test_get_queue_url_errors(): + client = boto3.client("sqs", region_name="us-east-1") + + client.get_queue_url.when.called_with(QueueName="non-existing-queue").should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) @mock_sqs def test_get_nonexistent_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") with assert_raises(ClientError) as err: - sqs.get_queue_by_name(QueueName='nonexisting-queue') + sqs.get_queue_by_name(QueueName="nonexisting-queue") ex = err.exception - ex.operation_name.should.equal('GetQueueUrl') - ex.response['Error']['Code'].should.equal( - 'AWS.SimpleQueueService.NonExistentQueue') + ex.operation_name.should.equal("GetQueueUrl") + ex.response["Error"]["Code"].should.equal("AWS.SimpleQueueService.NonExistentQueue") with assert_raises(ClientError) as err: - sqs.Queue('http://whatever-incorrect-queue-address').load() + sqs.Queue("http://whatever-incorrect-queue-address").load() ex = err.exception - ex.operation_name.should.equal('GetQueueAttributes') - ex.response['Error']['Code'].should.equal( - 'AWS.SimpleQueueService.NonExistentQueue') + ex.operation_name.should.equal("GetQueueAttributes") + ex.response["Error"]["Code"].should.equal("AWS.SimpleQueueService.NonExistentQueue") @mock_sqs def test_message_send_without_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") - msg = queue.send_message( - MessageBody="derp" - ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.shouldnt.have.key('MD5OfMessageAttributes') - msg.get('MessageId').should_not.contain(' \n') + msg = queue.send_message(MessageBody="derp") + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.shouldnt.have.key("MD5OfMessageAttributes") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -176,22 +182,17 @@ def test_message_send_without_attributes(): @mock_sqs def test_message_send_with_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message( MessageBody="derp", MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359900', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359900", "DataType": "Number"} + }, ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.get('MD5OfMessageAttributes').should.equal( - '235c5c510d26fb653d073faed50ae77c') - msg.get('MessageId').should_not.contain(' \n') + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.get("MD5OfMessageAttributes").should.equal("235c5c510d26fb653d073faed50ae77c") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -199,22 +200,20 @@ def test_message_send_with_attributes(): @mock_sqs def test_message_with_complex_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message( MessageBody="derp", MessageAttributes={ - 'ccc': {'StringValue': 'testjunk', 'DataType': 'String'}, - 'aaa': {'BinaryValue': b'\x02\x03\x04', 'DataType': 'Binary'}, - 'zzz': {'DataType': 'Number', 'StringValue': '0230.01'}, - 'öther_encodings': {'DataType': 'String', 'StringValue': 'T\xFCst'} - } + "ccc": {"StringValue": "testjunk", "DataType": "String"}, + "aaa": {"BinaryValue": b"\x02\x03\x04", "DataType": "Binary"}, + "zzz": {"DataType": "Number", "StringValue": "0230.01"}, + "öther_encodings": {"DataType": "String", "StringValue": "T\xFCst"}, + }, ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.get('MD5OfMessageAttributes').should.equal( - '8ae21a7957029ef04146b42aeaa18a22') - msg.get('MessageId').should_not.contain(' \n') + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.get("MD5OfMessageAttributes").should.equal("8ae21a7957029ef04146b42aeaa18a22") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -222,9 +221,10 @@ def test_message_with_complex_attributes(): @mock_sqs def test_send_message_with_message_group_id(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-group-id.fifo", - Attributes={'FifoQueue': 'true'}) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-group-id.fifo", Attributes={"FifoQueue": "true"} + ) sent = queue.send_message( MessageBody="mydata", @@ -236,17 +236,17 @@ def test_send_message_with_message_group_id(): messages.should.have.length_of(1) message_attributes = messages[0].attributes - message_attributes.should.contain('MessageGroupId') - message_attributes['MessageGroupId'].should.equal('group_id_1') - message_attributes.should.contain('MessageDeduplicationId') - message_attributes['MessageDeduplicationId'].should.equal('dedupe_id_1') + message_attributes.should.contain("MessageGroupId") + message_attributes["MessageGroupId"].should.equal("group_id_1") + message_attributes.should.contain("MessageDeduplicationId") + message_attributes["MessageDeduplicationId"].should.equal("dedupe_id_1") @mock_sqs def test_send_message_with_unicode_characters(): - body_one = 'Héllo!😀' + body_one = "Héllo!😀" - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message(MessageBody=body_one) @@ -258,173 +258,253 @@ def test_send_message_with_unicode_characters(): @mock_sqs def test_set_queue_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") - queue.attributes['VisibilityTimeout'].should.equal("30") + queue.attributes["VisibilityTimeout"].should.equal("30") queue.set_attributes(Attributes={"VisibilityTimeout": "45"}) - queue.attributes['VisibilityTimeout'].should.equal("45") + queue.attributes["VisibilityTimeout"].should.equal("45") @mock_sqs def test_create_queues_in_multiple_region(): - west1_conn = boto3.client('sqs', region_name='us-west-1') + west1_conn = boto3.client("sqs", region_name="us-west-1") west1_conn.create_queue(QueueName="blah") - west2_conn = boto3.client('sqs', region_name='us-west-2') + west2_conn = boto3.client("sqs", region_name="us-west-2") west2_conn.create_queue(QueueName="test-queue") - list(west1_conn.list_queues()['QueueUrls']).should.have.length_of(1) - list(west2_conn.list_queues()['QueueUrls']).should.have.length_of(1) + list(west1_conn.list_queues()["QueueUrls"]).should.have.length_of(1) + list(west2_conn.list_queues()["QueueUrls"]).should.have.length_of(1) if settings.TEST_SERVER_MODE: - base_url = 'http://localhost:5000' + base_url = "http://localhost:5000" else: - base_url = 'https://us-west-1.queue.amazonaws.com' + base_url = "https://us-west-1.queue.amazonaws.com" - west1_conn.list_queues()['QueueUrls'][0].should.equal( - '{base_url}/123456789012/blah'.format(base_url=base_url)) + west1_conn.list_queues()["QueueUrls"][0].should.equal( + "{base_url}/{AccountId}/blah".format(base_url=base_url, AccountId=ACCOUNT_ID) + ) @mock_sqs def test_get_queue_with_prefix(): - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="prefixa-queue") conn.create_queue(QueueName="prefixb-queue") conn.create_queue(QueueName="test-queue") - conn.list_queues()['QueueUrls'].should.have.length_of(3) + conn.list_queues()["QueueUrls"].should.have.length_of(3) - queue = conn.list_queues(QueueNamePrefix="test-")['QueueUrls'] + queue = conn.list_queues(QueueNamePrefix="test-")["QueueUrls"] queue.should.have.length_of(1) if settings.TEST_SERVER_MODE: - base_url = 'http://localhost:5000' + base_url = "http://localhost:5000" else: - base_url = 'https://us-west-1.queue.amazonaws.com' + base_url = "https://us-west-1.queue.amazonaws.com" queue[0].should.equal( - "{base_url}/123456789012/test-queue".format(base_url=base_url)) + "{base_url}/{AccountId}/test-queue".format( + base_url=base_url, AccountId=ACCOUNT_ID + ) + ) @mock_sqs def test_delete_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') - conn.create_queue(QueueName="test-queue", - Attributes={"VisibilityTimeout": "3"}) - queue = sqs.Queue('test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "3"}) + queue = sqs.Queue("test-queue") - conn.list_queues()['QueueUrls'].should.have.length_of(1) + conn.list_queues()["QueueUrls"].should.have.length_of(1) queue.delete() - conn.list_queues().get('QueueUrls').should.equal(None) + conn.list_queues().get("QueueUrls").should.equal(None) with assert_raises(botocore.exceptions.ClientError): queue.delete() +@mock_sqs +def test_get_queue_attributes(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + + response = client.get_queue_attributes(QueueUrl=queue_url) + + response["Attributes"]["ApproximateNumberOfMessages"].should.equal("0") + response["Attributes"]["ApproximateNumberOfMessagesDelayed"].should.equal("0") + response["Attributes"]["ApproximateNumberOfMessagesNotVisible"].should.equal("0") + response["Attributes"]["CreatedTimestamp"].should.be.a(six.string_types) + response["Attributes"]["DelaySeconds"].should.equal("0") + response["Attributes"]["LastModifiedTimestamp"].should.be.a(six.string_types) + response["Attributes"]["MaximumMessageSize"].should.equal("65536") + response["Attributes"]["MessageRetentionPeriod"].should.equal("345600") + response["Attributes"]["QueueArn"].should.equal( + "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID) + ) + response["Attributes"]["ReceiveMessageWaitTimeSeconds"].should.equal("0") + response["Attributes"]["VisibilityTimeout"].should.equal("30") + + response = client.get_queue_attributes( + QueueUrl=queue_url, + AttributeNames=[ + "ApproximateNumberOfMessages", + "MaximumMessageSize", + "QueueArn", + "VisibilityTimeout", + ], + ) + + response["Attributes"].should.equal( + { + "ApproximateNumberOfMessages": "0", + "MaximumMessageSize": "65536", + "QueueArn": "arn:aws:sqs:us-east-1:{}:test-queue".format(ACCOUNT_ID), + "VisibilityTimeout": "30", + } + ) + + # should not return any attributes, if it was not set before + response = client.get_queue_attributes( + QueueUrl=queue_url, AttributeNames=["KmsMasterKeyId"] + ) + + response.should_not.have.key("Attributes") + + +@mock_sqs +def test_get_queue_attributes_errors(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + + client.get_queue_attributes.when.called_with( + QueueUrl=queue_url + "-non-existing" + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + client.get_queue_attributes.when.called_with( + QueueUrl=queue_url, + AttributeNames=["QueueArn", "not-existing", "VisibilityTimeout"], + ).should.throw(ClientError, "Unknown Attribute not-existing.") + + client.get_queue_attributes.when.called_with( + QueueUrl=queue_url, AttributeNames=[""] + ).should.throw(ClientError, "Unknown Attribute .") + + client.get_queue_attributes.when.called_with( + QueueUrl=queue_url, AttributeNames=[] + ).should.throw(ClientError, "Unknown Attribute .") + + @mock_sqs def test_set_queue_attribute(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') - conn.create_queue(QueueName="test-queue", - Attributes={"VisibilityTimeout": '3'}) + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "3"}) queue = sqs.Queue("test-queue") - queue.attributes['VisibilityTimeout'].should.equal('3') + queue.attributes["VisibilityTimeout"].should.equal("3") - queue.set_attributes(Attributes={"VisibilityTimeout": '45'}) + queue.set_attributes(Attributes={"VisibilityTimeout": "45"}) queue = sqs.Queue("test-queue") - queue.attributes['VisibilityTimeout'].should.equal('45') + queue.attributes["VisibilityTimeout"].should.equal("45") @mock_sqs def test_send_receive_message_without_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.send_message(MessageBody=body_one) queue.send_message(MessageBody=body_two) - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ + "Messages" + ] message1 = messages[0] message2 = messages[1] - message1['Body'].should.equal(body_one) - message2['Body'].should.equal(body_two) + message1["Body"].should.equal(body_one) + message2["Body"].should.equal(body_two) - message1.shouldnt.have.key('MD5OfMessageAttributes') - message2.shouldnt.have.key('MD5OfMessageAttributes') + message1.shouldnt.have.key("MD5OfMessageAttributes") + message2.shouldnt.have.key("MD5OfMessageAttributes") @mock_sqs def test_send_receive_message_with_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.send_message( MessageBody=body_one, MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359900', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359900", "DataType": "Number"} + }, ) queue.send_message( MessageBody=body_two, MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359901', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359901", "DataType": "Number"} + }, ) - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ + "Messages" + ] message1 = messages[0] message2 = messages[1] - message1.get('Body').should.equal(body_one) - message2.get('Body').should.equal(body_two) + message1.get("Body").should.equal(body_one) + message2.get("Body").should.equal(body_two) - message1.get('MD5OfMessageAttributes').should.equal('235c5c510d26fb653d073faed50ae77c') - message2.get('MD5OfMessageAttributes').should.equal('994258b45346a2cc3f9cbb611aa7af30') + message1.get("MD5OfMessageAttributes").should.equal( + "235c5c510d26fb653d073faed50ae77c" + ) + message2.get("MD5OfMessageAttributes").should.equal( + "994258b45346a2cc3f9cbb611aa7af30" + ) @mock_sqs def test_send_receive_message_timestamps(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") response = queue.send_message(MessageBody="derp") - assert response['ResponseMetadata']['RequestId'] + assert response["ResponseMetadata"]["RequestId"] - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=1)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=1)[ + "Messages" + ] message = messages[0] - sent_timestamp = message.get('Attributes').get('SentTimestamp') - approximate_first_receive_timestamp = message.get('Attributes').get('ApproximateFirstReceiveTimestamp') + sent_timestamp = message.get("Attributes").get("SentTimestamp") + approximate_first_receive_timestamp = message.get("Attributes").get( + "ApproximateFirstReceiveTimestamp" + ) int.when.called_with(sent_timestamp).shouldnt.throw(ValueError) int.when.called_with(approximate_first_receive_timestamp).shouldnt.throw(ValueError) @@ -432,8 +512,8 @@ def test_send_receive_message_timestamps(): @mock_sqs def test_max_number_of_messages_invalid_param(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") with assert_raises(ClientError): queue.receive_messages(MaxNumberOfMessages=11) @@ -447,8 +527,8 @@ def test_max_number_of_messages_invalid_param(): @mock_sqs def test_wait_time_seconds_invalid_param(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") with assert_raises(ClientError): queue.receive_messages(WaitTimeSeconds=-1) @@ -468,7 +548,7 @@ def test_receive_messages_with_wait_seconds_timeout_of_zero(): :return: """ - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") messages = queue.receive_messages(WaitTimeSeconds=0) @@ -477,11 +557,11 @@ def test_receive_messages_with_wait_seconds_timeout_of_zero(): @mock_sqs_deprecated def test_send_message_with_xml_characters(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = '< & >' + body_one = "< & >" queue.write(queue.new_message(body_one)) @@ -493,17 +573,23 @@ def test_send_message_with_xml_characters(): @requires_boto_gte("2.28") @mock_sqs_deprecated def test_send_message_with_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body = 'this is a test message' + body = "this is a test message" message = queue.new_message(body) - BASE64_BINARY = base64.b64encode(b'binary value').decode('utf-8') + BASE64_BINARY = base64.b64encode(b"binary value").decode("utf-8") message_attributes = { - 'test.attribute_name': {'data_type': 'String', 'string_value': 'attribute value'}, - 'test.binary_attribute': {'data_type': 'Binary', 'binary_value': BASE64_BINARY}, - 'test.number_attribute': {'data_type': 'Number', 'string_value': 'string value'} + "test.attribute_name": { + "data_type": "String", + "string_value": "attribute value", + }, + "test.binary_attribute": {"data_type": "Binary", "binary_value": BASE64_BINARY}, + "test.number_attribute": { + "data_type": "Number", + "string_value": "string value", + }, } message.message_attributes = message_attributes @@ -519,12 +605,12 @@ def test_send_message_with_attributes(): @mock_sqs_deprecated def test_send_message_with_delay(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.write(queue.new_message(body_one), delay_seconds=3) queue.write(queue.new_message(body_two)) @@ -540,11 +626,11 @@ def test_send_message_with_delay(): @mock_sqs_deprecated def test_send_large_message_fails(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'test message' * 200000 + body_one = "test message" * 200000 huge_message = queue.new_message(body_one) queue.write.when.called_with(huge_message).should.throw(SQSError) @@ -552,11 +638,11 @@ def test_send_large_message_fails(): @mock_sqs_deprecated def test_message_becomes_inflight_when_received(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is a test message' + body_one = "this is a test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -573,16 +659,15 @@ def test_message_becomes_inflight_when_received(): @mock_sqs_deprecated def test_receive_message_with_explicit_visibility_timeout(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) - messages = conn.receive_message( - queue, number_messages=1, visibility_timeout=0) + messages = conn.receive_message(queue, number_messages=1, visibility_timeout=0) assert len(messages) == 1 @@ -592,11 +677,11 @@ def test_receive_message_with_explicit_visibility_timeout(): @mock_sqs_deprecated def test_change_message_visibility(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -626,11 +711,11 @@ def test_change_message_visibility(): @mock_sqs_deprecated def test_message_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -642,19 +727,19 @@ def test_message_attributes(): message_attributes = messages[0].attributes - assert message_attributes.get('ApproximateFirstReceiveTimestamp') - assert int(message_attributes.get('ApproximateReceiveCount')) == 1 - assert message_attributes.get('SentTimestamp') - assert message_attributes.get('SenderId') + assert message_attributes.get("ApproximateFirstReceiveTimestamp") + assert int(message_attributes.get("ApproximateReceiveCount")) == 1 + assert message_attributes.get("SentTimestamp") + assert message_attributes.get("SenderId") @mock_sqs_deprecated def test_read_message_from_queue(): conn = boto.connect_sqs() - queue = conn.create_queue('testqueue') + queue = conn.create_queue("testqueue") queue.set_message_class(RawMessage) - body = 'foo bar baz' + body = "foo bar baz" queue.write(queue.new_message(body)) message = queue.read(1) message.get_body().should.equal(body) @@ -662,23 +747,23 @@ def test_read_message_from_queue(): @mock_sqs_deprecated def test_queue_length(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is a test message')) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is a test message")) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(2) @mock_sqs_deprecated def test_delete_message(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is a test message')) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is a test message")) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(2) messages = conn.receive_message(queue, number_messages=1) @@ -694,17 +779,19 @@ def test_delete_message(): @mock_sqs_deprecated def test_send_batch_operation(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) # See https://github.com/boto/boto/issues/831 queue.set_message_class(RawMessage) - queue.write_batch([ - ("my_first_message", 'test message 1', 0), - ("my_second_message", 'test message 2', 0), - ("my_third_message", 'test message 3', 0), - ]) + queue.write_batch( + [ + ("my_first_message", "test message 1", 0), + ("my_second_message", "test message 2", 0), + ("my_third_message", "test message 3", 0), + ] + ) messages = queue.get_messages(3) messages[0].get_body().should.equal("test message 1") @@ -716,12 +803,16 @@ def test_send_batch_operation(): @requires_boto_gte("2.28") @mock_sqs_deprecated def test_send_batch_operation_with_message_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - message_tuple = ("my_first_message", 'test message 1', 0, { - 'name1': {'data_type': 'String', 'string_value': 'foo'}}) + message_tuple = ( + "my_first_message", + "test message 1", + 0, + {"name1": {"data_type": "String", "string_value": "foo"}}, + ) queue.write_batch([message_tuple]) messages = queue.get_messages() @@ -733,14 +824,17 @@ def test_send_batch_operation_with_message_attributes(): @mock_sqs_deprecated def test_delete_batch_operation(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) - conn.send_message_batch(queue, [ - ("my_first_message", 'test message 1', 0), - ("my_second_message", 'test message 2', 0), - ("my_third_message", 'test message 3', 0), - ]) + conn.send_message_batch( + queue, + [ + ("my_first_message", "test message 1", 0), + ("my_second_message", "test message 2", 0), + ("my_third_message", "test message 3", 0), + ], + ) messages = queue.get_messages(2) queue.delete_message_batch(messages) @@ -750,42 +844,44 @@ def test_delete_batch_operation(): @mock_sqs_deprecated def test_queue_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") - queue_name = 'test-queue' + queue_name = "test-queue" visibility_timeout = 3 - queue = conn.create_queue( - queue_name, visibility_timeout=visibility_timeout) + queue = conn.create_queue(queue_name, visibility_timeout=visibility_timeout) attributes = queue.get_attributes() - attributes['QueueArn'].should.look_like( - 'arn:aws:sqs:us-east-1:123456789012:%s' % queue_name) + attributes["QueueArn"].should.look_like( + "arn:aws:sqs:us-east-1:{AccountId}:{name}".format( + AccountId=ACCOUNT_ID, name=queue_name + ) + ) - attributes['VisibilityTimeout'].should.look_like(str(visibility_timeout)) + attributes["VisibilityTimeout"].should.look_like(str(visibility_timeout)) attribute_names = queue.get_attributes().keys() - attribute_names.should.contain('ApproximateNumberOfMessagesNotVisible') - attribute_names.should.contain('MessageRetentionPeriod') - attribute_names.should.contain('ApproximateNumberOfMessagesDelayed') - attribute_names.should.contain('MaximumMessageSize') - attribute_names.should.contain('CreatedTimestamp') - attribute_names.should.contain('ApproximateNumberOfMessages') - attribute_names.should.contain('ReceiveMessageWaitTimeSeconds') - attribute_names.should.contain('DelaySeconds') - attribute_names.should.contain('VisibilityTimeout') - attribute_names.should.contain('LastModifiedTimestamp') - attribute_names.should.contain('QueueArn') + attribute_names.should.contain("ApproximateNumberOfMessagesNotVisible") + attribute_names.should.contain("MessageRetentionPeriod") + attribute_names.should.contain("ApproximateNumberOfMessagesDelayed") + attribute_names.should.contain("MaximumMessageSize") + attribute_names.should.contain("CreatedTimestamp") + attribute_names.should.contain("ApproximateNumberOfMessages") + attribute_names.should.contain("ReceiveMessageWaitTimeSeconds") + attribute_names.should.contain("DelaySeconds") + attribute_names.should.contain("VisibilityTimeout") + attribute_names.should.contain("LastModifiedTimestamp") + attribute_names.should.contain("QueueArn") @mock_sqs_deprecated def test_change_message_visibility_on_invalid_receipt(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=1) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) messages = conn.receive_message(queue, number_messages=1) @@ -803,17 +899,16 @@ def test_change_message_visibility_on_invalid_receipt(): assert len(messages) == 1 - original_message.change_visibility.when.called_with( - 100).should.throw(SQSError) + original_message.change_visibility.when.called_with(100).should.throw(SQSError) @mock_sqs_deprecated def test_change_message_visibility_on_visible_message(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=1) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) messages = conn.receive_message(queue, number_messages=1) @@ -827,16 +922,15 @@ def test_change_message_visibility_on_visible_message(): queue.count().should.equal(1) - original_message.change_visibility.when.called_with( - 100).should.throw(SQSError) + original_message.change_visibility.when.called_with(100).should.throw(SQSError) @mock_sqs_deprecated def test_purge_action(): conn = boto.sqs.connect_to_region("us-east-1") - queue = conn.create_queue('new-queue') - queue.write(queue.new_message('this is another test message')) + queue = conn.create_queue("new-queue") + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) queue.purge() @@ -848,11 +942,10 @@ def test_purge_action(): def test_delete_message_after_visibility_timeout(): VISIBILITY_TIMEOUT = 1 conn = boto.sqs.connect_to_region("us-east-1") - new_queue = conn.create_queue( - 'new-queue', visibility_timeout=VISIBILITY_TIMEOUT) + new_queue = conn.create_queue("new-queue", visibility_timeout=VISIBILITY_TIMEOUT) m1 = Message() - m1.set_body('Message 1!') + m1.set_body("Message 1!") new_queue.write(m1) assert new_queue.count() == 1 @@ -866,273 +959,521 @@ def test_delete_message_after_visibility_timeout(): assert new_queue.count() == 0 +@mock_sqs +def test_delete_message_errors(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + client.send_message(QueueUrl=queue_url, MessageBody="body") + response = client.receive_message(QueueUrl=queue_url) + receipt_handle = response["Messages"][0]["ReceiptHandle"] + + client.delete_message.when.called_with( + QueueUrl=queue_url + "-not-existing", ReceiptHandle=receipt_handle + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + client.delete_message.when.called_with( + QueueUrl=queue_url, ReceiptHandle="not-existing" + ).should.throw(ClientError, "The input receipt handle is invalid.") + + +@mock_sqs +def test_send_message_batch(): + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + + response = client.send_message_batch( + QueueUrl=queue_url, + Entries=[ + { + "Id": "id_1", + "MessageBody": "body_1", + "DelaySeconds": 0, + "MessageAttributes": { + "attribute_name_1": { + "StringValue": "attribute_value_1", + "DataType": "String", + } + }, + }, + { + "Id": "id_2", + "MessageBody": "body_2", + "DelaySeconds": 0, + "MessageAttributes": { + "attribute_name_2": {"StringValue": "123", "DataType": "Number"} + }, + }, + ], + ) + + sorted([entry["Id"] for entry in response["Successful"]]).should.equal( + ["id_1", "id_2"] + ) + + response = client.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=10) + + response["Messages"][0]["Body"].should.equal("body_1") + response["Messages"][0]["MessageAttributes"].should.equal( + {"attribute_name_1": {"StringValue": "attribute_value_1", "DataType": "String"}} + ) + response["Messages"][1]["Body"].should.equal("body_2") + response["Messages"][1]["MessageAttributes"].should.equal( + {"attribute_name_2": {"StringValue": "123", "DataType": "Number"}} + ) + + +@mock_sqs +def test_send_message_batch_errors(): + client = boto3.client("sqs", region_name="us-east-1") + + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + + client.send_message_batch.when.called_with( + QueueUrl=queue_url + "-not-existing", + Entries=[{"Id": "id_1", "MessageBody": "body_1"}], + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=[] + ).should.throw( + ClientError, + "There should be at least one SendMessageBatchRequestEntry in the request.", + ) + + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=[{"Id": "", "MessageBody": "body_1"}] + ).should.throw( + ClientError, + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", + ) + + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=[{"Id": ".!@#$%^&*()+=", "MessageBody": "body_1"}] + ).should.throw( + ClientError, + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", + ) + + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=[{"Id": "i" * 81, "MessageBody": "body_1"}] + ).should.throw( + ClientError, + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", + ) + + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=[{"Id": "id_1", "MessageBody": "b" * 262145}] + ).should.throw( + ClientError, + "Batch requests cannot be longer than 262144 bytes. " + "You have sent 262145 bytes.", + ) + + # only the first duplicated Id is reported + client.send_message_batch.when.called_with( + QueueUrl=queue_url, + Entries=[ + {"Id": "id_1", "MessageBody": "body_1"}, + {"Id": "id_2", "MessageBody": "body_2"}, + {"Id": "id_2", "MessageBody": "body_2"}, + {"Id": "id_1", "MessageBody": "body_1"}, + ], + ).should.throw(ClientError, "Id id_2 repeated.") + + entries = [ + {"Id": "id_{}".format(i), "MessageBody": "body_{}".format(i)} for i in range(11) + ] + client.send_message_batch.when.called_with( + QueueUrl=queue_url, Entries=entries + ).should.throw( + ClientError, + "Maximum number of entries per request are 10. " "You have sent 11.", + ) + + @mock_sqs def test_batch_change_message_visibility(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") with freeze_time("2015-01-01 12:00:00"): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - sqs.send_message(QueueUrl=queue_url, MessageBody='msg1') - sqs.send_message(QueueUrl=queue_url, MessageBody='msg2') - sqs.send_message(QueueUrl=queue_url, MessageBody='msg3') + sqs.send_message(QueueUrl=queue_url, MessageBody="msg1") + sqs.send_message(QueueUrl=queue_url, MessageBody="msg2") + sqs.send_message(QueueUrl=queue_url, MessageBody="msg3") with freeze_time("2015-01-01 12:01:00"): receive_resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=2) - len(receive_resp['Messages']).should.equal(2) + len(receive_resp["Messages"]).should.equal(2) - handles = [item['ReceiptHandle'] for item in receive_resp['Messages']] - entries = [{'Id': str(uuid.uuid4()), 'ReceiptHandle': handle, 'VisibilityTimeout': 43200} for handle in handles] + handles = [item["ReceiptHandle"] for item in receive_resp["Messages"]] + entries = [ + { + "Id": str(uuid.uuid4()), + "ReceiptHandle": handle, + "VisibilityTimeout": 43200, + } + for handle in handles + ] resp = sqs.change_message_visibility_batch(QueueUrl=queue_url, Entries=entries) - len(resp['Successful']).should.equal(2) + len(resp["Successful"]).should.equal(2) with freeze_time("2015-01-01 14:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(1) + len(resp["Messages"]).should.equal(1) with freeze_time("2015-01-01 16:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(1) + len(resp["Messages"]).should.equal(1) with freeze_time("2015-01-02 12:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(3) + len(resp["Messages"]).should.equal(3) @mock_sqs def test_permissions(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") resp = client.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - client.add_permission(QueueUrl=queue_url, Label='account1', AWSAccountIds=['111111111111'], Actions=['*']) - client.add_permission(QueueUrl=queue_url, Label='account2', AWSAccountIds=['222211111111'], Actions=['SendMessage']) + client.add_permission( + QueueUrl=queue_url, + Label="account1", + AWSAccountIds=["111111111111"], + Actions=["*"], + ) + client.add_permission( + QueueUrl=queue_url, + Label="account2", + AWSAccountIds=["222211111111"], + Actions=["SendMessage"], + ) with assert_raises(ClientError): - client.add_permission(QueueUrl=queue_url, Label='account2', AWSAccountIds=['222211111111'], Actions=['SomeRubbish']) + client.add_permission( + QueueUrl=queue_url, + Label="account2", + AWSAccountIds=["222211111111"], + Actions=["SomeRubbish"], + ) - client.remove_permission(QueueUrl=queue_url, Label='account2') + client.remove_permission(QueueUrl=queue_url, Label="account2") with assert_raises(ClientError): - client.remove_permission(QueueUrl=queue_url, Label='non_existant') + client.remove_permission(QueueUrl=queue_url, Label="non_existant") @mock_sqs def test_tags(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") resp = client.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - client.tag_queue( - QueueUrl=queue_url, - Tags={ - 'test1': 'value1', - 'test2': 'value2', - } - ) + client.tag_queue(QueueUrl=queue_url, Tags={"test1": "value1", "test2": "value2"}) resp = client.list_queue_tags(QueueUrl=queue_url) - resp['Tags'].should.contain('test1') - resp['Tags'].should.contain('test2') + resp["Tags"].should.contain("test1") + resp["Tags"].should.contain("test2") - client.untag_queue( - QueueUrl=queue_url, - TagKeys=['test2'] - ) + client.untag_queue(QueueUrl=queue_url, TagKeys=["test2"]) resp = client.list_queue_tags(QueueUrl=queue_url) - resp['Tags'].should.contain('test1') - resp['Tags'].should_not.contain('test2') + resp["Tags"].should.contain("test1") + resp["Tags"].should_not.contain("test2") + + # removing a non existing tag should not raise any error + client.untag_queue(QueueUrl=queue_url, TagKeys=["not-existing-tag"]) + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal({"test1": "value1"}) + + +@mock_sqs +def test_list_queue_tags_errors(): + client = boto3.client("sqs", region_name="us-east-1") + + response = client.create_queue( + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_X"} + ) + queue_url = response["QueueUrl"] + + client.list_queue_tags.when.called_with( + QueueUrl=queue_url + "-not-existing" + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + +@mock_sqs +def test_tag_queue_errors(): + client = boto3.client("sqs", region_name="us-east-1") + + response = client.create_queue( + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_X"} + ) + queue_url = response["QueueUrl"] + + client.tag_queue.when.called_with( + QueueUrl=queue_url + "-not-existing", Tags={"tag_key_1": "tag_value_1"} + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + client.tag_queue.when.called_with(QueueUrl=queue_url, Tags={}).should.throw( + ClientError, "The request must contain the parameter Tags." + ) + + too_many_tags = { + "tag_key_{}".format(i): "tag_value_{}".format(i) for i in range(51) + } + client.tag_queue.when.called_with( + QueueUrl=queue_url, Tags=too_many_tags + ).should.throw(ClientError, "Too many tags added for queue test-queue-with-tags.") + + # when the request fails, the tags should not be updated + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal( + {"tag_key_1": "tag_value_X"} + ) + + +@mock_sqs +def test_untag_queue_errors(): + client = boto3.client("sqs", region_name="us-east-1") + + response = client.create_queue( + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_1"} + ) + queue_url = response["QueueUrl"] + + client.untag_queue.when.called_with( + QueueUrl=queue_url + "-not-existing", TagKeys=["tag_key_1"] + ).should.throw( + ClientError, "The specified queue does not exist for this wsdl version." + ) + + client.untag_queue.when.called_with(QueueUrl=queue_url, TagKeys=[]).should.throw( + ClientError, "Tag keys must be between 1 and 128 characters in length." + ) @mock_sqs def test_create_fifo_queue_with_dlq(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"]["QueueArn"] resp = sqs.create_queue( - QueueName='test-dlr-queue', - Attributes={'FifoQueue': 'false'} + QueueName="test-dlr-queue", Attributes={"FifoQueue": "false"} ) - queue_url2 = resp['QueueUrl'] - queue_arn2 = sqs.get_queue_attributes(QueueUrl=queue_url2)['Attributes']['QueueArn'] + queue_url2 = resp["QueueUrl"] + queue_arn2 = sqs.get_queue_attributes(QueueUrl=queue_url2)["Attributes"]["QueueArn"] sqs.create_queue( - QueueName='test-queue.fifo', + QueueName="test-queue.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 2} + ), + }, ) # Cant have fifo queue with non fifo DLQ with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue2.fifo', + QueueName="test-queue2.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn2, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn2, "maxReceiveCount": 2} + ), + }, ) @mock_sqs def test_queue_with_dlq(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") with freeze_time("2015-01-01 12:00:00"): resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"][ + "QueueArn" + ] resp = sqs.create_queue( - QueueName='test-queue.fifo', + QueueName="test-queue.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 2} + ), + }, ) - queue_url2 = resp['QueueUrl'] + queue_url2 = resp["QueueUrl"] - sqs.send_message(QueueUrl=queue_url2, MessageBody='msg1') - sqs.send_message(QueueUrl=queue_url2, MessageBody='msg2') + sqs.send_message(QueueUrl=queue_url2, MessageBody="msg1") + sqs.send_message(QueueUrl=queue_url2, MessageBody="msg2") with freeze_time("2015-01-01 13:00:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") with freeze_time("2015-01-01 13:01:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") with freeze_time("2015-01-01 13:02:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - len(resp['Messages']).should.equal(1) + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + len(resp["Messages"]).should.equal(1) - resp = sqs.receive_message(QueueUrl=queue_url1, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url1, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") # Might as well test list source queues resp = sqs.list_dead_letter_source_queues(QueueUrl=queue_url1) - resp['queueUrls'][0].should.equal(queue_url2) + resp["queueUrls"][0].should.equal(queue_url2) @mock_sqs def test_redrive_policy_available(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - resp = sqs.create_queue(QueueName='test-deadletter') - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] - redrive_policy = { - 'deadLetterTargetArn': queue_arn1, - 'maxReceiveCount': 1, - } + resp = sqs.create_queue(QueueName="test-deadletter") + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"]["QueueArn"] + redrive_policy = {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 1} resp = sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy) - } + QueueName="test-queue", Attributes={"RedrivePolicy": json.dumps(redrive_policy)} ) - queue_url2 = resp['QueueUrl'] - attributes = sqs.get_queue_attributes(QueueUrl=queue_url2)['Attributes'] - assert 'RedrivePolicy' in attributes - assert json.loads(attributes['RedrivePolicy']) == redrive_policy + queue_url2 = resp["QueueUrl"] + attributes = sqs.get_queue_attributes(QueueUrl=queue_url2)["Attributes"] + assert "RedrivePolicy" in attributes + assert json.loads(attributes["RedrivePolicy"]) == redrive_policy # Cant have redrive policy without maxReceiveCount with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue2', + QueueName="test-queue2", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps({"deadLetterTargetArn": queue_arn1}), + }, ) @mock_sqs def test_redrive_policy_non_existent_queue(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") redrive_policy = { - 'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:no-queue', - 'maxReceiveCount': 1, + "deadLetterTargetArn": "arn:aws:sqs:us-east-1:{}:no-queue".format(ACCOUNT_ID), + "maxReceiveCount": 1, } with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy) - } + QueueName="test-queue", + Attributes={"RedrivePolicy": json.dumps(redrive_policy)}, ) @mock_sqs def test_redrive_policy_set_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") - queue = sqs.create_queue(QueueName='test-queue') - deadletter_queue = sqs.create_queue(QueueName='test-deadletter') + queue = sqs.create_queue(QueueName="test-queue") + deadletter_queue = sqs.create_queue(QueueName="test-deadletter") redrive_policy = { - 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], - 'maxReceiveCount': 1, + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": 1, } - queue.set_attributes(Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy)}) + queue.set_attributes(Attributes={"RedrivePolicy": json.dumps(redrive_policy)}) - copy = sqs.get_queue_by_name(QueueName='test-queue') - assert 'RedrivePolicy' in copy.attributes - copy_policy = json.loads(copy.attributes['RedrivePolicy']) + copy = sqs.get_queue_by_name(QueueName="test-queue") + assert "RedrivePolicy" in copy.attributes + copy_policy = json.loads(copy.attributes["RedrivePolicy"]) assert copy_policy == redrive_policy +@mock_sqs +def test_redrive_policy_set_attributes_with_string_value(): + sqs = boto3.resource("sqs", region_name="us-east-1") + + queue = sqs.create_queue(QueueName="test-queue") + deadletter_queue = sqs.create_queue(QueueName="test-deadletter") + + queue.set_attributes( + Attributes={ + "RedrivePolicy": json.dumps( + { + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": "1", + } + ) + } + ) + + copy = sqs.get_queue_by_name(QueueName="test-queue") + assert "RedrivePolicy" in copy.attributes + copy_policy = json.loads(copy.attributes["RedrivePolicy"]) + assert copy_policy == { + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": 1, + } + + @mock_sqs def test_receive_messages_with_message_group_id(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1151,20 +1492,13 @@ def test_receive_messages_with_message_group_id(): @mock_sqs def test_receive_messages_with_message_group_id_on_requeue(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1184,24 +1518,17 @@ def test_receive_messages_with_message_group_id_on_requeue(): @mock_sqs def test_receive_messages_with_message_group_id_on_visibility_timeout(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") with freeze_time("2015-01-01 12:00:00"): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1225,15 +1552,13 @@ def test_receive_messages_with_message_group_id_on_visibility_timeout(): messages.should.have.length_of(1) messages[0].message_id.should.equal(message.message_id) + @mock_sqs def test_receive_message_for_queue_with_receive_message_wait_time_seconds_set(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'ReceiveMessageWaitTimeSeconds': '2', - } + QueueName="test-queue", Attributes={"ReceiveMessageWaitTimeSeconds": "2"} ) queue.receive_messages() diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index 77d439d83..5b978520d 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -2,12 +2,12 @@ from __future__ import unicode_literals import boto3 import botocore.exceptions -import sure # noqa +import sure # noqa import datetime import uuid import json -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, ParamValidationError from nose.tools import assert_raises from moto import mock_ssm, mock_cloudformation @@ -15,714 +15,1078 @@ from moto import mock_ssm, mock_cloudformation @mock_ssm def test_delete_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(1) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(1) - client.delete_parameter(Name='test') + client.delete_parameter(Name="test") - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(0) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(0) @mock_ssm def test_delete_parameters(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(1) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(1) - result = client.delete_parameters(Names=['test', 'invalid']) - len(result['DeletedParameters']).should.equal(1) - len(result['InvalidParameters']).should.equal(1) + result = client.delete_parameters(Names=["test", "invalid"]) + len(result["DeletedParameters"]).should.equal(1) + len(result["InvalidParameters"]).should.equal(1) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(0) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(0) @mock_ssm def test_get_parameters_by_path(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='/foo/name1', - Description='A test parameter', - Value='value1', - Type='String') - - client.put_parameter( - Name='/foo/name2', - Description='A test parameter', - Value='value2', - Type='String') - - client.put_parameter( - Name='/bar/name3', - Description='A test parameter', - Value='value3', - Type='String') - - client.put_parameter( - Name='/bar/name3/name4', - Description='A test parameter', - Value='value4', - Type='String') - - client.put_parameter( - Name='/baz/name1', - Description='A test parameter (list)', - Value='value1,value2,value3', - Type='StringList') - - client.put_parameter( - Name='/baz/name2', - Description='A test parameter', - Value='value1', - Type='String') - - client.put_parameter( - Name='/baz/pwd', - Description='A secure test parameter', - Value='my_secret', - Type='SecureString', - KeyId='alias/aws/ssm') - - client.put_parameter( - Name='foo', - Description='A test parameter', - Value='bar', - Type='String') - - client.put_parameter( - Name='baz', - Description='A test parameter', - Value='qux', - Type='String') - - response = client.get_parameters_by_path(Path='/', Recursive=False) - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['bar', 'qux']) + Name="/foo/name1", Description="A test parameter", Value="value1", Type="String" ) - response = client.get_parameters_by_path(Path='/', Recursive=True) - len(response['Parameters']).should.equal(9) - - response = client.get_parameters_by_path(Path='/foo') - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['value1', 'value2']) + client.put_parameter( + Name="/foo/name2", Description="A test parameter", Value="value2", Type="String" ) - response = client.get_parameters_by_path(Path='/bar', Recursive=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Value'].should.equal('value3') - - response = client.get_parameters_by_path(Path='/bar', Recursive=True) - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['value3', 'value4']) + client.put_parameter( + Name="/bar/name3", Description="A test parameter", Value="value3", Type="String" ) - response = client.get_parameters_by_path(Path='/baz') - len(response['Parameters']).should.equal(3) - - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['StringList'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1']) + client.put_parameter( + Name="/bar/name3/name4", + Description="A test parameter", + Value="value4", + Type="String", ) + client.put_parameter( + Name="/baz/name1", + Description="A test parameter (list)", + Value="value1,value2,value3", + Type="StringList", + ) + + client.put_parameter( + Name="/baz/name2", Description="A test parameter", Value="value1", Type="String" + ) + + client.put_parameter( + Name="/baz/pwd", + Description="A secure test parameter", + Value="my_secret", + Type="SecureString", + KeyId="alias/aws/ssm", + ) + + client.put_parameter( + Name="foo", Description="A test parameter", Value="bar", Type="String" + ) + + client.put_parameter( + Name="baz", Description="A test parameter", Value="qux", Type="String" + ) + + response = client.get_parameters_by_path(Path="/", Recursive=False) + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["bar", "qux"])) + {p["ARN"] for p in response["Parameters"]}.should.equal( + set( + [ + "arn:aws:ssm:us-east-1:1234567890:parameter/foo", + "arn:aws:ssm:us-east-1:1234567890:parameter/baz", + ] + ) + ) + { + p["LastModifiedDate"].should.be.a(datetime.datetime) + for p in response["Parameters"] + } + + response = client.get_parameters_by_path(Path="/", Recursive=True) + len(response["Parameters"]).should.equal(9) + + response = client.get_parameters_by_path(Path="/foo") + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["value1", "value2"])) + + response = client.get_parameters_by_path(Path="/bar", Recursive=False) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Value"].should.equal("value3") + + response = client.get_parameters_by_path(Path="/bar", Recursive=True) + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["value3", "value4"])) + + response = client.get_parameters_by_path(Path="/baz") + len(response["Parameters"]).should.equal(3) + + filters = [{"Key": "Type", "Option": "Equals", "Values": ["StringList"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name1"])) + # note: 'Option' is optional (default: 'Equals') - filters = [{ - 'Key': 'Type', - 'Values': ['StringList'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1']) + filters = [{"Key": "Type", "Values": ["StringList"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name1"])) + + filters = [{"Key": "Type", "Option": "Equals", "Values": ["String"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name2"])) + + filters = [ + {"Key": "Type", "Option": "Equals", "Values": ["String", "SecureString"]} + ] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(2) + {p["Name"] for p in response["Parameters"]}.should.equal( + set(["/baz/name2", "/baz/pwd"]) ) - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['String'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name2']) + filters = [{"Key": "Type", "Option": "BeginsWith", "Values": ["String"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(2) + {p["Name"] for p in response["Parameters"]}.should.equal( + set(["/baz/name1", "/baz/name2"]) ) - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['String', 'SecureString'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(2) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name2', '/baz/pwd']) - ) + filters = [{"Key": "KeyId", "Option": "Equals", "Values": ["alias/aws/ssm"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/pwd"])) - filters = [{ - 'Key': 'Type', - 'Option': 'BeginsWith', - 'Values': ['String'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(2) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1', '/baz/name2']) + response = client.get_parameters_by_path(Path="/", Recursive=True, MaxResults=4) + len(response["Parameters"]).should.equal(4) + response["NextToken"].should.equal("4") + response = client.get_parameters_by_path( + Path="/", Recursive=True, MaxResults=4, NextToken=response["NextToken"] ) - - filters = [{ - 'Key': 'KeyId', - 'Option': 'Equals', - 'Values': ['alias/aws/ssm'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/pwd']) + len(response["Parameters"]).should.equal(4) + response["NextToken"].should.equal("8") + response = client.get_parameters_by_path( + Path="/", Recursive=True, MaxResults=4, NextToken=response["NextToken"] ) + len(response["Parameters"]).should.equal(1) + response.should_not.have.key("NextToken") @mock_ssm def test_put_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") response = client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response['Version'].should.equal(1) + response["Version"].should.equal(1) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(1) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(1) + response["Parameters"][0]["LastModifiedDate"].should.be.a(datetime.datetime) + response["Parameters"][0]["ARN"].should.equal( + "arn:aws:ssm:us-east-1:1234567890:parameter/test" + ) + initial_modification_date = response["Parameters"][0]["LastModifiedDate"] try: client.put_parameter( - Name='test', - Description='desc 2', - Value='value 2', - Type='String') - raise RuntimeError('Should fail') + Name="test", Description="desc 2", Value="value 2", Type="String" + ) + raise RuntimeError("Should fail") except botocore.exceptions.ClientError as err: - err.operation_name.should.equal('PutParameter') - err.response['Error']['Message'].should.equal('Parameter test already exists.') + err.operation_name.should.equal("PutParameter") + err.response["Error"]["Message"].should.equal("Parameter test already exists.") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) # without overwrite nothing change - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(1) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(1) + response["Parameters"][0]["LastModifiedDate"].should.equal( + initial_modification_date + ) + response["Parameters"][0]["ARN"].should.equal( + "arn:aws:ssm:us-east-1:1234567890:parameter/test" + ) response = client.put_parameter( - Name='test', - Description='desc 3', - Value='value 3', - Type='String', - Overwrite=True) + Name="test", + Description="desc 3", + Value="value 3", + Type="String", + Overwrite=True, + ) - response['Version'].should.equal(2) + response["Version"].should.equal(2) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) # without overwrite nothing change - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value 3') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(2) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value 3") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(2) + response["Parameters"][0]["LastModifiedDate"].should_not.equal( + initial_modification_date + ) + response["Parameters"][0]["ARN"].should.equal( + "arn:aws:ssm:us-east-1:1234567890:parameter/test" + ) + @mock_ssm def test_put_parameter_china(): - client = boto3.client('ssm', region_name='cn-north-1') + client = boto3.client("ssm", region_name="cn-north-1") response = client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response['Version'].should.equal(1) + response["Version"].should.equal(1) @mock_ssm def test_get_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameter( - Name='test', - WithDecryption=False) + response = client.get_parameter(Name="test", WithDecryption=False) - response['Parameter']['Name'].should.equal('test') - response['Parameter']['Value'].should.equal('value') - response['Parameter']['Type'].should.equal('String') + response["Parameter"]["Name"].should.equal("test") + response["Parameter"]["Value"].should.equal("value") + response["Parameter"]["Type"].should.equal("String") + response["Parameter"]["LastModifiedDate"].should.be.a(datetime.datetime) + response["Parameter"]["ARN"].should.equal( + "arn:aws:ssm:us-east-1:1234567890:parameter/test" + ) @mock_ssm def test_get_nonexistant_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") try: - client.get_parameter( - Name='test_noexist', - WithDecryption=False) - raise RuntimeError('Should of failed') + client.get_parameter(Name="test_noexist", WithDecryption=False) + raise RuntimeError("Should of failed") except botocore.exceptions.ClientError as err: - err.operation_name.should.equal('GetParameter') - err.response['Error']['Message'].should.equal('Parameter test_noexist not found.') + err.operation_name.should.equal("GetParameter") + err.response["Error"]["Message"].should.equal( + "Parameter test_noexist not found." + ) @mock_ssm def test_describe_parameters(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String', - AllowedPattern=r'.*') + Name="test", + Description="A test parameter", + Value="value", + Type="String", + AllowedPattern=r".*", + ) response = client.describe_parameters() - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['AllowedPattern'].should.equal(r'.*') + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("test") + parameters[0]["Type"].should.equal("String") + parameters[0]["AllowedPattern"].should.equal(r".*") @mock_ssm def test_describe_parameters_paging(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - client.put_parameter( - Name="param-%d" % i, - Value="value-%d" % i, - Type="String" - ) + client.put_parameter(Name="param-%d" % i, Value="value-%d" % i, Type="String") response = client.describe_parameters() - len(response['Parameters']).should.equal(10) - response['NextToken'].should.equal('10') + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("10") - response = client.describe_parameters(NextToken=response['NextToken']) - len(response['Parameters']).should.equal(10) - response['NextToken'].should.equal('20') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("20") - response = client.describe_parameters(NextToken=response['NextToken']) - len(response['Parameters']).should.equal(10) - response['NextToken'].should.equal('30') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("30") - response = client.describe_parameters(NextToken=response['NextToken']) - len(response['Parameters']).should.equal(10) - response['NextToken'].should.equal('40') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("40") - response = client.describe_parameters(NextToken=response['NextToken']) - len(response['Parameters']).should.equal(10) - response['NextToken'].should.equal('50') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("50") - response = client.describe_parameters(NextToken=response['NextToken']) - len(response['Parameters']).should.equal(0) - ''.should.equal(response.get('NextToken', '')) + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(0) + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_filter_names(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = 'a key' + p["Type"] = "SecureString" + p["KeyId"] = "a key" client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'Name', - 'Values': ['param-22'] - }, - ]) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('param-22') - response['Parameters'][0]['Type'].should.equal('String') - ''.should.equal(response.get('NextToken', '')) + response = client.describe_parameters( + Filters=[{"Key": "Name", "Values": ["param-22"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("param-22") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_filter_type(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = 'a key' + p["Type"] = "SecureString" + p["KeyId"] = "a key" client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'Type', - 'Values': ['SecureString'] - }, - ]) - len(response['Parameters']).should.equal(10) - response['Parameters'][0]['Type'].should.equal('SecureString') - '10'.should.equal(response.get('NextToken', '')) + response = client.describe_parameters( + Filters=[{"Key": "Type", "Values": ["SecureString"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(10) + parameters[0]["Type"].should.equal("SecureString") + response.should.have.key("NextToken").which.should.equal("10") @mock_ssm def test_describe_parameters_filter_keyid(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = "key:%d" % i + p["Type"] = "SecureString" + p["KeyId"] = "key:%d" % i client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'KeyId', - 'Values': ['key:10'] - }, - ]) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('param-10') - response['Parameters'][0]['Type'].should.equal('SecureString') - ''.should.equal(response.get('NextToken', '')) + response = client.describe_parameters( + Filters=[{"Key": "KeyId", "Values": ["key:10"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("param-10") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") + + +@mock_ssm +def test_describe_parameters_with_parameter_filters_keyid(): + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="secure-param", Value="secure-value", Type="SecureString") + client.put_parameter( + Name="custom-secure-param", + Value="custom-secure-value", + Type="SecureString", + KeyId="alias/custom", + ) + client.put_parameter(Name="param", Value="value", Type="String") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "KeyId", "Values": ["alias/aws/ssm"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("secure-param") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "KeyId", "Values": ["alias/custom"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("custom-secure-param") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "KeyId", "Option": "BeginsWith", "Values": ["alias"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(2) + response.should_not.have.key("NextToken") + + +@mock_ssm +def test_describe_parameters_with_parameter_filters_name(): + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="param", Value="value", Type="String") + client.put_parameter(Name="/param-2", Value="value-2", Type="String") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Name", "Values": ["param"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("param") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Name", "Values": ["/param"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("param") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Name", "Values": ["param-2"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("/param-2") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Name", "Option": "BeginsWith", "Values": ["param"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(2) + response.should_not.have.key("NextToken") + + +@mock_ssm +def test_describe_parameters_with_parameter_filters_path(): + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="/foo/name1", Value="value1", Type="String") + + client.put_parameter(Name="/foo/name2", Value="value2", Type="String") + + client.put_parameter(Name="/bar/name3", Value="value3", Type="String") + + client.put_parameter(Name="/bar/name3/name4", Value="value4", Type="String") + + client.put_parameter(Name="foo", Value="bar", Type="String") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Values": ["/fo"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(0) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Values": ["/"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("foo") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Values": ["/", "/foo"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(3) + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2", "foo"} + ) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Values": ["/foo/"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(2) + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2"} + ) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[ + {"Key": "Path", "Option": "OneLevel", "Values": ["/bar/name3"]} + ] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("/bar/name3/name4") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/fo"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(0) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(5) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[ + {"Key": "Path", "Option": "Recursive", "Values": ["/foo", "/bar"]} + ] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(4) + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2", "/bar/name3", "/bar/name3/name4"} + ) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/foo/"]}] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(2) + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2"} + ) + response.should_not.have.key("NextToken") + + response = client.describe_parameters( + ParameterFilters=[ + {"Key": "Path", "Option": "Recursive", "Values": ["/bar/name3"]} + ] + ) + + parameters = response["Parameters"] + parameters.should.have.length_of(1) + parameters[0]["Name"].should.equal("/bar/name3/name4") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") + + +@mock_ssm +def test_describe_parameters_invalid_parameter_filters(): + client = boto3.client("ssm", region_name="us-east-1") + + client.describe_parameters.when.called_with( + Filters=[{"Key": "Name", "Values": ["test"]}], + ParameterFilters=[{"Key": "Name", "Values": ["test"]}], + ).should.throw( + ClientError, + "You can use either Filters or ParameterFilters in a single request.", + ) + + client.describe_parameters.when.called_with(ParameterFilters=[{}]).should.throw( + ParamValidationError, + 'Parameter validation failed:\nMissing required parameter in ParameterFilters[0]: "Key"', + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "key"}] + ).should.throw( + ClientError, + '1 validation error detected: Value "key" at "parameterFilters.1.member.key" failed to satisfy constraint: ' + "Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", + ) + + long_key = "tag:" + "t" * 129 + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": long_key}] + ).should.throw( + ClientError, + '1 validation error detected: Value "{value}" at "parameterFilters.1.member.key" failed to satisfy constraint: ' + "Member must have length less than or equal to 132".format(value=long_key), + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Option": "over 10 chars"}] + ).should.throw( + ClientError, + '1 validation error detected: Value "over 10 chars" at "parameterFilters.1.member.option" failed to satisfy constraint: ' + "Member must have length less than or equal to 10", + ) + + many_values = ["test"] * 51 + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Values": many_values}] + ).should.throw( + ClientError, + '1 validation error detected: Value "{value}" at "parameterFilters.1.member.values" failed to satisfy constraint: ' + "Member must have length less than or equal to 50".format(value=many_values), + ) + + long_value = ["t" * 1025] + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Values": long_value}] + ).should.throw( + ClientError, + '1 validation error detected: Value "{value}" at "parameterFilters.1.member.values" failed to satisfy constraint: ' + "[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]".format( + value=long_value + ), + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Option": "over 10 chars"}, {"Key": "key"}] + ).should.throw( + ClientError, + "2 validation errors detected: " + 'Value "over 10 chars" at "parameterFilters.1.member.option" failed to satisfy constraint: ' + "Member must have length less than or equal to 10; " + 'Value "key" at "parameterFilters.2.member.key" failed to satisfy constraint: ' + "Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Label"}] + ).should.throw( + ClientError, + "The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier].", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name"}] + ).should.throw( + ClientError, + "The following filter values are missing : null for filter key Name.", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Values": []}] + ).should.throw( + ParamValidationError, + "Invalid length for parameter ParameterFilters[0].Values, value: 0, valid range: 1-inf", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[ + {"Key": "Name", "Values": ["test"]}, + {"Key": "Name", "Values": ["test test"]}, + ] + ).should.throw( + ClientError, + "The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter.", + ) + + for value in ["/###", "//", "test"]: + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Path", "Values": [value]}] + ).should.throw( + ClientError, + 'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". ' + 'It can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). " + 'Special characters are not allowed. All sub-paths, if specified, must use the forward slash symbol "/". ' + "Valid example: /get/parameters2-/by1./path0_.", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Path", "Values": ["/aws", "/ssm"]}] + ).should.throw( + ClientError, + 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "When using global parameters, please specify within a global namespace.", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Path", "Option": "Equals", "Values": ["test"]}] + ).should.throw( + ClientError, + "The following filter option is not valid: Equals. Valid options include: [Recursive, OneLevel].", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Tier", "Values": ["test"]}] + ).should.throw( + ClientError, + "The following filter value is not valid: test. Valid values include: [Standard, Advanced, Intelligent-Tiering]", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Type", "Values": ["test"]}] + ).should.throw( + ClientError, + "The following filter value is not valid: test. Valid values include: [String, StringList, SecureString]", + ) + + client.describe_parameters.when.called_with( + ParameterFilters=[{"Key": "Name", "Option": "option", "Values": ["test"]}] + ).should.throw( + ClientError, + "The following filter option is not valid: option. Valid options include: [BeginsWith, Equals].", + ) @mock_ssm def test_describe_parameters_attributes(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='aa', - Value='11', - Type='String', - Description='my description' + Name="aa", Value="11", Type="String", Description="my description" ) - client.put_parameter( - Name='bb', - Value='22', - Type='String' - ) + client.put_parameter(Name="bb", Value="22", Type="String") response = client.describe_parameters() - len(response['Parameters']).should.equal(2) - response['Parameters'][0]['Description'].should.equal('my description') - response['Parameters'][0]['Version'].should.equal(1) - response['Parameters'][0]['LastModifiedDate'].should.be.a(datetime.date) - response['Parameters'][0]['LastModifiedUser'].should.equal('N/A') + parameters = response["Parameters"] + parameters.should.have.length_of(2) - response['Parameters'][1].get('Description').should.be.none - response['Parameters'][1]['Version'].should.equal(1) + parameters[0]["Description"].should.equal("my description") + parameters[0]["Version"].should.equal(1) + parameters[0]["LastModifiedDate"].should.be.a(datetime.date) + parameters[0]["LastModifiedUser"].should.equal("N/A") + + parameters[1].should_not.have.key("Description") + parameters[1]["Version"].should.equal(1) @mock_ssm def test_get_parameter_invalid(): - client = client = boto3.client('ssm', region_name='us-east-1') - response = client.get_parameters( - Names=[ - 'invalid' - ], - WithDecryption=False) + client = client = boto3.client("ssm", region_name="us-east-1") + response = client.get_parameters(Names=["invalid"], WithDecryption=False) - len(response['Parameters']).should.equal(0) - len(response['InvalidParameters']).should.equal(1) - response['InvalidParameters'][0].should.equal('invalid') + len(response["Parameters"]).should.equal(0) + len(response["InvalidParameters"]).should.equal(1) + response["InvalidParameters"][0].should.equal("invalid") @mock_ssm def test_put_parameter_secure_default_kms(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='SecureString') + Name="test", Description="A test parameter", Value="value", Type="SecureString" + ) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('kms:default:value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("kms:alias/aws/ssm:value") + response["Parameters"][0]["Type"].should.equal("SecureString") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=True) + response = client.get_parameters(Names=["test"], WithDecryption=True) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("SecureString") @mock_ssm def test_put_parameter_secure_custom_kms(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='SecureString', - KeyId='foo') + Name="test", + Description="A test parameter", + Value="value", + Type="SecureString", + KeyId="foo", + ) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('kms:foo:value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("kms:foo:value") + response["Parameters"][0]["Type"].should.equal("SecureString") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=True) + response = client.get_parameters(Names=["test"], WithDecryption=True) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("SecureString") + + +@mock_ssm +def test_get_parameter_history(): + client = boto3.client("ssm", region_name="us-east-1") + + test_parameter_name = "test" + + for i in range(3): + client.put_parameter( + Name=test_parameter_name, + Description="A test parameter version %d" % i, + Value="value-%d" % i, + Type="String", + Overwrite=True, + ) + + response = client.get_parameter_history(Name=test_parameter_name) + parameters_response = response["Parameters"] + + for index, param in enumerate(parameters_response): + param["Name"].should.equal(test_parameter_name) + param["Type"].should.equal("String") + param["Value"].should.equal("value-%d" % index) + param["Version"].should.equal(index + 1) + param["Description"].should.equal("A test parameter version %d" % index) + + len(parameters_response).should.equal(3) + + +@mock_ssm +def test_get_parameter_history_with_secure_string(): + client = boto3.client("ssm", region_name="us-east-1") + + test_parameter_name = "test" + + for i in range(3): + client.put_parameter( + Name=test_parameter_name, + Description="A test parameter version %d" % i, + Value="value-%d" % i, + Type="SecureString", + Overwrite=True, + ) + + for with_decryption in [True, False]: + response = client.get_parameter_history( + Name=test_parameter_name, WithDecryption=with_decryption + ) + parameters_response = response["Parameters"] + + for index, param in enumerate(parameters_response): + param["Name"].should.equal(test_parameter_name) + param["Type"].should.equal("SecureString") + expected_plaintext_value = "value-%d" % index + if with_decryption: + param["Value"].should.equal(expected_plaintext_value) + else: + param["Value"].should.equal( + "kms:alias/aws/ssm:%s" % expected_plaintext_value + ) + param["Version"].should.equal(index + 1) + param["Description"].should.equal("A test parameter version %d" % index) + + len(parameters_response).should.equal(3) + + +@mock_ssm +def test_get_parameter_history_missing_parameter(): + client = boto3.client("ssm", region_name="us-east-1") + + try: + client.get_parameter_history(Name="test_noexist") + raise RuntimeError("Should have failed") + except botocore.exceptions.ClientError as err: + err.operation_name.should.equal("GetParameterHistory") + err.response["Error"]["Message"].should.equal( + "Parameter test_noexist not found." + ) @mock_ssm def test_add_remove_list_tags_for_resource(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.add_tags_to_resource( - ResourceId='test', - ResourceType='Parameter', - Tags=[{'Key': 'test-key', 'Value': 'test-value'}] + ResourceId="test", + ResourceType="Parameter", + Tags=[{"Key": "test-key", "Value": "test-value"}], ) response = client.list_tags_for_resource( - ResourceId='test', - ResourceType='Parameter' + ResourceId="test", ResourceType="Parameter" ) - len(response['TagList']).should.equal(1) - response['TagList'][0]['Key'].should.equal('test-key') - response['TagList'][0]['Value'].should.equal('test-value') + len(response["TagList"]).should.equal(1) + response["TagList"][0]["Key"].should.equal("test-key") + response["TagList"][0]["Value"].should.equal("test-value") client.remove_tags_from_resource( - ResourceId='test', - ResourceType='Parameter', - TagKeys=['test-key'] + ResourceId="test", ResourceType="Parameter", TagKeys=["test-key"] ) response = client.list_tags_for_resource( - ResourceId='test', - ResourceType='Parameter' + ResourceId="test", ResourceType="Parameter" ) - len(response['TagList']).should.equal(0) + len(response["TagList"]).should.equal(0) @mock_ssm def test_send_command(): - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") # note the timeout is determined server side, so this is a simpler check. before = datetime.datetime.now() response = client.send_command( - InstanceIds=['i-123456'], + InstanceIds=["i-123456"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref' + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", ) - cmd = response['Command'] + cmd = response["Command"] - cmd['CommandId'].should_not.be(None) - cmd['DocumentName'].should.equal(ssm_document) - cmd['Parameters'].should.equal(params) + cmd["CommandId"].should_not.be(None) + cmd["DocumentName"].should.equal(ssm_document) + cmd["Parameters"].should.equal(params) - cmd['OutputS3Region'].should.equal('us-east-2') - cmd['OutputS3BucketName'].should.equal('the-bucket') - cmd['OutputS3KeyPrefix'].should.equal('pref') + cmd["OutputS3Region"].should.equal("us-east-2") + cmd["OutputS3BucketName"].should.equal("the-bucket") + cmd["OutputS3KeyPrefix"].should.equal("pref") - cmd['ExpiresAfter'].should.be.greater_than(before) + cmd["ExpiresAfter"].should.be.greater_than(before) # test sending a command without any optional parameters - response = client.send_command( - DocumentName=ssm_document) + response = client.send_command(DocumentName=ssm_document) - cmd = response['Command'] + cmd = response["Command"] - cmd['CommandId'].should_not.be(None) - cmd['DocumentName'].should.equal(ssm_document) + cmd["CommandId"].should_not.be(None) + cmd["DocumentName"].should.equal(ssm_document) @mock_ssm def test_list_commands(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - InstanceIds=['i-123456'], + InstanceIds=["i-123456"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] # get the command by id - response = client.list_commands( - CommandId=cmd_id) + response = client.list_commands(CommandId=cmd_id) - cmds = response['Commands'] + cmds = response["Commands"] len(cmds).should.equal(1) - cmds[0]['CommandId'].should.equal(cmd_id) + cmds[0]["CommandId"].should.equal(cmd_id) # add another command with the same instance id to test listing by # instance id - client.send_command( - InstanceIds=['i-123456'], - DocumentName=ssm_document) + client.send_command(InstanceIds=["i-123456"], DocumentName=ssm_document) - response = client.list_commands( - InstanceId='i-123456') + response = client.list_commands(InstanceId="i-123456") - cmds = response['Commands'] + cmds = response["Commands"] len(cmds).should.equal(2) for cmd in cmds: - cmd['InstanceIds'].should.contain('i-123456') + cmd["InstanceIds"].should.contain("i-123456") # test the error case for an invalid command id with assert_raises(ClientError): - response = client.list_commands( - CommandId=str(uuid.uuid4())) + response = client.list_commands(CommandId=str(uuid.uuid4())) + @mock_ssm def test_get_command_invocation(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - InstanceIds=['i-123456', 'i-234567', 'i-345678'], + InstanceIds=["i-123456", "i-234567", "i-345678"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] - instance_id = 'i-345678' + instance_id = "i-345678" invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_id, - PluginName='aws:runShellScript') + CommandId=cmd_id, InstanceId=instance_id, PluginName="aws:runShellScript" + ) - invocation_response['CommandId'].should.equal(cmd_id) - invocation_response['InstanceId'].should.equal(instance_id) + invocation_response["CommandId"].should.equal(cmd_id) + invocation_response["InstanceId"].should.equal(instance_id) # test the error case for an invalid instance id with assert_raises(ClientError): invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId='i-FAKE') + CommandId=cmd_id, InstanceId="i-FAKE" + ) # test the error case for an invalid plugin name with assert_raises(ClientError): invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_id, - PluginName='FAKE') + CommandId=cmd_id, InstanceId=instance_id, PluginName="FAKE" + ) + @mock_ssm @mock_cloudformation @@ -738,63 +1102,52 @@ def test_get_command_invocations_from_stack(): "KeyName": "test", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Test Description", - "Value": "Test tag" - }, - { - "Key": "Test Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Test Description", "Value": "Test tag"}, + {"Key": "Test Name", "Value": "Name tag for tests"}, + ], + }, } }, "Outputs": { "test": { "Description": "Test Output", "Value": "Test output value", - "Export": { - "Name": "Test value to export" - } + "Export": {"Name": "Test value to export"}, }, - "PublicIP": { - "Value": "Test public ip" - } - } + "PublicIP": {"Value": "Test public ip"}, + }, } - cloudformation_client = boto3.client( - 'cloudformation', - region_name='us-east-1') + cloudformation_client = boto3.client("cloudformation", region_name="us-east-1") stack_template_str = json.dumps(stack_template) response = cloudformation_client.create_stack( - StackName='test_stack', + StackName="test_stack", TemplateBody=stack_template_str, - Capabilities=('CAPABILITY_IAM', )) + Capabilities=("CAPABILITY_IAM",), + ) - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - Targets=[{ - 'Key': 'tag:aws:cloudformation:stack-name', - 'Values': ('test_stack', )}], + Targets=[ + {"Key": "tag:aws:cloudformation:stack-name", "Values": ("test_stack",)} + ], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] - instance_ids = cmd['InstanceIds'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] + instance_ids = cmd["InstanceIds"] invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_ids[0], - PluginName='aws:runShellScript') + CommandId=cmd_id, InstanceId=instance_ids[0], PluginName="aws:runShellScript" + ) diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py new file mode 100644 index 000000000..3e0a8115d --- /dev/null +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -0,0 +1,534 @@ +from __future__ import unicode_literals + +import boto3 +import sure # noqa +import datetime + +from datetime import datetime +from botocore.exceptions import ClientError +from nose.tools import assert_raises + +from moto import mock_sts, mock_stepfunctions +from moto.core import ACCOUNT_ID + +region = "us-east-1" +simple_definition = ( + '{"Comment": "An example of the Amazon States Language using a choice state.",' + '"StartAt": "DefaultState",' + '"States": ' + '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' +) +account_id = None + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_succeeds(): + client = boto3.client("stepfunctions", region_name=region) + name = "example_step_function" + # + response = client.create_state_machine( + name=name, definition=str(simple_definition), roleArn=_get_default_role() + ) + # + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["creationDate"].should.be.a(datetime) + response["stateMachineArn"].should.equal( + "arn:aws:states:" + region + ":" + ACCOUNT_ID + ":stateMachine:" + name + ) + + +@mock_stepfunctions +def test_state_machine_creation_fails_with_invalid_names(): + client = boto3.client("stepfunctions", region_name=region) + invalid_names = [ + "with space", + "withbracket", + "with{bracket", + "with}bracket", + "with[bracket", + "with]bracket", + "with?wildcard", + "with*wildcard", + 'special"char', + "special#char", + "special%char", + "special\\char", + "special^char", + "special|char", + "special~char", + "special`char", + "special$char", + "special&char", + "special,char", + "special;char", + "special:char", + "special/char", + "uni\u0000code", + "uni\u0001code", + "uni\u0002code", + "uni\u0003code", + "uni\u0004code", + "uni\u0005code", + "uni\u0006code", + "uni\u0007code", + "uni\u0008code", + "uni\u0009code", + "uni\u000Acode", + "uni\u000Bcode", + "uni\u000Ccode", + "uni\u000Dcode", + "uni\u000Ecode", + "uni\u000Fcode", + "uni\u0010code", + "uni\u0011code", + "uni\u0012code", + "uni\u0013code", + "uni\u0014code", + "uni\u0015code", + "uni\u0016code", + "uni\u0017code", + "uni\u0018code", + "uni\u0019code", + "uni\u001Acode", + "uni\u001Bcode", + "uni\u001Ccode", + "uni\u001Dcode", + "uni\u001Ecode", + "uni\u001Fcode", + "uni\u007Fcode", + "uni\u0080code", + "uni\u0081code", + "uni\u0082code", + "uni\u0083code", + "uni\u0084code", + "uni\u0085code", + "uni\u0086code", + "uni\u0087code", + "uni\u0088code", + "uni\u0089code", + "uni\u008Acode", + "uni\u008Bcode", + "uni\u008Ccode", + "uni\u008Dcode", + "uni\u008Ecode", + "uni\u008Fcode", + "uni\u0090code", + "uni\u0091code", + "uni\u0092code", + "uni\u0093code", + "uni\u0094code", + "uni\u0095code", + "uni\u0096code", + "uni\u0097code", + "uni\u0098code", + "uni\u0099code", + "uni\u009Acode", + "uni\u009Bcode", + "uni\u009Ccode", + "uni\u009Dcode", + "uni\u009Ecode", + "uni\u009Fcode", + ] + # + + for invalid_name in invalid_names: + with assert_raises(ClientError) as exc: + client.create_state_machine( + name=invalid_name, + definition=str(simple_definition), + roleArn=_get_default_role(), + ) + + +@mock_stepfunctions +def test_state_machine_creation_requires_valid_role_arn(): + client = boto3.client("stepfunctions", region_name=region) + name = "example_step_function" + # + with assert_raises(ClientError) as exc: + client.create_state_machine( + name=name, + definition=str(simple_definition), + roleArn="arn:aws:iam::1234:role/unknown_role", + ) + + +@mock_stepfunctions +def test_state_machine_list_returns_empty_list_by_default(): + client = boto3.client("stepfunctions", region_name=region) + # + list = client.list_state_machines() + list["stateMachines"].should.be.empty + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_returns_created_state_machines(): + client = boto3.client("stepfunctions", region_name=region) + # + machine2 = client.create_state_machine( + name="name2", definition=str(simple_definition), roleArn=_get_default_role() + ) + machine1 = client.create_state_machine( + name="name1", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{"key": "tag_key", "value": "tag_value"}], + ) + list = client.list_state_machines() + # + list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + list["stateMachines"].should.have.length_of(2) + list["stateMachines"][0]["creationDate"].should.be.a(datetime) + list["stateMachines"][0]["creationDate"].should.equal(machine1["creationDate"]) + list["stateMachines"][0]["name"].should.equal("name1") + list["stateMachines"][0]["stateMachineArn"].should.equal( + machine1["stateMachineArn"] + ) + list["stateMachines"][1]["creationDate"].should.be.a(datetime) + list["stateMachines"][1]["creationDate"].should.equal(machine2["creationDate"]) + list["stateMachines"][1]["name"].should.equal("name2") + list["stateMachines"][1]["stateMachineArn"].should.equal( + machine2["stateMachineArn"] + ) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_is_idempotent_by_name(): + client = boto3.client("stepfunctions", region_name=region) + # + client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + sm_list = client.list_state_machines() + sm_list["stateMachines"].should.have.length_of(1) + # + client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + sm_list = client.list_state_machines() + sm_list["stateMachines"].should.have.length_of(1) + # + client.create_state_machine( + name="diff_name", definition=str(simple_definition), roleArn=_get_default_role() + ) + sm_list = client.list_state_machines() + sm_list["stateMachines"].should.have.length_of(2) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_can_be_described(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + desc = client.describe_state_machine(stateMachineArn=sm["stateMachineArn"]) + desc["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + desc["creationDate"].should.equal(sm["creationDate"]) + desc["definition"].should.equal(str(simple_definition)) + desc["name"].should.equal("name") + desc["roleArn"].should.equal(_get_default_role()) + desc["stateMachineArn"].should.equal(sm["stateMachineArn"]) + desc["status"].should.equal("ACTIVE") + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client("stepfunctions", region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":stateMachine:unknown" + ) + client.describe_state_machine(stateMachineArn=unknown_state_machine) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_machine_in_different_account(): + client = boto3.client("stepfunctions", region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = ( + "arn:aws:states:" + region + ":000000000000:stateMachine:unknown" + ) + client.describe_state_machine(stateMachineArn=unknown_state_machine) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_deleted(): + client = boto3.client("stepfunctions", region_name=region) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + # + response = client.delete_state_machine(stateMachineArn=sm["stateMachineArn"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list["stateMachines"].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_deleted_nonexisting_machine(): + client = boto3.client("stepfunctions", region_name=region) + # + unknown_state_machine = ( + "arn:aws:states:" + region + ":" + ACCOUNT_ID + ":stateMachine:unknown" + ) + response = client.delete_state_machine(stateMachineArn=unknown_state_machine) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list["stateMachines"].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_created_machine(): + client = boto3.client("stepfunctions", region_name=region) + # + machine = client.create_state_machine( + name="name1", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{"key": "tag_key", "value": "tag_value"}], + ) + response = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags = response["tags"] + tags.should.have.length_of(1) + tags[0].should.equal({"key": "tag_key", "value": "tag_value"}) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_machine_without_tags(): + client = boto3.client("stepfunctions", region_name=region) + # + machine = client.create_state_machine( + name="name1", definition=str(simple_definition), roleArn=_get_default_role() + ) + response = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags = response["tags"] + tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_nonexisting_machine(): + client = boto3.client("stepfunctions", region_name=region) + # + non_existing_state_machine = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":stateMachine:unknown" + ) + response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) + tags = response["tags"] + tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_start_execution(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + # + execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + uuid_regex = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" + expected_exec_name = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":execution:name:" + + uuid_regex + ) + execution["executionArn"].should.match(expected_exec_name) + execution["startDate"].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_start_execution_with_custom_name(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution( + stateMachineArn=sm["stateMachineArn"], name="execution_name" + ) + # + execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + expected_exec_name = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":execution:name:execution_name" + ) + execution["executionArn"].should.equal(expected_exec_name) + execution["startDate"].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + execution_arn = execution["executionArn"] + execution_name = execution_arn[execution_arn.rindex(":") + 1 :] + executions = client.list_executions(stateMachineArn=sm["stateMachineArn"]) + # + executions["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + executions["executions"].should.have.length_of(1) + executions["executions"][0]["executionArn"].should.equal(execution_arn) + executions["executions"][0]["name"].should.equal(execution_name) + executions["executions"][0]["startDate"].should.equal(execution["startDate"]) + executions["executions"][0]["stateMachineArn"].should.equal(sm["stateMachineArn"]) + executions["executions"][0]["status"].should.equal("RUNNING") + executions["executions"][0].shouldnt.have("stopDate") + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions_when_none_exist(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + executions = client.list_executions(stateMachineArn=sm["stateMachineArn"]) + # + executions["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + executions["executions"].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + description = client.describe_execution(executionArn=execution["executionArn"]) + # + description["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + description["executionArn"].should.equal(execution["executionArn"]) + description["input"].should.equal("{}") + description["name"].shouldnt.be.empty + description["startDate"].should.equal(execution["startDate"]) + description["stateMachineArn"].should.equal(sm["stateMachineArn"]) + description["status"].should.equal("RUNNING") + description.shouldnt.have("stopDate") + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client("stepfunctions", region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" + ) + client.describe_execution(executionArn=unknown_execution) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_described_by_execution(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + desc = client.describe_state_machine_for_execution( + executionArn=execution["executionArn"] + ) + desc["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + desc["definition"].should.equal(str(simple_definition)) + desc["name"].should.equal("name") + desc["roleArn"].should.equal(_get_default_role()) + desc["stateMachineArn"].should.equal(sm["stateMachineArn"]) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_execution(): + client = boto3.client("stepfunctions", region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" + ) + client.describe_state_machine_for_execution(executionArn=unknown_execution) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_stop_execution(): + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + start = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + stop = client.stop_execution(executionArn=start["executionArn"]) + # + stop["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + stop["stopDate"].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution_after_stoppage(): + account_id + client = boto3.client("stepfunctions", region_name=region) + # + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + client.stop_execution(executionArn=execution["executionArn"]) + description = client.describe_execution(executionArn=execution["executionArn"]) + # + description["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + description["status"].should.equal("SUCCEEDED") + description["stopDate"].should.be.a(datetime) + + +def _get_account_id(): + global account_id + if account_id: + return account_id + sts = boto3.client("sts", region_name=region) + identity = sts.get_caller_identity() + account_id = identity["Account"] + return account_id + + +def _get_default_role(): + return "arn:aws:iam::" + _get_account_id() + ":role/unknown_sf_role" diff --git a/tests/test_sts/test_server.py b/tests/test_sts/test_server.py index 1cff6b0af..8903477d7 100644 --- a/tests/test_sts/test_server.py +++ b/tests/test_sts/test_server.py @@ -1,39 +1,39 @@ -from __future__ import unicode_literals -import sure # noqa - -import moto.server as server - -''' -Test the different server responses -''' - - -def test_sts_get_session_token(): - backend = server.create_backend_app("sts") - test_client = backend.test_client() - - res = test_client.get('/?Action=GetSessionToken') - res.status_code.should.equal(200) - res.data.should.contain(b"SessionToken") - res.data.should.contain(b"AccessKeyId") - - -def test_sts_get_federation_token(): - backend = server.create_backend_app("sts") - test_client = backend.test_client() - - res = test_client.get('/?Action=GetFederationToken&Name=Bob') - res.status_code.should.equal(200) - res.data.should.contain(b"SessionToken") - res.data.should.contain(b"AccessKeyId") - - -def test_sts_get_caller_identity(): - backend = server.create_backend_app("sts") - test_client = backend.test_client() - - res = test_client.get('/?Action=GetCallerIdentity') - res.status_code.should.equal(200) - res.data.should.contain(b"Arn") - res.data.should.contain(b"UserId") - res.data.should.contain(b"Account") +from __future__ import unicode_literals +import sure # noqa + +import moto.server as server + +""" +Test the different server responses +""" + + +def test_sts_get_session_token(): + backend = server.create_backend_app("sts") + test_client = backend.test_client() + + res = test_client.get("/?Action=GetSessionToken") + res.status_code.should.equal(200) + res.data.should.contain(b"SessionToken") + res.data.should.contain(b"AccessKeyId") + + +def test_sts_get_federation_token(): + backend = server.create_backend_app("sts") + test_client = backend.test_client() + + res = test_client.get("/?Action=GetFederationToken&Name=Bob") + res.status_code.should.equal(200) + res.data.should.contain(b"SessionToken") + res.data.should.contain(b"AccessKeyId") + + +def test_sts_get_caller_identity(): + backend = server.create_backend_app("sts") + test_client = backend.test_client() + + res = test_client.get("/?Action=GetCallerIdentity") + res.status_code.should.equal(200) + res.data.should.contain(b"Arn") + res.data.should.contain(b"UserId") + res.data.should.contain(b"Account") diff --git a/tests/test_sts/test_sts.py b/tests/test_sts/test_sts.py index b047a8d13..4dee9184f 100644 --- a/tests/test_sts/test_sts.py +++ b/tests/test_sts/test_sts.py @@ -10,7 +10,7 @@ import sure # noqa from moto import mock_sts, mock_sts_deprecated, mock_iam, settings -from moto.iam.models import ACCOUNT_ID +from moto.core import ACCOUNT_ID from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH @@ -20,9 +20,10 @@ def test_get_session_token(): conn = boto.connect_sts() token = conn.get_session_token(duration=123) - token.expiration.should.equal('2012-01-01T12:02:03.000Z') + token.expiration.should.equal("2012-01-01T12:02:03.000Z") token.session_token.should.equal( - "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE") + "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE" + ) token.access_key.should.equal("AKIAIOSFODNN7EXAMPLE") token.secret_key.should.equal("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") @@ -34,57 +35,72 @@ def test_get_federation_token(): token_name = "Bob" token = conn.get_federation_token(duration=123, name=token_name) - token.credentials.expiration.should.equal('2012-01-01T12:02:03.000Z') + token.credentials.expiration.should.equal("2012-01-01T12:02:03.000Z") token.credentials.session_token.should.equal( - "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==") + "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==" + ) token.credentials.access_key.should.equal("AKIAIOSFODNN7EXAMPLE") token.credentials.secret_key.should.equal( - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + ) token.federated_user_arn.should.equal( - "arn:aws:sts::{account_id}:federated-user/{token_name}".format(account_id=ACCOUNT_ID, token_name=token_name)) + "arn:aws:sts::{account_id}:federated-user/{token_name}".format( + account_id=ACCOUNT_ID, token_name=token_name + ) + ) token.federated_user_id.should.equal(str(ACCOUNT_ID) + ":" + token_name) @freeze_time("2012-01-01 12:00:00") @mock_sts def test_assume_role(): - client = boto3.client( - "sts", region_name='us-east-1') + client = boto3.client("sts", region_name="us-east-1") session_name = "session-name" - policy = json.dumps({ - "Statement": [ - { - "Sid": "Stmt13690092345534", - "Action": [ - "S3:ListBucket" - ], - "Effect": "Allow", - "Resource": [ - "arn:aws:s3:::foobar-tester" - ] - }, - ] - }) + policy = json.dumps( + { + "Statement": [ + { + "Sid": "Stmt13690092345534", + "Action": ["S3:ListBucket"], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::foobar-tester"], + } + ] + } + ) role_name = "test-role" - s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format(account_id=ACCOUNT_ID, role_name=role_name) - assume_role_response = client.assume_role(RoleArn=s3_role, RoleSessionName=session_name, - Policy=policy, DurationSeconds=900) + s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format( + account_id=ACCOUNT_ID, role_name=role_name + ) + assume_role_response = client.assume_role( + RoleArn=s3_role, + RoleSessionName=session_name, + Policy=policy, + DurationSeconds=900, + ) - credentials = assume_role_response['Credentials'] + credentials = assume_role_response["Credentials"] if not settings.TEST_SERVER_MODE: - credentials['Expiration'].isoformat().should.equal('2012-01-01T12:15:00+00:00') - credentials['SessionToken'].should.have.length_of(356) - assert credentials['SessionToken'].startswith("FQoGZXIvYXdzE") - credentials['AccessKeyId'].should.have.length_of(20) - assert credentials['AccessKeyId'].startswith("ASIA") - credentials['SecretAccessKey'].should.have.length_of(40) + credentials["Expiration"].isoformat().should.equal("2012-01-01T12:15:00+00:00") + credentials["SessionToken"].should.have.length_of(356) + assert credentials["SessionToken"].startswith("FQoGZXIvYXdzE") + credentials["AccessKeyId"].should.have.length_of(20) + assert credentials["AccessKeyId"].startswith("ASIA") + credentials["SecretAccessKey"].should.have.length_of(40) - assume_role_response['AssumedRoleUser']['Arn'].should.equal("arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name)) - assert assume_role_response['AssumedRoleUser']['AssumedRoleId'].startswith("AROA") - assert assume_role_response['AssumedRoleUser']['AssumedRoleId'].endswith(":" + session_name) - assume_role_response['AssumedRoleUser']['AssumedRoleId'].should.have.length_of(21 + 1 + len(session_name)) + assume_role_response["AssumedRoleUser"]["Arn"].should.equal( + "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( + account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name + ) + ) + assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].startswith("AROA") + assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].endswith( + ":" + session_name + ) + assume_role_response["AssumedRoleUser"]["AssumedRoleId"].should.have.length_of( + 21 + 1 + len(session_name) + ) @freeze_time("2012-01-01 12:00:00") @@ -92,122 +108,135 @@ def test_assume_role(): def test_assume_role_with_web_identity(): conn = boto.connect_sts() - policy = json.dumps({ - "Statement": [ - { - "Sid": "Stmt13690092345534", - "Action": [ - "S3:ListBucket" - ], - "Effect": "Allow", - "Resource": [ - "arn:aws:s3:::foobar-tester" - ] - }, - ] - }) + policy = json.dumps( + { + "Statement": [ + { + "Sid": "Stmt13690092345534", + "Action": ["S3:ListBucket"], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::foobar-tester"], + } + ] + } + ) role_name = "test-role" - s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format(account_id=ACCOUNT_ID, role_name=role_name) + s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format( + account_id=ACCOUNT_ID, role_name=role_name + ) session_name = "session-name" role = conn.assume_role_with_web_identity( - s3_role, session_name, policy, duration_seconds=123) + s3_role, session_name, policy, duration_seconds=123 + ) credentials = role.credentials - credentials.expiration.should.equal('2012-01-01T12:02:03.000Z') + credentials.expiration.should.equal("2012-01-01T12:02:03.000Z") credentials.session_token.should.have.length_of(356) assert credentials.session_token.startswith("FQoGZXIvYXdzE") credentials.access_key.should.have.length_of(20) assert credentials.access_key.startswith("ASIA") credentials.secret_key.should.have.length_of(40) - role.user.arn.should.equal("arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name)) + role.user.arn.should.equal( + "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( + account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name + ) + ) role.user.assume_role_id.should.contain("session-name") @mock_sts def test_get_caller_identity_with_default_credentials(): - identity = boto3.client( - "sts", region_name='us-east-1').get_caller_identity() + identity = boto3.client("sts", region_name="us-east-1").get_caller_identity() - identity['Arn'].should.equal('arn:aws:sts::{account_id}:user/moto'.format(account_id=ACCOUNT_ID)) - identity['UserId'].should.equal('AKIAIOSFODNN7EXAMPLE') - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal( + "arn:aws:sts::{account_id}:user/moto".format(account_id=ACCOUNT_ID) + ) + identity["UserId"].should.equal("AKIAIOSFODNN7EXAMPLE") + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts @mock_iam def test_get_caller_identity_with_iam_user_credentials(): - iam_client = boto3.client("iam", region_name='us-east-1') + iam_client = boto3.client("iam", region_name="us-east-1") iam_user_name = "new-user" - iam_user = iam_client.create_user(UserName=iam_user_name)['User'] - access_key = iam_client.create_access_key(UserName=iam_user_name)['AccessKey'] + iam_user = iam_client.create_user(UserName=iam_user_name)["User"] + access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"] identity = boto3.client( - "sts", region_name='us-east-1', aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']).get_caller_identity() + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ).get_caller_identity() - identity['Arn'].should.equal(iam_user['Arn']) - identity['UserId'].should.equal(iam_user['UserId']) - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal(iam_user["Arn"]) + identity["UserId"].should.equal(iam_user["UserId"]) + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts @mock_iam def test_get_caller_identity_with_assumed_role_credentials(): - iam_client = boto3.client("iam", region_name='us-east-1') - sts_client = boto3.client("sts", region_name='us-east-1') + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") iam_role_name = "new-user" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } iam_role_arn = iam_client.role_arn = iam_client.create_role( RoleName=iam_role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + AssumeRolePolicyDocument=json.dumps(trust_policy_document), + )["Role"]["Arn"] session_name = "new-session" - assumed_role = sts_client.assume_role(RoleArn=iam_role_arn, - RoleSessionName=session_name) - access_key = assumed_role['Credentials'] + assumed_role = sts_client.assume_role( + RoleArn=iam_role_arn, RoleSessionName=session_name + ) + access_key = assumed_role["Credentials"] identity = boto3.client( - "sts", region_name='us-east-1', aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']).get_caller_identity() + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ).get_caller_identity() - identity['Arn'].should.equal(assumed_role['AssumedRoleUser']['Arn']) - identity['UserId'].should.equal(assumed_role['AssumedRoleUser']['AssumedRoleId']) - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal(assumed_role["AssumedRoleUser"]["Arn"]) + identity["UserId"].should.equal(assumed_role["AssumedRoleUser"]["AssumedRoleId"]) + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts def test_federation_token_with_too_long_policy(): "Trying to get a federation token with a policy longer than 2048 character should fail" - cli = boto3.client("sts", region_name='us-east-1') - resource_tmpl = 'arn:aws:s3:::yyyy-xxxxx-cloud-default/my_default_folder/folder-name-%s/*' + cli = boto3.client("sts", region_name="us-east-1") + resource_tmpl = ( + "arn:aws:s3:::yyyy-xxxxx-cloud-default/my_default_folder/folder-name-%s/*" + ) statements = [] for num in range(30): statements.append( { - 'Effect': 'Allow', - 'Action': ['s3:*'], - 'Resource': resource_tmpl % str(num) + "Effect": "Allow", + "Action": ["s3:*"], + "Resource": resource_tmpl % str(num), } ) - policy = { - 'Version': '2012-10-17', - 'Statement': statements - } + policy = {"Version": "2012-10-17", "Statement": statements} json_policy = json.dumps(policy) assert len(json_policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH with assert_raises(ClientError) as exc: - cli.get_federation_token(Name='foo', DurationSeconds=3600, Policy=json_policy) - exc.exception.response['Error']['Code'].should.equal('ValidationError') - exc.exception.response['Error']['Message'].should.contain( + cli.get_federation_token(Name="foo", DurationSeconds=3600, Policy=json_policy) + exc.exception.response["Error"]["Code"].should.equal("ValidationError") + exc.exception.response["Error"]["Message"].should.contain( str(MAX_FEDERATION_TOKEN_POLICY_LENGTH) ) diff --git a/tests/test_swf/models/test_activity_task.py b/tests/test_swf/models/test_activity_task.py index dfcaf9801..96f7c345f 100644 --- a/tests/test_swf/models/test_activity_task.py +++ b/tests/test_swf/models/test_activity_task.py @@ -1,154 +1,151 @@ -from freezegun import freeze_time -import sure # noqa - -from moto.swf.exceptions import SWFWorkflowExecutionClosedError -from moto.swf.models import ( - ActivityTask, - ActivityType, - Timeout, -) - -from ..utils import ( - ACTIVITY_TASK_TIMEOUTS, - make_workflow_execution, - process_first_timeout, -) - - -def test_activity_task_creation(): - wfe = make_workflow_execution() - task = ActivityTask( - activity_id="my-activity-123", - activity_type="foo", - input="optional", - scheduled_event_id=117, - workflow_execution=wfe, - timeouts=ACTIVITY_TASK_TIMEOUTS, - ) - task.workflow_execution.should.equal(wfe) - task.state.should.equal("SCHEDULED") - task.task_token.should_not.be.empty - task.started_event_id.should.be.none - - task.start(123) - task.state.should.equal("STARTED") - task.started_event_id.should.equal(123) - - task.complete() - task.state.should.equal("COMPLETED") - - # NB: this doesn't make any sense for SWF, a task shouldn't go from a - # "COMPLETED" state to a "FAILED" one, but this is an internal state on our - # side and we don't care about invalid state transitions for now. - task.fail() - task.state.should.equal("FAILED") - - -def test_activity_task_full_dict_representation(): - wfe = make_workflow_execution() - at = ActivityTask( - activity_id="my-activity-123", - activity_type=ActivityType("foo", "v1.0"), - input="optional", - scheduled_event_id=117, - timeouts=ACTIVITY_TASK_TIMEOUTS, - workflow_execution=wfe, - ) - at.start(1234) - - fd = at.to_full_dict() - fd["activityId"].should.equal("my-activity-123") - fd["activityType"]["version"].should.equal("v1.0") - fd["input"].should.equal("optional") - fd["startedEventId"].should.equal(1234) - fd.should.contain("taskToken") - fd["workflowExecution"].should.equal(wfe.to_short_dict()) - - at.start(1234) - fd = at.to_full_dict() - fd["startedEventId"].should.equal(1234) - - -def test_activity_task_reset_heartbeat_clock(): - wfe = make_workflow_execution() - - with freeze_time("2015-01-01 12:00:00"): - task = ActivityTask( - activity_id="my-activity-123", - activity_type="foo", - input="optional", - scheduled_event_id=117, - timeouts=ACTIVITY_TASK_TIMEOUTS, - workflow_execution=wfe, - ) - - task.last_heartbeat_timestamp.should.equal(1420113600.0) - - with freeze_time("2015-01-01 13:00:00"): - task.reset_heartbeat_clock() - - task.last_heartbeat_timestamp.should.equal(1420117200.0) - - -def test_activity_task_first_timeout(): - wfe = make_workflow_execution() - - with freeze_time("2015-01-01 12:00:00"): - task = ActivityTask( - activity_id="my-activity-123", - activity_type="foo", - input="optional", - scheduled_event_id=117, - timeouts=ACTIVITY_TASK_TIMEOUTS, - workflow_execution=wfe, - ) - task.first_timeout().should.be.none - - # activity task timeout is 300s == 5mins - with freeze_time("2015-01-01 12:06:00"): - task.first_timeout().should.be.a(Timeout) - process_first_timeout(task) - task.state.should.equal("TIMED_OUT") - task.timeout_type.should.equal("HEARTBEAT") - - -def test_activity_task_cannot_timeout_on_closed_workflow_execution(): - with freeze_time("2015-01-01 12:00:00"): - wfe = make_workflow_execution() - wfe.start() - - with freeze_time("2015-01-01 13:58:00"): - task = ActivityTask( - activity_id="my-activity-123", - activity_type="foo", - input="optional", - scheduled_event_id=117, - timeouts=ACTIVITY_TASK_TIMEOUTS, - workflow_execution=wfe, - ) - - with freeze_time("2015-01-01 14:10:00"): - task.first_timeout().should.be.a(Timeout) - wfe.first_timeout().should.be.a(Timeout) - process_first_timeout(wfe) - task.first_timeout().should.be.none - - -def test_activity_task_cannot_change_state_on_closed_workflow_execution(): - wfe = make_workflow_execution() - wfe.start() - - task = ActivityTask( - activity_id="my-activity-123", - activity_type="foo", - input="optional", - scheduled_event_id=117, - timeouts=ACTIVITY_TASK_TIMEOUTS, - workflow_execution=wfe, - ) - wfe.complete(123) - - task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( - SWFWorkflowExecutionClosedError) - task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) - task.fail.when.called_with().should.throw(SWFWorkflowExecutionClosedError) +from freezegun import freeze_time +import sure # noqa + +from moto.swf.exceptions import SWFWorkflowExecutionClosedError +from moto.swf.models import ActivityTask, ActivityType, Timeout + +from ..utils import ( + ACTIVITY_TASK_TIMEOUTS, + make_workflow_execution, + process_first_timeout, +) + + +def test_activity_task_creation(): + wfe = make_workflow_execution() + task = ActivityTask( + activity_id="my-activity-123", + activity_type="foo", + input="optional", + scheduled_event_id=117, + workflow_execution=wfe, + timeouts=ACTIVITY_TASK_TIMEOUTS, + ) + task.workflow_execution.should.equal(wfe) + task.state.should.equal("SCHEDULED") + task.task_token.should_not.be.empty + task.started_event_id.should.be.none + + task.start(123) + task.state.should.equal("STARTED") + task.started_event_id.should.equal(123) + + task.complete() + task.state.should.equal("COMPLETED") + + # NB: this doesn't make any sense for SWF, a task shouldn't go from a + # "COMPLETED" state to a "FAILED" one, but this is an internal state on our + # side and we don't care about invalid state transitions for now. + task.fail() + task.state.should.equal("FAILED") + + +def test_activity_task_full_dict_representation(): + wfe = make_workflow_execution() + at = ActivityTask( + activity_id="my-activity-123", + activity_type=ActivityType("foo", "v1.0"), + input="optional", + scheduled_event_id=117, + timeouts=ACTIVITY_TASK_TIMEOUTS, + workflow_execution=wfe, + ) + at.start(1234) + + fd = at.to_full_dict() + fd["activityId"].should.equal("my-activity-123") + fd["activityType"]["version"].should.equal("v1.0") + fd["input"].should.equal("optional") + fd["startedEventId"].should.equal(1234) + fd.should.contain("taskToken") + fd["workflowExecution"].should.equal(wfe.to_short_dict()) + + at.start(1234) + fd = at.to_full_dict() + fd["startedEventId"].should.equal(1234) + + +def test_activity_task_reset_heartbeat_clock(): + wfe = make_workflow_execution() + + with freeze_time("2015-01-01 12:00:00"): + task = ActivityTask( + activity_id="my-activity-123", + activity_type="foo", + input="optional", + scheduled_event_id=117, + timeouts=ACTIVITY_TASK_TIMEOUTS, + workflow_execution=wfe, + ) + + task.last_heartbeat_timestamp.should.equal(1420113600.0) + + with freeze_time("2015-01-01 13:00:00"): + task.reset_heartbeat_clock() + + task.last_heartbeat_timestamp.should.equal(1420117200.0) + + +def test_activity_task_first_timeout(): + wfe = make_workflow_execution() + + with freeze_time("2015-01-01 12:00:00"): + task = ActivityTask( + activity_id="my-activity-123", + activity_type="foo", + input="optional", + scheduled_event_id=117, + timeouts=ACTIVITY_TASK_TIMEOUTS, + workflow_execution=wfe, + ) + task.first_timeout().should.be.none + + # activity task timeout is 300s == 5mins + with freeze_time("2015-01-01 12:06:00"): + task.first_timeout().should.be.a(Timeout) + process_first_timeout(task) + task.state.should.equal("TIMED_OUT") + task.timeout_type.should.equal("HEARTBEAT") + + +def test_activity_task_cannot_timeout_on_closed_workflow_execution(): + with freeze_time("2015-01-01 12:00:00"): + wfe = make_workflow_execution() + wfe.start() + + with freeze_time("2015-01-01 13:58:00"): + task = ActivityTask( + activity_id="my-activity-123", + activity_type="foo", + input="optional", + scheduled_event_id=117, + timeouts=ACTIVITY_TASK_TIMEOUTS, + workflow_execution=wfe, + ) + + with freeze_time("2015-01-01 14:10:00"): + task.first_timeout().should.be.a(Timeout) + wfe.first_timeout().should.be.a(Timeout) + process_first_timeout(wfe) + task.first_timeout().should.be.none + + +def test_activity_task_cannot_change_state_on_closed_workflow_execution(): + wfe = make_workflow_execution() + wfe.start() + + task = ActivityTask( + activity_id="my-activity-123", + activity_type="foo", + input="optional", + scheduled_event_id=117, + timeouts=ACTIVITY_TASK_TIMEOUTS, + workflow_execution=wfe, + ) + wfe.complete(123) + + task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( + SWFWorkflowExecutionClosedError + ) + task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) + task.fail.when.called_with().should.throw(SWFWorkflowExecutionClosedError) diff --git a/tests/test_swf/models/test_decision_task.py b/tests/test_swf/models/test_decision_task.py index b593db5ff..0661adffb 100644 --- a/tests/test_swf/models/test_decision_task.py +++ b/tests/test_swf/models/test_decision_task.py @@ -1,80 +1,81 @@ -from boto.swf.exceptions import SWFResponseError -from freezegun import freeze_time -from sure import expect - -from moto.swf.models import DecisionTask, Timeout -from moto.swf.exceptions import SWFWorkflowExecutionClosedError - -from ..utils import make_workflow_execution, process_first_timeout - - -def test_decision_task_creation(): - wfe = make_workflow_execution() - dt = DecisionTask(wfe, 123) - dt.workflow_execution.should.equal(wfe) - dt.state.should.equal("SCHEDULED") - dt.task_token.should_not.be.empty - dt.started_event_id.should.be.none - - -def test_decision_task_full_dict_representation(): - wfe = make_workflow_execution() - wft = wfe.workflow_type - dt = DecisionTask(wfe, 123) - - fd = dt.to_full_dict() - fd["events"].should.be.a("list") - fd["previousStartedEventId"].should.equal(0) - fd.should_not.contain("startedEventId") - fd.should.contain("taskToken") - fd["workflowExecution"].should.equal(wfe.to_short_dict()) - fd["workflowType"].should.equal(wft.to_short_dict()) - - dt.start(1234) - fd = dt.to_full_dict() - fd["startedEventId"].should.equal(1234) - - -def test_decision_task_first_timeout(): - wfe = make_workflow_execution() - dt = DecisionTask(wfe, 123) - dt.first_timeout().should.be.none - - with freeze_time("2015-01-01 12:00:00"): - dt.start(1234) - dt.first_timeout().should.be.none - - # activity task timeout is 300s == 5mins - with freeze_time("2015-01-01 12:06:00"): - dt.first_timeout().should.be.a(Timeout) - - dt.complete() - dt.first_timeout().should.be.none - - -def test_decision_task_cannot_timeout_on_closed_workflow_execution(): - with freeze_time("2015-01-01 12:00:00"): - wfe = make_workflow_execution() - wfe.start() - - with freeze_time("2015-01-01 13:55:00"): - dt = DecisionTask(wfe, 123) - dt.start(1234) - - with freeze_time("2015-01-01 14:10:00"): - dt.first_timeout().should.be.a(Timeout) - wfe.first_timeout().should.be.a(Timeout) - process_first_timeout(wfe) - dt.first_timeout().should.be.none - - -def test_decision_task_cannot_change_state_on_closed_workflow_execution(): - wfe = make_workflow_execution() - wfe.start() - task = DecisionTask(wfe, 123) - - wfe.complete(123) - - task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( - SWFWorkflowExecutionClosedError) - task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) +from boto.swf.exceptions import SWFResponseError +from freezegun import freeze_time +from sure import expect + +from moto.swf.models import DecisionTask, Timeout +from moto.swf.exceptions import SWFWorkflowExecutionClosedError + +from ..utils import make_workflow_execution, process_first_timeout + + +def test_decision_task_creation(): + wfe = make_workflow_execution() + dt = DecisionTask(wfe, 123) + dt.workflow_execution.should.equal(wfe) + dt.state.should.equal("SCHEDULED") + dt.task_token.should_not.be.empty + dt.started_event_id.should.be.none + + +def test_decision_task_full_dict_representation(): + wfe = make_workflow_execution() + wft = wfe.workflow_type + dt = DecisionTask(wfe, 123) + + fd = dt.to_full_dict() + fd["events"].should.be.a("list") + fd["previousStartedEventId"].should.equal(0) + fd.should_not.contain("startedEventId") + fd.should.contain("taskToken") + fd["workflowExecution"].should.equal(wfe.to_short_dict()) + fd["workflowType"].should.equal(wft.to_short_dict()) + + dt.start(1234) + fd = dt.to_full_dict() + fd["startedEventId"].should.equal(1234) + + +def test_decision_task_first_timeout(): + wfe = make_workflow_execution() + dt = DecisionTask(wfe, 123) + dt.first_timeout().should.be.none + + with freeze_time("2015-01-01 12:00:00"): + dt.start(1234) + dt.first_timeout().should.be.none + + # activity task timeout is 300s == 5mins + with freeze_time("2015-01-01 12:06:00"): + dt.first_timeout().should.be.a(Timeout) + + dt.complete() + dt.first_timeout().should.be.none + + +def test_decision_task_cannot_timeout_on_closed_workflow_execution(): + with freeze_time("2015-01-01 12:00:00"): + wfe = make_workflow_execution() + wfe.start() + + with freeze_time("2015-01-01 13:55:00"): + dt = DecisionTask(wfe, 123) + dt.start(1234) + + with freeze_time("2015-01-01 14:10:00"): + dt.first_timeout().should.be.a(Timeout) + wfe.first_timeout().should.be.a(Timeout) + process_first_timeout(wfe) + dt.first_timeout().should.be.none + + +def test_decision_task_cannot_change_state_on_closed_workflow_execution(): + wfe = make_workflow_execution() + wfe.start() + task = DecisionTask(wfe, 123) + + wfe.complete(123) + + task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( + SWFWorkflowExecutionClosedError + ) + task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) diff --git a/tests/test_swf/models/test_domain.py b/tests/test_swf/models/test_domain.py index 1dc5cec65..32940753f 100644 --- a/tests/test_swf/models/test_domain.py +++ b/tests/test_swf/models/test_domain.py @@ -9,15 +9,13 @@ import tests.backport_assert_raises # noqa # Fake WorkflowExecution for tests purposes WorkflowExecution = namedtuple( - "WorkflowExecution", - ["workflow_id", "run_id", "execution_status", "open"] + "WorkflowExecution", ["workflow_id", "run_id", "execution_status", "open"] ) def test_domain_short_dict_representation(): domain = Domain("foo", "52") - domain.to_short_dict().should.equal( - {"name": "foo", "status": "REGISTERED"}) + domain.to_short_dict().should.equal({"name": "foo", "status": "REGISTERED"}) domain.description = "foo bar" domain.to_short_dict()["description"].should.equal("foo bar") @@ -39,9 +37,7 @@ def test_domain_string_representation(): def test_domain_add_to_activity_task_list(): domain = Domain("my-domain", "60") domain.add_to_activity_task_list("foo", "bar") - domain.activity_task_lists.should.equal({ - "foo": ["bar"] - }) + domain.activity_task_lists.should.equal({"foo": ["bar"]}) def test_domain_activity_tasks(): @@ -54,9 +50,7 @@ def test_domain_activity_tasks(): def test_domain_add_to_decision_task_list(): domain = Domain("my-domain", "60") domain.add_to_decision_task_list("foo", "bar") - domain.decision_task_lists.should.equal({ - "foo": ["bar"] - }) + domain.decision_task_lists.should.equal({"foo": ["bar"]}) def test_domain_decision_tasks(): @@ -70,50 +64,44 @@ def test_domain_get_workflow_execution(): domain = Domain("my-domain", "60") wfe1 = WorkflowExecution( - workflow_id="wf-id-1", run_id="run-id-1", execution_status="OPEN", open=True) + workflow_id="wf-id-1", run_id="run-id-1", execution_status="OPEN", open=True + ) wfe2 = WorkflowExecution( - workflow_id="wf-id-1", run_id="run-id-2", execution_status="CLOSED", open=False) + workflow_id="wf-id-1", run_id="run-id-2", execution_status="CLOSED", open=False + ) wfe3 = WorkflowExecution( - workflow_id="wf-id-2", run_id="run-id-3", execution_status="OPEN", open=True) + workflow_id="wf-id-2", run_id="run-id-3", execution_status="OPEN", open=True + ) wfe4 = WorkflowExecution( - workflow_id="wf-id-3", run_id="run-id-4", execution_status="CLOSED", open=False) + workflow_id="wf-id-3", run_id="run-id-4", execution_status="CLOSED", open=False + ) domain.workflow_executions = [wfe1, wfe2, wfe3, wfe4] # get workflow execution through workflow_id and run_id - domain.get_workflow_execution( - "wf-id-1", run_id="run-id-1").should.equal(wfe1) - domain.get_workflow_execution( - "wf-id-1", run_id="run-id-2").should.equal(wfe2) - domain.get_workflow_execution( - "wf-id-3", run_id="run-id-4").should.equal(wfe4) + domain.get_workflow_execution("wf-id-1", run_id="run-id-1").should.equal(wfe1) + domain.get_workflow_execution("wf-id-1", run_id="run-id-2").should.equal(wfe2) + domain.get_workflow_execution("wf-id-3", run_id="run-id-4").should.equal(wfe4) domain.get_workflow_execution.when.called_with( "wf-id-1", run_id="non-existent" - ).should.throw( - SWFUnknownResourceFault, - ) + ).should.throw(SWFUnknownResourceFault) # get OPEN workflow execution by default if no run_id domain.get_workflow_execution("wf-id-1").should.equal(wfe1) - domain.get_workflow_execution.when.called_with( - "wf-id-3" - ).should.throw( + domain.get_workflow_execution.when.called_with("wf-id-3").should.throw( SWFUnknownResourceFault ) - domain.get_workflow_execution.when.called_with( - "wf-id-non-existent" - ).should.throw( + domain.get_workflow_execution.when.called_with("wf-id-non-existent").should.throw( SWFUnknownResourceFault ) # raise_if_closed attribute domain.get_workflow_execution( - "wf-id-1", run_id="run-id-1", raise_if_closed=True).should.equal(wfe1) + "wf-id-1", run_id="run-id-1", raise_if_closed=True + ).should.equal(wfe1) domain.get_workflow_execution.when.called_with( "wf-id-3", run_id="run-id-4", raise_if_closed=True - ).should.throw( - SWFUnknownResourceFault - ) + ).should.throw(SWFUnknownResourceFault) # raise_if_none attribute domain.get_workflow_execution("foo", raise_if_none=False).should.be.none diff --git a/tests/test_swf/models/test_generic_type.py b/tests/test_swf/models/test_generic_type.py index bea07ce1c..ef7378d06 100644 --- a/tests/test_swf/models/test_generic_type.py +++ b/tests/test_swf/models/test_generic_type.py @@ -1,58 +1,58 @@ -from moto.swf.models import GenericType -import sure # noqa - - -# Tests for GenericType (ActivityType, WorkflowType) -class FooType(GenericType): - - @property - def kind(self): - return "foo" - - @property - def _configuration_keys(self): - return ["justAnExampleTimeout"] - - -def test_type_short_dict_representation(): - _type = FooType("test-foo", "v1.0") - _type.to_short_dict().should.equal({"name": "test-foo", "version": "v1.0"}) - - -def test_type_medium_dict_representation(): - _type = FooType("test-foo", "v1.0") - _type.to_medium_dict()["fooType"].should.equal(_type.to_short_dict()) - _type.to_medium_dict()["status"].should.equal("REGISTERED") - _type.to_medium_dict().should.contain("creationDate") - _type.to_medium_dict().should_not.contain("deprecationDate") - _type.to_medium_dict().should_not.contain("description") - - _type.description = "foo bar" - _type.to_medium_dict()["description"].should.equal("foo bar") - - _type.status = "DEPRECATED" - _type.to_medium_dict().should.contain("deprecationDate") - - -def test_type_full_dict_representation(): - _type = FooType("test-foo", "v1.0") - _type.to_full_dict()["typeInfo"].should.equal(_type.to_medium_dict()) - _type.to_full_dict()["configuration"].should.equal({}) - - _type.task_list = "foo" - _type.to_full_dict()["configuration"][ - "defaultTaskList"].should.equal({"name": "foo"}) - - _type.just_an_example_timeout = "60" - _type.to_full_dict()["configuration"][ - "justAnExampleTimeout"].should.equal("60") - - _type.non_whitelisted_property = "34" - keys = _type.to_full_dict()["configuration"].keys() - sorted(keys).should.equal(["defaultTaskList", "justAnExampleTimeout"]) - - -def test_type_string_representation(): - _type = FooType("test-foo", "v1.0") - str(_type).should.equal( - "FooType(name: test-foo, version: v1.0, status: REGISTERED)") +from moto.swf.models import GenericType +import sure # noqa + + +# Tests for GenericType (ActivityType, WorkflowType) +class FooType(GenericType): + @property + def kind(self): + return "foo" + + @property + def _configuration_keys(self): + return ["justAnExampleTimeout"] + + +def test_type_short_dict_representation(): + _type = FooType("test-foo", "v1.0") + _type.to_short_dict().should.equal({"name": "test-foo", "version": "v1.0"}) + + +def test_type_medium_dict_representation(): + _type = FooType("test-foo", "v1.0") + _type.to_medium_dict()["fooType"].should.equal(_type.to_short_dict()) + _type.to_medium_dict()["status"].should.equal("REGISTERED") + _type.to_medium_dict().should.contain("creationDate") + _type.to_medium_dict().should_not.contain("deprecationDate") + _type.to_medium_dict().should_not.contain("description") + + _type.description = "foo bar" + _type.to_medium_dict()["description"].should.equal("foo bar") + + _type.status = "DEPRECATED" + _type.to_medium_dict().should.contain("deprecationDate") + + +def test_type_full_dict_representation(): + _type = FooType("test-foo", "v1.0") + _type.to_full_dict()["typeInfo"].should.equal(_type.to_medium_dict()) + _type.to_full_dict()["configuration"].should.equal({}) + + _type.task_list = "foo" + _type.to_full_dict()["configuration"]["defaultTaskList"].should.equal( + {"name": "foo"} + ) + + _type.just_an_example_timeout = "60" + _type.to_full_dict()["configuration"]["justAnExampleTimeout"].should.equal("60") + + _type.non_whitelisted_property = "34" + keys = _type.to_full_dict()["configuration"].keys() + sorted(keys).should.equal(["defaultTaskList", "justAnExampleTimeout"]) + + +def test_type_string_representation(): + _type = FooType("test-foo", "v1.0") + str(_type).should.equal( + "FooType(name: test-foo, version: v1.0, status: REGISTERED)" + ) diff --git a/tests/test_swf/models/test_history_event.py b/tests/test_swf/models/test_history_event.py index fcf4a4a55..8b8234187 100644 --- a/tests/test_swf/models/test_history_event.py +++ b/tests/test_swf/models/test_history_event.py @@ -1,31 +1,31 @@ -from freezegun import freeze_time -import sure # noqa - -from moto.swf.models import HistoryEvent - - -@freeze_time("2015-01-01 12:00:00") -def test_history_event_creation(): - he = HistoryEvent(123, "DecisionTaskStarted", scheduled_event_id=2) - he.event_id.should.equal(123) - he.event_type.should.equal("DecisionTaskStarted") - he.event_timestamp.should.equal(1420113600.0) - - -@freeze_time("2015-01-01 12:00:00") -def test_history_event_to_dict_representation(): - he = HistoryEvent(123, "DecisionTaskStarted", scheduled_event_id=2) - he.to_dict().should.equal({ - "eventId": 123, - "eventType": "DecisionTaskStarted", - "eventTimestamp": 1420113600.0, - "decisionTaskStartedEventAttributes": { - "scheduledEventId": 2 - } - }) - - -def test_history_event_breaks_on_initialization_if_not_implemented(): - HistoryEvent.when.called_with( - 123, "UnknownHistoryEvent" - ).should.throw(NotImplementedError) +from freezegun import freeze_time +import sure # noqa + +from moto.swf.models import HistoryEvent + + +@freeze_time("2015-01-01 12:00:00") +def test_history_event_creation(): + he = HistoryEvent(123, "DecisionTaskStarted", scheduled_event_id=2) + he.event_id.should.equal(123) + he.event_type.should.equal("DecisionTaskStarted") + he.event_timestamp.should.equal(1420113600.0) + + +@freeze_time("2015-01-01 12:00:00") +def test_history_event_to_dict_representation(): + he = HistoryEvent(123, "DecisionTaskStarted", scheduled_event_id=2) + he.to_dict().should.equal( + { + "eventId": 123, + "eventType": "DecisionTaskStarted", + "eventTimestamp": 1420113600.0, + "decisionTaskStartedEventAttributes": {"scheduledEventId": 2}, + } + ) + + +def test_history_event_breaks_on_initialization_if_not_implemented(): + HistoryEvent.when.called_with(123, "UnknownHistoryEvent").should.throw( + NotImplementedError + ) diff --git a/tests/test_swf/models/test_workflow_execution.py b/tests/test_swf/models/test_workflow_execution.py index 7271cca7f..6c73a9686 100644 --- a/tests/test_swf/models/test_workflow_execution.py +++ b/tests/test_swf/models/test_workflow_execution.py @@ -1,501 +1,510 @@ -from freezegun import freeze_time -import sure # noqa - -from moto.swf.models import ( - ActivityType, - Timeout, - WorkflowType, - WorkflowExecution, -) -from moto.swf.exceptions import SWFDefaultUndefinedFault -from ..utils import ( - auto_start_decision_tasks, - get_basic_domain, - get_basic_workflow_type, - make_workflow_execution, -) - - -VALID_ACTIVITY_TASK_ATTRIBUTES = { - "activityId": "my-activity-001", - "activityType": {"name": "test-activity", "version": "v1.1"}, - "taskList": {"name": "task-list-name"}, - "scheduleToStartTimeout": "600", - "scheduleToCloseTimeout": "600", - "startToCloseTimeout": "600", - "heartbeatTimeout": "300", -} - - -def test_workflow_execution_creation(): - domain = get_basic_domain() - wft = get_basic_workflow_type() - wfe = WorkflowExecution(domain, wft, "ab1234", child_policy="TERMINATE") - - wfe.domain.should.equal(domain) - wfe.workflow_type.should.equal(wft) - wfe.child_policy.should.equal("TERMINATE") - - -def test_workflow_execution_creation_child_policy_logic(): - domain = get_basic_domain() - - WorkflowExecution( - domain, - WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ), - "ab1234" - ).child_policy.should.equal("ABANDON") - - WorkflowExecution( - domain, - WorkflowType( - "test-workflow", "v1.0", task_list="queue", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ), - "ab1234", - child_policy="REQUEST_CANCEL" - ).child_policy.should.equal("REQUEST_CANCEL") - - WorkflowExecution.when.called_with( - domain, - WorkflowType("test-workflow", "v1.0"), "ab1234" - ).should.throw(SWFDefaultUndefinedFault) - - -def test_workflow_execution_string_representation(): - wfe = make_workflow_execution(child_policy="TERMINATE") - str(wfe).should.match(r"^WorkflowExecution\(run_id: .*\)") - - -def test_workflow_execution_generates_a_random_run_id(): - domain = get_basic_domain() - wft = get_basic_workflow_type() - wfe1 = WorkflowExecution(domain, wft, "ab1234", child_policy="TERMINATE") - wfe2 = WorkflowExecution(domain, wft, "ab1235", child_policy="TERMINATE") - wfe1.run_id.should_not.equal(wfe2.run_id) - - -def test_workflow_execution_short_dict_representation(): - domain = get_basic_domain() - wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ) - wfe = WorkflowExecution(domain, wf_type, "ab1234") - - sd = wfe.to_short_dict() - sd["workflowId"].should.equal("ab1234") - sd.should.contain("runId") - - -def test_workflow_execution_medium_dict_representation(): - domain = get_basic_domain() - wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ) - wfe = WorkflowExecution(domain, wf_type, "ab1234") - - md = wfe.to_medium_dict() - md["execution"].should.equal(wfe.to_short_dict()) - md["workflowType"].should.equal(wf_type.to_short_dict()) - md["startTimestamp"].should.be.a('float') - md["executionStatus"].should.equal("OPEN") - md["cancelRequested"].should.be.falsy - md.should_not.contain("tagList") - - wfe.tag_list = ["foo", "bar", "baz"] - md = wfe.to_medium_dict() - md["tagList"].should.equal(["foo", "bar", "baz"]) - - -def test_workflow_execution_full_dict_representation(): - domain = get_basic_domain() - wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ) - wfe = WorkflowExecution(domain, wf_type, "ab1234") - - fd = wfe.to_full_dict() - fd["executionInfo"].should.equal(wfe.to_medium_dict()) - fd["openCounts"]["openTimers"].should.equal(0) - fd["openCounts"]["openDecisionTasks"].should.equal(0) - fd["openCounts"]["openActivityTasks"].should.equal(0) - fd["executionConfiguration"].should.equal({ - "childPolicy": "ABANDON", - "executionStartToCloseTimeout": "300", - "taskList": {"name": "queue"}, - "taskStartToCloseTimeout": "300", - }) - - -def test_workflow_execution_list_dict_representation(): - domain = get_basic_domain() - wf_type = WorkflowType( - 'test-workflow', 'v1.0', - task_list='queue', default_child_policy='ABANDON', - default_execution_start_to_close_timeout='300', - default_task_start_to_close_timeout='300', - ) - wfe = WorkflowExecution(domain, wf_type, 'ab1234') - - ld = wfe.to_list_dict() - ld['workflowType']['version'].should.equal('v1.0') - ld['workflowType']['name'].should.equal('test-workflow') - ld['executionStatus'].should.equal('OPEN') - ld['execution']['workflowId'].should.equal('ab1234') - ld['execution'].should.contain('runId') - ld['cancelRequested'].should.be.false - ld.should.contain('startTimestamp') - - -def test_workflow_execution_schedule_decision_task(): - wfe = make_workflow_execution() - wfe.open_counts["openDecisionTasks"].should.equal(0) - wfe.schedule_decision_task() - wfe.open_counts["openDecisionTasks"].should.equal(1) - - -def test_workflow_execution_start_decision_task(): - wfe = make_workflow_execution() - wfe.schedule_decision_task() - dt = wfe.decision_tasks[0] - wfe.start_decision_task(dt.task_token, identity="srv01") - dt = wfe.decision_tasks[0] - dt.state.should.equal("STARTED") - wfe.events()[-1].event_type.should.equal("DecisionTaskStarted") - wfe.events()[-1].event_attributes["identity"].should.equal("srv01") - - -def test_workflow_execution_history_events_ids(): - wfe = make_workflow_execution() - wfe._add_event("WorkflowExecutionStarted") - wfe._add_event("DecisionTaskScheduled") - wfe._add_event("DecisionTaskStarted") - ids = [evt.event_id for evt in wfe.events()] - ids.should.equal([1, 2, 3]) - - -@freeze_time("2015-01-01 12:00:00") -def test_workflow_execution_start(): - wfe = make_workflow_execution() - wfe.events().should.equal([]) - - wfe.start() - wfe.start_timestamp.should.equal(1420113600.0) - wfe.events().should.have.length_of(2) - wfe.events()[0].event_type.should.equal("WorkflowExecutionStarted") - wfe.events()[1].event_type.should.equal("DecisionTaskScheduled") - - -@freeze_time("2015-01-02 12:00:00") -def test_workflow_execution_complete(): - wfe = make_workflow_execution() - wfe.complete(123, result="foo") - - wfe.execution_status.should.equal("CLOSED") - wfe.close_status.should.equal("COMPLETED") - wfe.close_timestamp.should.equal(1420200000.0) - wfe.events()[-1].event_type.should.equal("WorkflowExecutionCompleted") - wfe.events()[-1].event_attributes["decisionTaskCompletedEventId"].should.equal(123) - wfe.events()[-1].event_attributes["result"].should.equal("foo") - - -@freeze_time("2015-01-02 12:00:00") -def test_workflow_execution_fail(): - wfe = make_workflow_execution() - wfe.fail(123, details="some details", reason="my rules") - - wfe.execution_status.should.equal("CLOSED") - wfe.close_status.should.equal("FAILED") - wfe.close_timestamp.should.equal(1420200000.0) - wfe.events()[-1].event_type.should.equal("WorkflowExecutionFailed") - wfe.events()[-1].event_attributes["decisionTaskCompletedEventId"].should.equal(123) - wfe.events()[-1].event_attributes["details"].should.equal("some details") - wfe.events()[-1].event_attributes["reason"].should.equal("my rules") - - -@freeze_time("2015-01-01 12:00:00") -def test_workflow_execution_schedule_activity_task(): - wfe = make_workflow_execution() - wfe.latest_activity_task_timestamp.should.be.none - - wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) - - wfe.latest_activity_task_timestamp.should.equal(1420113600.0) - - wfe.open_counts["openActivityTasks"].should.equal(1) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ActivityTaskScheduled") - last_event.event_attributes[ - "decisionTaskCompletedEventId"].should.equal(123) - last_event.event_attributes["taskList"][ - "name"].should.equal("task-list-name") - - wfe.activity_tasks.should.have.length_of(1) - task = wfe.activity_tasks[0] - task.activity_id.should.equal("my-activity-001") - task.activity_type.name.should.equal("test-activity") - wfe.domain.activity_task_lists["task-list-name"].should.contain(task) - - -def test_workflow_execution_schedule_activity_task_without_task_list_should_take_default(): - wfe = make_workflow_execution() - wfe.domain.add_type( - ActivityType("test-activity", "v1.2", task_list="foobar") - ) - wfe.schedule_activity_task(123, { - "activityId": "my-activity-001", - "activityType": {"name": "test-activity", "version": "v1.2"}, - "scheduleToStartTimeout": "600", - "scheduleToCloseTimeout": "600", - "startToCloseTimeout": "600", - "heartbeatTimeout": "300", - }) - - wfe.open_counts["openActivityTasks"].should.equal(1) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ActivityTaskScheduled") - last_event.event_attributes["taskList"]["name"].should.equal("foobar") - - task = wfe.activity_tasks[0] - wfe.domain.activity_task_lists["foobar"].should.contain(task) - - -def test_workflow_execution_schedule_activity_task_should_fail_if_wrong_attributes(): - wfe = make_workflow_execution() - at = ActivityType("test-activity", "v1.1") - at.status = "DEPRECATED" - wfe.domain.add_type(at) - wfe.domain.add_type(ActivityType("test-activity", "v1.2")) - - hsh = { - "activityId": "my-activity-001", - "activityType": {"name": "test-activity-does-not-exists", "version": "v1.1"}, - } - - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_TYPE_DOES_NOT_EXIST") - - hsh["activityType"]["name"] = "test-activity" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_TYPE_DEPRECATED") - - hsh["activityType"]["version"] = "v1.2" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_TASK_LIST_UNDEFINED") - - hsh["taskList"] = {"name": "foobar"} - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_SCHEDULE_TO_START_TIMEOUT_UNDEFINED") - - hsh["scheduleToStartTimeout"] = "600" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_SCHEDULE_TO_CLOSE_TIMEOUT_UNDEFINED") - - hsh["scheduleToCloseTimeout"] = "600" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_START_TO_CLOSE_TIMEOUT_UNDEFINED") - - hsh["startToCloseTimeout"] = "600" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_HEARTBEAT_TIMEOUT_UNDEFINED") - - wfe.open_counts["openActivityTasks"].should.equal(0) - wfe.activity_tasks.should.have.length_of(0) - wfe.domain.activity_task_lists.should.have.length_of(0) - - hsh["heartbeatTimeout"] = "300" - wfe.schedule_activity_task(123, hsh) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ActivityTaskScheduled") - - task = wfe.activity_tasks[0] - wfe.domain.activity_task_lists["foobar"].should.contain(task) - wfe.open_counts["openDecisionTasks"].should.equal(0) - wfe.open_counts["openActivityTasks"].should.equal(1) - - -def test_workflow_execution_schedule_activity_task_failure_triggers_new_decision(): - wfe = make_workflow_execution() - wfe.start() - task_token = wfe.decision_tasks[-1].task_token - wfe.start_decision_task(task_token) - wfe.complete_decision_task( - task_token, - execution_context="free-form execution context", - decisions=[ - { - "decisionType": "ScheduleActivityTask", - "scheduleActivityTaskDecisionAttributes": { - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity-does-not-exist", - "version": "v1.2" - }, - } - }, - { - "decisionType": "ScheduleActivityTask", - "scheduleActivityTaskDecisionAttributes": { - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity-does-not-exist", - "version": "v1.2" - }, - } - }, - ]) - - wfe.latest_execution_context.should.equal("free-form execution context") - wfe.open_counts["openActivityTasks"].should.equal(0) - wfe.open_counts["openDecisionTasks"].should.equal(1) - last_events = wfe.events()[-3:] - last_events[0].event_type.should.equal("ScheduleActivityTaskFailed") - last_events[1].event_type.should.equal("ScheduleActivityTaskFailed") - last_events[2].event_type.should.equal("DecisionTaskScheduled") - - -def test_workflow_execution_schedule_activity_task_with_same_activity_id(): - wfe = make_workflow_execution() - - wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) - wfe.open_counts["openActivityTasks"].should.equal(1) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ActivityTaskScheduled") - - wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) - wfe.open_counts["openActivityTasks"].should.equal(1) - last_event = wfe.events()[-1] - last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_ID_ALREADY_IN_USE") - - -def test_workflow_execution_start_activity_task(): - wfe = make_workflow_execution() - wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) - task_token = wfe.activity_tasks[-1].task_token - wfe.start_activity_task(task_token, identity="worker01") - task = wfe.activity_tasks[-1] - task.state.should.equal("STARTED") - wfe.events()[-1].event_type.should.equal("ActivityTaskStarted") - wfe.events()[-1].event_attributes["identity"].should.equal("worker01") - - -def test_complete_activity_task(): - wfe = make_workflow_execution() - wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) - task_token = wfe.activity_tasks[-1].task_token - - wfe.open_counts["openActivityTasks"].should.equal(1) - wfe.open_counts["openDecisionTasks"].should.equal(0) - - wfe.start_activity_task(task_token, identity="worker01") - wfe.complete_activity_task(task_token, result="a superb result") - - task = wfe.activity_tasks[-1] - task.state.should.equal("COMPLETED") - wfe.events()[-2].event_type.should.equal("ActivityTaskCompleted") - wfe.events()[-1].event_type.should.equal("DecisionTaskScheduled") - - wfe.open_counts["openActivityTasks"].should.equal(0) - wfe.open_counts["openDecisionTasks"].should.equal(1) - - -def test_terminate(): - wfe = make_workflow_execution() - wfe.schedule_decision_task() - wfe.terminate() - - wfe.execution_status.should.equal("CLOSED") - wfe.close_status.should.equal("TERMINATED") - wfe.close_cause.should.equal("OPERATOR_INITIATED") - wfe.open_counts["openDecisionTasks"].should.equal(1) - - last_event = wfe.events()[-1] - last_event.event_type.should.equal("WorkflowExecutionTerminated") - # take default child_policy if not provided (as here) - last_event.event_attributes["childPolicy"].should.equal("ABANDON") - - -def test_first_timeout(): - wfe = make_workflow_execution() - wfe.first_timeout().should.be.none - - with freeze_time("2015-01-01 12:00:00"): - wfe.start() - wfe.first_timeout().should.be.none - - with freeze_time("2015-01-01 14:01"): - # 2 hours timeout reached - wfe.first_timeout().should.be.a(Timeout) - - -# See moto/swf/models/workflow_execution.py "_process_timeouts()" for more -# details -def test_timeouts_are_processed_in_order_and_reevaluated(): - # Let's make a Workflow Execution with the following properties: - # - execution start to close timeout of 8 mins - # - (decision) task start to close timeout of 5 mins - # - # Now start the workflow execution, and look at the history 15 mins later: - # - a first decision task is fired just after workflow execution start - # - the first decision task should have timed out after 5 mins - # - that fires a new decision task (which we hack to start automatically) - # - then the workflow timeouts after 8 mins (shows gradual reevaluation) - # - but the last scheduled decision task should *not* timeout (workflow closed) - with freeze_time("2015-01-01 12:00:00"): - wfe = make_workflow_execution( - execution_start_to_close_timeout=8 * 60, - task_start_to_close_timeout=5 * 60, - ) - # decision will automatically start - wfe = auto_start_decision_tasks(wfe) - wfe.start() - event_idx = len(wfe.events()) - - with freeze_time("2015-01-01 12:08:00"): - wfe._process_timeouts() - - event_types = [e.event_type for e in wfe.events()[event_idx:]] - event_types.should.equal([ - "DecisionTaskTimedOut", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "WorkflowExecutionTimedOut", - ]) +from freezegun import freeze_time +import sure # noqa + +from moto.swf.models import ActivityType, Timeout, WorkflowType, WorkflowExecution +from moto.swf.exceptions import SWFDefaultUndefinedFault +from ..utils import ( + auto_start_decision_tasks, + get_basic_domain, + get_basic_workflow_type, + make_workflow_execution, +) + + +VALID_ACTIVITY_TASK_ATTRIBUTES = { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "taskList": {"name": "task-list-name"}, + "scheduleToStartTimeout": "600", + "scheduleToCloseTimeout": "600", + "startToCloseTimeout": "600", + "heartbeatTimeout": "300", +} + + +def test_workflow_execution_creation(): + domain = get_basic_domain() + wft = get_basic_workflow_type() + wfe = WorkflowExecution(domain, wft, "ab1234", child_policy="TERMINATE") + + wfe.domain.should.equal(domain) + wfe.workflow_type.should.equal(wft) + wfe.child_policy.should.equal("TERMINATE") + + +def test_workflow_execution_creation_child_policy_logic(): + domain = get_basic_domain() + + WorkflowExecution( + domain, + WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ), + "ab1234", + ).child_policy.should.equal("ABANDON") + + WorkflowExecution( + domain, + WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ), + "ab1234", + child_policy="REQUEST_CANCEL", + ).child_policy.should.equal("REQUEST_CANCEL") + + WorkflowExecution.when.called_with( + domain, WorkflowType("test-workflow", "v1.0"), "ab1234" + ).should.throw(SWFDefaultUndefinedFault) + + +def test_workflow_execution_string_representation(): + wfe = make_workflow_execution(child_policy="TERMINATE") + str(wfe).should.match(r"^WorkflowExecution\(run_id: .*\)") + + +def test_workflow_execution_generates_a_random_run_id(): + domain = get_basic_domain() + wft = get_basic_workflow_type() + wfe1 = WorkflowExecution(domain, wft, "ab1234", child_policy="TERMINATE") + wfe2 = WorkflowExecution(domain, wft, "ab1235", child_policy="TERMINATE") + wfe1.run_id.should_not.equal(wfe2.run_id) + + +def test_workflow_execution_short_dict_representation(): + domain = get_basic_domain() + wf_type = WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ) + wfe = WorkflowExecution(domain, wf_type, "ab1234") + + sd = wfe.to_short_dict() + sd["workflowId"].should.equal("ab1234") + sd.should.contain("runId") + + +def test_workflow_execution_medium_dict_representation(): + domain = get_basic_domain() + wf_type = WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ) + wfe = WorkflowExecution(domain, wf_type, "ab1234") + + md = wfe.to_medium_dict() + md["execution"].should.equal(wfe.to_short_dict()) + md["workflowType"].should.equal(wf_type.to_short_dict()) + md["startTimestamp"].should.be.a("float") + md["executionStatus"].should.equal("OPEN") + md["cancelRequested"].should.be.falsy + md.should_not.contain("tagList") + + wfe.tag_list = ["foo", "bar", "baz"] + md = wfe.to_medium_dict() + md["tagList"].should.equal(["foo", "bar", "baz"]) + + +def test_workflow_execution_full_dict_representation(): + domain = get_basic_domain() + wf_type = WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ) + wfe = WorkflowExecution(domain, wf_type, "ab1234") + + fd = wfe.to_full_dict() + fd["executionInfo"].should.equal(wfe.to_medium_dict()) + fd["openCounts"]["openTimers"].should.equal(0) + fd["openCounts"]["openDecisionTasks"].should.equal(0) + fd["openCounts"]["openActivityTasks"].should.equal(0) + fd["executionConfiguration"].should.equal( + { + "childPolicy": "ABANDON", + "executionStartToCloseTimeout": "300", + "taskList": {"name": "queue"}, + "taskStartToCloseTimeout": "300", + } + ) + + +def test_workflow_execution_list_dict_representation(): + domain = get_basic_domain() + wf_type = WorkflowType( + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ) + wfe = WorkflowExecution(domain, wf_type, "ab1234") + + ld = wfe.to_list_dict() + ld["workflowType"]["version"].should.equal("v1.0") + ld["workflowType"]["name"].should.equal("test-workflow") + ld["executionStatus"].should.equal("OPEN") + ld["execution"]["workflowId"].should.equal("ab1234") + ld["execution"].should.contain("runId") + ld["cancelRequested"].should.be.false + ld.should.contain("startTimestamp") + + +def test_workflow_execution_schedule_decision_task(): + wfe = make_workflow_execution() + wfe.open_counts["openDecisionTasks"].should.equal(0) + wfe.schedule_decision_task() + wfe.open_counts["openDecisionTasks"].should.equal(1) + + +def test_workflow_execution_start_decision_task(): + wfe = make_workflow_execution() + wfe.schedule_decision_task() + dt = wfe.decision_tasks[0] + wfe.start_decision_task(dt.task_token, identity="srv01") + dt = wfe.decision_tasks[0] + dt.state.should.equal("STARTED") + wfe.events()[-1].event_type.should.equal("DecisionTaskStarted") + wfe.events()[-1].event_attributes["identity"].should.equal("srv01") + + +def test_workflow_execution_history_events_ids(): + wfe = make_workflow_execution() + wfe._add_event("WorkflowExecutionStarted") + wfe._add_event("DecisionTaskScheduled") + wfe._add_event("DecisionTaskStarted") + ids = [evt.event_id for evt in wfe.events()] + ids.should.equal([1, 2, 3]) + + +@freeze_time("2015-01-01 12:00:00") +def test_workflow_execution_start(): + wfe = make_workflow_execution() + wfe.events().should.equal([]) + + wfe.start() + wfe.start_timestamp.should.equal(1420113600.0) + wfe.events().should.have.length_of(2) + wfe.events()[0].event_type.should.equal("WorkflowExecutionStarted") + wfe.events()[1].event_type.should.equal("DecisionTaskScheduled") + + +@freeze_time("2015-01-02 12:00:00") +def test_workflow_execution_complete(): + wfe = make_workflow_execution() + wfe.complete(123, result="foo") + + wfe.execution_status.should.equal("CLOSED") + wfe.close_status.should.equal("COMPLETED") + wfe.close_timestamp.should.equal(1420200000.0) + wfe.events()[-1].event_type.should.equal("WorkflowExecutionCompleted") + wfe.events()[-1].event_attributes["decisionTaskCompletedEventId"].should.equal(123) + wfe.events()[-1].event_attributes["result"].should.equal("foo") + + +@freeze_time("2015-01-02 12:00:00") +def test_workflow_execution_fail(): + wfe = make_workflow_execution() + wfe.fail(123, details="some details", reason="my rules") + + wfe.execution_status.should.equal("CLOSED") + wfe.close_status.should.equal("FAILED") + wfe.close_timestamp.should.equal(1420200000.0) + wfe.events()[-1].event_type.should.equal("WorkflowExecutionFailed") + wfe.events()[-1].event_attributes["decisionTaskCompletedEventId"].should.equal(123) + wfe.events()[-1].event_attributes["details"].should.equal("some details") + wfe.events()[-1].event_attributes["reason"].should.equal("my rules") + + +@freeze_time("2015-01-01 12:00:00") +def test_workflow_execution_schedule_activity_task(): + wfe = make_workflow_execution() + wfe.latest_activity_task_timestamp.should.be.none + + wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) + + wfe.latest_activity_task_timestamp.should.equal(1420113600.0) + + wfe.open_counts["openActivityTasks"].should.equal(1) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ActivityTaskScheduled") + last_event.event_attributes["decisionTaskCompletedEventId"].should.equal(123) + last_event.event_attributes["taskList"]["name"].should.equal("task-list-name") + + wfe.activity_tasks.should.have.length_of(1) + task = wfe.activity_tasks[0] + task.activity_id.should.equal("my-activity-001") + task.activity_type.name.should.equal("test-activity") + wfe.domain.activity_task_lists["task-list-name"].should.contain(task) + + +def test_workflow_execution_schedule_activity_task_without_task_list_should_take_default(): + wfe = make_workflow_execution() + wfe.domain.add_type(ActivityType("test-activity", "v1.2", task_list="foobar")) + wfe.schedule_activity_task( + 123, + { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.2"}, + "scheduleToStartTimeout": "600", + "scheduleToCloseTimeout": "600", + "startToCloseTimeout": "600", + "heartbeatTimeout": "300", + }, + ) + + wfe.open_counts["openActivityTasks"].should.equal(1) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ActivityTaskScheduled") + last_event.event_attributes["taskList"]["name"].should.equal("foobar") + + task = wfe.activity_tasks[0] + wfe.domain.activity_task_lists["foobar"].should.contain(task) + + +def test_workflow_execution_schedule_activity_task_should_fail_if_wrong_attributes(): + wfe = make_workflow_execution() + at = ActivityType("test-activity", "v1.1") + at.status = "DEPRECATED" + wfe.domain.add_type(at) + wfe.domain.add_type(ActivityType("test-activity", "v1.2")) + + hsh = { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity-does-not-exists", "version": "v1.1"}, + } + + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal("ACTIVITY_TYPE_DOES_NOT_EXIST") + + hsh["activityType"]["name"] = "test-activity" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal("ACTIVITY_TYPE_DEPRECATED") + + hsh["activityType"]["version"] = "v1.2" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal("DEFAULT_TASK_LIST_UNDEFINED") + + hsh["taskList"] = {"name": "foobar"} + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal( + "DEFAULT_SCHEDULE_TO_START_TIMEOUT_UNDEFINED" + ) + + hsh["scheduleToStartTimeout"] = "600" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal( + "DEFAULT_SCHEDULE_TO_CLOSE_TIMEOUT_UNDEFINED" + ) + + hsh["scheduleToCloseTimeout"] = "600" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal( + "DEFAULT_START_TO_CLOSE_TIMEOUT_UNDEFINED" + ) + + hsh["startToCloseTimeout"] = "600" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal( + "DEFAULT_HEARTBEAT_TIMEOUT_UNDEFINED" + ) + + wfe.open_counts["openActivityTasks"].should.equal(0) + wfe.activity_tasks.should.have.length_of(0) + wfe.domain.activity_task_lists.should.have.length_of(0) + + hsh["heartbeatTimeout"] = "300" + wfe.schedule_activity_task(123, hsh) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ActivityTaskScheduled") + + task = wfe.activity_tasks[0] + wfe.domain.activity_task_lists["foobar"].should.contain(task) + wfe.open_counts["openDecisionTasks"].should.equal(0) + wfe.open_counts["openActivityTasks"].should.equal(1) + + +def test_workflow_execution_schedule_activity_task_failure_triggers_new_decision(): + wfe = make_workflow_execution() + wfe.start() + task_token = wfe.decision_tasks[-1].task_token + wfe.start_decision_task(task_token) + wfe.complete_decision_task( + task_token, + execution_context="free-form execution context", + decisions=[ + { + "decisionType": "ScheduleActivityTask", + "scheduleActivityTaskDecisionAttributes": { + "activityId": "my-activity-001", + "activityType": { + "name": "test-activity-does-not-exist", + "version": "v1.2", + }, + }, + }, + { + "decisionType": "ScheduleActivityTask", + "scheduleActivityTaskDecisionAttributes": { + "activityId": "my-activity-001", + "activityType": { + "name": "test-activity-does-not-exist", + "version": "v1.2", + }, + }, + }, + ], + ) + + wfe.latest_execution_context.should.equal("free-form execution context") + wfe.open_counts["openActivityTasks"].should.equal(0) + wfe.open_counts["openDecisionTasks"].should.equal(1) + last_events = wfe.events()[-3:] + last_events[0].event_type.should.equal("ScheduleActivityTaskFailed") + last_events[1].event_type.should.equal("ScheduleActivityTaskFailed") + last_events[2].event_type.should.equal("DecisionTaskScheduled") + + +def test_workflow_execution_schedule_activity_task_with_same_activity_id(): + wfe = make_workflow_execution() + + wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) + wfe.open_counts["openActivityTasks"].should.equal(1) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ActivityTaskScheduled") + + wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) + wfe.open_counts["openActivityTasks"].should.equal(1) + last_event = wfe.events()[-1] + last_event.event_type.should.equal("ScheduleActivityTaskFailed") + last_event.event_attributes["cause"].should.equal("ACTIVITY_ID_ALREADY_IN_USE") + + +def test_workflow_execution_start_activity_task(): + wfe = make_workflow_execution() + wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) + task_token = wfe.activity_tasks[-1].task_token + wfe.start_activity_task(task_token, identity="worker01") + task = wfe.activity_tasks[-1] + task.state.should.equal("STARTED") + wfe.events()[-1].event_type.should.equal("ActivityTaskStarted") + wfe.events()[-1].event_attributes["identity"].should.equal("worker01") + + +def test_complete_activity_task(): + wfe = make_workflow_execution() + wfe.schedule_activity_task(123, VALID_ACTIVITY_TASK_ATTRIBUTES) + task_token = wfe.activity_tasks[-1].task_token + + wfe.open_counts["openActivityTasks"].should.equal(1) + wfe.open_counts["openDecisionTasks"].should.equal(0) + + wfe.start_activity_task(task_token, identity="worker01") + wfe.complete_activity_task(task_token, result="a superb result") + + task = wfe.activity_tasks[-1] + task.state.should.equal("COMPLETED") + wfe.events()[-2].event_type.should.equal("ActivityTaskCompleted") + wfe.events()[-1].event_type.should.equal("DecisionTaskScheduled") + + wfe.open_counts["openActivityTasks"].should.equal(0) + wfe.open_counts["openDecisionTasks"].should.equal(1) + + +def test_terminate(): + wfe = make_workflow_execution() + wfe.schedule_decision_task() + wfe.terminate() + + wfe.execution_status.should.equal("CLOSED") + wfe.close_status.should.equal("TERMINATED") + wfe.close_cause.should.equal("OPERATOR_INITIATED") + wfe.open_counts["openDecisionTasks"].should.equal(1) + + last_event = wfe.events()[-1] + last_event.event_type.should.equal("WorkflowExecutionTerminated") + # take default child_policy if not provided (as here) + last_event.event_attributes["childPolicy"].should.equal("ABANDON") + + +def test_first_timeout(): + wfe = make_workflow_execution() + wfe.first_timeout().should.be.none + + with freeze_time("2015-01-01 12:00:00"): + wfe.start() + wfe.first_timeout().should.be.none + + with freeze_time("2015-01-01 14:01"): + # 2 hours timeout reached + wfe.first_timeout().should.be.a(Timeout) + + +# See moto/swf/models/workflow_execution.py "_process_timeouts()" for more +# details +def test_timeouts_are_processed_in_order_and_reevaluated(): + # Let's make a Workflow Execution with the following properties: + # - execution start to close timeout of 8 mins + # - (decision) task start to close timeout of 5 mins + # + # Now start the workflow execution, and look at the history 15 mins later: + # - a first decision task is fired just after workflow execution start + # - the first decision task should have timed out after 5 mins + # - that fires a new decision task (which we hack to start automatically) + # - then the workflow timeouts after 8 mins (shows gradual reevaluation) + # - but the last scheduled decision task should *not* timeout (workflow closed) + with freeze_time("2015-01-01 12:00:00"): + wfe = make_workflow_execution( + execution_start_to_close_timeout=8 * 60, task_start_to_close_timeout=5 * 60 + ) + # decision will automatically start + wfe = auto_start_decision_tasks(wfe) + wfe.start() + event_idx = len(wfe.events()) + + with freeze_time("2015-01-01 12:08:00"): + wfe._process_timeouts() + + event_types = [e.event_type for e in wfe.events()[event_idx:]] + event_types.should.equal( + [ + "DecisionTaskTimedOut", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "WorkflowExecutionTimedOut", + ] + ) diff --git a/tests/test_swf/responses/test_activity_tasks.py b/tests/test_swf/responses/test_activity_tasks.py index e67013f6b..0b72b7ca7 100644 --- a/tests/test_swf/responses/test_activity_tasks.py +++ b/tests/test_swf/responses/test_activity_tasks.py @@ -1,228 +1,233 @@ -from boto.swf.exceptions import SWFResponseError -from freezegun import freeze_time -import sure # noqa - -from moto import mock_swf_deprecated -from moto.swf import swf_backend - -from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION - - -# PollForActivityTask endpoint -@mock_swf_deprecated -def test_poll_for_activity_task_when_one(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - resp = conn.poll_for_activity_task( - "test-domain", "activity-task-list", identity="surprise") - resp["activityId"].should.equal("my-activity-001") - resp["taskToken"].should_not.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") - resp["events"][-1]["activityTaskStartedEventAttributes"].should.equal( - {"identity": "surprise", "scheduledEventId": 5} - ) - - -@mock_swf_deprecated -def test_poll_for_activity_task_when_none(): - conn = setup_workflow() - resp = conn.poll_for_activity_task("test-domain", "activity-task-list") - resp.should.equal({"startedEventId": 0}) - - -@mock_swf_deprecated -def test_poll_for_activity_task_on_non_existent_queue(): - conn = setup_workflow() - resp = conn.poll_for_activity_task("test-domain", "non-existent-queue") - resp.should.equal({"startedEventId": 0}) - - -# CountPendingActivityTasks endpoint -@mock_swf_deprecated -def test_count_pending_activity_tasks(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - - resp = conn.count_pending_activity_tasks( - "test-domain", "activity-task-list") - resp.should.equal({"count": 1, "truncated": False}) - - -@mock_swf_deprecated -def test_count_pending_decision_tasks_on_non_existent_task_list(): - conn = setup_workflow() - resp = conn.count_pending_activity_tasks("test-domain", "non-existent") - resp.should.equal({"count": 0, "truncated": False}) - - -# RespondActivityTaskCompleted endpoint -@mock_swf_deprecated -def test_respond_activity_task_completed(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - resp = conn.respond_activity_task_completed( - activity_token, result="result of the task") - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - resp["events"][-2]["eventType"].should.equal("ActivityTaskCompleted") - resp["events"][-2]["activityTaskCompletedEventAttributes"].should.equal( - {"result": "result of the task", "scheduledEventId": 5, "startedEventId": 6} - ) - - -@mock_swf_deprecated -def test_respond_activity_task_completed_on_closed_workflow_execution(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - # bad: we're closing workflow execution manually, but endpoints are not - # coded for now.. - wfe = swf_backend.domains[0].workflow_executions[-1] - wfe.execution_status = "CLOSED" - # /bad - - conn.respond_activity_task_completed.when.called_with( - activity_token - ).should.throw(SWFResponseError, "WorkflowExecution=") - - -@mock_swf_deprecated -def test_respond_activity_task_completed_with_task_already_completed(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - conn.respond_activity_task_completed(activity_token) - - conn.respond_activity_task_completed.when.called_with( - activity_token - ).should.throw(SWFResponseError, "Unknown activity, scheduledEventId = 5") - - -# RespondActivityTaskFailed endpoint -@mock_swf_deprecated -def test_respond_activity_task_failed(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - resp = conn.respond_activity_task_failed(activity_token, - reason="short reason", - details="long details") - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - resp["events"][-2]["eventType"].should.equal("ActivityTaskFailed") - resp["events"][-2]["activityTaskFailedEventAttributes"].should.equal( - {"reason": "short reason", "details": "long details", - "scheduledEventId": 5, "startedEventId": 6} - ) - - -@mock_swf_deprecated -def test_respond_activity_task_completed_with_wrong_token(): - # NB: we just test ONE failure case for RespondActivityTaskFailed - # because the safeguards are shared with RespondActivityTaskCompleted, so - # no need to retest everything end-to-end. - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - conn.poll_for_activity_task("test-domain", "activity-task-list") - conn.respond_activity_task_failed.when.called_with( - "not-a-correct-token" - ).should.throw(SWFResponseError, "Invalid token") - - -# RecordActivityTaskHeartbeat endpoint -@mock_swf_deprecated -def test_record_activity_task_heartbeat(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - resp = conn.record_activity_task_heartbeat(activity_token) - resp.should.equal({"cancelRequested": False}) - - -@mock_swf_deprecated -def test_record_activity_task_heartbeat_with_wrong_token(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - - conn.record_activity_task_heartbeat.when.called_with( - "bad-token", details="some progress details" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_record_activity_task_heartbeat_sets_details_in_case_of_timeout(): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - with freeze_time("2015-01-01 12:00:00"): - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] - conn.record_activity_task_heartbeat( - activity_token, details="some progress details") - - with freeze_time("2015-01-01 12:05:30"): - # => Activity Task Heartbeat timeout reached!! - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") - attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] - attrs["details"].should.equal("some progress details") +from boto.swf.exceptions import SWFResponseError +from freezegun import freeze_time +import sure # noqa + +from moto import mock_swf_deprecated +from moto.swf import swf_backend + +from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION + + +# PollForActivityTask endpoint +@mock_swf_deprecated +def test_poll_for_activity_task_when_one(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + resp = conn.poll_for_activity_task( + "test-domain", "activity-task-list", identity="surprise" + ) + resp["activityId"].should.equal("my-activity-001") + resp["taskToken"].should_not.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") + resp["events"][-1]["activityTaskStartedEventAttributes"].should.equal( + {"identity": "surprise", "scheduledEventId": 5} + ) + + +@mock_swf_deprecated +def test_poll_for_activity_task_when_none(): + conn = setup_workflow() + resp = conn.poll_for_activity_task("test-domain", "activity-task-list") + resp.should.equal({"startedEventId": 0}) + + +@mock_swf_deprecated +def test_poll_for_activity_task_on_non_existent_queue(): + conn = setup_workflow() + resp = conn.poll_for_activity_task("test-domain", "non-existent-queue") + resp.should.equal({"startedEventId": 0}) + + +# CountPendingActivityTasks endpoint +@mock_swf_deprecated +def test_count_pending_activity_tasks(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + + resp = conn.count_pending_activity_tasks("test-domain", "activity-task-list") + resp.should.equal({"count": 1, "truncated": False}) + + +@mock_swf_deprecated +def test_count_pending_decision_tasks_on_non_existent_task_list(): + conn = setup_workflow() + resp = conn.count_pending_activity_tasks("test-domain", "non-existent") + resp.should.equal({"count": 0, "truncated": False}) + + +# RespondActivityTaskCompleted endpoint +@mock_swf_deprecated +def test_respond_activity_task_completed(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] + + resp = conn.respond_activity_task_completed( + activity_token, result="result of the task" + ) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + resp["events"][-2]["eventType"].should.equal("ActivityTaskCompleted") + resp["events"][-2]["activityTaskCompletedEventAttributes"].should.equal( + {"result": "result of the task", "scheduledEventId": 5, "startedEventId": 6} + ) + + +@mock_swf_deprecated +def test_respond_activity_task_completed_on_closed_workflow_execution(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] + + # bad: we're closing workflow execution manually, but endpoints are not + # coded for now.. + wfe = swf_backend.domains[0].workflow_executions[-1] + wfe.execution_status = "CLOSED" + # /bad + + conn.respond_activity_task_completed.when.called_with(activity_token).should.throw( + SWFResponseError, "WorkflowExecution=" + ) + + +@mock_swf_deprecated +def test_respond_activity_task_completed_with_task_already_completed(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] + + conn.respond_activity_task_completed(activity_token) + + conn.respond_activity_task_completed.when.called_with(activity_token).should.throw( + SWFResponseError, "Unknown activity, scheduledEventId = 5" + ) + + +# RespondActivityTaskFailed endpoint +@mock_swf_deprecated +def test_respond_activity_task_failed(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] + + resp = conn.respond_activity_task_failed( + activity_token, reason="short reason", details="long details" + ) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + resp["events"][-2]["eventType"].should.equal("ActivityTaskFailed") + resp["events"][-2]["activityTaskFailedEventAttributes"].should.equal( + { + "reason": "short reason", + "details": "long details", + "scheduledEventId": 5, + "startedEventId": 6, + } + ) + + +@mock_swf_deprecated +def test_respond_activity_task_completed_with_wrong_token(): + # NB: we just test ONE failure case for RespondActivityTaskFailed + # because the safeguards are shared with RespondActivityTaskCompleted, so + # no need to retest everything end-to-end. + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + conn.poll_for_activity_task("test-domain", "activity-task-list") + conn.respond_activity_task_failed.when.called_with( + "not-a-correct-token" + ).should.throw(SWFResponseError, "Invalid token") + + +# RecordActivityTaskHeartbeat endpoint +@mock_swf_deprecated +def test_record_activity_task_heartbeat(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] + + resp = conn.record_activity_task_heartbeat(activity_token) + resp.should.equal({"cancelRequested": False}) + + +@mock_swf_deprecated +def test_record_activity_task_heartbeat_with_wrong_token(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + conn.poll_for_activity_task("test-domain", "activity-task-list")["taskToken"] + + conn.record_activity_task_heartbeat.when.called_with( + "bad-token", details="some progress details" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_record_activity_task_heartbeat_sets_details_in_case_of_timeout(): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + with freeze_time("2015-01-01 12:00:00"): + activity_token = conn.poll_for_activity_task( + "test-domain", "activity-task-list" + )["taskToken"] + conn.record_activity_task_heartbeat( + activity_token, details="some progress details" + ) + + with freeze_time("2015-01-01 12:05:30"): + # => Activity Task Heartbeat timeout reached!! + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") + attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] + attrs["details"].should.equal("some progress details") diff --git a/tests/test_swf/responses/test_activity_types.py b/tests/test_swf/responses/test_activity_types.py index 7bb66ac32..3fa9ad6b1 100644 --- a/tests/test_swf/responses/test_activity_types.py +++ b/tests/test_swf/responses/test_activity_types.py @@ -1,134 +1,141 @@ -import boto -from boto.swf.exceptions import SWFResponseError -import sure # noqa - -from moto import mock_swf_deprecated - - -# RegisterActivityType endpoint -@mock_swf_deprecated -def test_register_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0") - - types = conn.list_activity_types("test-domain", "REGISTERED") - actype = types["typeInfos"][0] - actype["activityType"]["name"].should.equal("test-activity") - actype["activityType"]["version"].should.equal("v1.0") - - -@mock_swf_deprecated -def test_register_already_existing_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0") - - conn.register_activity_type.when.called_with( - "test-domain", "test-activity", "v1.0" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_register_with_wrong_parameter_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.register_activity_type.when.called_with( - "test-domain", "test-activity", 12 - ).should.throw(SWFResponseError) - - -# ListActivityTypes endpoint -@mock_swf_deprecated -def test_list_activity_types(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "b-test-activity", "v1.0") - conn.register_activity_type("test-domain", "a-test-activity", "v1.0") - conn.register_activity_type("test-domain", "c-test-activity", "v1.0") - - all_activity_types = conn.list_activity_types("test-domain", "REGISTERED") - names = [activity_type["activityType"]["name"] - for activity_type in all_activity_types["typeInfos"]] - names.should.equal( - ["a-test-activity", "b-test-activity", "c-test-activity"]) - - -@mock_swf_deprecated -def test_list_activity_types_reverse_order(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "b-test-activity", "v1.0") - conn.register_activity_type("test-domain", "a-test-activity", "v1.0") - conn.register_activity_type("test-domain", "c-test-activity", "v1.0") - - all_activity_types = conn.list_activity_types("test-domain", "REGISTERED", - reverse_order=True) - names = [activity_type["activityType"]["name"] - for activity_type in all_activity_types["typeInfos"]] - names.should.equal( - ["c-test-activity", "b-test-activity", "a-test-activity"]) - - -# DeprecateActivityType endpoint -@mock_swf_deprecated -def test_deprecate_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0") - conn.deprecate_activity_type("test-domain", "test-activity", "v1.0") - - actypes = conn.list_activity_types("test-domain", "DEPRECATED") - actype = actypes["typeInfos"][0] - actype["activityType"]["name"].should.equal("test-activity") - actype["activityType"]["version"].should.equal("v1.0") - - -@mock_swf_deprecated -def test_deprecate_already_deprecated_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0") - conn.deprecate_activity_type("test-domain", "test-activity", "v1.0") - - conn.deprecate_activity_type.when.called_with( - "test-domain", "test-activity", "v1.0" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_deprecate_non_existent_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.deprecate_activity_type.when.called_with( - "test-domain", "non-existent", "v1.0" - ).should.throw(SWFResponseError) - - -# DescribeActivityType endpoint -@mock_swf_deprecated -def test_describe_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0", - task_list="foo", default_task_heartbeat_timeout="32") - - actype = conn.describe_activity_type( - "test-domain", "test-activity", "v1.0") - actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") - infos = actype["typeInfo"] - infos["activityType"]["name"].should.equal("test-activity") - infos["activityType"]["version"].should.equal("v1.0") - infos["status"].should.equal("REGISTERED") - - -@mock_swf_deprecated -def test_describe_non_existent_activity_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.describe_activity_type.when.called_with( - "test-domain", "non-existent", "v1.0" - ).should.throw(SWFResponseError) +import boto +from boto.swf.exceptions import SWFResponseError +import sure # noqa + +from moto import mock_swf_deprecated + + +# RegisterActivityType endpoint +@mock_swf_deprecated +def test_register_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "test-activity", "v1.0") + + types = conn.list_activity_types("test-domain", "REGISTERED") + actype = types["typeInfos"][0] + actype["activityType"]["name"].should.equal("test-activity") + actype["activityType"]["version"].should.equal("v1.0") + + +@mock_swf_deprecated +def test_register_already_existing_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "test-activity", "v1.0") + + conn.register_activity_type.when.called_with( + "test-domain", "test-activity", "v1.0" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_register_with_wrong_parameter_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.register_activity_type.when.called_with( + "test-domain", "test-activity", 12 + ).should.throw(SWFResponseError) + + +# ListActivityTypes endpoint +@mock_swf_deprecated +def test_list_activity_types(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "b-test-activity", "v1.0") + conn.register_activity_type("test-domain", "a-test-activity", "v1.0") + conn.register_activity_type("test-domain", "c-test-activity", "v1.0") + + all_activity_types = conn.list_activity_types("test-domain", "REGISTERED") + names = [ + activity_type["activityType"]["name"] + for activity_type in all_activity_types["typeInfos"] + ] + names.should.equal(["a-test-activity", "b-test-activity", "c-test-activity"]) + + +@mock_swf_deprecated +def test_list_activity_types_reverse_order(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "b-test-activity", "v1.0") + conn.register_activity_type("test-domain", "a-test-activity", "v1.0") + conn.register_activity_type("test-domain", "c-test-activity", "v1.0") + + all_activity_types = conn.list_activity_types( + "test-domain", "REGISTERED", reverse_order=True + ) + names = [ + activity_type["activityType"]["name"] + for activity_type in all_activity_types["typeInfos"] + ] + names.should.equal(["c-test-activity", "b-test-activity", "a-test-activity"]) + + +# DeprecateActivityType endpoint +@mock_swf_deprecated +def test_deprecate_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "test-activity", "v1.0") + conn.deprecate_activity_type("test-domain", "test-activity", "v1.0") + + actypes = conn.list_activity_types("test-domain", "DEPRECATED") + actype = actypes["typeInfos"][0] + actype["activityType"]["name"].should.equal("test-activity") + actype["activityType"]["version"].should.equal("v1.0") + + +@mock_swf_deprecated +def test_deprecate_already_deprecated_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type("test-domain", "test-activity", "v1.0") + conn.deprecate_activity_type("test-domain", "test-activity", "v1.0") + + conn.deprecate_activity_type.when.called_with( + "test-domain", "test-activity", "v1.0" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_deprecate_non_existent_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.deprecate_activity_type.when.called_with( + "test-domain", "non-existent", "v1.0" + ).should.throw(SWFResponseError) + + +# DescribeActivityType endpoint +@mock_swf_deprecated +def test_describe_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_activity_type( + "test-domain", + "test-activity", + "v1.0", + task_list="foo", + default_task_heartbeat_timeout="32", + ) + + actype = conn.describe_activity_type("test-domain", "test-activity", "v1.0") + actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") + infos = actype["typeInfo"] + infos["activityType"]["name"].should.equal("test-activity") + infos["activityType"]["version"].should.equal("v1.0") + infos["status"].should.equal("REGISTERED") + + +@mock_swf_deprecated +def test_describe_non_existent_activity_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.describe_activity_type.when.called_with( + "test-domain", "non-existent", "v1.0" + ).should.throw(SWFResponseError) diff --git a/tests/test_swf/responses/test_decision_tasks.py b/tests/test_swf/responses/test_decision_tasks.py index ecb3c3117..6389536e6 100644 --- a/tests/test_swf/responses/test_decision_tasks.py +++ b/tests/test_swf/responses/test_decision_tasks.py @@ -1,342 +1,353 @@ -from boto.swf.exceptions import SWFResponseError -from freezegun import freeze_time -import sure # noqa - -from moto import mock_swf_deprecated -from moto.swf import swf_backend - -from ..utils import setup_workflow - - -# PollForDecisionTask endpoint -@mock_swf_deprecated -def test_poll_for_decision_task_when_one(): - conn = setup_workflow() - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) - - resp = conn.poll_for_decision_task( - "test-domain", "queue", identity="srv01") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal(["WorkflowExecutionStarted", - "DecisionTaskScheduled", "DecisionTaskStarted"]) - - resp[ - "events"][-1]["decisionTaskStartedEventAttributes"]["identity"].should.equal("srv01") - - -@mock_swf_deprecated -def test_poll_for_decision_task_when_none(): - conn = setup_workflow() - conn.poll_for_decision_task("test-domain", "queue") - - resp = conn.poll_for_decision_task("test-domain", "queue") - # this is the DecisionTask representation you get from the real SWF - # after waiting 60s when there's no decision to be taken - resp.should.equal({"previousStartedEventId": 0, "startedEventId": 0}) - - -@mock_swf_deprecated -def test_poll_for_decision_task_on_non_existent_queue(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "non-existent-queue") - resp.should.equal({"previousStartedEventId": 0, "startedEventId": 0}) - - -@mock_swf_deprecated -def test_poll_for_decision_task_with_reverse_order(): - conn = setup_workflow() - resp = conn.poll_for_decision_task( - "test-domain", "queue", reverse_order=True) - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal( - ["DecisionTaskStarted", "DecisionTaskScheduled", "WorkflowExecutionStarted"]) - - -# CountPendingDecisionTasks endpoint -@mock_swf_deprecated -def test_count_pending_decision_tasks(): - conn = setup_workflow() - conn.poll_for_decision_task("test-domain", "queue") - resp = conn.count_pending_decision_tasks("test-domain", "queue") - resp.should.equal({"count": 1, "truncated": False}) - - -@mock_swf_deprecated -def test_count_pending_decision_tasks_on_non_existent_task_list(): - conn = setup_workflow() - resp = conn.count_pending_decision_tasks("test-domain", "non-existent") - resp.should.equal({"count": 0, "truncated": False}) - - -@mock_swf_deprecated -def test_count_pending_decision_tasks_after_decision_completes(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - conn.respond_decision_task_completed(resp["taskToken"]) - - resp = conn.count_pending_decision_tasks("test-domain", "queue") - resp.should.equal({"count": 0, "truncated": False}) - - -# RespondDecisionTaskCompleted endpoint -@mock_swf_deprecated -def test_respond_decision_task_completed_with_no_decision(): - conn = setup_workflow() - - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - resp = conn.respond_decision_task_completed( - task_token, - execution_context="free-form context", - ) - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - ]) - evt = resp["events"][-1] - evt["decisionTaskCompletedEventAttributes"].should.equal({ - "executionContext": "free-form context", - "scheduledEventId": 2, - "startedEventId": 3, - }) - - resp = conn.describe_workflow_execution( - "test-domain", conn.run_id, "uid-abcd1234") - resp["latestExecutionContext"].should.equal("free-form context") - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_wrong_token(): - conn = setup_workflow() - conn.poll_for_decision_task("test-domain", "queue") - conn.respond_decision_task_completed.when.called_with( - "not-a-correct-token" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_on_close_workflow_execution(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - # bad: we're closing workflow execution manually, but endpoints are not - # coded for now.. - wfe = swf_backend.domains[0].workflow_executions[-1] - wfe.execution_status = "CLOSED" - # /bad - - conn.respond_decision_task_completed.when.called_with( - task_token - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_task_already_completed(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - conn.respond_decision_task_completed(task_token) - - conn.respond_decision_task_completed.when.called_with( - task_token - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_complete_workflow_execution(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [{ - "decisionType": "CompleteWorkflowExecution", - "completeWorkflowExecutionDecisionAttributes": {"result": "foo bar"} - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "WorkflowExecutionCompleted", - ]) - resp["events"][-1]["workflowExecutionCompletedEventAttributes"][ - "result"].should.equal("foo bar") - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_close_decision_not_last(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [ - {"decisionType": "CompleteWorkflowExecution"}, - {"decisionType": "WeDontCare"}, - ] - - conn.respond_decision_task_completed.when.called_with( - task_token, decisions=decisions - ).should.throw(SWFResponseError, r"Close must be last decision in list") - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_invalid_decision_type(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [ - {"decisionType": "BadDecisionType"}, - {"decisionType": "CompleteWorkflowExecution"}, - ] - - conn.respond_decision_task_completed.when.called_with( - task_token, decisions=decisions).should.throw( - SWFResponseError, - r"Value 'BadDecisionType' at 'decisions.1.member.decisionType'" - ) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_missing_attributes(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [ - { - "decisionType": "should trigger even with incorrect decision type", - "startTimerDecisionAttributes": {} - }, - ] - - conn.respond_decision_task_completed.when.called_with( - task_token, decisions=decisions - ).should.throw( - SWFResponseError, - r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " - r"failed to satisfy constraint: Member must not be null" - ) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_missing_attributes_totally(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [ - {"decisionType": "StartTimer"}, - ] - - conn.respond_decision_task_completed.when.called_with( - task_token, decisions=decisions - ).should.throw( - SWFResponseError, - r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " - r"failed to satisfy constraint: Member must not be null" - ) - - -@mock_swf_deprecated -def test_respond_decision_task_completed_with_fail_workflow_execution(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [{ - "decisionType": "FailWorkflowExecution", - "failWorkflowExecutionDecisionAttributes": {"reason": "my rules", "details": "foo"} - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "WorkflowExecutionFailed", - ]) - attrs = resp["events"][-1]["workflowExecutionFailedEventAttributes"] - attrs["reason"].should.equal("my rules") - attrs["details"].should.equal("foo") - - -@mock_swf_deprecated -@freeze_time("2015-01-01 12:00:00") -def test_respond_decision_task_completed_with_schedule_activity_task(): - conn = setup_workflow() - resp = conn.poll_for_decision_task("test-domain", "queue") - task_token = resp["taskToken"] - - decisions = [{ - "decisionType": "ScheduleActivityTask", - "scheduleActivityTaskDecisionAttributes": { - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity", - "version": "v1.1" - }, - "heartbeatTimeout": "60", - "input": "123", - "taskList": { - "name": "my-task-list" - }, - } - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "ActivityTaskScheduled", - ]) - resp["events"][-1]["activityTaskScheduledEventAttributes"].should.equal({ - "decisionTaskCompletedEventId": 4, - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity", - "version": "v1.1", - }, - "heartbeatTimeout": "60", - "input": "123", - "taskList": { - "name": "my-task-list" - }, - }) - - resp = conn.describe_workflow_execution( - "test-domain", conn.run_id, "uid-abcd1234") - resp["latestActivityTaskTimestamp"].should.equal(1420113600.0) +from boto.swf.exceptions import SWFResponseError +from freezegun import freeze_time +import sure # noqa + +from moto import mock_swf_deprecated +from moto.swf import swf_backend + +from ..utils import setup_workflow + + +# PollForDecisionTask endpoint +@mock_swf_deprecated +def test_poll_for_decision_task_when_one(): + conn = setup_workflow() + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) + + resp = conn.poll_for_decision_task("test-domain", "queue", identity="srv01") + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted"] + ) + + resp["events"][-1]["decisionTaskStartedEventAttributes"]["identity"].should.equal( + "srv01" + ) + + +@mock_swf_deprecated +def test_poll_for_decision_task_when_none(): + conn = setup_workflow() + conn.poll_for_decision_task("test-domain", "queue") + + resp = conn.poll_for_decision_task("test-domain", "queue") + # this is the DecisionTask representation you get from the real SWF + # after waiting 60s when there's no decision to be taken + resp.should.equal({"previousStartedEventId": 0, "startedEventId": 0}) + + +@mock_swf_deprecated +def test_poll_for_decision_task_on_non_existent_queue(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "non-existent-queue") + resp.should.equal({"previousStartedEventId": 0, "startedEventId": 0}) + + +@mock_swf_deprecated +def test_poll_for_decision_task_with_reverse_order(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue", reverse_order=True) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + ["DecisionTaskStarted", "DecisionTaskScheduled", "WorkflowExecutionStarted"] + ) + + +# CountPendingDecisionTasks endpoint +@mock_swf_deprecated +def test_count_pending_decision_tasks(): + conn = setup_workflow() + conn.poll_for_decision_task("test-domain", "queue") + resp = conn.count_pending_decision_tasks("test-domain", "queue") + resp.should.equal({"count": 1, "truncated": False}) + + +@mock_swf_deprecated +def test_count_pending_decision_tasks_on_non_existent_task_list(): + conn = setup_workflow() + resp = conn.count_pending_decision_tasks("test-domain", "non-existent") + resp.should.equal({"count": 0, "truncated": False}) + + +@mock_swf_deprecated +def test_count_pending_decision_tasks_after_decision_completes(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + conn.respond_decision_task_completed(resp["taskToken"]) + + resp = conn.count_pending_decision_tasks("test-domain", "queue") + resp.should.equal({"count": 0, "truncated": False}) + + +# RespondDecisionTaskCompleted endpoint +@mock_swf_deprecated +def test_respond_decision_task_completed_with_no_decision(): + conn = setup_workflow() + + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + resp = conn.respond_decision_task_completed( + task_token, execution_context="free-form context" + ) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + ] + ) + evt = resp["events"][-1] + evt["decisionTaskCompletedEventAttributes"].should.equal( + { + "executionContext": "free-form context", + "scheduledEventId": 2, + "startedEventId": 3, + } + ) + + resp = conn.describe_workflow_execution("test-domain", conn.run_id, "uid-abcd1234") + resp["latestExecutionContext"].should.equal("free-form context") + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_wrong_token(): + conn = setup_workflow() + conn.poll_for_decision_task("test-domain", "queue") + conn.respond_decision_task_completed.when.called_with( + "not-a-correct-token" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_on_close_workflow_execution(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + # bad: we're closing workflow execution manually, but endpoints are not + # coded for now.. + wfe = swf_backend.domains[0].workflow_executions[-1] + wfe.execution_status = "CLOSED" + # /bad + + conn.respond_decision_task_completed.when.called_with(task_token).should.throw( + SWFResponseError + ) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_task_already_completed(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + conn.respond_decision_task_completed(task_token) + + conn.respond_decision_task_completed.when.called_with(task_token).should.throw( + SWFResponseError + ) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_complete_workflow_execution(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + { + "decisionType": "CompleteWorkflowExecution", + "completeWorkflowExecutionDecisionAttributes": {"result": "foo bar"}, + } + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "WorkflowExecutionCompleted", + ] + ) + resp["events"][-1]["workflowExecutionCompletedEventAttributes"][ + "result" + ].should.equal("foo bar") + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_close_decision_not_last(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + {"decisionType": "CompleteWorkflowExecution"}, + {"decisionType": "WeDontCare"}, + ] + + conn.respond_decision_task_completed.when.called_with( + task_token, decisions=decisions + ).should.throw(SWFResponseError, r"Close must be last decision in list") + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_invalid_decision_type(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + {"decisionType": "BadDecisionType"}, + {"decisionType": "CompleteWorkflowExecution"}, + ] + + conn.respond_decision_task_completed.when.called_with( + task_token, decisions=decisions + ).should.throw( + SWFResponseError, + r"Value 'BadDecisionType' at 'decisions.1.member.decisionType'", + ) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_missing_attributes(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + { + "decisionType": "should trigger even with incorrect decision type", + "startTimerDecisionAttributes": {}, + } + ] + + conn.respond_decision_task_completed.when.called_with( + task_token, decisions=decisions + ).should.throw( + SWFResponseError, + r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " + r"failed to satisfy constraint: Member must not be null", + ) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_missing_attributes_totally(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [{"decisionType": "StartTimer"}] + + conn.respond_decision_task_completed.when.called_with( + task_token, decisions=decisions + ).should.throw( + SWFResponseError, + r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " + r"failed to satisfy constraint: Member must not be null", + ) + + +@mock_swf_deprecated +def test_respond_decision_task_completed_with_fail_workflow_execution(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + { + "decisionType": "FailWorkflowExecution", + "failWorkflowExecutionDecisionAttributes": { + "reason": "my rules", + "details": "foo", + }, + } + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "WorkflowExecutionFailed", + ] + ) + attrs = resp["events"][-1]["workflowExecutionFailedEventAttributes"] + attrs["reason"].should.equal("my rules") + attrs["details"].should.equal("foo") + + +@mock_swf_deprecated +@freeze_time("2015-01-01 12:00:00") +def test_respond_decision_task_completed_with_schedule_activity_task(): + conn = setup_workflow() + resp = conn.poll_for_decision_task("test-domain", "queue") + task_token = resp["taskToken"] + + decisions = [ + { + "decisionType": "ScheduleActivityTask", + "scheduleActivityTaskDecisionAttributes": { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "heartbeatTimeout": "60", + "input": "123", + "taskList": {"name": "my-task-list"}, + }, + } + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) + resp.should.be.none + + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "ActivityTaskScheduled", + ] + ) + resp["events"][-1]["activityTaskScheduledEventAttributes"].should.equal( + { + "decisionTaskCompletedEventId": 4, + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "heartbeatTimeout": "60", + "input": "123", + "taskList": {"name": "my-task-list"}, + } + ) + + resp = conn.describe_workflow_execution("test-domain", conn.run_id, "uid-abcd1234") + resp["latestActivityTaskTimestamp"].should.equal(1420113600.0) diff --git a/tests/test_swf/responses/test_domains.py b/tests/test_swf/responses/test_domains.py index 4004496ed..199219d27 100644 --- a/tests/test_swf/responses/test_domains.py +++ b/tests/test_swf/responses/test_domains.py @@ -82,18 +82,16 @@ def test_deprecate_already_deprecated_domain(): conn.register_domain("test-domain", "60", description="A test domain") conn.deprecate_domain("test-domain") - conn.deprecate_domain.when.called_with( - "test-domain" - ).should.throw(SWFResponseError) + conn.deprecate_domain.when.called_with("test-domain").should.throw(SWFResponseError) @mock_swf_deprecated def test_deprecate_non_existent_domain(): conn = boto.connect_swf("the_key", "the_secret") - conn.deprecate_domain.when.called_with( - "non-existent" - ).should.throw(SWFResponseError) + conn.deprecate_domain.when.called_with("non-existent").should.throw( + SWFResponseError + ) # DescribeDomain endpoint @@ -103,8 +101,7 @@ def test_describe_domain(): conn.register_domain("test-domain", "60", description="A test domain") domain = conn.describe_domain("test-domain") - domain["configuration"][ - "workflowExecutionRetentionPeriodInDays"].should.equal("60") + domain["configuration"]["workflowExecutionRetentionPeriodInDays"].should.equal("60") domain["domainInfo"]["description"].should.equal("A test domain") domain["domainInfo"]["name"].should.equal("test-domain") domain["domainInfo"]["status"].should.equal("REGISTERED") @@ -114,6 +111,4 @@ def test_describe_domain(): def test_describe_non_existent_domain(): conn = boto.connect_swf("the_key", "the_secret") - conn.describe_domain.when.called_with( - "non-existent" - ).should.throw(SWFResponseError) + conn.describe_domain.when.called_with("non-existent").should.throw(SWFResponseError) diff --git a/tests/test_swf/responses/test_timeouts.py b/tests/test_swf/responses/test_timeouts.py index 95d956f99..25ca8ae7d 100644 --- a/tests/test_swf/responses/test_timeouts.py +++ b/tests/test_swf/responses/test_timeouts.py @@ -1,110 +1,126 @@ -from freezegun import freeze_time -import sure # noqa - -from moto import mock_swf_deprecated - -from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION - - -# Activity Task Heartbeat timeout -# Default value in workflow helpers: 5 mins -@mock_swf_deprecated -def test_activity_task_heartbeat_timeout(): - with freeze_time("2015-01-01 12:00:00"): - conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - conn.poll_for_activity_task( - "test-domain", "activity-task-list", identity="surprise") - - with freeze_time("2015-01-01 12:04:30"): - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") - - with freeze_time("2015-01-01 12:05:30"): - # => Activity Task Heartbeat timeout reached!! - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - - resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") - attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] - attrs["timeoutType"].should.equal("HEARTBEAT") - # checks that event has been emitted at 12:05:00, not 12:05:30 - resp["events"][-2]["eventTimestamp"].should.equal(1420113900.0) - - resp["events"][-1]["eventType"].should.equal("DecisionTaskScheduled") - - -# Decision Task Start to Close timeout -# Default value in workflow helpers: 5 mins -@mock_swf_deprecated -def test_decision_task_start_to_close_timeout(): - pass - with freeze_time("2015-01-01 12:00:00"): - conn = setup_workflow() - conn.poll_for_decision_task("test-domain", "queue")["taskToken"] - - with freeze_time("2015-01-01 12:04:30"): - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - - event_types = [evt["eventType"] for evt in resp["events"]] - event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted"] - ) - - with freeze_time("2015-01-01 12:05:30"): - # => Decision Task Start to Close timeout reached!! - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - - event_types = [evt["eventType"] for evt in resp["events"]] - event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted", - "DecisionTaskTimedOut", "DecisionTaskScheduled"] - ) - attrs = resp["events"][-2]["decisionTaskTimedOutEventAttributes"] - attrs.should.equal({ - "scheduledEventId": 2, "startedEventId": 3, "timeoutType": "START_TO_CLOSE" - }) - # checks that event has been emitted at 12:05:00, not 12:05:30 - resp["events"][-2]["eventTimestamp"].should.equal(1420113900.0) - - -# Workflow Execution Start to Close timeout -# Default value in workflow helpers: 2 hours -@mock_swf_deprecated -def test_workflow_execution_start_to_close_timeout(): - pass - with freeze_time("2015-01-01 12:00:00"): - conn = setup_workflow() - - with freeze_time("2015-01-01 13:59:30"): - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - - event_types = [evt["eventType"] for evt in resp["events"]] - event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled"] - ) - - with freeze_time("2015-01-01 14:00:30"): - # => Workflow Execution Start to Close timeout reached!! - resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") - - event_types = [evt["eventType"] for evt in resp["events"]] - event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled", - "WorkflowExecutionTimedOut"] - ) - attrs = resp["events"][-1]["workflowExecutionTimedOutEventAttributes"] - attrs.should.equal({ - "childPolicy": "ABANDON", "timeoutType": "START_TO_CLOSE" - }) - # checks that event has been emitted at 14:00:00, not 14:00:30 - resp["events"][-1]["eventTimestamp"].should.equal(1420120800.0) +from freezegun import freeze_time +import sure # noqa + +from moto import mock_swf_deprecated + +from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION + + +# Activity Task Heartbeat timeout +# Default value in workflow helpers: 5 mins +@mock_swf_deprecated +def test_activity_task_heartbeat_timeout(): + with freeze_time("2015-01-01 12:00:00"): + conn = setup_workflow() + decision_token = conn.poll_for_decision_task("test-domain", "queue")[ + "taskToken" + ] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + conn.poll_for_activity_task( + "test-domain", "activity-task-list", identity="surprise" + ) + + with freeze_time("2015-01-01 12:04:30"): + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") + + with freeze_time("2015-01-01 12:05:30"): + # => Activity Task Heartbeat timeout reached!! + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + + resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") + attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] + attrs["timeoutType"].should.equal("HEARTBEAT") + # checks that event has been emitted at 12:05:00, not 12:05:30 + resp["events"][-2]["eventTimestamp"].should.equal(1420113900.0) + + resp["events"][-1]["eventType"].should.equal("DecisionTaskScheduled") + + +# Decision Task Start to Close timeout +# Default value in workflow helpers: 5 mins +@mock_swf_deprecated +def test_decision_task_start_to_close_timeout(): + pass + with freeze_time("2015-01-01 12:00:00"): + conn = setup_workflow() + conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + + with freeze_time("2015-01-01 12:04:30"): + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + + event_types = [evt["eventType"] for evt in resp["events"]] + event_types.should.equal( + ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted"] + ) + + with freeze_time("2015-01-01 12:05:30"): + # => Decision Task Start to Close timeout reached!! + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + + event_types = [evt["eventType"] for evt in resp["events"]] + event_types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskTimedOut", + "DecisionTaskScheduled", + ] + ) + attrs = resp["events"][-2]["decisionTaskTimedOutEventAttributes"] + attrs.should.equal( + { + "scheduledEventId": 2, + "startedEventId": 3, + "timeoutType": "START_TO_CLOSE", + } + ) + # checks that event has been emitted at 12:05:00, not 12:05:30 + resp["events"][-2]["eventTimestamp"].should.equal(1420113900.0) + + +# Workflow Execution Start to Close timeout +# Default value in workflow helpers: 2 hours +@mock_swf_deprecated +def test_workflow_execution_start_to_close_timeout(): + pass + with freeze_time("2015-01-01 12:00:00"): + conn = setup_workflow() + + with freeze_time("2015-01-01 13:59:30"): + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + + event_types = [evt["eventType"] for evt in resp["events"]] + event_types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) + + with freeze_time("2015-01-01 14:00:30"): + # => Workflow Execution Start to Close timeout reached!! + resp = conn.get_workflow_execution_history( + "test-domain", conn.run_id, "uid-abcd1234" + ) + + event_types = [evt["eventType"] for evt in resp["events"]] + event_types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "WorkflowExecutionTimedOut", + ] + ) + attrs = resp["events"][-1]["workflowExecutionTimedOutEventAttributes"] + attrs.should.equal({"childPolicy": "ABANDON", "timeoutType": "START_TO_CLOSE"}) + # checks that event has been emitted at 14:00:00, not 14:00:30 + resp["events"][-1]["eventTimestamp"].should.equal(1420120800.0) diff --git a/tests/test_swf/responses/test_workflow_executions.py b/tests/test_swf/responses/test_workflow_executions.py index 2cb092260..bec352ce8 100644 --- a/tests/test_swf/responses/test_workflow_executions.py +++ b/tests/test_swf/responses/test_workflow_executions.py @@ -1,262 +1,277 @@ -import boto -from boto.swf.exceptions import SWFResponseError -from datetime import datetime, timedelta - -import sure # noqa -# Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises # noqa - -from moto import mock_swf_deprecated -from moto.core.utils import unix_time - - -# Utils -@mock_swf_deprecated -def setup_swf_environment(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60", description="A test domain") - conn.register_workflow_type( - "test-domain", "test-workflow", "v1.0", - task_list="queue", default_child_policy="TERMINATE", - default_execution_start_to_close_timeout="300", - default_task_start_to_close_timeout="300", - ) - conn.register_activity_type("test-domain", "test-activity", "v1.1") - return conn - - -# StartWorkflowExecution endpoint -@mock_swf_deprecated -def test_start_workflow_execution(): - conn = setup_swf_environment() - - wf = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - wf.should.contain("runId") - -@mock_swf_deprecated -def test_signal_workflow_execution(): - conn = setup_swf_environment() - hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - run_id = hsh["runId"] - - wfe = conn.signal_workflow_execution( - "test-domain", "my_signal", "uid-abcd1234", "my_input", run_id) - - wfe = conn.describe_workflow_execution( - "test-domain", run_id, "uid-abcd1234") - - wfe["openCounts"]["openDecisionTasks"].should.equal(2) - -@mock_swf_deprecated -def test_start_already_started_workflow_execution(): - conn = setup_swf_environment() - conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - - conn.start_workflow_execution.when.called_with( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_start_workflow_execution_on_deprecated_type(): - conn = setup_swf_environment() - conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") - - conn.start_workflow_execution.when.called_with( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0" - ).should.throw(SWFResponseError) - - -# DescribeWorkflowExecution endpoint -@mock_swf_deprecated -def test_describe_workflow_execution(): - conn = setup_swf_environment() - hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - run_id = hsh["runId"] - - wfe = conn.describe_workflow_execution( - "test-domain", run_id, "uid-abcd1234") - wfe["executionInfo"]["execution"][ - "workflowId"].should.equal("uid-abcd1234") - wfe["executionInfo"]["executionStatus"].should.equal("OPEN") - - -@mock_swf_deprecated -def test_describe_non_existent_workflow_execution(): - conn = setup_swf_environment() - - conn.describe_workflow_execution.when.called_with( - "test-domain", "wrong-run-id", "wrong-workflow-id" - ).should.throw(SWFResponseError) - - -# GetWorkflowExecutionHistory endpoint -@mock_swf_deprecated -def test_get_workflow_execution_history(): - conn = setup_swf_environment() - hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - run_id = hsh["runId"] - - resp = conn.get_workflow_execution_history( - "test-domain", run_id, "uid-abcd1234") - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) - - -@mock_swf_deprecated -def test_get_workflow_execution_history_with_reverse_order(): - conn = setup_swf_environment() - hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - run_id = hsh["runId"] - - resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234", - reverse_order=True) - types = [evt["eventType"] for evt in resp["events"]] - types.should.equal(["DecisionTaskScheduled", "WorkflowExecutionStarted"]) - - -@mock_swf_deprecated -def test_get_workflow_execution_history_on_non_existent_workflow_execution(): - conn = setup_swf_environment() - - conn.get_workflow_execution_history.when.called_with( - "test-domain", "wrong-run-id", "wrong-workflow-id" - ).should.throw(SWFResponseError) - - -# ListOpenWorkflowExecutions endpoint -@mock_swf_deprecated -def test_list_open_workflow_executions(): - conn = setup_swf_environment() - # One open workflow execution - conn.start_workflow_execution( - 'test-domain', 'uid-abcd1234', 'test-workflow', 'v1.0' - ) - # One closed workflow execution to make sure it isn't displayed - run_id = conn.start_workflow_execution( - 'test-domain', 'uid-abcd12345', 'test-workflow', 'v1.0' - )['runId'] - conn.terminate_workflow_execution('test-domain', 'uid-abcd12345', - details='some details', - reason='a more complete reason', - run_id=run_id) - - yesterday = datetime.utcnow() - timedelta(days=1) - oldest_date = unix_time(yesterday) - response = conn.list_open_workflow_executions('test-domain', - oldest_date, - workflow_id='test-workflow') - execution_infos = response['executionInfos'] - len(execution_infos).should.equal(1) - open_workflow = execution_infos[0] - open_workflow['workflowType'].should.equal({'version': 'v1.0', - 'name': 'test-workflow'}) - open_workflow.should.contain('startTimestamp') - open_workflow['execution']['workflowId'].should.equal('uid-abcd1234') - open_workflow['execution'].should.contain('runId') - open_workflow['cancelRequested'].should.be(False) - open_workflow['executionStatus'].should.equal('OPEN') - - -# ListClosedWorkflowExecutions endpoint -@mock_swf_deprecated -def test_list_closed_workflow_executions(): - conn = setup_swf_environment() - # Leave one workflow execution open to make sure it isn't displayed - conn.start_workflow_execution( - 'test-domain', 'uid-abcd1234', 'test-workflow', 'v1.0' - ) - # One closed workflow execution - run_id = conn.start_workflow_execution( - 'test-domain', 'uid-abcd12345', 'test-workflow', 'v1.0' - )['runId'] - conn.terminate_workflow_execution('test-domain', 'uid-abcd12345', - details='some details', - reason='a more complete reason', - run_id=run_id) - - yesterday = datetime.utcnow() - timedelta(days=1) - oldest_date = unix_time(yesterday) - response = conn.list_closed_workflow_executions( - 'test-domain', - start_oldest_date=oldest_date, - workflow_id='test-workflow') - execution_infos = response['executionInfos'] - len(execution_infos).should.equal(1) - open_workflow = execution_infos[0] - open_workflow['workflowType'].should.equal({'version': 'v1.0', - 'name': 'test-workflow'}) - open_workflow.should.contain('startTimestamp') - open_workflow['execution']['workflowId'].should.equal('uid-abcd12345') - open_workflow['execution'].should.contain('runId') - open_workflow['cancelRequested'].should.be(False) - open_workflow['executionStatus'].should.equal('CLOSED') - - -# TerminateWorkflowExecution endpoint -@mock_swf_deprecated -def test_terminate_workflow_execution(): - conn = setup_swf_environment() - run_id = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0" - )["runId"] - - resp = conn.terminate_workflow_execution("test-domain", "uid-abcd1234", - details="some details", - reason="a more complete reason", - run_id=run_id) - resp.should.be.none - - resp = conn.get_workflow_execution_history( - "test-domain", run_id, "uid-abcd1234") - evt = resp["events"][-1] - evt["eventType"].should.equal("WorkflowExecutionTerminated") - attrs = evt["workflowExecutionTerminatedEventAttributes"] - attrs["details"].should.equal("some details") - attrs["reason"].should.equal("a more complete reason") - attrs["cause"].should.equal("OPERATOR_INITIATED") - - -@mock_swf_deprecated -def test_terminate_workflow_execution_with_wrong_workflow_or_run_id(): - conn = setup_swf_environment() - run_id = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0" - )["runId"] - - # terminate workflow execution - conn.terminate_workflow_execution("test-domain", "uid-abcd1234") - - # already closed, with run_id - conn.terminate_workflow_execution.when.called_with( - "test-domain", "uid-abcd1234", run_id=run_id - ).should.throw( - SWFResponseError, "WorkflowExecution=[workflowId=uid-abcd1234, runId=" - ) - - # already closed, without run_id - conn.terminate_workflow_execution.when.called_with( - "test-domain", "uid-abcd1234" - ).should.throw( - SWFResponseError, "Unknown execution, workflowId = uid-abcd1234" - ) - - # wrong workflow id - conn.terminate_workflow_execution.when.called_with( - "test-domain", "uid-non-existent" - ).should.throw( - SWFResponseError, "Unknown execution, workflowId = uid-non-existent" - ) - - # wrong run_id - conn.terminate_workflow_execution.when.called_with( - "test-domain", "uid-abcd1234", run_id="foo" - ).should.throw( - SWFResponseError, "WorkflowExecution=[workflowId=uid-abcd1234, runId=" - ) +import boto +from boto.swf.exceptions import SWFResponseError +from datetime import datetime, timedelta + +import sure # noqa + +# Ensure 'assert_raises' context manager support for Python 2.6 +import tests.backport_assert_raises # noqa + +from moto import mock_swf_deprecated +from moto.core.utils import unix_time + + +# Utils +@mock_swf_deprecated +def setup_swf_environment(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60", description="A test domain") + conn.register_workflow_type( + "test-domain", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="TERMINATE", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", + ) + conn.register_activity_type("test-domain", "test-activity", "v1.1") + return conn + + +# StartWorkflowExecution endpoint +@mock_swf_deprecated +def test_start_workflow_execution(): + conn = setup_swf_environment() + + wf = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + wf.should.contain("runId") + + +@mock_swf_deprecated +def test_signal_workflow_execution(): + conn = setup_swf_environment() + hsh = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + run_id = hsh["runId"] + + wfe = conn.signal_workflow_execution( + "test-domain", "my_signal", "uid-abcd1234", "my_input", run_id + ) + + wfe = conn.describe_workflow_execution("test-domain", run_id, "uid-abcd1234") + + wfe["openCounts"]["openDecisionTasks"].should.equal(2) + + +@mock_swf_deprecated +def test_start_already_started_workflow_execution(): + conn = setup_swf_environment() + conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + + conn.start_workflow_execution.when.called_with( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_start_workflow_execution_on_deprecated_type(): + conn = setup_swf_environment() + conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") + + conn.start_workflow_execution.when.called_with( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ).should.throw(SWFResponseError) + + +# DescribeWorkflowExecution endpoint +@mock_swf_deprecated +def test_describe_workflow_execution(): + conn = setup_swf_environment() + hsh = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + run_id = hsh["runId"] + + wfe = conn.describe_workflow_execution("test-domain", run_id, "uid-abcd1234") + wfe["executionInfo"]["execution"]["workflowId"].should.equal("uid-abcd1234") + wfe["executionInfo"]["executionStatus"].should.equal("OPEN") + + +@mock_swf_deprecated +def test_describe_non_existent_workflow_execution(): + conn = setup_swf_environment() + + conn.describe_workflow_execution.when.called_with( + "test-domain", "wrong-run-id", "wrong-workflow-id" + ).should.throw(SWFResponseError) + + +# GetWorkflowExecutionHistory endpoint +@mock_swf_deprecated +def test_get_workflow_execution_history(): + conn = setup_swf_environment() + hsh = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + run_id = hsh["runId"] + + resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234") + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) + + +@mock_swf_deprecated +def test_get_workflow_execution_history_with_reverse_order(): + conn = setup_swf_environment() + hsh = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + run_id = hsh["runId"] + + resp = conn.get_workflow_execution_history( + "test-domain", run_id, "uid-abcd1234", reverse_order=True + ) + types = [evt["eventType"] for evt in resp["events"]] + types.should.equal(["DecisionTaskScheduled", "WorkflowExecutionStarted"]) + + +@mock_swf_deprecated +def test_get_workflow_execution_history_on_non_existent_workflow_execution(): + conn = setup_swf_environment() + + conn.get_workflow_execution_history.when.called_with( + "test-domain", "wrong-run-id", "wrong-workflow-id" + ).should.throw(SWFResponseError) + + +# ListOpenWorkflowExecutions endpoint +@mock_swf_deprecated +def test_list_open_workflow_executions(): + conn = setup_swf_environment() + # One open workflow execution + conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + # One closed workflow execution to make sure it isn't displayed + run_id = conn.start_workflow_execution( + "test-domain", "uid-abcd12345", "test-workflow", "v1.0" + )["runId"] + conn.terminate_workflow_execution( + "test-domain", + "uid-abcd12345", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) + + yesterday = datetime.utcnow() - timedelta(days=1) + oldest_date = unix_time(yesterday) + response = conn.list_open_workflow_executions( + "test-domain", oldest_date, workflow_id="test-workflow" + ) + execution_infos = response["executionInfos"] + len(execution_infos).should.equal(1) + open_workflow = execution_infos[0] + open_workflow["workflowType"].should.equal( + {"version": "v1.0", "name": "test-workflow"} + ) + open_workflow.should.contain("startTimestamp") + open_workflow["execution"]["workflowId"].should.equal("uid-abcd1234") + open_workflow["execution"].should.contain("runId") + open_workflow["cancelRequested"].should.be(False) + open_workflow["executionStatus"].should.equal("OPEN") + + +# ListClosedWorkflowExecutions endpoint +@mock_swf_deprecated +def test_list_closed_workflow_executions(): + conn = setup_swf_environment() + # Leave one workflow execution open to make sure it isn't displayed + conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + # One closed workflow execution + run_id = conn.start_workflow_execution( + "test-domain", "uid-abcd12345", "test-workflow", "v1.0" + )["runId"] + conn.terminate_workflow_execution( + "test-domain", + "uid-abcd12345", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) + + yesterday = datetime.utcnow() - timedelta(days=1) + oldest_date = unix_time(yesterday) + response = conn.list_closed_workflow_executions( + "test-domain", start_oldest_date=oldest_date, workflow_id="test-workflow" + ) + execution_infos = response["executionInfos"] + len(execution_infos).should.equal(1) + open_workflow = execution_infos[0] + open_workflow["workflowType"].should.equal( + {"version": "v1.0", "name": "test-workflow"} + ) + open_workflow.should.contain("startTimestamp") + open_workflow["execution"]["workflowId"].should.equal("uid-abcd12345") + open_workflow["execution"].should.contain("runId") + open_workflow["cancelRequested"].should.be(False) + open_workflow["executionStatus"].should.equal("CLOSED") + + +# TerminateWorkflowExecution endpoint +@mock_swf_deprecated +def test_terminate_workflow_execution(): + conn = setup_swf_environment() + run_id = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + )["runId"] + + resp = conn.terminate_workflow_execution( + "test-domain", + "uid-abcd1234", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) + resp.should.be.none + + resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234") + evt = resp["events"][-1] + evt["eventType"].should.equal("WorkflowExecutionTerminated") + attrs = evt["workflowExecutionTerminatedEventAttributes"] + attrs["details"].should.equal("some details") + attrs["reason"].should.equal("a more complete reason") + attrs["cause"].should.equal("OPERATOR_INITIATED") + + +@mock_swf_deprecated +def test_terminate_workflow_execution_with_wrong_workflow_or_run_id(): + conn = setup_swf_environment() + run_id = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + )["runId"] + + # terminate workflow execution + conn.terminate_workflow_execution("test-domain", "uid-abcd1234") + + # already closed, with run_id + conn.terminate_workflow_execution.when.called_with( + "test-domain", "uid-abcd1234", run_id=run_id + ).should.throw( + SWFResponseError, "WorkflowExecution=[workflowId=uid-abcd1234, runId=" + ) + + # already closed, without run_id + conn.terminate_workflow_execution.when.called_with( + "test-domain", "uid-abcd1234" + ).should.throw(SWFResponseError, "Unknown execution, workflowId = uid-abcd1234") + + # wrong workflow id + conn.terminate_workflow_execution.when.called_with( + "test-domain", "uid-non-existent" + ).should.throw(SWFResponseError, "Unknown execution, workflowId = uid-non-existent") + + # wrong run_id + conn.terminate_workflow_execution.when.called_with( + "test-domain", "uid-abcd1234", run_id="foo" + ).should.throw( + SWFResponseError, "WorkflowExecution=[workflowId=uid-abcd1234, runId=" + ) diff --git a/tests/test_swf/responses/test_workflow_types.py b/tests/test_swf/responses/test_workflow_types.py index f0b39e7ad..4c92d7762 100644 --- a/tests/test_swf/responses/test_workflow_types.py +++ b/tests/test_swf/responses/test_workflow_types.py @@ -1,137 +1,143 @@ -import sure -import boto - -from moto import mock_swf_deprecated -from boto.swf.exceptions import SWFResponseError - - -# RegisterWorkflowType endpoint -@mock_swf_deprecated -def test_register_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0") - - types = conn.list_workflow_types("test-domain", "REGISTERED") - actype = types["typeInfos"][0] - actype["workflowType"]["name"].should.equal("test-workflow") - actype["workflowType"]["version"].should.equal("v1.0") - - -@mock_swf_deprecated -def test_register_already_existing_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0") - - conn.register_workflow_type.when.called_with( - "test-domain", "test-workflow", "v1.0" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_register_with_wrong_parameter_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.register_workflow_type.when.called_with( - "test-domain", "test-workflow", 12 - ).should.throw(SWFResponseError) - - -# ListWorkflowTypes endpoint -@mock_swf_deprecated -def test_list_workflow_types(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "b-test-workflow", "v1.0") - conn.register_workflow_type("test-domain", "a-test-workflow", "v1.0") - conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") - - all_workflow_types = conn.list_workflow_types("test-domain", "REGISTERED") - names = [activity_type["workflowType"]["name"] - for activity_type in all_workflow_types["typeInfos"]] - names.should.equal( - ["a-test-workflow", "b-test-workflow", "c-test-workflow"]) - - -@mock_swf_deprecated -def test_list_workflow_types_reverse_order(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "b-test-workflow", "v1.0") - conn.register_workflow_type("test-domain", "a-test-workflow", "v1.0") - conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") - - all_workflow_types = conn.list_workflow_types("test-domain", "REGISTERED", - reverse_order=True) - names = [activity_type["workflowType"]["name"] - for activity_type in all_workflow_types["typeInfos"]] - names.should.equal( - ["c-test-workflow", "b-test-workflow", "a-test-workflow"]) - - -# DeprecateWorkflowType endpoint -@mock_swf_deprecated -def test_deprecate_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0") - conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") - - actypes = conn.list_workflow_types("test-domain", "DEPRECATED") - actype = actypes["typeInfos"][0] - actype["workflowType"]["name"].should.equal("test-workflow") - actype["workflowType"]["version"].should.equal("v1.0") - - -@mock_swf_deprecated -def test_deprecate_already_deprecated_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0") - conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") - - conn.deprecate_workflow_type.when.called_with( - "test-domain", "test-workflow", "v1.0" - ).should.throw(SWFResponseError) - - -@mock_swf_deprecated -def test_deprecate_non_existent_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.deprecate_workflow_type.when.called_with( - "test-domain", "non-existent", "v1.0" - ).should.throw(SWFResponseError) - - -# DescribeWorkflowType endpoint -@mock_swf_deprecated -def test_describe_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0", - task_list="foo", default_child_policy="TERMINATE") - - actype = conn.describe_workflow_type( - "test-domain", "test-workflow", "v1.0") - actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") - actype["configuration"]["defaultChildPolicy"].should.equal("TERMINATE") - actype["configuration"].keys().should_not.contain( - "defaultTaskStartToCloseTimeout") - infos = actype["typeInfo"] - infos["workflowType"]["name"].should.equal("test-workflow") - infos["workflowType"]["version"].should.equal("v1.0") - infos["status"].should.equal("REGISTERED") - - -@mock_swf_deprecated -def test_describe_non_existent_workflow_type(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60") - - conn.describe_workflow_type.when.called_with( - "test-domain", "non-existent", "v1.0" - ).should.throw(SWFResponseError) +import sure +import boto + +from moto import mock_swf_deprecated +from boto.swf.exceptions import SWFResponseError + + +# RegisterWorkflowType endpoint +@mock_swf_deprecated +def test_register_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "test-workflow", "v1.0") + + types = conn.list_workflow_types("test-domain", "REGISTERED") + actype = types["typeInfos"][0] + actype["workflowType"]["name"].should.equal("test-workflow") + actype["workflowType"]["version"].should.equal("v1.0") + + +@mock_swf_deprecated +def test_register_already_existing_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "test-workflow", "v1.0") + + conn.register_workflow_type.when.called_with( + "test-domain", "test-workflow", "v1.0" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_register_with_wrong_parameter_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.register_workflow_type.when.called_with( + "test-domain", "test-workflow", 12 + ).should.throw(SWFResponseError) + + +# ListWorkflowTypes endpoint +@mock_swf_deprecated +def test_list_workflow_types(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "b-test-workflow", "v1.0") + conn.register_workflow_type("test-domain", "a-test-workflow", "v1.0") + conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") + + all_workflow_types = conn.list_workflow_types("test-domain", "REGISTERED") + names = [ + activity_type["workflowType"]["name"] + for activity_type in all_workflow_types["typeInfos"] + ] + names.should.equal(["a-test-workflow", "b-test-workflow", "c-test-workflow"]) + + +@mock_swf_deprecated +def test_list_workflow_types_reverse_order(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "b-test-workflow", "v1.0") + conn.register_workflow_type("test-domain", "a-test-workflow", "v1.0") + conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") + + all_workflow_types = conn.list_workflow_types( + "test-domain", "REGISTERED", reverse_order=True + ) + names = [ + activity_type["workflowType"]["name"] + for activity_type in all_workflow_types["typeInfos"] + ] + names.should.equal(["c-test-workflow", "b-test-workflow", "a-test-workflow"]) + + +# DeprecateWorkflowType endpoint +@mock_swf_deprecated +def test_deprecate_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "test-workflow", "v1.0") + conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") + + actypes = conn.list_workflow_types("test-domain", "DEPRECATED") + actype = actypes["typeInfos"][0] + actype["workflowType"]["name"].should.equal("test-workflow") + actype["workflowType"]["version"].should.equal("v1.0") + + +@mock_swf_deprecated +def test_deprecate_already_deprecated_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type("test-domain", "test-workflow", "v1.0") + conn.deprecate_workflow_type("test-domain", "test-workflow", "v1.0") + + conn.deprecate_workflow_type.when.called_with( + "test-domain", "test-workflow", "v1.0" + ).should.throw(SWFResponseError) + + +@mock_swf_deprecated +def test_deprecate_non_existent_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.deprecate_workflow_type.when.called_with( + "test-domain", "non-existent", "v1.0" + ).should.throw(SWFResponseError) + + +# DescribeWorkflowType endpoint +@mock_swf_deprecated +def test_describe_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + conn.register_workflow_type( + "test-domain", + "test-workflow", + "v1.0", + task_list="foo", + default_child_policy="TERMINATE", + ) + + actype = conn.describe_workflow_type("test-domain", "test-workflow", "v1.0") + actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") + actype["configuration"]["defaultChildPolicy"].should.equal("TERMINATE") + actype["configuration"].keys().should_not.contain("defaultTaskStartToCloseTimeout") + infos = actype["typeInfo"] + infos["workflowType"]["name"].should.equal("test-workflow") + infos["workflowType"]["version"].should.equal("v1.0") + infos["status"].should.equal("REGISTERED") + + +@mock_swf_deprecated +def test_describe_non_existent_workflow_type(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60") + + conn.describe_workflow_type.when.called_with( + "test-domain", "non-existent", "v1.0" + ).should.throw(SWFResponseError) diff --git a/tests/test_swf/test_exceptions.py b/tests/test_swf/test_exceptions.py index b91a697b9..2e42cdb9b 100644 --- a/tests/test_swf/test_exceptions.py +++ b/tests/test_swf/test_exceptions.py @@ -1,158 +1,181 @@ -from __future__ import unicode_literals -import sure # noqa - -import json - -from moto.swf.exceptions import ( - SWFClientError, - SWFUnknownResourceFault, - SWFDomainAlreadyExistsFault, - SWFDomainDeprecatedFault, - SWFSerializationException, - SWFTypeAlreadyExistsFault, - SWFTypeDeprecatedFault, - SWFWorkflowExecutionAlreadyStartedFault, - SWFDefaultUndefinedFault, - SWFValidationException, - SWFDecisionValidationException, -) -from moto.swf.models import ( - WorkflowType, -) - - -def test_swf_client_error(): - ex = SWFClientError("ASpecificType", "error message") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "ASpecificType", - "message": "error message" - }) - - -def test_swf_unknown_resource_fault(): - ex = SWFUnknownResourceFault("type", "detail") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", - "message": "Unknown type: detail" - }) - - -def test_swf_unknown_resource_fault_with_only_one_parameter(): - ex = SWFUnknownResourceFault("foo bar baz") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", - "message": "Unknown foo bar baz" - }) - - -def test_swf_domain_already_exists_fault(): - ex = SWFDomainAlreadyExistsFault("domain-name") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", - "message": "domain-name" - }) - - -def test_swf_domain_deprecated_fault(): - ex = SWFDomainDeprecatedFault("domain-name") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DomainDeprecatedFault", - "message": "domain-name" - }) - - -def test_swf_serialization_exception(): - ex = SWFSerializationException("value") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#SerializationException", - "message": "class java.lang.Foo can not be converted to an String (not a real SWF exception ; happened on: value)" - }) - - -def test_swf_type_already_exists_fault(): - wft = WorkflowType("wf-name", "wf-version") - ex = SWFTypeAlreadyExistsFault(wft) - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", - "message": "WorkflowType=[name=wf-name, version=wf-version]" - }) - - -def test_swf_type_deprecated_fault(): - wft = WorkflowType("wf-name", "wf-version") - ex = SWFTypeDeprecatedFault(wft) - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#TypeDeprecatedFault", - "message": "WorkflowType=[name=wf-name, version=wf-version]" - }) - - -def test_swf_workflow_execution_already_started_fault(): - ex = SWFWorkflowExecutionAlreadyStartedFault() - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", - 'message': 'Already Started', - }) - - -def test_swf_default_undefined_fault(): - ex = SWFDefaultUndefinedFault("execution_start_to_close_timeout") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DefaultUndefinedFault", - "message": "executionStartToCloseTimeout", - }) - - -def test_swf_validation_exception(): - ex = SWFValidationException("Invalid token") - - ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazon.coral.validate#ValidationException", - "message": "Invalid token", - }) - - -def test_swf_decision_validation_error(): - ex = SWFDecisionValidationException([ - {"type": "null_value", - "where": "decisions.1.member.startTimerDecisionAttributes.startToFireTimeout"}, - {"type": "bad_decision_type", - "value": "FooBar", - "where": "decisions.1.member.decisionType", - "possible_values": "Foo, Bar, Baz"}, - ]) - - ex.code.should.equal(400) - ex.error_type.should.equal("com.amazon.coral.validate#ValidationException") - - msg = ex.get_body() - msg.should.match(r"2 validation errors detected:") - msg.should.match( - r"Value null at 'decisions.1.member.startTimerDecisionAttributes.startToFireTimeout' " - r"failed to satisfy constraint: Member must not be null;" - ) - msg.should.match( - r"Value 'FooBar' at 'decisions.1.member.decisionType' failed to satisfy constraint: " - r"Member must satisfy enum value set: \[Foo, Bar, Baz\]" - ) +from __future__ import unicode_literals +import sure # noqa + +import json + +from moto.swf.exceptions import ( + SWFClientError, + SWFUnknownResourceFault, + SWFDomainAlreadyExistsFault, + SWFDomainDeprecatedFault, + SWFSerializationException, + SWFTypeAlreadyExistsFault, + SWFTypeDeprecatedFault, + SWFWorkflowExecutionAlreadyStartedFault, + SWFDefaultUndefinedFault, + SWFValidationException, + SWFDecisionValidationException, +) +from moto.swf.models import WorkflowType + + +def test_swf_client_error(): + ex = SWFClientError("ASpecificType", "error message") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + {"__type": "ASpecificType", "message": "error message"} + ) + + +def test_swf_unknown_resource_fault(): + ex = SWFUnknownResourceFault("type", "detail") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", + "message": "Unknown type: detail", + } + ) + + +def test_swf_unknown_resource_fault_with_only_one_parameter(): + ex = SWFUnknownResourceFault("foo bar baz") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", + "message": "Unknown foo bar baz", + } + ) + + +def test_swf_domain_already_exists_fault(): + ex = SWFDomainAlreadyExistsFault("domain-name") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", + "message": "domain-name", + } + ) + + +def test_swf_domain_deprecated_fault(): + ex = SWFDomainDeprecatedFault("domain-name") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DomainDeprecatedFault", + "message": "domain-name", + } + ) + + +def test_swf_serialization_exception(): + ex = SWFSerializationException("value") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#SerializationException", + "message": "class java.lang.Foo can not be converted to an String (not a real SWF exception ; happened on: value)", + } + ) + + +def test_swf_type_already_exists_fault(): + wft = WorkflowType("wf-name", "wf-version") + ex = SWFTypeAlreadyExistsFault(wft) + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", + "message": "WorkflowType=[name=wf-name, version=wf-version]", + } + ) + + +def test_swf_type_deprecated_fault(): + wft = WorkflowType("wf-name", "wf-version") + ex = SWFTypeDeprecatedFault(wft) + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#TypeDeprecatedFault", + "message": "WorkflowType=[name=wf-name, version=wf-version]", + } + ) + + +def test_swf_workflow_execution_already_started_fault(): + ex = SWFWorkflowExecutionAlreadyStartedFault() + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", + "message": "Already Started", + } + ) + + +def test_swf_default_undefined_fault(): + ex = SWFDefaultUndefinedFault("execution_start_to_close_timeout") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DefaultUndefinedFault", + "message": "executionStartToCloseTimeout", + } + ) + + +def test_swf_validation_exception(): + ex = SWFValidationException("Invalid token") + + ex.code.should.equal(400) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazon.coral.validate#ValidationException", + "message": "Invalid token", + } + ) + + +def test_swf_decision_validation_error(): + ex = SWFDecisionValidationException( + [ + { + "type": "null_value", + "where": "decisions.1.member.startTimerDecisionAttributes.startToFireTimeout", + }, + { + "type": "bad_decision_type", + "value": "FooBar", + "where": "decisions.1.member.decisionType", + "possible_values": "Foo, Bar, Baz", + }, + ] + ) + + ex.code.should.equal(400) + ex.error_type.should.equal("com.amazon.coral.validate#ValidationException") + + msg = ex.get_body() + msg.should.match(r"2 validation errors detected:") + msg.should.match( + r"Value null at 'decisions.1.member.startTimerDecisionAttributes.startToFireTimeout' " + r"failed to satisfy constraint: Member must not be null;" + ) + msg.should.match( + r"Value 'FooBar' at 'decisions.1.member.decisionType' failed to satisfy constraint: " + r"Member must satisfy enum value set: \[Foo, Bar, Baz\]" + ) diff --git a/tests/test_swf/test_utils.py b/tests/test_swf/test_utils.py index 2e04b990c..143804ca9 100644 --- a/tests/test_swf/test_utils.py +++ b/tests/test_swf/test_utils.py @@ -4,10 +4,6 @@ from moto.swf.utils import decapitalize def test_decapitalize(): - cases = { - "fooBar": "fooBar", - "FooBar": "fooBar", - "FOO BAR": "fOO BAR", - } + cases = {"fooBar": "fooBar", "FooBar": "fooBar", "FOO BAR": "fOO BAR"} for before, after in cases.items(): decapitalize(before).should.equal(after) diff --git a/tests/test_swf/utils.py b/tests/test_swf/utils.py index 4879a0011..48c2cbd94 100644 --- a/tests/test_swf/utils.py +++ b/tests/test_swf/utils.py @@ -1,100 +1,100 @@ -import boto - -from moto.swf.models import ( - ActivityType, - Domain, - WorkflowType, - WorkflowExecution, -) - - -# Some useful constants -# Here are some activity timeouts we use in moto/swf tests ; they're extracted -# from semi-real world example, the goal is mostly to have predictible and -# intuitive behaviour in moto/swf own tests... -ACTIVITY_TASK_TIMEOUTS = { - "heartbeatTimeout": "300", # 5 mins - "scheduleToStartTimeout": "1800", # 30 mins - "startToCloseTimeout": "1800", # 30 mins - "scheduleToCloseTimeout": "2700", # 45 mins -} - -# Some useful decisions -SCHEDULE_ACTIVITY_TASK_DECISION = { - "decisionType": "ScheduleActivityTask", - "scheduleActivityTaskDecisionAttributes": { - "activityId": "my-activity-001", - "activityType": {"name": "test-activity", "version": "v1.1"}, - "taskList": {"name": "activity-task-list"}, - } -} -for key, value in ACTIVITY_TASK_TIMEOUTS.items(): - SCHEDULE_ACTIVITY_TASK_DECISION[ - "scheduleActivityTaskDecisionAttributes"][key] = value - - -# A test Domain -def get_basic_domain(): - return Domain("test-domain", "90") - - -# A test WorkflowType -def _generic_workflow_type_attributes(): - return [ - "test-workflow", "v1.0" - ], { - "task_list": "queue", - "default_child_policy": "ABANDON", - "default_execution_start_to_close_timeout": "7200", - "default_task_start_to_close_timeout": "300", - } - - -def get_basic_workflow_type(): - args, kwargs = _generic_workflow_type_attributes() - return WorkflowType(*args, **kwargs) - - -def mock_basic_workflow_type(domain_name, conn): - args, kwargs = _generic_workflow_type_attributes() - conn.register_workflow_type(domain_name, *args, **kwargs) - return conn - - -# A test WorkflowExecution -def make_workflow_execution(**kwargs): - domain = get_basic_domain() - domain.add_type(ActivityType("test-activity", "v1.1")) - wft = get_basic_workflow_type() - return WorkflowExecution(domain, wft, "ab1234", **kwargs) - - -# Makes decision tasks start automatically on a given workflow -def auto_start_decision_tasks(wfe): - wfe.schedule_decision_task = wfe.schedule_and_start_decision_task - return wfe - - -# Setup a complete example workflow and return the connection object -def setup_workflow(): - conn = boto.connect_swf("the_key", "the_secret") - conn.register_domain("test-domain", "60", description="A test domain") - conn = mock_basic_workflow_type("test-domain", conn) - conn.register_activity_type( - "test-domain", "test-activity", "v1.1", - default_task_heartbeat_timeout="600", - default_task_schedule_to_close_timeout="600", - default_task_schedule_to_start_timeout="600", - default_task_start_to_close_timeout="600", - ) - wfe = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") - conn.run_id = wfe["runId"] - return conn - - -# A helper for processing the first timeout on a given object -def process_first_timeout(obj): - _timeout = obj.first_timeout() - if _timeout: - obj.timeout(_timeout) +import boto + +from moto.swf.models import ActivityType, Domain, WorkflowType, WorkflowExecution + + +# Some useful constants +# Here are some activity timeouts we use in moto/swf tests ; they're extracted +# from semi-real world example, the goal is mostly to have predictible and +# intuitive behaviour in moto/swf own tests... +ACTIVITY_TASK_TIMEOUTS = { + "heartbeatTimeout": "300", # 5 mins + "scheduleToStartTimeout": "1800", # 30 mins + "startToCloseTimeout": "1800", # 30 mins + "scheduleToCloseTimeout": "2700", # 45 mins +} + +# Some useful decisions +SCHEDULE_ACTIVITY_TASK_DECISION = { + "decisionType": "ScheduleActivityTask", + "scheduleActivityTaskDecisionAttributes": { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "taskList": {"name": "activity-task-list"}, + }, +} +for key, value in ACTIVITY_TASK_TIMEOUTS.items(): + SCHEDULE_ACTIVITY_TASK_DECISION["scheduleActivityTaskDecisionAttributes"][ + key + ] = value + + +# A test Domain +def get_basic_domain(): + return Domain("test-domain", "90") + + +# A test WorkflowType +def _generic_workflow_type_attributes(): + return ( + ["test-workflow", "v1.0"], + { + "task_list": "queue", + "default_child_policy": "ABANDON", + "default_execution_start_to_close_timeout": "7200", + "default_task_start_to_close_timeout": "300", + }, + ) + + +def get_basic_workflow_type(): + args, kwargs = _generic_workflow_type_attributes() + return WorkflowType(*args, **kwargs) + + +def mock_basic_workflow_type(domain_name, conn): + args, kwargs = _generic_workflow_type_attributes() + conn.register_workflow_type(domain_name, *args, **kwargs) + return conn + + +# A test WorkflowExecution +def make_workflow_execution(**kwargs): + domain = get_basic_domain() + domain.add_type(ActivityType("test-activity", "v1.1")) + wft = get_basic_workflow_type() + return WorkflowExecution(domain, wft, "ab1234", **kwargs) + + +# Makes decision tasks start automatically on a given workflow +def auto_start_decision_tasks(wfe): + wfe.schedule_decision_task = wfe.schedule_and_start_decision_task + return wfe + + +# Setup a complete example workflow and return the connection object +def setup_workflow(): + conn = boto.connect_swf("the_key", "the_secret") + conn.register_domain("test-domain", "60", description="A test domain") + conn = mock_basic_workflow_type("test-domain", conn) + conn.register_activity_type( + "test-domain", + "test-activity", + "v1.1", + default_task_heartbeat_timeout="600", + default_task_schedule_to_close_timeout="600", + default_task_schedule_to_start_timeout="600", + default_task_start_to_close_timeout="600", + ) + wfe = conn.start_workflow_execution( + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) + conn.run_id = wfe["runId"] + return conn + + +# A helper for processing the first timeout on a given object +def process_first_timeout(obj): + _timeout = obj.first_timeout() + if _timeout: + obj.timeout(_timeout) diff --git a/tests/test_xray/test_xray_boto3.py b/tests/test_xray/test_xray_boto3.py index c754e3a69..4089abd2e 100644 --- a/tests/test_xray/test_xray_boto3.py +++ b/tests/test_xray/test_xray_boto3.py @@ -1,139 +1,148 @@ -from __future__ import unicode_literals - -import boto3 -import json -import botocore.exceptions -import sure # noqa - -from moto import mock_xray - -import datetime - - -@mock_xray -def test_put_telemetry(): - client = boto3.client('xray', region_name='us-east-1') - - client.put_telemetry_records( - TelemetryRecords=[ - { - 'Timestamp': datetime.datetime(2015, 1, 1), - 'SegmentsReceivedCount': 123, - 'SegmentsSentCount': 123, - 'SegmentsSpilloverCount': 123, - 'SegmentsRejectedCount': 123, - 'BackendConnectionErrors': { - 'TimeoutCount': 123, - 'ConnectionRefusedCount': 123, - 'HTTPCode4XXCount': 123, - 'HTTPCode5XXCount': 123, - 'UnknownHostCount': 123, - 'OtherCount': 123 - } - }, - ], - EC2InstanceId='string', - Hostname='string', - ResourceARN='string' - ) - - -@mock_xray -def test_put_trace_segments(): - client = boto3.client('xray', region_name='us-east-1') - - client.put_trace_segments( - TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1.478293361449E9 - }) - ] - ) - - -@mock_xray -def test_trace_summary(): - client = boto3.client('xray', region_name='us-east-1') - - client.put_trace_segments( - TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'in_progress': True - }), - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0b', - 'start_time': 1478293365, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1478293385 - }) - ] - ) - - client.get_trace_summaries( - StartTime=datetime.datetime(2014, 1, 1), - EndTime=datetime.datetime(2017, 1, 1) - ) - - -@mock_xray -def test_batch_get_trace(): - client = boto3.client('xray', region_name='us-east-1') - - client.put_trace_segments( - TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'in_progress': True - }), - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0b', - 'start_time': 1478293365, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1478293385 - }) - ] - ) - - resp = client.batch_get_traces( - TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] - ) - len(resp['UnprocessedTraceIds']).should.equal(1) - len(resp['Traces']).should.equal(1) - - -# Following are not implemented, just testing it returns what boto expects -@mock_xray -def test_batch_get_service_graph(): - client = boto3.client('xray', region_name='us-east-1') - - client.get_service_graph( - StartTime=datetime.datetime(2014, 1, 1), - EndTime=datetime.datetime(2017, 1, 1) - ) - - -@mock_xray -def test_batch_get_trace_graph(): - client = boto3.client('xray', region_name='us-east-1') - - client.batch_get_traces( - TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] - ) - - - - - +from __future__ import unicode_literals + +import boto3 +import json +import botocore.exceptions +import sure # noqa + +from moto import mock_xray + +import datetime + + +@mock_xray +def test_put_telemetry(): + client = boto3.client("xray", region_name="us-east-1") + + client.put_telemetry_records( + TelemetryRecords=[ + { + "Timestamp": datetime.datetime(2015, 1, 1), + "SegmentsReceivedCount": 123, + "SegmentsSentCount": 123, + "SegmentsSpilloverCount": 123, + "SegmentsRejectedCount": 123, + "BackendConnectionErrors": { + "TimeoutCount": 123, + "ConnectionRefusedCount": 123, + "HTTPCode4XXCount": 123, + "HTTPCode5XXCount": 123, + "UnknownHostCount": 123, + "OtherCount": 123, + }, + } + ], + EC2InstanceId="string", + Hostname="string", + ResourceARN="string", + ) + + +@mock_xray +def test_put_trace_segments(): + client = boto3.client("xray", region_name="us-east-1") + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1.478293361449e9, + } + ) + ] + ) + + +@mock_xray +def test_trace_summary(): + client = boto3.client("xray", region_name="us-east-1") + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "in_progress": True, + } + ), + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0b", + "start_time": 1478293365, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1478293385, + } + ), + ] + ) + + client.get_trace_summaries( + StartTime=datetime.datetime(2014, 1, 1), EndTime=datetime.datetime(2017, 1, 1) + ) + + +@mock_xray +def test_batch_get_trace(): + client = boto3.client("xray", region_name="us-east-1") + + client.put_trace_segments( + TraceSegmentDocuments=[ + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "in_progress": True, + } + ), + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0b", + "start_time": 1478293365, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1478293385, + } + ), + ] + ) + + resp = client.batch_get_traces( + TraceIds=[ + "1-581cf771-a006649127e371903a2de979", + "1-581cf772-b006649127e371903a2de979", + ] + ) + len(resp["UnprocessedTraceIds"]).should.equal(1) + len(resp["Traces"]).should.equal(1) + + +# Following are not implemented, just testing it returns what boto expects +@mock_xray +def test_batch_get_service_graph(): + client = boto3.client("xray", region_name="us-east-1") + + client.get_service_graph( + StartTime=datetime.datetime(2014, 1, 1), EndTime=datetime.datetime(2017, 1, 1) + ) + + +@mock_xray +def test_batch_get_trace_graph(): + client = boto3.client("xray", region_name="us-east-1") + + client.batch_get_traces( + TraceIds=[ + "1-581cf771-a006649127e371903a2de979", + "1-581cf772-b006649127e371903a2de979", + ] + ) diff --git a/tests/test_xray/test_xray_client.py b/tests/test_xray/test_xray_client.py index 8e7b84be0..6b74136c9 100644 --- a/tests/test_xray/test_xray_client.py +++ b/tests/test_xray/test_xray_client.py @@ -1,72 +1,74 @@ -from __future__ import unicode_literals -from moto import mock_xray_client, XRaySegment, mock_dynamodb2 -import sure # noqa -import boto3 - -from moto.xray.mock_client import MockEmitter -import aws_xray_sdk.core as xray_core -import aws_xray_sdk.core.patcher as xray_core_patcher - -import botocore.client -import botocore.endpoint -original_make_api_call = botocore.client.BaseClient._make_api_call -original_encode_headers = botocore.endpoint.Endpoint._encode_headers - -import requests -original_session_request = requests.Session.request -original_session_prep_request = requests.Session.prepare_request - - -@mock_xray_client -@mock_dynamodb2 -def test_xray_dynamo_request_id(): - # Could be ran in any order, so we need to tell sdk that its been unpatched - xray_core_patcher._PATCHED_MODULES = set() - xray_core.patch_all() - - client = boto3.client('dynamodb', region_name='us-east-1') - - with XRaySegment(): - resp = client.list_tables() - resp['ResponseMetadata'].should.contain('RequestId') - id1 = resp['ResponseMetadata']['RequestId'] - - with XRaySegment(): - client.list_tables() - resp = client.list_tables() - id2 = resp['ResponseMetadata']['RequestId'] - - id1.should_not.equal(id2) - - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) - - -@mock_xray_client -def test_xray_udp_emitter_patched(): - # Could be ran in any order, so we need to tell sdk that its been unpatched - xray_core_patcher._PATCHED_MODULES = set() - xray_core.patch_all() - - assert isinstance(xray_core.xray_recorder._emitter, MockEmitter) - - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) - - -@mock_xray_client -def test_xray_context_patched(): - # Could be ran in any order, so we need to tell sdk that its been unpatched - xray_core_patcher._PATCHED_MODULES = set() - xray_core.patch_all() - - xray_core.xray_recorder._context.context_missing.should.equal('LOG_ERROR') - - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) +from __future__ import unicode_literals +from moto import mock_xray_client, XRaySegment, mock_dynamodb2 +import sure # noqa +import boto3 + +from moto.xray.mock_client import MockEmitter +import aws_xray_sdk.core as xray_core +import aws_xray_sdk.core.patcher as xray_core_patcher + +import botocore.client +import botocore.endpoint + +original_make_api_call = botocore.client.BaseClient._make_api_call +original_encode_headers = botocore.endpoint.Endpoint._encode_headers + +import requests + +original_session_request = requests.Session.request +original_session_prep_request = requests.Session.prepare_request + + +@mock_xray_client +@mock_dynamodb2 +def test_xray_dynamo_request_id(): + # Could be ran in any order, so we need to tell sdk that its been unpatched + xray_core_patcher._PATCHED_MODULES = set() + xray_core.patch_all() + + client = boto3.client("dynamodb", region_name="us-east-1") + + with XRaySegment(): + resp = client.list_tables() + resp["ResponseMetadata"].should.contain("RequestId") + id1 = resp["ResponseMetadata"]["RequestId"] + + with XRaySegment(): + client.list_tables() + resp = client.list_tables() + id2 = resp["ResponseMetadata"]["RequestId"] + + id1.should_not.equal(id2) + + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) + + +@mock_xray_client +def test_xray_udp_emitter_patched(): + # Could be ran in any order, so we need to tell sdk that its been unpatched + xray_core_patcher._PATCHED_MODULES = set() + xray_core.patch_all() + + assert isinstance(xray_core.xray_recorder._emitter, MockEmitter) + + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) + + +@mock_xray_client +def test_xray_context_patched(): + # Could be ran in any order, so we need to tell sdk that its been unpatched + xray_core_patcher._PATCHED_MODULES = set() + xray_core.patch_all() + + xray_core.xray_recorder._context.context_missing.should.equal("LOG_ERROR") + + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) diff --git a/tox.ini b/tox.ini index 52b66711e..9dacca18c 100644 --- a/tox.ini +++ b/tox.ini @@ -15,5 +15,5 @@ commands = nosetests {posargs} [flake8] -ignore = E128,E501 -exclude = moto/packages,dist \ No newline at end of file +ignore = W503,W605,E128,E501,E203,E266,E501,E231 +exclude = moto/packages,dist