From feedf1802d6f9e77476066dd1225c75dea655977 Mon Sep 17 00:00:00 2001 From: Karoline Pauls <43616133+karolinepauls@users.noreply.github.com> Date: Tue, 26 Jul 2022 01:29:18 +0100 Subject: [PATCH] Cache Jinja environment by response type (#5308) --- moto/core/responses.py | 36 +++++++++++++++++------------ tests/test_core/test_responses.py | 38 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 15 deletions(-) diff --git a/moto/core/responses.py b/moto/core/responses.py index b63186288..aeef13fd5 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -25,6 +25,8 @@ from moto import settings log = logging.getLogger(__name__) +JINJA_ENVS = {} + def _decode_dict(d): decoded = OrderedDict() @@ -81,20 +83,30 @@ class _TemplateEnvironmentMixin(object): LEFT_PATTERN = re.compile(r"[\s\n]+<") RIGHT_PATTERN = re.compile(r">[\s\n]+") - def __init__(self): - super().__init__() - self.loader = DynamicDictLoader({}) - self.environment = Environment( - loader=self.loader, autoescape=self.should_autoescape - ) - @property def should_autoescape(self): # Allow for subclass to overwrite return False + @property + def environment(self): + key = type(self) + try: + environment = JINJA_ENVS[key] + except KeyError: + loader = DynamicDictLoader({}) + environment = Environment( + loader=loader, + autoescape=self.should_autoescape, + trim_blocks=True, + lstrip_blocks=True, + ) + JINJA_ENVS[key] = environment + + return environment + def contains_template(self, template_id): - return self.loader.contains(template_id) + return self.environment.loader.contains(template_id) def response_template(self, source): template_id = id(source) @@ -102,13 +114,7 @@ class _TemplateEnvironmentMixin(object): collapsed = re.sub( self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source) ) - self.loader.update({template_id: collapsed}) - self.environment = Environment( - loader=self.loader, - autoescape=self.should_autoescape, - trim_blocks=True, - lstrip_blocks=True, - ) + self.environment.loader.update({template_id: collapsed}) return self.environment.get_template(template_id) diff --git a/tests/test_core/test_responses.py b/tests/test_core/test_responses.py index f454350ea..524d200dd 100644 --- a/tests/test_core/test_responses.py +++ b/tests/test_core/test_responses.py @@ -163,3 +163,41 @@ def test_get_dict_list_params(): result = subject._get_multi_param_dict("VpcSecurityGroupIds") result.should.equal({"VpcSecurityGroupId": ["sg-123", "sg-456", "sg-789"]}) + + +def test_response_environment_preserved_by_type(): + """Ensure Jinja environment is cached by response type.""" + + class ResponseA(BaseResponse): + pass + + class ResponseB(BaseResponse): + pass + + resp_a = ResponseA() + another_resp_a = ResponseA() + resp_b = ResponseB() + + assert resp_a.environment is another_resp_a.environment + assert resp_b.environment is not resp_a.environment + + source_1 = "template" + source_2 = "amother template" + + assert not resp_a.contains_template(id(source_1)) + resp_a.response_template(source_1) + assert resp_a.contains_template(id(source_1)) + + assert not resp_a.contains_template(id(source_2)) + resp_a.response_template(source_2) + assert resp_a.contains_template(id(source_2)) + + assert not resp_b.contains_template(id(source_1)) + assert not resp_b.contains_template(id(source_2)) + + assert another_resp_a.contains_template(id(source_1)) + assert another_resp_a.contains_template(id(source_2)) + + resp_a_new_instance = ResponseA() + assert resp_a_new_instance.contains_template(id(source_1)) + assert resp_a_new_instance.contains_template(id(source_2))