diff --git a/.coveragerc b/.coveragerc index 25d85b805..2130ec2ad 100644 --- a/.coveragerc +++ b/.coveragerc @@ -3,6 +3,7 @@ exclude_lines = if __name__ == .__main__.: raise NotImplemented. + return NotImplemented def __repr__ [run] diff --git a/moto/emr/responses.py b/moto/emr/responses.py index 234fbc8e7..a5d98ced4 100644 --- a/moto/emr/responses.py +++ b/moto/emr/responses.py @@ -13,7 +13,7 @@ from moto.core.responses import xml_to_json_response from moto.core.utils import tags_from_query_string from .exceptions import EmrError from .models import emr_backends -from .utils import steps_from_query_string, Unflattener +from .utils import steps_from_query_string, Unflattener, ReleaseLabel def generate_boto3_response(operation): @@ -323,7 +323,9 @@ class ElasticMapReduceResponse(BaseResponse): custom_ami_id = self._get_param("CustomAmiId") if custom_ami_id: kwargs["custom_ami_id"] = custom_ami_id - if release_label and release_label < "emr-5.7.0": + if release_label and ( + ReleaseLabel(release_label) < ReleaseLabel("emr-5.7.0") + ): message = "Custom AMI is not allowed" raise EmrError( error_type="ValidationException", diff --git a/moto/emr/utils.py b/moto/emr/utils.py index 48f3232fa..506201c1c 100644 --- a/moto/emr/utils.py +++ b/moto/emr/utils.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals import random +import re import string from moto.core.utils import camelcase_to_underscores @@ -144,3 +145,76 @@ class CamelToUnderscoresWalker: @staticmethod def parse_scalar(x): return x + + +class ReleaseLabel(object): + + version_re = re.compile(r"^emr-(\d+)\.(\d+)\.(\d+)$") + + def __init__(self, release_label): + major, minor, patch = self.parse(release_label) + + self.major = major + self.minor = minor + self.patch = patch + + @classmethod + def parse(cls, release_label): + if not release_label: + raise ValueError("Invalid empty ReleaseLabel: %r" % release_label) + + match = cls.version_re.match(release_label) + if not match: + raise ValueError("Invalid ReleaseLabel: %r" % release_label) + + major, minor, patch = match.groups() + + major = int(major) + minor = int(minor) + patch = int(patch) + + return major, minor, patch + + def __str__(self): + version = "emr-%d.%d.%d" % (self.major, self.minor, self.patch) + return version + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, str(self)) + + def __iter__(self): + return iter((self.major, self.minor, self.patch)) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return ( + self.major == other.major + and self.minor == other.minor + and self.patch == other.patch + ) + + def __ne__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return tuple(self) != tuple(other) + + def __lt__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return tuple(self) < tuple(other) + + def __le__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return tuple(self) <= tuple(other) + + def __gt__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return tuple(self) > tuple(other) + + def __ge__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return tuple(self) >= tuple(other) diff --git a/tests/test_emr/test_emr_boto3.py b/tests/test_emr/test_emr_boto3.py index 8b815e0fa..e2aa49444 100644 --- a/tests/test_emr/test_emr_boto3.py +++ b/tests/test_emr/test_emr_boto3.py @@ -636,7 +636,7 @@ def test_run_job_flow_with_custom_ami(): args = deepcopy(run_job_flow_args) args["CustomAmiId"] = "MyEmrCustomAmi" - args["ReleaseLabel"] = "emr-5.7.0" + args["ReleaseLabel"] = "emr-5.31.0" cluster_id = client.run_job_flow(**args)["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) resp["Cluster"]["CustomAmiId"].should.equal("MyEmrCustomAmi") diff --git a/tests/test_emr/test_utils.py b/tests/test_emr/test_utils.py new file mode 100644 index 000000000..b836ebf48 --- /dev/null +++ b/tests/test_emr/test_utils.py @@ -0,0 +1,49 @@ +import pytest + +from moto.emr.utils import ReleaseLabel + + +def test_invalid_release_labels_raise_exception(): + invalid_releases = [ + "", + "0", + "1.0", + "emr-2.0", + ] + for invalid_release in invalid_releases: + with pytest.raises(ValueError): + ReleaseLabel(invalid_release) + + +def test_release_label_comparisons(): + assert str(ReleaseLabel("emr-5.1.2")) == "emr-5.1.2" + + assert ReleaseLabel("emr-5.0.0") != ReleaseLabel("emr-5.0.1") + assert ReleaseLabel("emr-5.0.0") == ReleaseLabel("emr-5.0.0") + + assert ReleaseLabel("emr-5.31.0") > ReleaseLabel("emr-5.7.0") + assert ReleaseLabel("emr-6.0.0") > ReleaseLabel("emr-5.7.0") + + assert ReleaseLabel("emr-5.7.0") < ReleaseLabel("emr-5.10.0") + assert ReleaseLabel("emr-5.10.0") < ReleaseLabel("emr-5.10.1") + + assert ReleaseLabel("emr-5.60.0") >= ReleaseLabel("emr-5.7.0") + assert ReleaseLabel("emr-6.0.0") >= ReleaseLabel("emr-6.0.0") + + assert ReleaseLabel("emr-5.7.0") <= ReleaseLabel("emr-5.17.0") + assert ReleaseLabel("emr-5.7.0") <= ReleaseLabel("emr-5.7.0") + + releases_unsorted = [ + ReleaseLabel("emr-5.60.2"), + ReleaseLabel("emr-4.0.1"), + ReleaseLabel("emr-4.0.0"), + ReleaseLabel("emr-5.7.3"), + ] + releases_sorted = [str(label) for label in sorted(releases_unsorted)] + expected = [ + "emr-4.0.0", + "emr-4.0.1", + "emr-5.7.3", + "emr-5.60.2", + ] + assert releases_sorted == expected