diff --git a/moto/core/models.py b/moto/core/models.py index 720d98c95..72b2cb512 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals + import functools +import inspect import re from httpretty import HTTPretty @@ -17,6 +19,8 @@ class MockAWS(object): HTTPretty.reset() def __call__(self, func): + if inspect.isclass(func): + return self.decorate_class(func) return self.decorate_callable(func) def __enter__(self): @@ -67,6 +71,26 @@ class MockAWS(object): wrapper.__wrapped__ = func return wrapper + def decorate_class(self, klass): + for attr in dir(klass): + if attr.startswith("_"): + continue + + attr_value = getattr(klass, attr) + if not hasattr(attr_value, "__call__"): + continue + + # Check if this is a classmethod. If so, skip patching + if inspect.ismethod(attr_value) and attr_value.__self__ is klass: + continue + + try: + setattr(klass, attr, self(attr_value)) + except TypeError: + # Sometimes we can't set this for built-in types + continue + return klass + class Model(type): def __new__(self, clsname, bases, namespace): diff --git a/tests/test_core/test_decorator_calls.py b/tests/test_core/test_decorator_calls.py index 2688abc41..5360061c8 100644 --- a/tests/test_core/test_decorator_calls.py +++ b/tests/test_core/test_decorator_calls.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import boto from boto.exception import EC2ResponseError import sure # noqa -import tests.backport_assert_raises +import tests.backport_assert_raises # noqa from nose.tools import assert_raises from moto import mock_ec2 @@ -57,3 +57,14 @@ def test_decorater_wrapped_gets_set(): Moto decorator's __wrapped__ should get set to the tests function """ test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal('test_decorater_wrapped_gets_set') + + +@mock_ec2 +class Tester(object): + def test_the_class(self): + conn = boto.connect_ec2() + list(conn.get_all_instances()).should.have.length_of(0) + + def test_still_the_same(self): + conn = boto.connect_ec2() + list(conn.get_all_instances()).should.have.length_of(0)