move back to bundled httpretty for now
This commit is contained in:
parent
46b38c705c
commit
fe2b3518ae
@ -1,7 +1,7 @@
|
||||
import functools
|
||||
import re
|
||||
|
||||
from httpretty import HTTPretty
|
||||
from moto.packages.httpretty import HTTPretty
|
||||
from .responses import metadata_response
|
||||
from .utils import convert_regex_to_flask_path
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# #!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# <HTTPretty - HTTP client mock for Python>
|
||||
# Copyright (C) <2011-2012> Gabriel Falcão <gabriel@nacaolivre.org>
|
||||
# Copyright (C) <2011-2013> Gabriel Falcão <gabriel@nacaolivre.org>
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person
|
||||
# obtaining a copy of this software and associated documentation
|
||||
@ -23,7 +23,9 @@
|
||||
# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
# OTHER DEALINGS IN THE SOFTWARE.
|
||||
version = '0.5.8'
|
||||
from __future__ import unicode_literals
|
||||
|
||||
version = '0.5.12'
|
||||
|
||||
import re
|
||||
import inspect
|
||||
@ -40,8 +42,10 @@ PY3 = sys.version_info[0] == 3
|
||||
if PY3:
|
||||
text_type = str
|
||||
byte_type = bytes
|
||||
basestring = (str, bytes)
|
||||
|
||||
import io
|
||||
StringIO = io.StringIO
|
||||
StringIO = io.BytesIO
|
||||
|
||||
class Py3kObject(object):
|
||||
def __repr__(self):
|
||||
@ -64,9 +68,10 @@ class Py3kObject(object):
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
try:
|
||||
from urllib.parse import urlsplit, parse_qs
|
||||
from urllib.parse import urlsplit, urlunsplit, parse_qs, quote, quote_plus
|
||||
except ImportError:
|
||||
from urlparse import urlsplit, parse_qs
|
||||
from urlparse import urlsplit, urlunsplit, parse_qs
|
||||
from urllib import quote, quote_plus
|
||||
|
||||
try:
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
@ -99,6 +104,14 @@ except ImportError:
|
||||
ssl = None
|
||||
|
||||
|
||||
ClassTypes = (type,)
|
||||
if not PY3:
|
||||
ClassTypes = (type, types.ClassType)
|
||||
|
||||
|
||||
POTENTIAL_HTTP_PORTS = [80, 443]
|
||||
|
||||
|
||||
class HTTPrettyError(Exception):
|
||||
pass
|
||||
|
||||
@ -110,6 +123,13 @@ def utf8(s):
|
||||
return byte_type(s)
|
||||
|
||||
|
||||
def decode_utf8(s):
|
||||
if isinstance(s, byte_type):
|
||||
s = s.decode("utf-8")
|
||||
|
||||
return text_type(s)
|
||||
|
||||
|
||||
def parse_requestline(s):
|
||||
"""
|
||||
http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5
|
||||
@ -123,8 +143,8 @@ def parse_requestline(s):
|
||||
...
|
||||
ValueError: Not a Request-Line
|
||||
"""
|
||||
methods = '|'.join(HTTPretty.METHODS)
|
||||
m = re.match(r'('+methods+')\s+(.*)\s+HTTP/(1.[0|1])', s, re.I)
|
||||
methods = b'|'.join(HTTPretty.METHODS)
|
||||
m = re.match(br'(' + methods + b')\s+(.*)\s+HTTP/(1.[0|1])', s, re.I)
|
||||
if m:
|
||||
return m.group(1).upper(), m.group(2), m.group(3)
|
||||
else:
|
||||
@ -135,7 +155,9 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, Py3kObject):
|
||||
def __init__(self, headers, body=''):
|
||||
self.body = utf8(body)
|
||||
self.raw_headers = utf8(headers)
|
||||
self.rfile = StringIO('\r\n\r\n'.join([headers.strip(), body]))
|
||||
self.client_address = ['10.0.0.1']
|
||||
self.rfile = StringIO(b'\r\n\r\n'.join([headers.strip(), body]))
|
||||
self.wfile = StringIO()
|
||||
self.raw_requestline = self.rfile.readline()
|
||||
self.error_code = self.error_message = None
|
||||
self.parse_request()
|
||||
@ -159,15 +181,7 @@ class HTTPrettyRequestEmpty(object):
|
||||
|
||||
|
||||
class FakeSockFile(StringIO):
|
||||
def read(self, amount=None):
|
||||
amount = amount or self.len
|
||||
new_amount = amount
|
||||
|
||||
if amount > self.len:
|
||||
new_amount = self.len - self.tell()
|
||||
|
||||
ret = StringIO.read(self, new_amount)
|
||||
return ret
|
||||
pass
|
||||
|
||||
|
||||
class FakeSSLSocket(object):
|
||||
@ -194,6 +208,7 @@ class fakesock(object):
|
||||
self.fd = FakeSockFile()
|
||||
self.timeout = socket._GLOBAL_DEFAULT_TIMEOUT
|
||||
self._sock = self
|
||||
self.is_http = False
|
||||
|
||||
def getpeercert(self, *a, **kw):
|
||||
now = datetime.now()
|
||||
@ -230,6 +245,9 @@ class fakesock(object):
|
||||
def connect(self, address):
|
||||
self._address = (self._host, self._port) = address
|
||||
self._closed = False
|
||||
self.is_http = self._port in POTENTIAL_HTTP_PORTS
|
||||
if not self.is_http:
|
||||
self.truesock.connect(self._address)
|
||||
|
||||
def close(self):
|
||||
if not self._closed:
|
||||
@ -246,17 +264,22 @@ class fakesock(object):
|
||||
return self.fd
|
||||
|
||||
def _true_sendall(self, data, *args, **kw):
|
||||
self.truesock.connect(self._address)
|
||||
if self.is_http:
|
||||
self.truesock.connect(self._address)
|
||||
|
||||
self.truesock.sendall(data, *args, **kw)
|
||||
_d = self.truesock.recv(16)
|
||||
self.fd.seek(0)
|
||||
self.fd.write(_d)
|
||||
|
||||
_d = True
|
||||
while _d:
|
||||
_d = self.truesock.recv(16)
|
||||
self.fd.write(_d)
|
||||
try:
|
||||
_d = self.truesock.recv(16)
|
||||
self.truesock.settimeout(0.0)
|
||||
self.fd.write(_d)
|
||||
|
||||
except socket.error:
|
||||
break
|
||||
|
||||
self.fd.seek(0)
|
||||
self.truesock.close()
|
||||
|
||||
def sendall(self, data, *args, **kw):
|
||||
|
||||
@ -264,22 +287,13 @@ class fakesock(object):
|
||||
hostnames = [getattr(i.info, 'hostname', None) for i in HTTPretty._entries.keys()]
|
||||
self.fd.seek(0)
|
||||
try:
|
||||
print("data", data)
|
||||
requestline, _ = data.split('\r\n', 1)
|
||||
requestline, _ = data.split(b'\r\n', 1)
|
||||
method, path, version = parse_requestline(requestline)
|
||||
is_parsing_headers = True
|
||||
except ValueError:
|
||||
is_parsing_headers = False
|
||||
|
||||
# This need to be reconsidered. URIMatchers with regexs don't
|
||||
# have hostnames which can cause this to return even though
|
||||
# the regex may have matched
|
||||
# if self._host not in hostnames:
|
||||
# return self._true_sendall(data)
|
||||
|
||||
import pdb;pdb.set_trace()
|
||||
if not is_parsing_headers:
|
||||
|
||||
if len(self._sent_data) > 1:
|
||||
headers, body = map(utf8, self._sent_data[-2:])
|
||||
|
||||
@ -288,8 +302,7 @@ class fakesock(object):
|
||||
|
||||
info = URIInfo(hostname=self._host, port=self._port,
|
||||
path=split_url.path,
|
||||
query=split_url.query,
|
||||
method=method)
|
||||
query=split_url.query)
|
||||
|
||||
# If we are sending more data to a dynamic response entry,
|
||||
# we need to call the method again.
|
||||
@ -298,41 +311,36 @@ class fakesock(object):
|
||||
|
||||
try:
|
||||
return HTTPretty.historify_request(headers, body, False)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(traceback.format_exc(e))
|
||||
return self._true_sendall(data, *args, **kw)
|
||||
|
||||
# path might come with
|
||||
s = urlsplit(path)
|
||||
|
||||
headers, body = map(utf8, data.split('\r\n\r\n', 1))
|
||||
POTENTIAL_HTTP_PORTS.append(int(s.port or 80))
|
||||
headers, body = map(utf8, data.split(b'\r\n\r\n', 1))
|
||||
|
||||
request = HTTPretty.historify_request(headers, body)
|
||||
|
||||
info = URIInfo(hostname=self._host, port=self._port,
|
||||
path=s.path,
|
||||
query=s.query,
|
||||
last_request=request,
|
||||
method=method)
|
||||
last_request=request)
|
||||
|
||||
entries = []
|
||||
|
||||
for matcher, value in HTTPretty._entries.items():
|
||||
if matcher.matches(info) and matcher.method == method:
|
||||
if matcher.matches(info):
|
||||
entries = value
|
||||
#info = matcher.info
|
||||
break
|
||||
|
||||
if not entries:
|
||||
self._true_sendall(data)
|
||||
return
|
||||
|
||||
entry = matcher.get_next_entry()
|
||||
if entry.method == method:
|
||||
self._entry = entry
|
||||
self._request = (info, method, body, headers)
|
||||
else:
|
||||
raise ValueError("No match found for", method, entry.uri)
|
||||
self._entry = matcher.get_next_entry(method)
|
||||
self._request = (info, body, headers)
|
||||
|
||||
def debug(*a, **kw):
|
||||
frame = inspect.stack()[0][0]
|
||||
@ -343,7 +351,7 @@ class fakesock(object):
|
||||
("Please open an issue at "
|
||||
"'https://github.com/gabrielfalcao/HTTPretty/issues'"),
|
||||
"And paste the following traceback:\n",
|
||||
"".join(lines),
|
||||
"".join(decode_utf8(lines)),
|
||||
]
|
||||
raise RuntimeError("\n".join(message))
|
||||
|
||||
@ -527,75 +535,72 @@ class Entry(Py3kObject):
|
||||
def normalize_headers(self, headers):
|
||||
new = {}
|
||||
for k in headers:
|
||||
new_k = '-'.join([s.title() for s in k.split('-')])
|
||||
new_k = '-'.join([s.lower() for s in k.split('-')])
|
||||
new[new_k] = headers[k]
|
||||
|
||||
return new
|
||||
|
||||
def fill_filekind(self, fk, request):
|
||||
req_info, method, req_body, req_headers = request
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
headers = {
|
||||
'Status': self.status,
|
||||
'Date': now.strftime('%a, %d %b %Y %H:%M:%S GMT'),
|
||||
'Server': 'Python/HTTPretty',
|
||||
'Connection': 'close',
|
||||
'status': self.status,
|
||||
'date': now.strftime('%a, %d %b %Y %H:%M:%S GMT'),
|
||||
'server': 'Python/HTTPretty',
|
||||
'connection': 'close',
|
||||
}
|
||||
|
||||
if self.dynamic_response:
|
||||
response = self.body(req_info, method, req_body, req_headers)
|
||||
if isinstance(response, basestring):
|
||||
body = response
|
||||
new_headers = {}
|
||||
else:
|
||||
body, new_headers = response
|
||||
else:
|
||||
body = self.body
|
||||
new_headers = {}
|
||||
|
||||
|
||||
if self.forcing_headers:
|
||||
headers = self.forcing_headers
|
||||
|
||||
headers.update(new_headers)
|
||||
if self.dynamic_response:
|
||||
req_info, req_body, req_headers = request
|
||||
response = self.body(req_info, self.method, req_body, req_headers)
|
||||
if isinstance(response, basestring):
|
||||
body = response
|
||||
else:
|
||||
body, new_headers = response
|
||||
headers.update(new_headers)
|
||||
else:
|
||||
body = self.body
|
||||
|
||||
if self.adding_headers:
|
||||
headers.update(self.adding_headers)
|
||||
headers.update(self.normalize_headers(self.adding_headers))
|
||||
|
||||
headers = self.normalize_headers(headers)
|
||||
|
||||
status = headers.get('Status', self.status)
|
||||
status = headers.get('status', self.status)
|
||||
string_list = [
|
||||
'HTTP/1.1 %d %s' % (status, STATUSES[status]),
|
||||
]
|
||||
|
||||
if 'Date' in headers:
|
||||
string_list.append('Date: %s' % headers.pop('Date'))
|
||||
if 'date' in headers:
|
||||
string_list.append('date: %s' % headers.pop('date'))
|
||||
|
||||
if not self.forcing_headers:
|
||||
content_type = headers.pop('Content-Type',
|
||||
content_type = headers.pop('content-type',
|
||||
'text/plain; charset=utf-8')
|
||||
|
||||
body_length = self.body_length
|
||||
if self.dynamic_response:
|
||||
body_length = len(body)
|
||||
content_length = headers.pop('Content-Length', body_length)
|
||||
content_length = headers.pop('content-length', body_length)
|
||||
|
||||
string_list.append('Content-Type: %s' % content_type)
|
||||
string_list.append('content-type: %s' % content_type)
|
||||
if not self.streaming:
|
||||
string_list.append('Content-Length: %s' % content_length)
|
||||
string_list.append('content-length: %s' % content_length)
|
||||
|
||||
string_list.append('Server: %s' % headers.pop('Server'))
|
||||
string_list.append('server: %s' % headers.pop('server'))
|
||||
|
||||
for k, v in headers.items():
|
||||
string_list.append(
|
||||
'%s: %s' % (k, utf8(v)),
|
||||
'{0}: {1}'.format(k, v),
|
||||
)
|
||||
|
||||
fk.write("\n".join(string_list))
|
||||
fk.write('\n\r\n')
|
||||
for item in string_list:
|
||||
fk.write(utf8(item) + b'\n')
|
||||
|
||||
fk.write(b'\r\n')
|
||||
|
||||
if self.streaming:
|
||||
self.body, body = itertools.tee(body)
|
||||
@ -608,25 +613,10 @@ class Entry(Py3kObject):
|
||||
|
||||
|
||||
def url_fix(s, charset='utf-8'):
|
||||
import urllib
|
||||
import urlparse
|
||||
"""Sometimes you get an URL by a user that just isn't a real
|
||||
URL because it contains unsafe characters like ' ' and so on. This
|
||||
function can fix some of the problems in a similar way browsers
|
||||
handle data entered by the user:
|
||||
|
||||
>>> url_fix(u'http://de.wikipedia.org/wiki/Elf (Begriffsklärung)')
|
||||
'http://de.wikipedia.org/wiki/Elf%20%28Begriffskl%C3%A4rung%29'
|
||||
|
||||
:param charset: The target charset for the URL if the url was
|
||||
given as unicode string.
|
||||
"""
|
||||
if isinstance(s, unicode):
|
||||
s = s.encode(charset, 'ignore')
|
||||
scheme, netloc, path, qs, anchor = urlparse.urlsplit(s)
|
||||
path = urllib.quote(path, '/%')
|
||||
qs = urllib.quote_plus(qs, ':&=')
|
||||
return urlparse.urlunsplit((scheme, netloc, path, qs, anchor))
|
||||
scheme, netloc, path, querystring, fragment = urlsplit(s)
|
||||
path = quote(path, b'/%')
|
||||
querystring = quote_plus(querystring, b':&=')
|
||||
return urlunsplit((scheme, netloc, path, querystring, fragment))
|
||||
|
||||
|
||||
class URIInfo(Py3kObject):
|
||||
@ -639,7 +629,6 @@ class URIInfo(Py3kObject):
|
||||
query='',
|
||||
fragment='',
|
||||
scheme='',
|
||||
method=None,
|
||||
last_request=None):
|
||||
|
||||
self.username = username or ''
|
||||
@ -653,11 +642,10 @@ class URIInfo(Py3kObject):
|
||||
port = 443
|
||||
|
||||
self.port = port or 80
|
||||
self.path = url_fix(path) or ''
|
||||
self.path = path or ''
|
||||
self.query = query or ''
|
||||
self.scheme = scheme or (self.port is 80 and "http" or "https")
|
||||
self.fragment = fragment or ''
|
||||
self.method = method
|
||||
self.last_request = last_request
|
||||
|
||||
def __str__(self):
|
||||
@ -675,17 +663,17 @@ class URIInfo(Py3kObject):
|
||||
return hash(text_type(self))
|
||||
|
||||
def __eq__(self, other):
|
||||
orig_hostname = self.hostname
|
||||
orig_other = other.hostname
|
||||
|
||||
self.hostname = None
|
||||
other.hostname = None
|
||||
result = text_type(self) == text_type(other)
|
||||
|
||||
self.hostname = orig_hostname
|
||||
other.hostname = orig_other
|
||||
|
||||
return result
|
||||
self_tuple = (
|
||||
self.port,
|
||||
decode_utf8(self.hostname),
|
||||
url_fix(decode_utf8(self.path)),
|
||||
)
|
||||
other_tuple = (
|
||||
other.port,
|
||||
decode_utf8(other.hostname),
|
||||
url_fix(decode_utf8(other.path)),
|
||||
)
|
||||
return self_tuple == other_tuple
|
||||
|
||||
def full_url(self):
|
||||
credentials = ""
|
||||
@ -693,21 +681,18 @@ class URIInfo(Py3kObject):
|
||||
credentials = "{0}:{1}@".format(
|
||||
self.username, self.password)
|
||||
|
||||
# query = ""
|
||||
# if self.query:
|
||||
# query = "?{0}".format(self.query)
|
||||
|
||||
return "{scheme}://{credentials}{host}{path}".format(
|
||||
result = "{scheme}://{credentials}{host}{path}".format(
|
||||
scheme=self.scheme,
|
||||
credentials=credentials,
|
||||
host=self.hostname,
|
||||
path=self.path,
|
||||
#query=query
|
||||
host=decode_utf8(self.hostname),
|
||||
path=decode_utf8(self.path)
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_uri(cls, uri, entry):
|
||||
result = urlsplit(uri)
|
||||
POTENTIAL_HTTP_PORTS.append(int(result.port or 80))
|
||||
return cls(result.username,
|
||||
result.password,
|
||||
result.hostname,
|
||||
@ -723,15 +708,16 @@ class URIMatcher(object):
|
||||
regex = None
|
||||
info = None
|
||||
|
||||
def __init__(self, uri, method, entries):
|
||||
def __init__(self, uri, entries):
|
||||
if type(uri).__name__ == 'SRE_Pattern':
|
||||
self.regex = uri
|
||||
else:
|
||||
self.info = URIInfo.from_uri(uri, entries)
|
||||
|
||||
self.method = method
|
||||
self.entries = entries
|
||||
self.current_entry = 0
|
||||
|
||||
#hash of current_entry pointers, per method.
|
||||
self.current_entries = {}
|
||||
|
||||
def matches(self, info):
|
||||
if self.info:
|
||||
@ -740,22 +726,32 @@ class URIMatcher(object):
|
||||
return self.regex.search(info.full_url())
|
||||
|
||||
def __str__(self):
|
||||
wrap = 'URLMatcher({0} {1})'
|
||||
wrap = 'URLMatcher({0})'
|
||||
if self.info:
|
||||
return wrap.format(text_type(self.info), self.method)
|
||||
return wrap.format(text_type(self.info))
|
||||
else:
|
||||
return wrap.format(self.regex.pattern, self.method)
|
||||
return wrap.format(self.regex.pattern)
|
||||
|
||||
def get_next_entry(self):
|
||||
if self.current_entry >= len(self.entries):
|
||||
self.current_entry = -1
|
||||
def get_next_entry(self, method='GET'):
|
||||
"""Cycle through available responses, but only once.
|
||||
Any subsequent requests will receive the last response"""
|
||||
|
||||
if not self.entries:
|
||||
raise ValueError('I have no entries: %s' % self)
|
||||
if method not in self.current_entries:
|
||||
self.current_entries[method] = 0
|
||||
|
||||
entry = self.entries[self.current_entry]
|
||||
if self.current_entry != -1:
|
||||
self.current_entry += 1
|
||||
#restrict selection to entries that match the requested method
|
||||
entries_for_method = [e for e in self.entries if e.method == method]
|
||||
|
||||
if self.current_entries[method] >= len(entries_for_method):
|
||||
self.current_entries[method] = -1
|
||||
|
||||
if not self.entries or not entries_for_method:
|
||||
raise ValueError('I have no entries for method %s: %s'
|
||||
% (method, self))
|
||||
|
||||
entry = entries_for_method[self.current_entries[method]]
|
||||
if self.current_entries[method] != -1:
|
||||
self.current_entries[method] += 1
|
||||
return entry
|
||||
|
||||
def __hash__(self):
|
||||
@ -769,14 +765,15 @@ class HTTPretty(Py3kObject):
|
||||
u"""The URI registration class"""
|
||||
_entries = {}
|
||||
latest_requests = []
|
||||
GET = 'GET'
|
||||
PUT = 'PUT'
|
||||
POST = 'POST'
|
||||
DELETE = 'DELETE'
|
||||
HEAD = 'HEAD'
|
||||
PATCH = 'PATCH'
|
||||
GET = b'GET'
|
||||
PUT = b'PUT'
|
||||
POST = b'POST'
|
||||
DELETE = b'DELETE'
|
||||
HEAD = b'HEAD'
|
||||
PATCH = b'PATCH'
|
||||
METHODS = (GET, PUT, POST, DELETE, HEAD, PATCH)
|
||||
last_request = HTTPrettyRequestEmpty()
|
||||
_is_enabled = False
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
@ -802,6 +799,9 @@ class HTTPretty(Py3kObject):
|
||||
responses=None, **headers):
|
||||
|
||||
if isinstance(responses, list) and len(responses) > 0:
|
||||
for response in responses:
|
||||
response.uri = uri
|
||||
response.method = method
|
||||
entries_for_this_uri = responses
|
||||
else:
|
||||
headers['body'] = body
|
||||
@ -813,11 +813,9 @@ class HTTPretty(Py3kObject):
|
||||
cls.Response(method=method, uri=uri, **headers),
|
||||
]
|
||||
|
||||
map(lambda e: setattr(e, 'uri', uri) or setattr(e, 'method', method),
|
||||
entries_for_this_uri)
|
||||
|
||||
matcher = URIMatcher(uri, method, entries_for_this_uri)
|
||||
matcher = URIMatcher(uri, entries_for_this_uri)
|
||||
if matcher in cls._entries:
|
||||
matcher.entries.extend(cls._entries[matcher])
|
||||
del cls._entries[matcher]
|
||||
|
||||
cls._entries[matcher] = entries_for_this_uri
|
||||
@ -838,6 +836,7 @@ class HTTPretty(Py3kObject):
|
||||
|
||||
@classmethod
|
||||
def disable(cls):
|
||||
cls._is_enabled = False
|
||||
socket.socket = old_socket
|
||||
socket.SocketType = old_socket
|
||||
socket._socketobject = old_socket
|
||||
@ -872,8 +871,13 @@ class HTTPretty(Py3kObject):
|
||||
ssl.sslwrap_simple = old_sslwrap_simple
|
||||
ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple
|
||||
|
||||
@classmethod
|
||||
def is_enabled(cls):
|
||||
return cls._is_enabled
|
||||
|
||||
@classmethod
|
||||
def enable(cls):
|
||||
cls._is_enabled = True
|
||||
socket.socket = fakesock.socket
|
||||
socket._socketobject = fakesock.socket
|
||||
socket.SocketType = fakesock.socket
|
||||
@ -912,12 +916,29 @@ class HTTPretty(Py3kObject):
|
||||
|
||||
def httprettified(test):
|
||||
"A decorator tests that use HTTPretty"
|
||||
@functools.wraps(test)
|
||||
def wrapper(*args, **kw):
|
||||
HTTPretty.reset()
|
||||
HTTPretty.enable()
|
||||
try:
|
||||
return test(*args, **kw)
|
||||
finally:
|
||||
HTTPretty.disable()
|
||||
return wrapper
|
||||
def decorate_class(klass):
|
||||
for attr in dir(klass):
|
||||
if not attr.startswith('test_'):
|
||||
continue
|
||||
|
||||
attr_value = getattr(klass, attr)
|
||||
if not hasattr(attr_value, "__call__"):
|
||||
continue
|
||||
|
||||
setattr(klass, attr, decorate_callable(attr_value))
|
||||
return klass
|
||||
|
||||
def decorate_callable(test):
|
||||
@functools.wraps(test)
|
||||
def wrapper(*args, **kw):
|
||||
HTTPretty.reset()
|
||||
HTTPretty.enable()
|
||||
try:
|
||||
return test(*args, **kw)
|
||||
finally:
|
||||
HTTPretty.disable()
|
||||
return wrapper
|
||||
|
||||
if isinstance(test, ClassTypes):
|
||||
return decorate_class(test)
|
||||
return decorate_callable(test)
|
||||
|
Loading…
Reference in New Issue
Block a user