diff --git a/requirements-dev.txt b/requirements-dev.txt index 7dda4026b..6d84d7a86 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,3 +12,4 @@ six>=1.9 prompt-toolkit==1.0.14 click==6.7 inflection==0.3.1 +lxml==4.0.0 diff --git a/scaffold.py b/scaffold.py index 2c168da69..d2f06b127 100755 --- a/scaffold.py +++ b/scaffold.py @@ -1,6 +1,8 @@ #!/usr/bin/env python import os import re +import inspect +import importlib from lxml import etree import click @@ -15,6 +17,8 @@ from botocore import xform_name from botocore.session import Session import boto3 +from moto.core.responses import BaseResponse +from moto.core import BaseBackend from implementation_coverage import ( get_moto_implementation ) @@ -142,7 +146,7 @@ def to_snake_case(s): return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() -def get_function_in_query_responses(service, operation): +def get_function_in_responses(service, operation, protocol): """refers to definition of API in botocore, and autogenerates function You can see example of elbv2 from link below. https://github.com/boto/botocore/blob/develop/botocore/data/elbv2/2015-12-01/service-2.json @@ -174,10 +178,14 @@ def get_function_in_query_responses(service, operation): body += ' {}={},\n'.format(input_name, input_name) body += ' )\n' - body += ' template = self.response_template({}_TEMPLATE)\n'.format(operation.upper()) - body += ' return template.render({})\n'.format( - ','.join(['{}={}'.format(_, _) for _ in output_names]) - ) + if protocol == 'query': + body += ' template = self.response_template({}_TEMPLATE)\n'.format(operation.upper()) + body += ' return template.render({})\n'.format( + ','.join(['{}={}'.format(_, _) for _ in output_names]) + ) + elif protocol == 'json': + body += ' # TODO: adjust reponse\n' + body += ' return json.dumps({})\n'.format(','.join(['{}={}'.format(_, _) for _ in output_names])) return body @@ -255,8 +263,8 @@ def get_response_query_template(service, operation): for output_name, output_shape in outputs.items(): t_result.append(_get_subtree(output_name, output_shape, replace_list)) t_root.append(t_result) - body = etree.tostring(t_root, pretty_print=True).decode('utf-8') - body_lines = body.splitlines() + xml_body = etree.tostring(t_root, pretty_print=True).decode('utf-8') + xml_body_lines = xml_body.splitlines() for replace in replace_list: name = replace[0] prefix = replace[1] @@ -268,35 +276,91 @@ def get_response_query_template(service, operation): end_tag = '' % name loop_end = '{{ endfor }}' - start_tag_indexes = [i for i, l in enumerate(body_lines) if start_tag in l] + start_tag_indexes = [i for i, l in enumerate(xml_body_lines) if start_tag in l] if len(start_tag_indexes) != 1: raise Exception('tag %s not found in response body' % start_tag) start_tag_index = start_tag_indexes[0] - body_lines.insert(start_tag_index + 1, loop_start) + xml_body_lines.insert(start_tag_index + 1, loop_start) - end_tag_indexes = [i for i, l in enumerate(body_lines) if end_tag in l] + end_tag_indexes = [i for i, l in enumerate(xml_body_lines) if end_tag in l] if len(end_tag_indexes) != 1: raise Exception('tag %s not found in response body' % end_tag) end_tag_index = end_tag_indexes[0] - body_lines.insert(end_tag_index, loop_end) - body = '\n'.join(body_lines) + xml_body_lines.insert(end_tag_index, loop_end) + xml_body = '\n'.join(xml_body_lines) + body = '\n{}_TEMPLATE = """{}"""'.format(operation.upper(), xml_body) return body + +def insert_code_to_class(path, base_class, new_code): + with open(path) as f: + lines = [_.replace('\n', '') for _ in f.readlines()] + mod_path = os.path.splitext(path)[0].replace('/', '.') + mod = importlib.import_module(mod_path) + clsmembers = inspect.getmembers(mod, inspect.isclass) + _response_cls = [_[1] for _ in clsmembers if issubclass(_[1], base_class) and _[1] != base_class] + if len(_response_cls) != 1: + raise Exception('unknown error, number of clsmembers is not 1') + response_cls = _response_cls[0] + code_lines, line_no = inspect.getsourcelines(response_cls) + end_line_no = line_no + len(code_lines) + + func_lines = [' ' * 4 + _ for _ in new_code.splitlines()] + + lines = lines[:end_line_no] + func_lines + lines[end_line_no:] + + with open(path, 'w') as f: + f.write('\n'.join(lines)) + + +def insert_query_codes(service, operation): + func_in_responses = get_function_in_responses(service, operation, 'query') + func_in_models = get_function_in_models(service, operation) + template = get_response_query_template(service, operation) + + # edit responses.py + responses_path = 'moto/{}/responses.py'.format(service) + print_progress('inserting code', responses_path, 'green') + insert_code_to_class(responses_path, BaseResponse, func_in_responses) + + # insert template + with open(responses_path) as f: + lines = [_[:-1] for _ in f.readlines()] + lines += template.splitlines() + with open(responses_path, 'w') as f: + f.write('\n'.join(lines)) + + # edit models.py + models_path = 'moto/{}/models.py'.format(service) + print_progress('inserting code', models_path, 'green') + insert_code_to_class(models_path, BaseBackend, func_in_models) + +def insert_json_codes(service, operation): + func_in_responses = get_function_in_responses(service, operation, 'json') + func_in_models = get_function_in_models(service, operation) + + # edit responses.py + responses_path = 'moto/{}/responses.py'.format(service) + print_progress('inserting code', responses_path, 'green') + insert_code_to_class(responses_path, BaseResponse, func_in_responses) + + # edit models.py + models_path = 'moto/{}/models.py'.format(service) + print_progress('inserting code', models_path, 'green') + insert_code_to_class(models_path, BaseBackend, func_in_models) + @click.command() def main(): service, operation = select_service_and_operation() - api_protocol = boto3.client(service_name)._service_model.metadata['protocol'] + api_protocol = boto3.client(service)._service_model.metadata['protocol'] initialize_service(service, operation, api_protocol) if api_protocol == 'query': - func_in_responses = get_function_in_responses(service, operation) - func_in_models = get_function_in_models(service, operation) - teamplte = get_response_xml_template(service, operation) - - + insert_query_codes(service, operation) + elif api_protocol == 'json': + insert_json_codes(service, operation) + pass + else: + print_progress('skip inserting code', 'api protocol "{}" is not supported'.format(api_protocol), 'yellow') if __name__ == '__main__': -# print(get_function_in_responses('elbv2', 'describe_listeners')) -# print(get_function_in_models('elbv2', 'describe_listeners')) - b = get_response_query_template('elbv2', 'describe_listeners') - print(b) -# main() + main()