From ed86df6baefd8ff995ecba560ba85e95e4d55f62 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sat, 29 Jan 2022 11:04:14 -0100 Subject: [PATCH] ELBv2 improvements (#4808) --- moto/elbv2/models.py | 64 +++++++++------- moto/elbv2/responses.py | 18 ++--- tests/test_elbv2/test_elbv2.py | 81 +++++++++++++++++++- tests/test_elbv2/test_elbv2_target_groups.py | 4 +- 4 files changed, 129 insertions(+), 38 deletions(-) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 403664ed6..d8883e4b5 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -62,6 +62,7 @@ class FakeTargetGroup(CloudFormationModel): vpc_id, protocol, port, + protocol_version=None, healthcheck_protocol=None, healthcheck_port=None, healthcheck_path=None, @@ -79,6 +80,7 @@ class FakeTargetGroup(CloudFormationModel): self.arn = arn self.vpc_id = vpc_id self.protocol = protocol + self.protocol_version = protocol_version or "HTTP1" self.port = port self.healthcheck_protocol = healthcheck_protocol or self.protocol self.healthcheck_port = healthcheck_port @@ -912,7 +914,7 @@ Member must satisfy regular expression pattern: {}".format( if target_group.name == name: raise DuplicateTargetGroupName() - valid_protocols = ["HTTPS", "HTTP", "TCP"] + valid_protocols = ["HTTPS", "HTTP", "TCP", "TLS", "UDP", "TCP_UDP", "GENEVE"] if ( kwargs.get("healthcheck_protocol") and kwargs["healthcheck_protocol"] not in valid_protocols @@ -948,6 +950,13 @@ Member must satisfy regular expression pattern: {}".format( self.target_groups[target_group.arn] = target_group return target_group + def modify_target_group_attributes(self, target_group_arn, attributes): + target_group = self.target_groups.get(target_group_arn) + if not target_group: + raise TargetGroupNotFoundError() + + target_group.attributes.update(attributes) + def convert_and_validate_certificates(self, certificates): # transform default certificate to conform with the rest of the code and XML templates @@ -1355,7 +1364,7 @@ Member must satisfy regular expression pattern: {}".format( "HttpCode must be like 200 | 200-399 | 200,201 ...", ) - if http_codes is not None: + if http_codes is not None and target_group.protocol in ["HTTP", "HTTPS"]: target_group.matcher["HttpCode"] = http_codes if health_check_interval is not None: target_group.healthcheck_interval_seconds = health_check_interval @@ -1397,34 +1406,35 @@ Member must satisfy regular expression pattern: {}".format( if port is not None: listener.port = port - if protocol is not None: - if protocol not in ("HTTP", "HTTPS", "TCP"): + if protocol not in (None, "HTTP", "HTTPS", "TCP"): + raise RESTError( + "UnsupportedProtocol", "Protocol {0} is not supported".format(protocol), + ) + + # HTTPS checks + protocol_becomes_https = protocol == "HTTPS" + protocol_stays_https = protocol is None and listener.protocol == "HTTPS" + if protocol_becomes_https or protocol_stays_https: + # Check certificates exist + if certificates: + default_cert = certificates[0] + default_cert_arn = default_cert["certificate_arn"] + try: + self.acm_backend.get_certificate(default_cert_arn) + except Exception: + raise RESTError( + "CertificateNotFound", + "Certificate {0} not found".format(default_cert_arn), + ) + listener.certificate = default_cert_arn + listener.certificates = certificates + else: raise RESTError( - "UnsupportedProtocol", - "Protocol {0} is not supported".format(protocol), + "CertificateWereNotPassed", + "You must provide a list containing exactly one certificate if the listener protocol is HTTPS.", ) - # HTTPS checks - if protocol == "HTTPS": - # Check certificates exist - if certificates: - default_cert = certificates[0] - default_cert_arn = default_cert["certificate_arn"] - try: - self.acm_backend.get_certificate(default_cert_arn) - except Exception: - raise RESTError( - "CertificateNotFound", - "Certificate {0} not found".format(default_cert_arn), - ) - listener.certificate = default_cert_arn - listener.certificates = certificates - else: - raise RESTError( - "CertificateWereNotPassed", - "You must provide a list containing exactly one certificate if the listener protocol is HTTPS.", - ) - + if protocol is not None: listener.protocol = protocol if ssl_policy is not None: diff --git a/moto/elbv2/responses.py b/moto/elbv2/responses.py index c576a6715..56d1ca3b0 100644 --- a/moto/elbv2/responses.py +++ b/moto/elbv2/responses.py @@ -178,6 +178,7 @@ class ELBV2Response(BaseResponse): name = self._get_param("Name") vpc_id = self._get_param("VpcId") protocol = self._get_param("Protocol") + protocol_version = self._get_param("ProtocolVersion", "HTTP1") port = self._get_param("Port") healthcheck_protocol = self._get_param("HealthCheckProtocol") healthcheck_port = self._get_param("HealthCheckPort") @@ -194,6 +195,7 @@ class ELBV2Response(BaseResponse): name, vpc_id=vpc_id, protocol=protocol, + protocol_version=protocol_version, port=port, healthcheck_protocol=healthcheck_protocol, healthcheck_port=healthcheck_port, @@ -366,14 +368,10 @@ class ELBV2Response(BaseResponse): @amzn_request_id def modify_target_group_attributes(self): target_group_arn = self._get_param("TargetGroupArn") - target_group = self.elbv2_backend.target_groups.get(target_group_arn) - attributes = { - attr["key"]: attr["value"] - for attr in self._get_list_prefix("Attributes.member") - } - target_group.attributes.update(attributes) - if not target_group: - raise TargetGroupNotFoundError() + attributes = self._get_list_prefix("Attributes.member") + attributes = {attr["key"]: attr["value"] for attr in attributes} + self.elbv2_backend.modify_target_group_attributes(target_group_arn, attributes) + template = self.response_template(MODIFY_TARGET_GROUP_ATTRIBUTES_TEMPLATE) return template.render(attributes=attributes) @@ -1085,6 +1083,7 @@ DESCRIBE_TARGET_GROUPS_TEMPLATE = """ """ + MODIFY_TARGET_GROUP_TEMPLATE = """ @@ -1650,7 +1650,7 @@ MODIFY_LISTENER_TEMPLATE = """ {% endfor %} diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 6af1c4df8..037b52805 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -1354,7 +1354,7 @@ def test_modify_listener_http_to_https(): ) # Bad cert - with pytest.raises(ClientError): + with pytest.raises(ClientError) as exc: client.modify_listener( ListenerArn=listener_arn, Port=443, @@ -1363,6 +1363,85 @@ def test_modify_listener_http_to_https(): Certificates=[{"CertificateArn": "lalala", "IsDefault": True}], DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) + err = exc.value.response["Error"] + err["Message"].should.equal("Certificate lalala not found") + + # Unknown protocol + with pytest.raises(ClientError) as exc: + client.modify_listener( + ListenerArn=listener_arn, + Port=443, + Protocol="HTP", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", + Certificates=[{"CertificateArn": yahoo_arn, "IsDefault": True}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + ) + err = exc.value.response["Error"] + err["Message"].should.equal("Protocol HTP is not supported") + + +@mock_acm +@mock_ec2 +@mock_elbv2 +def test_modify_listener_of_https_target_group(): + # Verify we can add a listener for a TargetGroup that is already HTTPS + client = boto3.client("elbv2", region_name="eu-central-1") + acm = boto3.client("acm", region_name="eu-central-1") + ec2 = boto3.resource("ec2", region_name="eu-central-1") + + security_group = ec2.create_security_group( + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet1 = ec2.create_subnet( + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="eu-central-1a" + ) + + response = client.create_load_balancer( + Name="my-lb", + Subnets=[subnet1.id], + SecurityGroups=[security_group.id], + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") + + response = client.create_target_group( + Name="a-target", Protocol="HTTPS", Port=8443, VpcId=vpc.id, + ) + target_group = response.get("TargetGroups")[0] + target_group_arn = target_group["TargetGroupArn"] + + # HTTPS listener + response = acm.request_certificate( + DomainName="google.com", SubjectAlternativeNames=["google.com"], + ) + google_arn = response["CertificateArn"] + response = client.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTPS", + Port=443, + Certificates=[{"CertificateArn": google_arn}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + ) + listener_arn = response["Listeners"][0]["ListenerArn"] + + # Now modify the HTTPS listener with a different certificate + response = acm.request_certificate( + DomainName="yahoo.com", SubjectAlternativeNames=["yahoo.com"], + ) + yahoo_arn = response["CertificateArn"] + + listener = client.modify_listener( + ListenerArn=listener_arn, + Certificates=[{"CertificateArn": yahoo_arn,},], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], + )["Listeners"][0] + listener["Certificates"].should.equal([{"CertificateArn": yahoo_arn}]) + + listener = client.describe_listeners(ListenerArns=[listener_arn])["Listeners"][0] + listener["Certificates"].should.equal([{"CertificateArn": yahoo_arn}]) @mock_elbv2 diff --git a/tests/test_elbv2/test_elbv2_target_groups.py b/tests/test_elbv2/test_elbv2_target_groups.py index 9c5a9ac11..2a665ae5b 100644 --- a/tests/test_elbv2/test_elbv2_target_groups.py +++ b/tests/test_elbv2/test_elbv2_target_groups.py @@ -32,7 +32,7 @@ def test_create_target_group_with_invalid_healthcheck_protocol(): err = exc.value.response["Error"] err["Code"].should.equal("ValidationError") err["Message"].should.equal( - "Value /HTTP at 'healthCheckProtocol' failed to satisfy constraint: Member must satisfy enum value set: ['HTTPS', 'HTTP', 'TCP']" + "Value /HTTP at 'healthCheckProtocol' failed to satisfy constraint: Member must satisfy enum value set: ['HTTPS', 'HTTP', 'TCP', 'TLS', 'UDP', 'TCP_UDP', 'GENEVE']" ) @@ -499,6 +499,8 @@ def test_modify_target_group(): response["TargetGroups"][0]["HealthCheckProtocol"].should.equal("HTTPS") response["TargetGroups"][0]["HealthCheckTimeoutSeconds"].should.equal(10) response["TargetGroups"][0]["HealthyThresholdCount"].should.equal(10) + response["TargetGroups"][0].should.have.key("Protocol").equals("HTTP") + response["TargetGroups"][0].should.have.key("ProtocolVersion").equals("HTTP1") response["TargetGroups"][0]["UnhealthyThresholdCount"].should.equal(4)