Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,59 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import re
_split_words_capitalization_re = re.compile(
'^[a-z0-9]+|[A-Z][a-z0-9]+|[A-Z]+(?=[A-Z][a-z0-9])|[A-Z]+$'
)
_split_words_dashes_re = re.compile('[-_/]+')
def split_words(name):
"""
Splits name based on capitalization, dashes, and underscores.
Example: 'GetFile' -> ['Get', 'File']
Example: 'get_file' -> ['get', 'file']
"""
all_words = []
for word in re.split(_split_words_dashes_re, name):
vals = _split_words_capitalization_re.findall(word)
if vals:
all_words.extend(vals)
else:
all_words.append(word)
return all_words
def fmt_camel(name):
"""
Converts name to lower camel case. Words are identified by capitalization,
dashes, and underscores.
"""
words = split_words(name)
assert len(words) > 0
first = words.pop(0).lower()
return first + ''.join([word.capitalize() for word in words])
def fmt_dashes(name):
"""
Converts name to words separated by dashes. Words are identified by
capitalization, dashes, and underscores.
"""
return '-'.join([word.lower() for word in split_words(name)])
def fmt_pascal(name):
"""
Converts name to pascal case. Words are identified by capitalization,
dashes, and underscores.
"""
return ''.join([word.capitalize() for word in split_words(name)])
def fmt_underscores(name):
"""
Converts name to words separated by underscores. Words are identified by
capitalization, dashes, and underscores.
"""
return '_'.join([word.lower() for word in split_words(name)])

View file

@ -0,0 +1,150 @@
from __future__ import absolute_import, division, print_function, unicode_literals
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
from stone.ir import ApiNamespace
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.backend import CodeBackend
from stone.backends.js_helpers import (
check_route_name_conflict,
fmt_error_type,
fmt_func,
fmt_obj,
fmt_type,
fmt_url,
)
from stone.ir import Void
_cmdline_parser = argparse.ArgumentParser(prog='js-client-backend')
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single Javascript file that is created and '
'contains all of the routes.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
type=str,
help=('The name of the class the generated functions will be attached to. '
'The name will be added to each function documentation, which makes '
'it available for tools like JSDoc.'),
)
_cmdline_parser.add_argument(
'--wrap-response-in',
type=str,
default='',
help=('Wraps the response in a response class')
)
_header = """\
// Auto-generated by Stone, do not modify.
var routes = {};
"""
class JavascriptClientBackend(CodeBackend):
"""Generates a single Javascript file with all of the routes defined."""
cmdline_parser = _cmdline_parser
# Instance var of the current namespace being generated
cur_namespace = None # type: typing.Optional[ApiNamespace]
preserve_aliases = True
def generate(self, api):
# first check for route name conflict
with self.output_to_relative_path(self.args.filename):
self.emit_raw(_header)
for namespace in api.namespaces.values():
# Hack: needed for _docf()
self.cur_namespace = namespace
check_route_name_conflict(namespace)
for route in namespace.routes:
self._generate_route(api.route_schema, namespace, route)
self.emit()
self.emit('export { routes };')
def _generate_route(self, route_schema, namespace, route):
function_name = fmt_func(namespace.name + '_' + route.name, route.version)
self.emit()
self.emit('/**')
if route.doc:
self.emit_wrapped_text(self.process_doc(route.doc, self._docf), prefix=' * ')
if self.args.class_name:
self.emit(' * @function {}#{}'.format(self.args.class_name,
function_name))
if route.deprecated:
self.emit(' * @deprecated')
return_type = None
if self.args.wrap_response_in:
return_type = '%s<%s>' % (self.args.wrap_response_in,
fmt_type(route.result_data_type))
else:
return_type = fmt_type(route.result_data_type)
if route.arg_data_type.__class__ != Void:
self.emit(' * @arg {%s} arg - The request parameters.' %
fmt_type(route.arg_data_type))
self.emit(' * @returns {Promise.<%s, %s>}' %
(return_type,
fmt_error_type(route.error_data_type)))
self.emit(' */')
if route.arg_data_type.__class__ != Void:
self.emit('routes.%s = function (arg) {' % (function_name))
else:
self.emit('routes.%s = function () {' % (function_name))
with self.indent(dent=2):
url = fmt_url(namespace.name, route.name, route.version)
if route_schema.fields:
additional_args = []
for field in route_schema.fields:
additional_args.append(fmt_obj(route.attrs[field.name]))
if route.arg_data_type.__class__ != Void:
self.emit(
"return this.request('{}', arg, {});".format(
url, ', '.join(additional_args)))
else:
self.emit(
"return this.request('{}', null, {});".format(
url, ', '.join(additional_args)))
else:
if route.arg_data_type.__class__ != Void:
self.emit(
'return this.request("%s", arg);' % url)
else:
self.emit(
'return this.request("%s", null);' % url)
self.emit('};')
def _docf(self, tag, val):
"""
Callback used as the handler argument to process_docs(). This converts
Stone doc references to JSDoc-friendly annotations.
"""
# TODO(kelkabany): We're currently just dropping all doc ref tags ...
# NOTE(praneshp): ... except for versioned routes
if tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
url = fmt_url(self.cur_namespace.name, val, version)
# NOTE: In js, for comments, we drop the namespace name and the '/' when
# documenting URLs
return url[(len(self.cur_namespace.name) + 1):]
return val

View file

@ -0,0 +1,127 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import six
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_list_type,
is_struct_type,
is_user_defined_type,
)
from stone.backends.helpers import (
fmt_camel,
fmt_pascal,
)
_base_type_table = {
Boolean: 'boolean',
Bytes: 'string',
Float32: 'number',
Float64: 'number',
Int32: 'number',
Int64: 'number',
List: 'Array',
String: 'string',
UInt32: 'number',
UInt64: 'number',
Timestamp: 'Timestamp',
Void: 'void',
}
def fmt_obj(o):
if isinstance(o, six.text_type):
# Prioritize single-quoted strings per JS style guides.
return repr(o).lstrip('u')
else:
return json.dumps(o, indent=2)
def fmt_error_type(data_type):
"""
Converts the error type into a JSDoc type.
"""
return 'Error.<%s>' % fmt_type(data_type)
def fmt_type_name(data_type):
"""
Returns the JSDoc name for the given data type.
(Does not attempt to enumerate subtypes.)
"""
if is_user_defined_type(data_type):
return fmt_pascal('%s%s' % (data_type.namespace.name, data_type.name))
else:
fmted_type = _base_type_table.get(data_type.__class__, 'Object')
if is_list_type(data_type):
fmted_type += '.<' + fmt_type(data_type.data_type) + '>'
return fmted_type
def fmt_type(data_type):
"""
Returns a JSDoc annotation for a data type.
May contain a union of enumerated subtypes.
"""
if is_struct_type(data_type) and data_type.has_enumerated_subtypes():
possible_types = []
possible_subtypes = data_type.get_all_subtypes_with_tags()
for _, subtype in possible_subtypes:
possible_types.append(fmt_type_name(subtype))
if data_type.is_catch_all():
possible_types.append(fmt_type_name(data_type))
return fmt_jsdoc_union(possible_types)
else:
return fmt_type_name(data_type)
def fmt_jsdoc_union(type_strings):
"""
Returns a JSDoc union of the given type strings.
"""
return '(' + '|'.join(type_strings) + ')' if len(type_strings) > 1 else type_strings[0]
def fmt_func(name, version):
if version == 1:
return fmt_camel(name)
return fmt_camel(name) + 'V{}'.format(version)
def fmt_url(namespace_name, route_name, route_version):
if route_version != 1:
return '{}/{}_v{}'.format(namespace_name, route_name, route_version)
else:
return '{}/{}'.format(namespace_name, route_name)
def fmt_var(name):
return fmt_camel(name)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, version=route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route

View file

@ -0,0 +1,285 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import six
import sys
from stone.ir import (
is_user_defined_type,
is_union_type,
is_struct_type,
is_void_type,
unwrap,
)
from stone.backend import CodeBackend
from stone.backends.js_helpers import (
fmt_jsdoc_union,
fmt_type,
fmt_type_name,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(prog='js-types-backend')
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single Javascript file that is created and '
'contains all of the JSDoc types.'),
)
_cmdline_parser.add_argument(
'-e',
'--extra-arg',
action='append',
type=str,
default=[],
help=("Additional properties to add to a route's argument type based "
"on if the route has a certain attribute set. Format (JSON): "
'{"match": ["ROUTE_ATTR", ROUTE_VALUE_TO_MATCH], '
'"arg_name": "ARG_NAME", "arg_type": "ARG_TYPE", '
'"arg_docstring": "ARG_DOCSTRING"}'),
)
_header = """\
// Auto-generated by Stone, do not modify.
/**
* An Error object returned from a route.
* @typedef {Object} Error
* @property {string} error_summary - Text summary of the error.
* @property {T} error - The error object.
* @property {UserMessage} user_message - An optional field. If present, it includes a
message that can be shown directly to the end user of your app. You should show this message
if your app is unprepared to programmatically handle the error returned by an endpoint.
* @template T
*/
/**
* User-friendly error message.
* @typedef {Object} UserMessage
* @property {string} text - The message.
* @property {string} locale
*/
/**
* @typedef {string} Timestamp
*/
"""
class JavascriptTypesBackend(CodeBackend):
"""Generates a single Javascript file with all of the data types defined in JSDoc."""
cmdline_parser = _cmdline_parser
preserve_aliases = True
def generate(self, api):
with self.output_to_relative_path(self.args.filename):
self.emit_raw(_header)
extra_args = self._parse_extra_args(api, self.args.extra_arg)
for namespace in api.namespaces.values():
for data_type in namespace.data_types:
self._generate_type(data_type, extra_args.get(data_type, []))
def _parse_extra_args(self, api, extra_args_raw):
"""
Parses extra arguments into a map keyed on particular data types.
"""
extra_args = {}
def die(m, extra_arg_raw):
print('Invalid --extra-arg:%s: %s' % (m, extra_arg_raw),
file=sys.stderr)
sys.exit(1)
for extra_arg_raw in extra_args_raw:
try:
extra_arg = json.loads(extra_arg_raw)
except ValueError as e:
die(str(e), extra_arg_raw)
# Validate extra_arg JSON blob
if 'match' not in extra_arg:
die('No match key', extra_arg_raw)
elif (not isinstance(extra_arg['match'], list) or
len(extra_arg['match']) != 2):
die('match key is not a list of two strings', extra_arg_raw)
elif (not isinstance(extra_arg['match'][0], six.text_type) or
not isinstance(extra_arg['match'][1], six.text_type)):
print(type(extra_arg['match'][0]))
die('match values are not strings', extra_arg_raw)
elif 'arg_name' not in extra_arg:
die('No arg_name key', extra_arg_raw)
elif not isinstance(extra_arg['arg_name'], six.text_type):
die('arg_name is not a string', extra_arg_raw)
elif 'arg_type' not in extra_arg:
die('No arg_type key', extra_arg_raw)
elif not isinstance(extra_arg['arg_type'], six.text_type):
die('arg_type is not a string', extra_arg_raw)
elif ('arg_docstring' in extra_arg and
not isinstance(extra_arg['arg_docstring'], six.text_type)):
die('arg_docstring is not a string', extra_arg_raw)
attr_key, attr_val = extra_arg['match'][0], extra_arg['match'][1]
extra_args.setdefault(attr_key, {})[attr_val] = \
(extra_arg['arg_name'], extra_arg['arg_type'],
extra_arg.get('arg_docstring'))
# Extra arguments, keyed on data type objects.
extra_args_for_types = {}
# Locate data types that contain extra arguments
for namespace in api.namespaces.values():
for route in namespace.routes:
extra_parameters = []
if is_user_defined_type(route.arg_data_type):
for attr_key in route.attrs:
if attr_key not in extra_args:
continue
attr_val = route.attrs[attr_key]
if attr_val in extra_args[attr_key]:
extra_parameters.append(extra_args[attr_key][attr_val])
if len(extra_parameters) > 0:
extra_args_for_types[route.arg_data_type] = extra_parameters
return extra_args_for_types
def _generate_type(self, data_type, extra_parameters):
if is_struct_type(data_type):
self._generate_struct(data_type, extra_parameters)
elif is_union_type(data_type):
self._generate_union(data_type)
def _emit_jsdoc_header(self, doc=None):
self.emit()
self.emit('/**')
if doc:
self.emit_wrapped_text(self.process_doc(doc, self._docf), prefix=' * ')
def _generate_struct(self, struct_type, extra_parameters=None, nameOverride=None):
"""
Emits a JSDoc @typedef for a struct.
"""
extra_parameters = extra_parameters if extra_parameters is not None else []
self._emit_jsdoc_header(struct_type.doc)
self.emit(
' * @typedef {Object} %s' % (
nameOverride if nameOverride else fmt_type_name(struct_type)
)
)
# Some structs can explicitly list their subtypes. These structs
# have a .tag field that indicate which subtype they are.
if struct_type.is_member_of_enumerated_subtypes_tree():
if struct_type.has_enumerated_subtypes():
# This struct is the parent to multiple subtypes.
# Determine all of the possible values of the .tag
# property.
tag_values = []
for tags, _ in struct_type.get_all_subtypes_with_tags():
for tag in tags:
tag_values.append('"%s"' % tag)
jsdoc_tag_union = fmt_jsdoc_union(tag_values)
txt = '@property {%s} .tag - Tag identifying the subtype variant.' % \
jsdoc_tag_union
self.emit_wrapped_text(txt)
else:
# This struct is a particular subtype. Find the applicable
# .tag value from the parent type, which may be an
# arbitrary number of steps up the inheritance hierarchy.
parent = struct_type.parent_type
while not parent.has_enumerated_subtypes():
parent = parent.parent_type
# parent now contains the closest parent type in the
# inheritance hierarchy that has enumerated subtypes.
# Determine which subtype this is.
for subtype in parent.get_enumerated_subtypes():
if subtype.data_type == struct_type:
txt = '@property {\'%s\'} [.tag] - Tag identifying ' \
'this subtype variant. This field is only ' \
'present when needed to discriminate ' \
'between multiple possible subtypes.' % \
subtype.name
self.emit_wrapped_text(txt)
break
for param_name, param_type, param_docstring in extra_parameters:
param_docstring = ' - %s' % param_docstring if param_docstring else ''
self.emit_wrapped_text(
'@property {%s} %s%s' % (
param_type,
param_name,
param_docstring,
),
prefix=' * ',
)
# NOTE: JSDoc @typedef does not support inheritance. Using @class would be inappropriate,
# since these are not nominal types backed by a constructor. Thus, we emit all_fields,
# which includes fields on parent types.
for field in struct_type.all_fields:
field_doc = ' - ' + field.doc if field.doc else ''
field_type, nullable, _ = unwrap(field.data_type)
field_js_type = fmt_type(field_type)
# Translate nullable types into optional properties.
field_name = '[' + field.name + ']' if nullable else field.name
self.emit_wrapped_text(
'@property {%s} %s%s' % (
field_js_type,
field_name,
self.process_doc(field_doc, self._docf),
),
prefix=' * ',
)
self.emit(' */')
def _generate_union(self, union_type):
"""
Emits a JSDoc @typedef for a union type.
"""
union_name = fmt_type_name(union_type)
self._emit_jsdoc_header(union_type.doc)
self.emit(' * @typedef {Object} %s' % union_name)
variant_types = []
for variant in union_type.all_fields:
variant_types.append("'%s'" % variant.name)
variant_data_type, _, _ = unwrap(variant.data_type)
# Don't emit fields for void types.
if not is_void_type(variant_data_type):
variant_doc = ' - Available if .tag is %s.' % variant.name
if variant.doc:
variant_doc += ' ' + variant.doc
self.emit_wrapped_text(
'@property {%s} [%s]%s' % (
fmt_type(variant_data_type),
variant.name,
variant_doc,
),
prefix=' * ',
)
jsdoc_tag_union = fmt_jsdoc_union(variant_types)
self.emit(' * @property {%s} .tag - Tag identifying the union variant.' % jsdoc_tag_union)
self.emit(' */')
def _docf(self, tag, val): # pylint: disable=unused-argument
"""
Callback used as the handler argument to process_docs(). This converts
Stone doc references to JSDoc-friendly annotations.
"""
# TODO(kelkabany): We're currently just dropping all doc ref tags.
return val

View file

@ -0,0 +1,283 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
from stone.ir import (
is_list_type,
is_map_type,
is_struct_type,
is_union_type,
is_nullable_type,
is_user_defined_type,
is_void_type,
unwrap_nullable, )
from stone.backend import CodeBackend
from stone.backends.obj_c_helpers import (
fmt_camel_upper,
fmt_class,
fmt_class_prefix,
fmt_import, )
stone_warning = """\
///
/// Copyright (c) 2016 Dropbox, Inc. All rights reserved.
///
/// Auto-generated by Stone, do not modify.
///
"""
# This will be at the top of the generated file.
base_file_comment = """\
{}\
""".format(stone_warning)
undocumented = '(no description).'
comment_prefix = '/// '
class ObjCBaseBackend(CodeBackend):
"""Wrapper class over Stone generator for Obj C logic."""
# pylint: disable=abstract-method
@contextmanager
def block_m(self, class_name):
with self.block(
'@implementation {}'.format(class_name),
delim=('', '@end'),
dent=0):
self.emit()
yield
@contextmanager
def block_h_from_data_type(self, data_type, protocol=None):
assert is_user_defined_type(data_type), \
'Expected user-defined type, got %r' % type(data_type)
if not protocol:
extensions = []
if data_type.parent_type and is_struct_type(data_type):
extensions.append(fmt_class_prefix(data_type.parent_type))
else:
if is_union_type(data_type):
# Use a handwritten base class
extensions.append('NSObject')
else:
extensions.append('NSObject')
extend_suffix = ' : {}'.format(
', '.join(extensions)) if extensions else ''
else:
base = fmt_class_prefix(data_type.parent_type) if (
data_type.parent_type and
not is_union_type(data_type)) else 'NSObject'
extend_suffix = ' : {} <{}>'.format(base, ', '.join(protocol))
with self.block(
'@interface {}{}'.format(
fmt_class_prefix(data_type), extend_suffix),
delim=('', '@end'),
dent=0):
self.emit()
yield
@contextmanager
def block_h(self,
class_name,
protocol=None,
extensions=None,
protected=None):
if not extensions:
extensions = ['NSObject']
if not protocol:
extend_suffix = ' : {}'.format(', '.join(extensions))
else:
extend_suffix = ' : {} <{}>'.format(', '.join(extensions),
fmt_class(protocol))
base_interface_str = '@interface {}{} {{' if protected else '@interface {}{}'
with self.block(
base_interface_str.format(class_name, extend_suffix),
delim=('', '@end'),
dent=0):
if protected:
with self.block('', delim=('', '')):
self.emit('@protected')
for field_name, field_type in protected:
self.emit('{} _{};'.format(field_type, field_name))
self.emit('}')
self.emit()
yield
@contextmanager
def block_init(self):
with self.block('if (self)'):
yield
self.emit('return self;')
@contextmanager
def block_func(self, func, args=None, return_type='void',
class_func=False):
args = args if args is not None else []
modifier = '-' if not class_func else '+'
base_string = '{} ({}){}:{}' if args else '{} ({}){}'
signature = base_string.format(modifier, return_type, func, args)
with self.block(signature):
yield
def _get_imports_m(self, data_types, default_imports):
"""Emits all necessary implementation file imports for the given Stone data type."""
if not isinstance(data_types, list):
data_types = [data_types]
import_classes = default_imports
for data_type in data_types:
import_classes.append(fmt_class_prefix(data_type))
if data_type.parent_type:
import_classes.append(fmt_class_prefix(data_type.parent_type))
if is_struct_type(
data_type) and data_type.has_enumerated_subtypes():
for _, subtype in data_type.get_all_subtypes_with_tags():
import_classes.append(fmt_class_prefix(subtype))
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
# unpack list or map
while is_list_type(data_type) or is_map_type(data_type):
data_type = (data_type.value_data_type if
is_map_type(data_type) else data_type.data_type)
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
if import_classes:
import_classes = list(set(import_classes))
import_classes.sort()
return import_classes
def _get_imports_h(self, data_types):
"""Emits all necessary header file imports for the given Stone data type."""
if not isinstance(data_types, list):
data_types = [data_types]
import_classes = []
for data_type in data_types:
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
# unpack list or map
while is_list_type(data_type) or is_map_type(data_type):
data_type = (data_type.value_data_type if
is_map_type(data_type) else data_type.data_type)
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
import_classes = list(set(import_classes))
import_classes.sort()
return import_classes
def _generate_imports_h(self, import_classes):
import_classes = list(set(import_classes))
import_classes.sort()
for import_class in import_classes:
self.emit('@class {};'.format(import_class))
if import_classes:
self.emit()
def _generate_imports_m(self, import_classes):
import_classes = list(set(import_classes))
import_classes.sort()
for import_class in import_classes:
self.emit(fmt_import(import_class))
self.emit()
def _generate_init_imports_h(self, data_type):
self.emit('#import <Foundation/Foundation.h>')
self.emit()
self.emit('#import "DBSerializableProtocol.h"')
if data_type.parent_type and not is_union_type(data_type):
self.emit(fmt_import(fmt_class_prefix(data_type.parent_type)))
self.emit()
def _get_namespace_route_imports(self,
namespace,
include_route_args=True,
include_route_deep_args=False):
result = []
def _unpack_and_store_data_type(data_type):
data_type, _ = unwrap_nullable(data_type)
if is_list_type(data_type):
while is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
if not is_void_type(data_type) and is_user_defined_type(data_type):
result.append(data_type)
for route in namespace.routes:
if include_route_args:
data_type, _ = unwrap_nullable(route.arg_data_type)
_unpack_and_store_data_type(data_type)
elif include_route_deep_args:
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_union_type(data_type) or is_list_type(data_type):
_unpack_and_store_data_type(data_type)
elif not is_void_type(data_type):
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
if (is_struct_type(data_type) or
is_union_type(data_type) or
is_list_type(data_type)):
_unpack_and_store_data_type(data_type)
_unpack_and_store_data_type(route.result_data_type)
_unpack_and_store_data_type(route.error_data_type)
return result
def _cstor_name_from_fields(self, fields):
"""Returns an Obj C appropriate name for a constructor based on
the name of the first argument."""
if fields:
return self._cstor_name_from_field(fields[0])
else:
return 'initDefault'
def _cstor_name_from_field(self, field):
"""Returns an Obj C appropriate name for a constructor based on
the name of the supplied argument."""
return 'initWith{}'.format(fmt_camel_upper(field.name))
def _cstor_name_from_fields_names(self, fields_names):
"""Returns an Obj C appropriate name for a constructor based on
the name of the first argument."""
if fields_names:
return 'initWith{}'.format(fmt_camel_upper(fields_names[0][0]))
else:
return 'initDefault'
def _struct_has_defaults(self, struct):
"""Returns whether the given struct has any default values."""
return [
f for f in struct.all_fields
if f.has_default or is_nullable_type(f.data_type)
]

View file

@ -0,0 +1,618 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
from stone.ir import (
is_nullable_type,
is_struct_type,
is_union_type,
is_void_type,
unwrap_nullable, )
from stone.backends.obj_c_helpers import (
fmt_alloc_call,
fmt_camel_upper,
fmt_class,
fmt_class_prefix,
fmt_func,
fmt_func_args,
fmt_func_args_declaration,
fmt_func_call,
fmt_import,
fmt_property_str,
fmt_route_obj_class,
fmt_route_func,
fmt_route_var,
fmt_routes_class,
fmt_signature,
fmt_type,
fmt_var, )
from stone.backends.obj_c import (
base_file_comment,
comment_prefix,
ObjCBaseBackend,
stone_warning,
undocumented, )
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(
prog='objc-client-backend',
description=(
'Generates a ObjC class with an object for each namespace, and in each '
'namespace object, a method for each route. This class assumes that the '
'obj_c_types backend was used with the same output directory.'), )
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=(
'The name of the ObjC module to generate. Please exclude the {.h,.m} '
'file extension.'), )
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help=(
'The name of the ObjC class that contains an object for each namespace, '
'and in each namespace object, a method for each route.'))
_cmdline_parser.add_argument(
'-t',
'--transport-client-name',
required=True,
type=str,
help='The name of the ObjC class that manages network API calls.', )
_cmdline_parser.add_argument(
'-w',
'--auth-type',
type=str,
help='The auth type of the client to generate.', )
_cmdline_parser.add_argument(
'-y',
'--client-args',
required=True,
type=str,
help='The client-side route arguments to append to each route by style type.', )
_cmdline_parser.add_argument(
'-z',
'--style-to-request',
required=True,
type=str,
help='The dict that maps a style type to a ObjC request object name.', )
class ObjCBackend(ObjCBaseBackend):
"""Generates ObjC client base that implements route interfaces."""
cmdline_parser = _cmdline_parser
obj_name_to_namespace = {} # type: typing.Dict[str, int]
namespace_to_has_routes = {} # type: typing.Dict[typing.Any, bool]
def generate(self, api):
for namespace in api.namespaces.values():
self.namespace_to_has_routes[namespace] = False
if namespace.routes:
for route in namespace.routes:
if self._should_generate_route(route):
self.namespace_to_has_routes[namespace] = True
break
for namespace in api.namespaces.values():
for data_type in namespace.linearize_data_types():
self.obj_name_to_namespace[data_type.name] = fmt_class_prefix(
data_type)
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
import_classes = [
fmt_routes_class(namespace.name, self.args.auth_type),
fmt_route_obj_class(namespace.name),
'{}Protocol'.format(self.args.transport_client_name),
'DBStoneBase',
'DBRequestErrors',
]
with self.output_to_relative_path('Routes/{}.m'.format(
fmt_routes_class(namespace.name,
self.args.auth_type))):
self.emit_raw(stone_warning)
imports_classes_m = import_classes + \
self._get_imports_m(
self._get_namespace_route_imports(namespace), [])
self._generate_imports_m(imports_classes_m)
self._generate_routes_m(namespace)
with self.output_to_relative_path('Routes/{}.h'.format(
fmt_routes_class(namespace.name,
self.args.auth_type))):
self.emit_raw(base_file_comment)
self.emit('#import <Foundation/Foundation.h>')
self.emit()
self.emit(fmt_import('DBTasks'))
self.emit()
import_classes_h = [
'DBNilObject',
]
import_classes_h = (import_classes_h + self._get_imports_h(
self._get_namespace_route_imports(
namespace,
include_route_args=False,
include_route_deep_args=True)))
self._generate_imports_h(import_classes_h)
self.emit(
'@protocol {};'.format(
self.args.transport_client_name), )
self.emit()
self._generate_routes_h(namespace)
with self.output_to_relative_path(
'Client/{}.m'.format(self.args.module_name)):
self._generate_client_m(api)
with self.output_to_relative_path(
'Client/{}.h'.format(self.args.module_name)):
self._generate_client_h(api)
def _generate_client_m(self, api):
"""Generates client base implementation file. For each namespace, the client will
have an object field that encapsulates each route in the particular namespace."""
self.emit_raw(base_file_comment)
import_classes = [self.args.module_name]
import_classes += [
fmt_routes_class(ns.name, self.args.auth_type)
for ns in api.namespaces.values()
if ns.routes and self.namespace_to_has_routes[ns]
]
import_classes.append(
'{}Protocol'.format(self.args.transport_client_name))
self._generate_imports_m(import_classes)
with self.block_m(self.args.class_name):
client_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
with self.block_func(
func='initWithTransportClient',
args=client_args,
return_type='instancetype'):
self.emit('self = [super init];')
with self.block_init():
self.emit('_transportClient = client;')
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
base_string = '_{}Routes = [[{} alloc] init:client];'
self.emit(
base_string.format(
fmt_var(namespace.name),
fmt_routes_class(namespace.name,
self.args.auth_type)))
def _generate_client_h(self, api):
"""Generates client base header file. For each namespace, the client will
have an object field that encapsulates each route in the particular namespace."""
self.emit_raw(stone_warning)
self.emit('#import <Foundation/Foundation.h>')
import_classes = [
fmt_routes_class(ns.name, self.args.auth_type)
for ns in api.namespaces.values()
if ns.routes and self.namespace_to_has_routes[ns]
]
import_classes.append('DBRequestErrors')
import_classes.append('DBTasks')
self._generate_imports_m(import_classes)
self.emit()
self.emit('NS_ASSUME_NONNULL_BEGIN')
self.emit()
self.emit('@protocol {};'.format(self.args.transport_client_name))
self.emit()
self.emit(comment_prefix)
description_str = (
'Base client object that contains an instance field for '
'each namespace, each of which contains references to all routes within '
'that namespace. Fully-implemented API clients will inherit this class.'
)
self.emit_wrapped_text(description_str, prefix=comment_prefix)
self.emit(comment_prefix)
with self.block_h(
self.args.class_name,
protected=[
('transportClient',
'id<{}>'.format(self.args.transport_client_name))
]):
self.emit()
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
class_doc = 'Routes within the `{}` namespace.'.format(
fmt_var(namespace.name))
self.emit_wrapped_text(class_doc, prefix=comment_prefix)
prop = '{}Routes'.format(fmt_var(namespace.name))
typ = '{} *'.format(
fmt_routes_class(namespace.name, self.args.auth_type))
self.emit(fmt_property_str(prop=prop, typ=typ))
self.emit()
client_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
description_str = (
'Initializes the `{}` object with a networking client.')
self.emit_wrapped_text(
description_str.format(self.args.class_name),
prefix=comment_prefix)
init_signature = fmt_signature(
func='initWithTransportClient',
args=client_args,
return_type='instancetype')
self.emit('{};'.format(init_signature))
self.emit()
self.emit()
self.emit('NS_ASSUME_NONNULL_END')
def _auth_type_in_route(self, route, desired_auth_type):
for auth_type in route.attrs.get('auth').split(','):
if auth_type.strip() == desired_auth_type:
return True
return False
def _route_is_special_noauth_case(self, route):
return self._auth_type_in_route(route, 'noauth') and self.args.auth_type == 'user'
def _should_generate_route(self, route):
return (self._auth_type_in_route(route, self.args.auth_type) or
self._route_is_special_noauth_case(route))
def _generate_routes_m(self, namespace):
"""Generates implementation file for namespace object that has as methods
all routes within the namespace."""
with self.block_m(
fmt_routes_class(namespace.name, self.args.auth_type)):
init_args = fmt_func_args_declaration([(
'client', 'id<{}>'.format(self.args.transport_client_name))])
with self.block_func(
func='init', args=init_args, return_type='instancetype'):
self.emit('self = [super init];')
with self.block_init():
self.emit('_client = client;')
self.emit()
style_to_request = json.loads(self.args.style_to_request)
for route in namespace.routes:
if not self._should_generate_route(route):
continue
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
if route_type in client_args.keys():
for args_data in client_args[route_type]:
task_type_key, type_data_dict = tuple(args_data)
task_type_name = style_to_request[task_type_key]
func_suffix = type_data_dict[0]
extra_args = [
tuple(type_data[:-1])
for type_data in type_data_dict[1]
]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, _ = self._get_default_route_args(
namespace, route)
self._generate_route_m(route, namespace,
route_args, extra_args,
task_type_name, func_suffix)
route_args, _ = self._get_route_args(namespace, route)
self._generate_route_m(route, namespace, route_args,
extra_args, task_type_name,
func_suffix)
else:
task_type_name = style_to_request[route_type]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, _ = self._get_default_route_args(
namespace, route)
self._generate_route_m(route, namespace, route_args,
[], task_type_name, '')
route_args, _ = self._get_route_args(namespace, route)
self._generate_route_m(route, namespace, route_args, [],
task_type_name, '')
def _generate_route_m(self, route, namespace, route_args, extra_args,
task_type_name, func_suffix):
"""Generates route method implementation for the given route."""
user_args = list(route_args)
transport_args = [
('route', 'route'),
('arg', 'arg' if not is_void_type(route.arg_data_type) else 'nil'),
]
for name, value, typ in extra_args:
user_args.append((name, typ))
transport_args.append((name, value))
with self.block_func(
func='{}{}'.format(fmt_route_func(route), func_suffix),
args=fmt_func_args_declaration(user_args),
return_type='{} *'.format(task_type_name)):
self.emit('DBRoute *route = {}.{};'.format(
fmt_route_obj_class(namespace.name),
fmt_route_var(namespace.name, route)))
if is_union_type(route.arg_data_type):
self.emit('{} *arg = {};'.format(
fmt_class_prefix(route.arg_data_type),
fmt_var(route.arg_data_type.name)))
elif not is_void_type(route.arg_data_type):
init_call = fmt_func_call(
caller=fmt_alloc_call(
caller=fmt_class_prefix(route.arg_data_type)),
callee=self._cstor_name_from_fields_names(route_args),
args=fmt_func_args([(f[0], f[0]) for f in route_args]))
self.emit('{} *arg = {};'.format(
fmt_class_prefix(route.arg_data_type), init_call))
request_call = fmt_func_call(
caller='self.client',
callee='request{}'.format(
fmt_camel_upper(route.attrs.get('style'))),
args=fmt_func_args(transport_args))
self.emit('return {};'.format(request_call))
self.emit()
def _generate_routes_h(self, namespace):
"""Generates header file for namespace object that has as methods
all routes within the namespace."""
self.emit(comment_prefix)
self.emit_wrapped_text(
'Routes for the `{}` namespace'.format(fmt_class(namespace.name)),
prefix=comment_prefix)
self.emit(comment_prefix)
self.emit()
self.emit('NS_ASSUME_NONNULL_BEGIN')
self.emit()
with self.block_h(
fmt_routes_class(namespace.name, self.args.auth_type)):
description_str = (
'An instance of the networking client that each '
'route will use to submit a request.')
self.emit_wrapped_text(description_str, prefix=comment_prefix)
self.emit(
fmt_property_str(
prop='client',
typ='id<{}>'.format(
self.args.transport_client_name)))
self.emit()
routes_obj_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
init_signature = fmt_signature(
func='init',
args=routes_obj_args,
return_type='instancetype')
description_str = (
'Initializes the `{}` namespace container object '
'with a networking client.')
self.emit_wrapped_text(
description_str.format(
fmt_routes_class(namespace.name, self.args.auth_type)),
prefix=comment_prefix)
self.emit('{};'.format(init_signature))
self.emit()
style_to_request = json.loads(self.args.style_to_request)
for route in namespace.routes:
if not self._should_generate_route(route):
continue
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
if route_type in client_args.keys():
for args_data in client_args[route_type]:
task_type_key, type_data_dict = tuple(args_data)
task_type_name = style_to_request[task_type_key]
func_suffix = type_data_dict[0]
extra_args = [
tuple(type_data[:-1])
for type_data in type_data_dict[1]
]
extra_docs = [(type_data[0], type_data[-1])
for type_data in type_data_dict[1]]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, doc_list = self._get_default_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, extra_args,
doc_list + extra_docs, task_type_name,
func_suffix)
route_args, doc_list = self._get_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, extra_args,
doc_list + extra_docs, task_type_name, func_suffix)
else:
task_type_name = style_to_request[route_type]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, doc_list = self._get_default_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, [], doc_list,
task_type_name, '')
route_args, doc_list = self._get_route_args(
namespace, route, tag=True)
self._generate_route_signature(route, namespace,
route_args, [], doc_list,
task_type_name, '')
self.emit()
self.emit('NS_ASSUME_NONNULL_END')
self.emit()
def _generate_route_signature(
self,
route,
namespace, # pylint: disable=unused-argument
route_args,
extra_args,
doc_list,
task_type_name,
func_suffix):
"""Generates route method signature for the given route."""
for name, _, typ in extra_args:
route_args.append((name, typ))
deprecated = 'DEPRECATED: ' if route.deprecated else ''
func_name = '{}{}'.format(fmt_route_func(route), func_suffix)
self.emit(comment_prefix)
if route.doc:
route_doc = self.process_doc(route.doc, self._docf)
else:
route_doc = 'The {} route'.format(func_name)
self.emit_wrapped_text(
deprecated + route_doc, prefix=comment_prefix, width=120)
self.emit(comment_prefix)
for name, doc in doc_list:
self.emit_wrapped_text(
'@param {} {}'.format(name, doc if doc else undocumented),
prefix=comment_prefix,
width=120)
self.emit(comment_prefix)
output = (
'@return Through the response callback, the caller will ' +
'receive a `{}` object on success or a `{}` object on failure.')
output = output.format(
fmt_type(route.result_data_type, tag=False, no_ptr=True),
fmt_type(route.error_data_type, tag=False, no_ptr=True))
self.emit_wrapped_text(output, prefix=comment_prefix, width=120)
self.emit(comment_prefix)
result_type_str = fmt_type(route.result_data_type) if not is_void_type(
route.result_data_type) else 'DBNilObject *'
error_type_str = fmt_type(route.error_data_type) if not is_void_type(
route.error_data_type) else 'DBNilObject *'
return_type = '{}<{}, {}> *'.format(task_type_name, result_type_str,
error_type_str)
deprecated = self._get_deprecation_warning(route)
route_signature = fmt_signature(
func=func_name,
args=fmt_func_args_declaration(route_args),
return_type='{}'.format(return_type))
self.emit('{}{};'.format(route_signature, deprecated))
self.emit()
def _get_deprecation_warning(self, route):
"""Returns a deprecation tag / message, if route is deprecated."""
result = ''
if route.deprecated:
msg = '{} is deprecated.'.format(fmt_route_func(route))
if route.deprecated.by:
msg += ' Use {}.'.format(fmt_var(route.deprecated.by.name))
result = ' __deprecated_msg("{}")'.format(msg)
return result
def _get_route_args(self, namespace, route, tag=False): # pylint: disable=unused-argument
"""Returns a list of name / value string pairs representing the arguments for
a particular route."""
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_struct_type(data_type):
arg_list = []
for field in data_type.all_fields:
arg_list.append((fmt_var(field.name), fmt_type(
field.data_type, tag=tag, has_default=field.has_default)))
doc_list = [(fmt_var(f.name), self.process_doc(f.doc, self._docf))
for f in data_type.fields if f.doc]
elif is_union_type(data_type):
arg_list = [(fmt_var(data_type.name), fmt_type(
route.arg_data_type, tag=tag))]
doc_list = [(fmt_var(data_type.name),
self.process_doc(data_type.doc,
self._docf) if data_type.doc
else 'The {} union'.format(
fmt_class(data_type
.name)))]
else:
arg_list = []
doc_list = []
return arg_list, doc_list
def _get_default_route_args(
self,
namespace, # pylint: disable=unused-argument
route,
tag=False):
"""Returns a list of name / value string pairs representing the default arguments for
a particular route."""
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_struct_type(data_type):
arg_list = []
for field in data_type.all_fields:
if not field.has_default and not is_nullable_type(
field.data_type):
arg_list.append((fmt_var(field.name), fmt_type(
field.data_type, tag=tag)))
doc_list = ([(fmt_var(f.name), self.process_doc(f.doc, self._docf))
for f in data_type.fields
if f.doc and not f.has_default and
not is_nullable_type(f.data_type)])
else:
arg_list = []
doc_list = []
return arg_list, doc_list
def _docf(self, tag, val):
if tag == 'route':
return '`{}`'.format(fmt_func(val))
elif tag == 'field':
if '.' in val:
cls_name, field = val.split('.')
return ('`{}` in `{}`'.format(
fmt_var(field), self.obj_name_to_namespace[cls_name]))
else:
return fmt_var(val)
elif tag in ('type', 'val', 'link'):
return val
else:
return val

View file

@ -0,0 +1,483 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import pprint
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
Map,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_boolean_type,
is_list_type,
is_map_type,
is_numeric_type,
is_string_type,
is_tag_ref,
is_user_defined_type,
is_void_type,
unwrap_nullable, )
from .helpers import split_words
# This file defines *stylistic* choices for Swift
# (ie, that class names are UpperCamelCase and that variables are lowerCamelCase)
_primitive_table = {
Boolean: 'NSNumber *',
Bytes: 'NSData',
Float32: 'NSNumber *',
Float64: 'NSNumber *',
Int32: 'NSNumber *',
Int64: 'NSNumber *',
List: 'NSArray',
Map: 'NSDictionary',
String: 'NSString *',
Timestamp: 'NSDate *',
UInt32: 'NSNumber *',
UInt64: 'NSNumber *',
Void: 'void',
}
_primitive_table_user_interface = {
Boolean: 'BOOL',
Bytes: 'NSData',
Float32: 'double',
Float64: 'double',
Int32: 'int',
Int64: 'long',
List: 'NSArray',
Map: 'NSDictionary',
String: 'NSString *',
Timestamp: 'NSDate *',
UInt32: 'unsigned int',
UInt64: 'unsigned long',
Void: 'void',
}
_serial_table = {
Boolean: 'DBBoolSerializer',
Bytes: 'DBNSDataSerializer',
Float32: 'DBNSNumberSerializer',
Float64: 'DBNSNumberSerializer',
Int32: 'DBNSNumberSerializer',
Int64: 'DBNSNumberSerializer',
List: 'DBArraySerializer',
Map: 'DBMapSerializer',
String: 'DBStringSerializer',
Timestamp: 'DBNSDateSerializer',
UInt32: 'DBNSNumberSerializer',
UInt64: 'DBNSNumberSerializer',
}
_validator_table = {
Float32: 'numericValidator',
Float64: 'numericValidator',
Int32: 'numericValidator',
Int64: 'numericValidator',
List: 'arrayValidator',
Map: 'mapValidator',
String: 'stringValidator',
UInt32: 'numericValidator',
UInt64: 'numericValidator',
}
_wrapper_primitives = {
Boolean,
Float32,
Float64,
UInt32,
UInt64,
Int32,
Int64,
String,
}
_reserved_words = {
'auto',
'else',
'long',
'switch',
'break',
'enum',
'register',
'typedef',
'case',
'extern',
'return',
'union',
'char',
'float',
'short',
'unsigned',
'const',
'for',
'signed',
'void',
'continue',
'goto',
'sizeof',
'volatile',
'default',
'if',
'static',
'while',
'do',
'int',
'struct',
'_Packed',
'double',
'protocol',
'interface',
'implementation',
'NSObject',
'NSInteger',
'NSNumber',
'CGFloat',
'property',
'nonatomic',
'retain',
'strong',
'weak',
'unsafe_unretained',
'readwrite',
'description',
'id',
'delete',
}
_reserved_prefixes = {
'copy',
'new',
}
def fmt_obj(o):
assert not isinstance(o, dict), "Only use for base type literals"
if o is True:
return 'true'
if o is False:
return 'false'
if o is None:
return 'nil'
return pprint.pformat(o, width=1)
def fmt_camel(name, upper_first=False, reserved=True):
name = str(name)
words = [word.capitalize() for word in split_words(name)]
if not upper_first:
words[0] = words[0].lower()
ret = ''.join(words)
if reserved:
if ret.lower() in _reserved_words:
ret += '_'
# properties can't begin with certain keywords
for reserved_prefix in _reserved_prefixes:
if ret.lower().startswith(reserved_prefix):
new_prefix = 'd' if not upper_first else 'D'
ret = new_prefix + ret[0].upper() + ret[1:]
continue
return ret
def fmt_enum_name(field_name, union):
return 'DB{}{}{}'.format(
fmt_class_caps(union.namespace.name),
fmt_camel_upper(union.name), fmt_camel_upper(field_name))
def fmt_camel_upper(name, reserved=True):
return fmt_camel(name, upper_first=True, reserved=reserved)
def fmt_public_name(name):
return fmt_camel_upper(name)
def fmt_class(name):
return fmt_camel_upper(name)
def fmt_class_caps(name):
return fmt_camel_upper(name).upper()
def fmt_class_type(data_type, suppress_ptr=False):
data_type, _ = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}'.format(fmt_class_prefix(data_type))
else:
result = _primitive_table.get(data_type.__class__,
fmt_class(data_type.name))
if suppress_ptr:
result = result.replace(' *', '')
result = result.replace('*', '')
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
result = result + '<{}>'.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
result = result + '<NSString *, {}>'.format(fmt_type(data_type))
return result
def fmt_func(name):
return fmt_camel(name)
def fmt_type(data_type, tag=False, has_default=False, no_ptr=False, is_prop=False):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
base = '{}' if no_ptr else '{} *'
result = base.format(fmt_class_prefix(data_type))
else:
result = _primitive_table.get(data_type.__class__,
fmt_class(data_type.name))
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
base = '<{}>' if no_ptr else '<{}> *'
result = result + base.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
base = '<NSString *, {}>' if no_ptr else '<NSString *, {}> *'
result = result + base.format(fmt_type(data_type))
if tag:
if (nullable or has_default) and not is_prop:
result = 'nullable ' + result
return result
def fmt_route_type(data_type, tag=False, has_default=False):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{} *'.format(fmt_class_prefix(data_type))
else:
result = _primitive_table_user_interface.get(data_type.__class__,
fmt_class(data_type.name))
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
result = result + '<{}> *'.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
result = result + '<NSString *, {}>'.format(fmt_type(data_type))
if is_user_defined_type(data_type) and tag:
if nullable or has_default:
result = 'nullable ' + result
elif not is_void_type(data_type):
result += ''
return result
def fmt_class_prefix(data_type):
return 'DB{}{}'.format(
fmt_class_caps(data_type.namespace.name), fmt_class(data_type.name))
def fmt_validator(data_type):
return _validator_table.get(data_type.__class__, fmt_class(data_type.name))
def fmt_serial_obj(data_type):
data_type, _ = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = fmt_serial_class(fmt_class_prefix(data_type))
else:
result = _serial_table.get(data_type.__class__,
fmt_class(data_type.name))
return result
def fmt_serial_class(class_name):
return '{}Serializer'.format(class_name)
def fmt_route_obj_class(namespace_name):
return 'DB{}RouteObjects'.format(fmt_class_caps(namespace_name))
def fmt_routes_class(namespace_name, auth_type):
auth_type_to_use = auth_type
if auth_type == 'noauth':
auth_type_to_use = 'user'
return 'DB{}{}AuthRoutes'.format(
fmt_class_caps(namespace_name), fmt_camel_upper(auth_type_to_use))
def fmt_route_var(namespace_name, route):
ret = 'DB{}{}'.format(
fmt_class_caps(namespace_name), fmt_camel_upper(route.name))
if route.version != 1:
ret = '{}V{}'.format(ret, route.version)
return ret
def fmt_route_func(route):
ret = fmt_var(route.name)
if route.version != 1:
ret = '{}V{}'.format(ret, route.version)
return ret
def fmt_func_args(arg_str_pairs):
result = []
first_arg = True
for arg_name, arg_value in arg_str_pairs:
if first_arg:
result.append('{}'.format(arg_value))
first_arg = False
else:
result.append('{}:{}'.format(arg_name, arg_value))
return ' '.join(result)
def fmt_func_args_declaration(arg_str_pairs):
result = []
first_arg = True
for arg_name, arg_type in arg_str_pairs:
if first_arg:
result.append('({}){}'.format(arg_type, arg_name))
first_arg = False
else:
result.append('{0}:({1}){0}'.format(arg_name, arg_type))
return ' '.join(result)
def fmt_func_args_from_fields(args):
result = []
first_arg = True
for arg in args:
if first_arg:
result.append(
'({}){}'.format(fmt_type(arg.data_type), fmt_var(arg.name)))
first_arg = False
else:
result.append('{}:({}){}'.format(
fmt_var(arg.name), fmt_type(arg.data_type), fmt_var(arg.name)))
return ' '.join(result)
def fmt_func_call(caller, callee, args=None):
if args:
result = '[{} {}:{}]'.format(caller, callee, args)
else:
result = '[{} {}]'.format(caller, callee)
return result
def fmt_alloc_call(caller):
return '[{} alloc]'.format(caller)
def fmt_default_value(field):
if is_tag_ref(field.default):
return '[[{} alloc] initWith{}]'.format(
fmt_class_prefix(field.default.union_data_type),
fmt_class(field.default.tag_name))
elif is_numeric_type(field.data_type):
return '@({})'.format(field.default)
elif is_boolean_type(field.data_type):
if field.default:
bool_str = 'YES'
else:
bool_str = 'NO'
return '@{}'.format(bool_str)
elif is_string_type(field.data_type):
return '@"{}"'.format(field.default)
else:
raise TypeError(
'Can\'t handle default value type %r' % type(field.data_type))
def fmt_ns_number_call(data_type):
result = ''
if is_numeric_type(data_type):
if isinstance(data_type, UInt32):
result = 'numberWithUnsignedInt'
elif isinstance(data_type, UInt64):
result = 'numberWithUnsignedLong'
elif isinstance(data_type, Int32):
result = 'numberWithInt'
elif isinstance(data_type, Int64):
result = 'numberWithLong'
elif isinstance(data_type, Float32):
result = 'numberWithDouble'
elif isinstance(data_type, Float64):
result = 'numberWithDouble'
elif is_boolean_type(data_type):
result = 'numberWithBool'
return result
def fmt_signature(func, args, return_type='void', class_func=False):
modifier = '-' if not class_func else '+'
if args:
result = '{} ({}){}:{}'.format(modifier, return_type, func, args)
else:
result = '{} ({}){}'.format(modifier, return_type, func)
return result
def is_primitive_type(data_type):
data_type, _ = unwrap_nullable(data_type)
return data_type.__class__ in _wrapper_primitives
def fmt_var(name):
return fmt_camel(name)
def fmt_property(field):
attrs = ['nonatomic', 'readonly']
data_type, nullable = unwrap_nullable(field.data_type)
if is_string_type(data_type):
attrs.append('copy')
if nullable:
attrs.append('nullable')
base_string = '@property ({}) {}{};'
return base_string.format(', '.join(attrs),
fmt_type(field.data_type, tag=True, is_prop=True),
fmt_var(field.name))
def fmt_import(header_file):
return '#import "{}.h"'.format(header_file)
def fmt_property_str(prop, typ, attrs=None):
if not attrs:
attrs = ['nonatomic', 'readonly']
base_string = '@property ({}) {} {};'
return base_string.format(', '.join(attrs), typ, prop)
def append_to_jazzy_category_dict(jazzy_dict, label, item):
for category_dict in jazzy_dict['custom_categories']:
if category_dict['name'] == label:
category_dict['children'].append(item)
return
return None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,560 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import re
from stone.backend import CodeBackend
from stone.backends.helpers import fmt_underscores
from stone.backends.python_helpers import (
check_route_name_conflict,
fmt_class,
fmt_func,
fmt_namespace,
fmt_obj,
fmt_type,
fmt_var,
)
from stone.backends.python_types import (
class_name_for_data_type,
)
from stone.ir import (
is_nullable_type,
is_list_type,
is_map_type,
is_struct_type,
is_tag_ref,
is_union_type,
is_user_defined_type,
is_void_type,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
# This will be at the top of the generated file.
base = """\
# -*- coding: utf-8 -*-
# Auto-generated by Stone, do not modify.
# flake8: noqa
# pylint: skip-file
from abc import ABCMeta, abstractmethod
"""
# Matches format of Babel doc tags
doc_sub_tag_re = re.compile(':(?P<tag>[A-z]*):`(?P<val>.*?)`')
DOCSTRING_CLOSE_RESPONSE = """\
If you do not consume the entire response body, then you must call close on the
response object, otherwise you will max out your available connections. We
recommend using the `contextlib.closing
<https://docs.python.org/2/library/contextlib.html#contextlib.closing>`_
context manager to ensure this."""
_cmdline_parser = argparse.ArgumentParser(
prog='python-client-backend',
description=(
'Generates a Python class with a method for each route. Extend the '
'generated class and implement the abstract request() method. This '
'class assumes that the python_types backend was used with the same '
'output directory.'),
)
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=('The name of the Python module to generate. Please exclude the .py '
'file extension.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help='The name of the Python class that contains each route as a method.',
)
_cmdline_parser.add_argument(
'-t',
'--types-package',
required=True,
type=str,
help='The output Python package of the python_types backend.',
)
_cmdline_parser.add_argument(
'-e',
'--error-class-path',
default='.exceptions.ApiError',
type=str,
help=(
"The path to the class that's raised when a route returns an error. "
"The class name is inserted into the doc for route methods."),
)
_cmdline_parser.add_argument(
'-w',
'--auth-type',
type=str,
help='The auth type of the client to generate.',
)
class PythonClientBackend(CodeBackend):
cmdline_parser = _cmdline_parser
supported_auth_types = None
def generate(self, api):
"""Generates a module called "base".
The module will contain a base class that will have a method for
each route across all namespaces.
"""
with self.output_to_relative_path('%s.py' % self.args.module_name):
self.emit_raw(base)
# Import "warnings" if any of the routes are deprecated.
found_deprecated = False
for namespace in api.namespaces.values():
for route in namespace.routes:
if route.deprecated:
self.emit('import warnings')
found_deprecated = True
break
if found_deprecated:
break
self.emit()
self._generate_imports(api.namespaces.values())
self.emit()
self.emit() # PEP-8 expects two-blank lines before class def
self.emit('class %s(object):' % self.args.class_name)
with self.indent():
self.emit('__metaclass__ = ABCMeta')
self.emit()
self.emit('@abstractmethod')
self.emit(
'def request(self, route, namespace, arg, arg_binary=None):')
with self.indent():
self.emit('pass')
self.emit()
self._generate_route_methods(api.namespaces.values())
def _generate_imports(self, namespaces):
# Only import namespaces that have user-defined types defined.
for namespace in namespaces:
if namespace.data_types:
self.emit('from {} import {}'.format(self.args.types_package, fmt_namespace(namespace.name)))
def _generate_route_methods(self, namespaces):
"""Creates methods for the routes in each namespace. All data types
and routes are represented as Python classes."""
self.cur_namespace = None
for namespace in namespaces:
if namespace.routes:
self.emit('# ------------------------------------------')
self.emit('# Routes in {} namespace'.format(namespace.name))
self.emit()
self._generate_routes(namespace)
def _generate_routes(self, namespace):
"""
Generates Python methods that correspond to routes in the namespace.
"""
# Hack: needed for _docf()
self.cur_namespace = namespace
# list of auth_types supported in this base class.
# this is passed with the new -w flag
if self.args.auth_type is not None:
self.supported_auth_types = [auth_type.strip().lower() for auth_type in self.args.auth_type.split(',')]
check_route_name_conflict(namespace)
for route in namespace.routes:
# compatibility mode : included routes are passed by whitelist
# actual auth attr inluded in the route is ignored in this mode.
if self.supported_auth_types is None:
self._generate_route_helper(namespace, route)
if route.attrs.get('style') == 'download':
self._generate_route_helper(namespace, route, True)
else:
route_auth_attr = None
if route.attrs is not None:
route_auth_attr = route.attrs.get('auth')
if route_auth_attr is None:
continue
route_auth_modes = [mode.strip().lower() for mode in route_auth_attr.split(',')]
for base_auth_type in self.supported_auth_types:
if base_auth_type in route_auth_modes:
self._generate_route_helper(namespace, route)
if route.attrs.get('style') == 'download':
self._generate_route_helper(namespace, route, True)
break # to avoid duplicate method declaration in the same base class
def _generate_route_helper(self, namespace, route, download_to_file=False):
"""Generate a Python method that corresponds to a route.
:param namespace: Namespace that the route belongs to.
:param stone.ir.ApiRoute route: IR node for the route.
:param bool download_to_file: Whether a special version of the route
that downloads the response body to a file should be generated.
This can only be used for download-style routes.
"""
arg_data_type = route.arg_data_type
result_data_type = route.result_data_type
request_binary_body = route.attrs.get('style') == 'upload'
response_binary_body = route.attrs.get('style') == 'download'
if download_to_file:
assert response_binary_body, 'download_to_file can only be set ' \
'for download-style routes.'
self._generate_route_method_decl(namespace,
route,
arg_data_type,
request_binary_body,
method_name_suffix='_to_file',
extra_args=['download_path'])
else:
self._generate_route_method_decl(namespace,
route,
arg_data_type,
request_binary_body)
with self.indent():
extra_request_args = None
extra_return_arg = None
footer = None
if request_binary_body:
extra_request_args = [('f',
'bytes',
'Contents to upload.')]
elif download_to_file:
extra_request_args = [('download_path',
'str',
'Path on local machine to save file.')]
if response_binary_body and not download_to_file:
extra_return_arg = ':class:`requests.models.Response`'
footer = DOCSTRING_CLOSE_RESPONSE
if route.doc:
func_docstring = self.process_doc(route.doc, self._docf)
else:
func_docstring = None
self._generate_docstring_for_func(
namespace,
arg_data_type,
result_data_type,
route.error_data_type,
overview=func_docstring,
extra_request_args=extra_request_args,
extra_return_arg=extra_return_arg,
footer=footer,
)
self._maybe_generate_deprecation_warning(route)
# Code to instantiate a class for the request data type
if is_void_type(arg_data_type):
self.emit('arg = None')
elif is_struct_type(arg_data_type):
self.generate_multiline_list(
[f.name for f in arg_data_type.all_fields],
before='arg = {}.{}'.format(
fmt_namespace(arg_data_type.namespace.name),
fmt_class(arg_data_type.name)),
)
elif not is_union_type(arg_data_type):
raise AssertionError('Unhandled request type %r' %
arg_data_type)
# Code to make the request
args = [
'{}.{}'.format(fmt_namespace(namespace.name),
fmt_func(route.name, version=route.version)),
"'{}'".format(namespace.name),
'arg']
if request_binary_body:
args.append('f')
else:
args.append('None')
self.generate_multiline_list(args, 'r = self.request', compact=False)
if download_to_file:
self.emit('self._save_body_to_file(download_path, r[1])')
if is_void_type(result_data_type):
self.emit('return None')
else:
self.emit('return r[0]')
else:
if is_void_type(result_data_type):
self.emit('return None')
else:
self.emit('return r')
self.emit()
def _generate_route_method_decl(
self, namespace, route, arg_data_type, request_binary_body,
method_name_suffix='', extra_args=None):
"""Generates the method prototype for a route."""
args = ['self']
if extra_args:
args += extra_args
if request_binary_body:
args.append('f')
if is_struct_type(arg_data_type):
for field in arg_data_type.all_fields:
if is_nullable_type(field.data_type):
args.append('{}=None'.format(field.name))
elif field.has_default:
# TODO(kelkabany): Decide whether we really want to set the
# default in the argument list. This will send the default
# over the wire even if it isn't overridden. The benefit is
# it locks in a default even if it is changed server-side.
if is_user_defined_type(field.data_type):
ns = field.data_type.namespace
else:
ns = None
arg = '{}={}'.format(
field.name,
self._generate_python_value(ns, field.default))
args.append(arg)
else:
args.append(field.name)
elif is_union_type(arg_data_type):
args.append('arg')
elif not is_void_type(arg_data_type):
raise AssertionError('Unhandled request type: %r' %
arg_data_type)
method_name = fmt_func(route.name + method_name_suffix, version=route.version)
namespace_name = fmt_underscores(namespace.name)
self.generate_multiline_list(args, 'def {}_{}'.format(namespace_name, method_name), ':')
def _maybe_generate_deprecation_warning(self, route):
if route.deprecated:
msg = '{} is deprecated.'.format(route.name)
if route.deprecated.by:
msg += ' Use {}.'.format(route.deprecated.by.name)
args = ["'{}'".format(msg), 'DeprecationWarning']
self.generate_multiline_list(
args,
before='warnings.warn',
delim=('(', ')'),
compact=False,
)
def _generate_docstring_for_func(self, namespace, arg_data_type,
result_data_type=None, error_data_type=None,
overview=None, extra_request_args=None,
extra_return_arg=None, footer=None):
"""
Generates a docstring for a function or method.
This function is versatile. It will create a docstring using all the
data that is provided.
:param arg_data_type: The data type describing the argument to the
route. The data type should be a struct, and each field will be
treated as an input parameter of the method.
:param result_data_type: The data type of the route result.
:param error_data_type: The data type of the route result in the case
of an error.
:param str overview: A description of the route that will be located
at the top of the docstring.
:param extra_request_args: [(field name, field type, field doc), ...]
Describes any additional parameters for the method that aren't a
field in arg_data_type.
:param str extra_return_arg: Name of an additional return type that. If
this is specified, it is assumed that the return of the function
will be a tuple of return_data_type and extra_return-arg.
:param str footer: Additional notes at the end of the docstring.
"""
fields = [] if is_void_type(arg_data_type) else arg_data_type.fields
if not fields and not overview:
# If we don't have an overview or any input parameters, we skip the
# docstring altogether.
return
self.emit('"""')
if overview:
self.emit_wrapped_text(overview)
# Description of all input parameters
if extra_request_args or fields:
if overview:
# Add a blank line if we had an overview
self.emit()
if extra_request_args:
for name, data_type_name, doc in extra_request_args:
if data_type_name:
field_doc = ':param {} {}: {}'.format(data_type_name,
name, doc)
self.emit_wrapped_text(field_doc,
subsequent_prefix=' ')
else:
self.emit_wrapped_text(
':param {}: {}'.format(name, doc),
subsequent_prefix=' ')
if is_struct_type(arg_data_type):
for field in fields:
if field.doc:
if is_user_defined_type(field.data_type):
field_doc = ':param {}: {}'.format(
field.name, self.process_doc(field.doc, self._docf))
else:
field_doc = ':param {} {}: {}'.format(
self._format_type_in_doc(namespace, field.data_type),
field.name,
self.process_doc(field.doc, self._docf),
)
self.emit_wrapped_text(
field_doc, subsequent_prefix=' ')
if is_user_defined_type(field.data_type):
# It's clearer to declare the type of a composite on
# a separate line since it references a class in
# another module
self.emit(':type {}: {}'.format(
field.name,
self._format_type_in_doc(namespace, field.data_type),
))
else:
# If the field has no docstring, then just document its
# type.
field_doc = ':type {}: {}'.format(
field.name,
self._format_type_in_doc(namespace, field.data_type),
)
self.emit_wrapped_text(field_doc)
elif is_union_type(arg_data_type):
if arg_data_type.doc:
self.emit_wrapped_text(':param arg: {}'.format(
self.process_doc(arg_data_type.doc, self._docf)),
subsequent_prefix=' ')
self.emit(':type arg: {}'.format(
self._format_type_in_doc(namespace, arg_data_type)))
if overview and not (extra_request_args or fields):
# Only output an empty line if we had an overview and haven't
# started a section on declaring types.
self.emit()
if extra_return_arg:
# Special case where the function returns a tuple. The first
# element is the JSON response. The second element is the
# the extra_return_arg param.
args = []
if is_void_type(result_data_type):
args.append('None')
else:
rtype = self._format_type_in_doc(namespace,
result_data_type)
args.append(rtype)
args.append(extra_return_arg)
self.generate_multiline_list(args, ':rtype: ')
else:
if is_void_type(result_data_type):
self.emit(':rtype: None')
else:
rtype = self._format_type_in_doc(namespace, result_data_type)
self.emit(':rtype: {}'.format(rtype))
if not is_void_type(error_data_type) and error_data_type.fields:
self.emit(':raises: :class:`{}`'.format(self.args.error_class_path))
self.emit()
# To provide more clarity to a dev who reads the docstring, suggest
# the route's error class. This is confusing, however, because we
# don't know where the error object that's raised will store
# the more detailed route error defined in stone.
error_class_name = self.args.error_class_path.rsplit('.', 1)[-1]
self.emit('If this raises, {} will contain:'.format(error_class_name))
with self.indent():
self.emit(self._format_type_in_doc(namespace, error_data_type))
if footer:
self.emit()
self.emit_wrapped_text(footer)
self.emit('"""')
def _docf(self, tag, val):
"""
Callback used as the handler argument to process_docs(). This converts
Babel doc references to Sphinx-friendly annotations.
"""
if tag == 'type':
fq_val = val
if '.' not in val:
fq_val = self.cur_namespace.name + '.' + fq_val
return ':class:`{}.{}`'.format(self.args.types_package, fq_val)
elif tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
if '.' in val:
return ':meth:`{}`'.format(fmt_func(val, version=version))
else:
return ':meth:`{}_{}`'.format(
self.cur_namespace.name, fmt_func(val, version=version))
elif tag == 'link':
anchor, link = val.rsplit(' ', 1)
return '`{} <{}>`_'.format(anchor, link)
elif tag == 'val':
if val == 'null':
return 'None'
elif val == 'true' or val == 'false':
return '``{}``'.format(val.capitalize())
else:
return val
elif tag == 'field':
return '``{}``'.format(val)
else:
raise RuntimeError('Unknown doc ref tag %r' % tag)
def _format_type_in_doc(self, namespace, data_type):
"""
Returns a string that can be recognized by Sphinx as a type reference
in a docstring.
"""
if is_void_type(data_type):
return 'None'
elif is_user_defined_type(data_type):
return ':class:`{}.{}.{}`'.format(
self.args.types_package, namespace.name, fmt_type(data_type))
elif is_nullable_type(data_type):
return 'Nullable[{}]'.format(
self._format_type_in_doc(namespace, data_type.data_type),
)
elif is_list_type(data_type):
return 'List[{}]'.format(
self._format_type_in_doc(namespace, data_type.data_type),
)
elif is_map_type(data_type):
return 'Map[{}, {}]'.format(
self._format_type_in_doc(namespace, data_type.key_data_type),
self._format_type_in_doc(namespace, data_type.value_data_type),
)
else:
return fmt_type(data_type)
def _generate_python_value(self, namespace, value):
if is_tag_ref(value):
return '{}.{}.{}'.format(
fmt_namespace(namespace.name),
class_name_for_data_type(value.union_data_type),
fmt_var(value.tag_name))
else:
return fmt_obj(value)

View file

@ -0,0 +1,201 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
import pprint
from stone.backend import Backend, CodeBackend
from stone.backends.helpers import (
fmt_pascal,
fmt_underscores,
)
from stone.ir import ApiNamespace
from stone.ir import (
AnnotationType,
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
is_user_defined_type,
is_alias,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
_type_table = {
Boolean: 'bool',
Bytes: 'bytes',
Float32: 'float',
Float64: 'float',
Int32: 'int',
Int64: 'int',
List: 'list',
String: 'str',
Timestamp: 'datetime',
UInt32: 'int',
UInt64: 'int',
}
_reserved_keywords = {
'break',
'class',
'continue',
'for',
'pass',
'while',
'async',
}
@contextmanager
def emit_pass_if_nothing_emitted(codegen):
# type: (CodeBackend) -> typing.Iterator[None]
starting_lineno = codegen.lineno
yield
ending_lineno = codegen.lineno
if starting_lineno == ending_lineno:
codegen.emit("pass")
codegen.emit()
def _rename_if_reserved(s):
if s in _reserved_keywords:
return s + '_'
else:
return s
def fmt_class(name, check_reserved=False):
s = fmt_pascal(name)
return _rename_if_reserved(s) if check_reserved else s
def fmt_func(name, check_reserved=False, version=1):
name = fmt_underscores(name)
if check_reserved:
name = _rename_if_reserved(name)
if version > 1:
name = '{}_v{}'.format(name, version)
return name
def fmt_obj(o):
return pprint.pformat(o, width=1)
def fmt_type(data_type):
return _type_table.get(data_type.__class__, fmt_class(data_type.name))
def fmt_var(name, check_reserved=False):
s = fmt_underscores(name)
return _rename_if_reserved(s) if check_reserved else s
def fmt_namespaced_var(ns_name, data_type_name, field_name):
return ".".join([ns_name, data_type_name, fmt_var(field_name)])
def fmt_namespace(name):
return _rename_if_reserved(name)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, version=route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route
TYPE_IGNORE_COMMENT = " # type: ignore"
def generate_imports_for_referenced_namespaces(
backend, namespace, package, insert_type_ignore=False):
# type: (Backend, ApiNamespace, typing.Text, bool) -> None
"""
Both the true Python backend and the Python PEP 484 Type Stub backend have
to perform the same imports.
:param insert_type_ignore: add a MyPy type-ignore comment to the imports in
the except: clause.
"""
imported_namespaces = namespace.get_imported_namespaces(consider_annotation_types=True)
if not imported_namespaces:
return
type_ignore_comment = TYPE_IGNORE_COMMENT if insert_type_ignore else ""
for ns in imported_namespaces:
backend.emit('from {package} import {namespace_name}{type_ignore_comment}'.format(
package=package,
namespace_name=fmt_namespace(ns.name),
type_ignore_comment=type_ignore_comment
))
backend.emit()
def generate_module_header(backend):
backend.emit('# -*- coding: utf-8 -*-')
backend.emit('# Auto-generated by Stone, do not modify.')
# Silly way to not type ATgenerated in our code to avoid having this
# file marked as auto-generated by our code review tool.
backend.emit('# @{}'.format('generated'))
backend.emit('# flake8: noqa')
backend.emit('# pylint: skip-file')
# This will be at the top of every generated file.
_validators_import_template = """\
from stone.backends.python_rsrc import stone_base as bb{type_ignore_comment}
from stone.backends.python_rsrc import stone_validators as bv{type_ignore_comment}
"""
validators_import = _validators_import_template.format(type_ignore_comment="")
validators_import_with_type_ignore = _validators_import_template.format(
type_ignore_comment=TYPE_IGNORE_COMMENT
)
def prefix_with_ns_if_necessary(name, name_ns, source_ns):
# type: (typing.Text, ApiNamespace, ApiNamespace) -> typing.Text
"""
Returns a name that can be used to reference `name` in namespace `name_ns`
from `source_ns`.
If `source_ns` and `name_ns` are the same, that's just `name`. Otherwise
it's `name_ns`.`name`.
"""
if source_ns == name_ns:
return name
return '{}.{}'.format(fmt_namespace(name_ns.name), name)
def class_name_for_data_type(data_type, ns=None):
"""
Returns the name of the Python class that maps to a user-defined type.
The name is identical to the name in the spec.
If ``ns`` is set to a Namespace and the namespace of `data_type` does
not match, then a namespace prefix is added to the returned name.
For example, ``foreign_ns.TypeName``.
"""
assert is_user_defined_type(data_type) or is_alias(data_type), \
'Expected composite type, got %r' % type(data_type)
name = fmt_class(data_type.name)
if ns:
return prefix_with_ns_if_necessary(name, data_type.namespace, ns)
return name
def class_name_for_annotation_type(annotation_type, ns=None):
"""
Same as class_name_for_data_type, but works with annotation types.
"""
assert isinstance(annotation_type, AnnotationType)
name = fmt_class(annotation_type.name)
if ns:
return prefix_with_ns_if_necessary(name, annotation_type.namespace, ns)
return name

View file

@ -0,0 +1 @@
# Make this a package so that the Python backend tests can import these.

View file

@ -0,0 +1,250 @@
"""
Helpers for representing Stone data types in Python.
"""
from __future__ import absolute_import, unicode_literals
import functools
from stone.backends.python_rsrc import stone_validators as bv
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class AnnotationType(object):
# This is a base class for all annotation types.
pass
if _MYPY:
T = typing.TypeVar('T', bound=AnnotationType)
U = typing.TypeVar('U')
class NotSet(object):
__slots__ = ()
def __copy__(self):
# type: () -> NotSet
# disable copying so we can do identity comparison even after copying stone objects
return self
def __deepcopy__(self, memo):
# type: (typing.Dict[typing.Text, typing.Any]) -> NotSet
# disable copying so we can do identity comparison even after copying stone objects
return self
def __repr__(self):
return "NOT_SET"
NOT_SET = NotSet() # dummy object to denote that a field has not been set
NO_DEFAULT = object()
class Attribute(object):
__slots__ = ("name", "default", "nullable", "user_defined", "validator")
def __init__(self, name, nullable=False, user_defined=False):
# type: (typing.Text, bool, bool) -> None
# Internal name to store actual value for attribute.
self.name = "_{}_value".format(name)
self.nullable = nullable
self.user_defined = user_defined
# These should be set later, because of possible cross-references.
self.validator = None # type: typing.Any
self.default = NO_DEFAULT
def __get__(self, instance, owner):
# type: (typing.Any, typing.Any) -> typing.Any
if instance is None:
return self
value = getattr(instance, self.name)
if value is not NOT_SET:
return value
if self.nullable:
return None
if self.default is not NO_DEFAULT:
return self.default
# No luck, give a nice error.
raise AttributeError("missing required field '{}'".format(public_name(self.name)))
def __set__(self, instance, value):
# type: (typing.Any, typing.Any) -> None
if self.nullable and value is None:
setattr(instance, self.name, NOT_SET)
return
if self.user_defined:
self.validator.validate_type_only(value)
else:
value = self.validator.validate(value)
setattr(instance, self.name, value)
def __delete__(self, instance):
# type: (typing.Any) -> None
setattr(instance, self.name, NOT_SET)
class Struct(object):
# This is a base class for all classes representing Stone structs.
# every parent class in the inheritance tree must define __slots__ in order to get full memory
# savings
__slots__ = ()
_all_field_names_ = set() # type: typing.Set[str]
def __eq__(self, other):
# type: (object) -> bool
if not isinstance(other, Struct):
return False
if self._all_field_names_ != other._all_field_names_:
return False
if not isinstance(other, self.__class__) and not isinstance(self, other.__class__):
return False
for field_name in self._all_field_names_:
if getattr(self, field_name) != getattr(other, field_name):
return False
return True
def __ne__(self, other):
# type: (object) -> bool
return not self == other
def __repr__(self):
args = ["{}={!r}".format(name, getattr(self, "_{}_value".format(name)))
for name in sorted(self._all_field_names_)]
return "{}({})".format(type(self).__name__, ", ".join(args))
def _process_custom_annotations(self, annotation_type, field_path, processor):
# type: (typing.Type[T], typing.Text, typing.Callable[[T, U], U]) -> None
pass
class Union(object):
# TODO(kelkabany): Possible optimization is to remove _value if a
# union is composed of only symbols.
__slots__ = ['_tag', '_value']
_tagmap = {} # type: typing.Dict[str, bv.Validator]
_permissioned_tagmaps = set() # type: typing.Set[typing.Text]
def __init__(self, tag, value=None):
validator = None
tagmap_names = ['_{}_tagmap'.format(map_name) for map_name in self._permissioned_tagmaps]
for tagmap_name in ['_tagmap'] + tagmap_names:
if tag in getattr(self, tagmap_name):
validator = getattr(self, tagmap_name)[tag]
assert validator is not None, 'Invalid tag %r.' % tag
if isinstance(validator, bv.Void):
assert value is None, 'Void type union member must have None value.'
elif isinstance(validator, (bv.Struct, bv.Union)):
validator.validate_type_only(value)
else:
validator.validate(value)
self._tag = tag
self._value = value
def __eq__(self, other):
# Also need to check if one class is a subclass of another. If one union extends another,
# the common fields should be able to be compared to each other.
return (
isinstance(other, Union) and
(isinstance(self, other.__class__) or isinstance(other, self.__class__)) and
self._tag == other._tag and self._value == other._value
)
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash((self._tag, self._value))
def __repr__(self):
return "{}({!r}, {!r})".format(type(self).__name__, self._tag, self._value)
def _process_custom_annotations(self, annotation_type, field_path, processor):
# type: (typing.Type[T], typing.Text, typing.Callable[[T, U], U]) -> None
pass
@classmethod
def _is_tag_present(cls, tag, caller_permissions):
assert tag is not None, 'tag value should not be None'
if tag in cls._tagmap:
return True
for extra_permission in caller_permissions.permissions:
tagmap_name = '_{}_tagmap'.format(extra_permission)
if hasattr(cls, tagmap_name) and tag in getattr(cls, tagmap_name):
return True
return False
@classmethod
def _get_val_data_type(cls, tag, caller_permissions):
assert tag is not None, 'tag value should not be None'
for extra_permission in caller_permissions.permissions:
tagmap_name = '_{}_tagmap'.format(extra_permission)
if hasattr(cls, tagmap_name) and tag in getattr(cls, tagmap_name):
return getattr(cls, tagmap_name)[tag]
return cls._tagmap[tag]
class Route(object):
__slots__ = ("name", "version", "deprecated", "arg_type", "result_type", "error_type", "attrs")
def __init__(self, name, version, deprecated, arg_type, result_type, error_type, attrs):
self.name = name
self.version = version
self.deprecated = deprecated
self.arg_type = arg_type
self.result_type = result_type
self.error_type = error_type
assert isinstance(attrs, dict), 'Expected dict, got %r' % attrs
self.attrs = attrs
def __repr__(self):
return 'Route({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})'.format(
self.name,
self.version,
self.deprecated,
self.arg_type,
self.result_type,
self.error_type,
self.attrs)
# helper functions used when constructing custom annotation processors
# put this here so that every other file doesn't need to import functools
partially_apply = functools.partial
def make_struct_annotation_processor(annotation_type, processor):
def g(field_path, struct):
if struct is None:
return struct
struct._process_custom_annotations(annotation_type, field_path, processor)
return struct
return g
def make_list_annotation_processor(processor):
def g(field_path, list_):
if list_ is None:
return list_
return [processor('{}[{}]'.format(field_path, idx), x) for idx, x in enumerate(list_)]
return g
def make_map_value_annotation_processor(processor):
def g(field_path, map_):
if map_ is None:
return map_
return {k: processor('{}[{}]'.format(field_path, repr(k)), v) for k, v in map_.items()}
return g
def public_name(name):
# _some_attr_value -> some_attr
return "_".join(name.split("_")[1:-1])

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,722 @@
"""
Defines classes to represent each Stone type in Python. These classes should
be used to validate Python objects and normalize them for a given type.
The data types defined here should not be specific to an RPC or serialization
format.
"""
from __future__ import absolute_import, unicode_literals
import datetime
import hashlib
import math
import numbers
import re
from abc import ABCMeta, abstractmethod
import six
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# See <http://python3porting.com/differences.html#buffer>
if six.PY3:
_binary_types = (bytes, memoryview) # noqa: E501,F821 # pylint: disable=undefined-variable,useless-suppression
else:
_binary_types = (bytes, buffer) # noqa: E501,F821 # pylint: disable=undefined-variable,useless-suppression
class ValidationError(Exception):
"""Raised when a value doesn't pass validation by its validator."""
def __init__(self, message, parent=None):
"""
Args:
message (str): Error message detailing validation failure.
parent (str): Adds the parent as the closest reference point for
the error. Use :meth:`add_parent` to add more.
"""
super(ValidationError, self).__init__(message)
self.message = message
self._parents = []
if parent:
self._parents.append(parent)
def add_parent(self, parent):
"""
Args:
parent (str): Adds the parent to the top of the tree of references
that lead to the validator that failed.
"""
self._parents.append(parent)
def __str__(self):
"""
Returns:
str: A descriptive message of the validation error that may also
include the path to the validator that failed.
"""
if self._parents:
return '{}: {}'.format('.'.join(self._parents[::-1]), self.message)
else:
return self.message
def __repr__(self):
# Not a perfect repr, but includes the error location information.
return 'ValidationError(%r)' % six.text_type(self)
def type_name_with_module(t):
# type: (typing.Type[typing.Any]) -> six.text_type
return '%s.%s' % (t.__module__, t.__name__)
def generic_type_name(v):
# type: (typing.Any) -> six.text_type
"""Return a descriptive type name that isn't Python specific. For example,
an int value will return 'integer' rather than 'int'."""
if isinstance(v, bool):
# Must come before any numbers checks since booleans are integers too
return 'boolean'
elif isinstance(v, numbers.Integral):
# Must come before real numbers check since integrals are reals too
return 'integer'
elif isinstance(v, numbers.Real):
return 'float'
elif isinstance(v, (tuple, list)):
return 'list'
elif isinstance(v, six.string_types):
return 'string'
elif v is None:
return 'null'
else:
return type_name_with_module(type(v))
class Validator(six.with_metaclass(ABCMeta, object)):
"""All primitive and composite data types should be a subclass of this."""
__slots__ = ("_redact",)
@abstractmethod
def validate(self, val):
"""Validates that val is of this data type.
Returns: A normalized value if validation succeeds.
Raises: ValidationError
"""
def has_default(self):
return False
def get_default(self):
raise AssertionError('No default available.')
class Primitive(Validator): # pylint: disable=abstract-method
"""A basic type that is defined by Stone."""
__slots__ = ()
class Boolean(Primitive):
__slots__ = ()
def validate(self, val):
if not isinstance(val, bool):
raise ValidationError('%r is not a valid boolean' % val)
return val
class Integer(Primitive):
"""
Do not use this class directly. Extend it and specify a 'default_minimum' and
'default_maximum' value as class variables for a more restrictive integer range.
"""
__slots__ = ("minimum", "maximum")
default_minimum = None # type: typing.Optional[int]
default_maximum = None # type: typing.Optional[int]
def __init__(self, min_value=None, max_value=None):
"""
A more restrictive minimum or maximum value can be specified than the
range inherent to the defined type.
"""
if min_value is not None:
assert isinstance(min_value, numbers.Integral), \
'min_value must be an integral number'
assert min_value >= self.default_minimum, \
'min_value cannot be less than the minimum value for this ' \
'type (%d < %d)' % (min_value, self.default_minimum)
self.minimum = min_value
else:
self.minimum = self.default_minimum
if max_value is not None:
assert isinstance(max_value, numbers.Integral), \
'max_value must be an integral number'
assert max_value <= self.default_maximum, \
'max_value cannot be greater than the maximum value for ' \
'this type (%d < %d)' % (max_value, self.default_maximum)
self.maximum = max_value
else:
self.maximum = self.default_maximum
def validate(self, val):
if not isinstance(val, numbers.Integral):
raise ValidationError('expected integer, got %s'
% generic_type_name(val))
elif not (self.minimum <= val <= self.maximum):
raise ValidationError('%d is not within range [%d, %d]'
% (val, self.minimum, self.maximum))
return val
def __repr__(self):
return '%s()' % self.__class__.__name__
class Int32(Integer):
__slots__ = ()
default_minimum = -2**31
default_maximum = 2**31 - 1
class UInt32(Integer):
__slots__ = ()
default_minimum = 0
default_maximum = 2**32 - 1
class Int64(Integer):
__slots__ = ()
default_minimum = -2**63
default_maximum = 2**63 - 1
class UInt64(Integer):
__slots__ = ()
default_minimum = 0
default_maximum = 2**64 - 1
class Real(Primitive):
"""
Do not use this class directly. Extend it and optionally set a 'default_minimum'
and 'default_maximum' value to enforce a range that's a subset of the Python float
implementation. Python floats are doubles.
"""
__slots__ = ("minimum", "maximum")
default_minimum = None # type: typing.Optional[float]
default_maximum = None # type: typing.Optional[float]
def __init__(self, min_value=None, max_value=None):
"""
A more restrictive minimum or maximum value can be specified than the
range inherent to the defined type.
"""
if min_value is not None:
assert isinstance(min_value, numbers.Real), \
'min_value must be a real number'
if not isinstance(min_value, float):
try:
min_value = float(min_value)
except OverflowError:
raise AssertionError('min_value is too small for a float')
if self.default_minimum is not None and min_value < self.default_minimum:
raise AssertionError('min_value cannot be less than the '
'minimum value for this type (%f < %f)' %
(min_value, self.default_minimum))
self.minimum = min_value
else:
self.minimum = self.default_minimum
if max_value is not None:
assert isinstance(max_value, numbers.Real), \
'max_value must be a real number'
if not isinstance(max_value, float):
try:
max_value = float(max_value)
except OverflowError:
raise AssertionError('max_value is too large for a float')
if self.default_maximum is not None and max_value > self.default_maximum:
raise AssertionError('max_value cannot be greater than the '
'maximum value for this type (%f < %f)' %
(max_value, self.default_maximum))
self.maximum = max_value
else:
self.maximum = self.default_maximum
def validate(self, val):
if not isinstance(val, numbers.Real):
raise ValidationError('expected real number, got %s' %
generic_type_name(val))
if not isinstance(val, float):
# This checks for the case where a number is passed in with a
# magnitude larger than supported by float64.
try:
val = float(val)
except OverflowError:
raise ValidationError('too large for float')
if math.isnan(val) or math.isinf(val):
raise ValidationError('%f values are not supported' % val)
if self.minimum is not None and val < self.minimum:
raise ValidationError('%f is not greater than %f' %
(val, self.minimum))
if self.maximum is not None and val > self.maximum:
raise ValidationError('%f is not less than %f' %
(val, self.maximum))
return val
def __repr__(self):
return '%s()' % self.__class__.__name__
class Float32(Real):
__slots__ = ()
# Maximum and minimums from the IEEE 754-1985 standard
default_minimum = -3.40282 * 10**38
default_maximum = 3.40282 * 10**38
class Float64(Real):
__slots__ = ()
class String(Primitive):
"""Represents a unicode string."""
__slots__ = ("min_length", "max_length", "pattern", "pattern_re")
def __init__(self, min_length=None, max_length=None, pattern=None):
if min_length is not None:
assert isinstance(min_length, numbers.Integral), \
'min_length must be an integral number'
assert min_length >= 0, 'min_length must be >= 0'
if max_length is not None:
assert isinstance(max_length, numbers.Integral), \
'max_length must be an integral number'
assert max_length > 0, 'max_length must be > 0'
if min_length and max_length:
assert max_length >= min_length, 'max_length must be >= min_length'
if pattern is not None:
assert isinstance(pattern, six.string_types), \
'pattern must be a string'
self.min_length = min_length
self.max_length = max_length
self.pattern = pattern
self.pattern_re = None
if pattern:
try:
self.pattern_re = re.compile(r"\A(?:" + pattern + r")\Z")
except re.error as e:
raise AssertionError('Regex {!r} failed: {}'.format(
pattern, e.args[0]))
def validate(self, val):
"""
A unicode string of the correct length and pattern will pass validation.
In PY2, we enforce that a str type must be valid utf-8, and a unicode
string will be returned.
"""
if not isinstance(val, six.string_types):
raise ValidationError("'%s' expected to be a string, got %s"
% (val, generic_type_name(val)))
if not six.PY3 and isinstance(val, str):
try:
val = val.decode('utf-8')
except UnicodeDecodeError:
raise ValidationError("'%s' was not valid utf-8")
if self.max_length is not None and len(val) > self.max_length:
raise ValidationError("'%s' must be at most %d characters, got %d"
% (val, self.max_length, len(val)))
if self.min_length is not None and len(val) < self.min_length:
raise ValidationError("'%s' must be at least %d characters, got %d"
% (val, self.min_length, len(val)))
if self.pattern and not self.pattern_re.match(val):
raise ValidationError("'%s' did not match pattern '%s'"
% (val, self.pattern))
return val
class Bytes(Primitive):
__slots__ = ("min_length", "max_length")
def __init__(self, min_length=None, max_length=None):
if min_length is not None:
assert isinstance(min_length, numbers.Integral), \
'min_length must be an integral number'
assert min_length >= 0, 'min_length must be >= 0'
if max_length is not None:
assert isinstance(max_length, numbers.Integral), \
'max_length must be an integral number'
assert max_length > 0, 'max_length must be > 0'
if min_length is not None and max_length is not None:
assert max_length >= min_length, 'max_length must be >= min_length'
self.min_length = min_length
self.max_length = max_length
def validate(self, val):
if not isinstance(val, _binary_types):
raise ValidationError("expected bytes type, got %s"
% generic_type_name(val))
elif self.max_length is not None and len(val) > self.max_length:
raise ValidationError("'%s' must have at most %d bytes, got %d"
% (val, self.max_length, len(val)))
elif self.min_length is not None and len(val) < self.min_length:
raise ValidationError("'%s' has fewer than %d bytes, got %d"
% (val, self.min_length, len(val)))
return val
class Timestamp(Primitive):
"""Note that while a format is specified, it isn't used in validation
since a native Python datetime object is preferred. The format, however,
can and should be used by serializers."""
__slots__ = ("format",)
def __init__(self, fmt):
"""fmt must be composed of format codes that the C standard (1989)
supports, most notably in its strftime() function."""
assert isinstance(fmt, six.text_type), 'format must be a string'
self.format = fmt
def validate(self, val):
if not isinstance(val, datetime.datetime):
raise ValidationError('expected timestamp, got %s'
% generic_type_name(val))
elif val.tzinfo is not None and \
val.tzinfo.utcoffset(val).total_seconds() != 0:
raise ValidationError('timestamp should have either a UTC '
'timezone or none set at all')
return val
class Composite(Validator): # pylint: disable=abstract-method
"""Validator for a type that builds on other primitive and composite
types."""
__slots__ = ()
class List(Composite):
"""Assumes list contents are homogeneous with respect to types."""
__slots__ = ("item_validator", "min_items", "max_items")
def __init__(self, item_validator, min_items=None, max_items=None):
"""Every list item will be validated with item_validator."""
self.item_validator = item_validator
if min_items is not None:
assert isinstance(min_items, numbers.Integral), \
'min_items must be an integral number'
assert min_items >= 0, 'min_items must be >= 0'
if max_items is not None:
assert isinstance(max_items, numbers.Integral), \
'max_items must be an integral number'
assert max_items > 0, 'max_items must be > 0'
if min_items is not None and max_items is not None:
assert max_items >= min_items, 'max_items must be >= min_items'
self.min_items = min_items
self.max_items = max_items
def validate(self, val):
if not isinstance(val, (tuple, list)):
raise ValidationError('%r is not a valid list' % val)
elif self.max_items is not None and len(val) > self.max_items:
raise ValidationError('%r has more than %s items'
% (val, self.max_items))
elif self.min_items is not None and len(val) < self.min_items:
raise ValidationError('%r has fewer than %s items'
% (val, self.min_items))
return [self.item_validator.validate(item) for item in val]
class Map(Composite):
"""Assumes map keys and values are homogeneous with respect to types."""
__slots__ = ("key_validator", "value_validator")
def __init__(self, key_validator, value_validator):
"""
Every Map key/value pair will be validated with item_validator.
key validators must be a subclass of a String validator
"""
self.key_validator = key_validator
self.value_validator = value_validator
def validate(self, val):
if not isinstance(val, dict):
raise ValidationError('%r is not a valid dict' % val)
return {
self.key_validator.validate(key):
self.value_validator.validate(value) for key, value in val.items()
}
class Struct(Composite):
__slots__ = ("definition",)
def __init__(self, definition):
"""
Args:
definition (class): A generated class representing a Stone struct
from a spec. Must have a _fields_ attribute with the following
structure:
_fields_ = [(field_name, validator), ...]
where
field_name: Name of the field (str).
validator: Validator object.
"""
super(Struct, self).__init__()
self.definition = definition
def validate(self, val):
"""
For a val to pass validation, val must be of the correct type and have
all required fields present.
"""
self.validate_type_only(val)
self.validate_fields_only(val)
return val
def validate_with_permissions(self, val, caller_permissions):
"""
For a val to pass validation, val must be of the correct type and have
all required permissioned fields present. Should only be called
for callers with extra permissions.
"""
self.validate(val)
self.validate_fields_only_with_permissions(val, caller_permissions)
return val
def validate_fields_only(self, val):
"""
To pass field validation, no required field should be missing.
This method assumes that the contents of each field have already been
validated on assignment, so it's merely a presence check.
FIXME(kelkabany): Since the definition object does not maintain a list
of which fields are required, all fields are scanned.
"""
for field_name in self.definition._all_field_names_:
if not hasattr(val, field_name):
raise ValidationError("missing required field '%s'" %
field_name)
def validate_fields_only_with_permissions(self, val, caller_permissions):
"""
To pass field validation, no required field should be missing.
This method assumes that the contents of each field have already been
validated on assignment, so it's merely a presence check.
Should only be called for callers with extra permissions.
"""
self.validate_fields_only(val)
# check if type has been patched
for extra_permission in caller_permissions.permissions:
all_field_names = '_all_{}_field_names_'.format(extra_permission)
for field_name in getattr(self.definition, all_field_names, set()):
if not hasattr(val, field_name):
raise ValidationError("missing required field '%s'" % field_name)
def validate_type_only(self, val):
"""
Use this when you only want to validate that the type of an object
is correct, but not yet validate each field.
"""
# Since the definition maintains the list of fields for serialization,
# we're okay with a subclass that might have extra information. This
# makes it easier to return one subclass for two routes, one of which
# relies on the parent class.
if not isinstance(val, self.definition):
raise ValidationError('expected type %s, got %s' %
(
type_name_with_module(self.definition),
generic_type_name(val),
),
)
def has_default(self):
return not self.definition._has_required_fields
def get_default(self):
assert not self.definition._has_required_fields, 'No default available.'
return self.definition()
class StructTree(Struct):
"""Validator for structs with enumerated subtypes.
NOTE: validate_fields_only() validates the fields known to this base
struct, but does not do any validation specific to the subtype.
"""
__slots__ = ()
# See PyCQA/pylint#1043 for why this is disabled; this should show up
# as a usless-suppression (and can be removed) once a fix is released
def __init__(self, definition): # pylint: disable=useless-super-delegation
super(StructTree, self).__init__(definition)
class Union(Composite):
__slots__ = ("definition",)
def __init__(self, definition):
"""
Args:
definition (class): A generated class representing a Stone union
from a spec. Must have a _tagmap attribute with the following
structure:
_tagmap = {field_name: validator, ...}
where
field_name (str): Tag name.
validator (Validator): Tag value validator.
"""
self.definition = definition
def validate(self, val):
"""
For a val to pass validation, it must have a _tag set. This assumes
that the object validated that _tag is a valid tag, and that any
associated value has also been validated.
"""
self.validate_type_only(val)
if not hasattr(val, '_tag') or val._tag is None:
raise ValidationError('no tag set')
return val
def validate_type_only(self, val):
"""
Use this when you only want to validate that the type of an object
is correct, but not yet validate each field.
We check whether val is a Python parent class of the definition. This
is because Union subtyping works in the opposite direction of Python
inheritance. For example, if a union U2 extends U1 in Python, this
validator will accept U1 in places where U2 is expected.
"""
if not issubclass(self.definition, type(val)):
raise ValidationError('expected type %s or subtype, got %s' %
(
type_name_with_module(self.definition),
generic_type_name(val),
),
)
class Void(Primitive):
__slots__ = ()
def validate(self, val):
if val is not None:
raise ValidationError('expected NoneType, got %s' %
generic_type_name(val))
def has_default(self):
return True
def get_default(self):
return None
class Nullable(Validator):
__slots__ = ("validator",)
def __init__(self, validator):
super(Nullable, self).__init__()
assert isinstance(validator, (Primitive, Composite)), \
'validator must be for a primitive or composite type'
assert not isinstance(validator, Nullable), \
'nullables cannot be stacked'
assert not isinstance(validator, Void), \
'void cannot be made nullable'
self.validator = validator
def validate(self, val):
if val is None:
return
else:
return self.validator.validate(val)
def validate_type_only(self, val):
"""Use this only if Nullable is wrapping a Composite."""
if val is None:
return
else:
return self.validator.validate_type_only(val)
def has_default(self):
return True
def get_default(self):
return None
class Redactor(object):
__slots__ = ("regex",)
def __init__(self, regex):
"""
Args:
regex: What parts of the field to redact.
"""
self.regex = regex
@abstractmethod
def apply(self, val):
"""Redacts information from annotated field.
Returns: A redacted version of the string provided.
"""
def _get_matches(self, val):
if not self.regex:
return None
try:
return re.search(self.regex, val)
except TypeError:
return None
class HashRedactor(Redactor):
__slots__ = ()
def apply(self, val):
matches = self._get_matches(val)
val_to_hash = str(val) if isinstance(val, int) or isinstance(val, float) else val
try:
# add string literal to ensure unicode
hashed = hashlib.md5(val_to_hash.encode('utf-8')).hexdigest() + ''
except [AttributeError, ValueError]:
hashed = None
if matches:
blotted = '***'.join(matches.groups())
if hashed:
return '{} ({})'.format(hashed, blotted)
return blotted
return hashed
class BlotRedactor(Redactor):
__slots__ = ()
def apply(self, val):
matches = self._get_matches(val)
if matches:
return '***'.join(matches.groups())
return '********'

View file

@ -0,0 +1,114 @@
from stone.ir import (
Alias,
ApiNamespace,
DataType,
List,
Map,
Nullable,
Timestamp,
UserDefined,
is_alias,
is_boolean_type,
is_bytes_type,
is_float_type,
is_integer_type,
is_list_type,
is_map_type,
is_nullable_type,
is_string_type,
is_timestamp_type,
is_user_defined_type,
is_void_type,
)
from stone.backends.python_helpers import class_name_for_data_type, fmt_namespace
from stone.ir.data_types import String
from stone.typing_hacks import cast
MYPY = False
if MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
DataTypeCls = typing.Type[DataType]
# Unfortunately these are codependent, so I'll weakly type the Dict in Callback
Callback = typing.Callable[
[ApiNamespace, DataType, typing.Dict[typing.Any, typing.Any]],
typing.Text
]
OverrideDefaultTypesDict = typing.Dict[DataTypeCls, Callback]
else:
OverrideDefaultTypesDict = "OverrideDefaultTypesDict"
def map_stone_type_to_python_type(ns, data_type, override_dict=None):
# type: (ApiNamespace, DataType, typing.Optional[OverrideDefaultTypesDict]) -> typing.Text
"""
Args:
override_dict: lets you override the default behavior for a given type by hooking into
a callback. (Currently only hooked up for stone's List and Nullable)
"""
override_dict = override_dict or {}
if is_string_type(data_type):
string_override = override_dict.get(String, None)
if string_override:
return string_override(ns, data_type, override_dict)
return 'str'
elif is_bytes_type(data_type):
return 'bytes'
elif is_boolean_type(data_type):
return 'bool'
elif is_float_type(data_type):
return 'float'
elif is_integer_type(data_type):
return 'int'
elif is_void_type(data_type):
return 'None'
elif is_timestamp_type(data_type):
timestamp_override = override_dict.get(Timestamp, None)
if timestamp_override:
return timestamp_override(ns, data_type, override_dict)
return 'datetime.datetime'
elif is_alias(data_type):
alias_type = cast(Alias, data_type)
return map_stone_type_to_python_type(ns, alias_type.data_type, override_dict)
elif is_user_defined_type(data_type):
user_defined_type = cast(UserDefined, data_type)
class_name = class_name_for_data_type(user_defined_type)
if user_defined_type.namespace.name != ns.name:
return '{}.{}'.format(
fmt_namespace(user_defined_type.namespace.name), class_name)
else:
return class_name
elif is_list_type(data_type):
list_type = cast(List, data_type)
if List in override_dict:
return override_dict[List](ns, list_type.data_type, override_dict)
# PyCharm understands this description format for a list
return 'list of [{}]'.format(
map_stone_type_to_python_type(ns, list_type.data_type, override_dict)
)
elif is_map_type(data_type):
map_type = cast(Map, data_type)
if Map in override_dict:
return override_dict[Map](
ns,
data_type,
override_dict
)
return 'dict of [{}:{}]'.format(
map_stone_type_to_python_type(ns, map_type.key_data_type, override_dict),
map_stone_type_to_python_type(ns, map_type.value_data_type, override_dict)
)
elif is_nullable_type(data_type):
nullable_type = cast(Nullable, data_type)
if Nullable in override_dict:
return override_dict[Nullable](ns, nullable_type.data_type, override_dict)
return 'Optional[{}]'.format(
map_stone_type_to_python_type(ns, nullable_type.data_type, override_dict)
)
else:
raise TypeError('Unknown data type %r' % data_type)

View file

@ -0,0 +1,486 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
from six import StringIO
from stone.backend import CodeBackend
from stone.backends.python_helpers import (
check_route_name_conflict,
class_name_for_annotation_type,
class_name_for_data_type,
emit_pass_if_nothing_emitted,
fmt_func,
fmt_namespace,
fmt_var,
generate_imports_for_referenced_namespaces,
generate_module_header,
validators_import_with_type_ignore,
)
from stone.backends.python_type_mapping import (
map_stone_type_to_python_type,
OverrideDefaultTypesDict,
)
from stone.ir import (
Alias,
AnnotationType,
Api,
ApiNamespace,
DataType,
is_nullable_type,
is_struct_type,
is_union_type,
is_user_defined_type,
is_void_type,
List,
Map,
Nullable,
Struct,
Timestamp,
Union,
unwrap_aliases,
)
from stone.ir.data_types import String
from stone.typing_hacks import cast
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class ImportTracker(object):
def __init__(self):
# type: () -> None
self.cur_namespace_typing_imports = set() # type: typing.Set[typing.Text]
self.cur_namespace_adhoc_imports = set() # type: typing.Set[typing.Text]
def clear(self):
# type: () -> None
self.cur_namespace_typing_imports.clear()
self.cur_namespace_adhoc_imports.clear()
def _register_typing_import(self, s):
# type: (typing.Text) -> None
"""
Denotes that we need to import something specifically from the `typing` module.
For example, _register_typing_import("Optional")
"""
self.cur_namespace_typing_imports.add(s)
def _register_adhoc_import(self, s):
# type: (typing.Text) -> None
"""
Denotes an ad-hoc import.
For example,
_register_adhoc_import("import datetime")
or
_register_adhoc_import("from xyz import abc")
"""
self.cur_namespace_adhoc_imports.add(s)
_cmdline_parser = argparse.ArgumentParser(prog='python-types-backend')
_cmdline_parser.add_argument(
'-p',
'--package',
type=str,
required=True,
help='Package prefix for absolute imports in generated files.',
)
class PythonTypeStubsBackend(CodeBackend):
"""Generates Python modules to represent the input Stone spec."""
cmdline_parser = _cmdline_parser
# Instance var of the current namespace being generated
cur_namespace = None
preserve_aliases = True
import_tracker = ImportTracker()
def __init__(self, *args, **kwargs):
# type: (...) -> None
super(PythonTypeStubsBackend, self).__init__(*args, **kwargs)
self._pep_484_type_mapping_callbacks = self._get_pep_484_type_mapping_callbacks()
def generate(self, api):
# type: (Api) -> None
"""
Generates a module for each namespace.
Each namespace will have Python classes to represent data types and
routes in the Stone spec.
"""
for namespace in api.namespaces.values():
with self.output_to_relative_path('{}.pyi'.format(fmt_namespace(namespace.name))):
self._generate_base_namespace_module(namespace)
def _generate_base_namespace_module(self, namespace):
# type: (ApiNamespace) -> None
"""Creates a module for the namespace. All data types and routes are
represented as Python classes."""
self.cur_namespace = namespace
self.import_tracker.clear()
generate_module_header(self)
self.emit_placeholder('imports_needed_for_typing')
self.emit_raw(validators_import_with_type_ignore)
# Generate import statements for all referenced namespaces.
self._generate_imports_for_referenced_namespaces(namespace)
self._generate_typevars()
for annotation_type in namespace.annotation_types:
self._generate_annotation_type_class(namespace, annotation_type)
for data_type in namespace.linearize_data_types():
if isinstance(data_type, Struct):
self._generate_struct_class(namespace, data_type)
elif isinstance(data_type, Union):
self._generate_union_class(namespace, data_type)
else:
raise TypeError('Cannot handle type %r' % type(data_type))
for alias in namespace.linearize_aliases():
self._generate_alias_definition(namespace, alias)
self._generate_routes(namespace)
self._generate_imports_needed_for_typing()
def _generate_imports_for_referenced_namespaces(self, namespace):
# type: (ApiNamespace) -> None
assert self.args is not None
generate_imports_for_referenced_namespaces(
backend=self,
namespace=namespace,
package=self.args.package,
insert_type_ignore=True,
)
def _generate_typevars(self):
# type: () -> None
"""
Creates type variables that are used by the type signatures for
_process_custom_annotations.
"""
self.emit("T = TypeVar('T', bound=bb.AnnotationType)")
self.emit("U = TypeVar('U')")
self.import_tracker._register_typing_import('TypeVar')
self.emit()
def _generate_annotation_type_class(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
"""Defines a Python class that represents an annotation type in Stone."""
self.emit('class {}(object):'.format(class_name_for_annotation_type(annotation_type, ns)))
with self.indent():
self._generate_annotation_type_class_init(ns, annotation_type)
self._generate_annotation_type_class_properties(ns, annotation_type)
self.emit()
def _generate_annotation_type_class_init(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
args = ['self']
for param in annotation_type.params:
param_name = fmt_var(param.name, True)
param_type = self.map_stone_type_to_pep484_type(ns, param.data_type)
if not is_nullable_type(param.data_type):
self.import_tracker._register_typing_import('Optional')
param_type = 'Optional[{}]'.format(param_type)
args.append(
"{param_name}: {param_type} = ...".format(
param_name=param_name,
param_type=param_type))
self.generate_multiline_list(args, before='def __init__', after=' -> None: ...')
self.emit()
def _generate_annotation_type_class_properties(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
for param in annotation_type.params:
prop_name = fmt_var(param.name, True)
param_type = self.map_stone_type_to_pep484_type(ns, param.data_type)
self.emit('@property')
self.emit('def {prop_name}(self) -> {param_type}: ...'.format(
prop_name=prop_name,
param_type=param_type,
))
self.emit()
def _generate_struct_class(self, ns, data_type):
# type: (ApiNamespace, Struct) -> None
"""Defines a Python class that represents a struct in Stone."""
self.emit(self._class_declaration_for_type(ns, data_type))
with self.indent():
self._generate_struct_class_init(ns, data_type)
self._generate_struct_class_properties(ns, data_type)
self._generate_struct_or_union_class_custom_annotations()
self._generate_validator_for(data_type)
self.emit()
def _generate_validator_for(self, data_type):
# type: (DataType) -> None
cls_name = class_name_for_data_type(data_type)
self.emit("{}_validator: bv.Validator = ...".format(
cls_name
))
def _generate_union_class(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
self.emit(self._class_declaration_for_type(ns, data_type))
with self.indent(), emit_pass_if_nothing_emitted(self):
self._generate_union_class_vars(ns, data_type)
self._generate_union_class_is_set(data_type)
self._generate_union_class_variant_creators(ns, data_type)
self._generate_union_class_get_helpers(ns, data_type)
self._generate_struct_or_union_class_custom_annotations()
self._generate_validator_for(data_type)
self.emit()
def _generate_union_class_vars(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
lineno = self.lineno
# Generate stubs for class variables so that IDEs like PyCharms have an
# easier time detecting their existence.
for field in data_type.fields:
if is_void_type(field.data_type):
field_name = fmt_var(field.name)
field_type = class_name_for_data_type(data_type, ns)
self.emit('{field_name}: {field_type} = ...'.format(
field_name=field_name,
field_type=field_type,
))
if lineno != self.lineno:
self.emit()
def _generate_union_class_is_set(self, union):
# type: (Union) -> None
for field in union.fields:
field_name = fmt_func(field.name)
self.emit('def is_{}(self) -> bool: ...'.format(field_name))
self.emit()
def _generate_union_class_variant_creators(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
"""
Generate the following section in the 'union Shape' example:
@classmethod
def circle(cls, val: float) -> Shape: ...
"""
union_type = class_name_for_data_type(data_type)
for field in data_type.fields:
if not is_void_type(field.data_type):
field_name_reserved_check = fmt_func(field.name, check_reserved=True)
val_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
self.emit('@classmethod')
self.emit('def {field_name}(cls, val: {val_type}) -> {union_type}: ...'.format(
field_name=field_name_reserved_check,
val_type=val_type,
union_type=union_type,
))
self.emit()
def _generate_union_class_get_helpers(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
"""
Generates the following section in the 'union Shape' example:
def get_circle(self) -> float: ...
"""
for field in data_type.fields:
field_name = fmt_func(field.name)
if not is_void_type(field.data_type):
# generate getter for field
val_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
self.emit('def get_{field_name}(self) -> {val_type}: ...'.format(
field_name=field_name,
val_type=val_type,
))
self.emit()
def _generate_alias_definition(self, namespace, alias):
# type: (ApiNamespace, Alias) -> None
self._generate_validator_for(alias)
unwrapped_dt, _ = unwrap_aliases(alias)
if is_user_defined_type(unwrapped_dt):
# If the alias is to a composite type, we want to alias the
# generated class as well.
self.emit('{} = {}'.format(
alias.name,
class_name_for_data_type(alias.data_type, namespace)))
def _class_declaration_for_type(self, ns, data_type):
# type: (ApiNamespace, typing.Union[Struct, Union]) -> typing.Text
assert is_user_defined_type(data_type), \
'Expected struct, got %r' % type(data_type)
if data_type.parent_type:
extends = class_name_for_data_type(data_type.parent_type, ns)
else:
if is_struct_type(data_type):
# Use a handwritten base class
extends = 'bb.Struct'
elif is_union_type(data_type):
extends = 'bb.Union'
else:
extends = 'object'
return 'class {}({}):'.format(
class_name_for_data_type(data_type), extends)
def _generate_struct_class_init(self, ns, struct):
# type: (ApiNamespace, Struct) -> None
args = ["self"]
for field in struct.all_fields:
field_name_reserved_check = fmt_var(field.name, True)
field_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
if field.has_default:
self.import_tracker._register_typing_import('Optional')
field_type = 'Optional[{}]'.format(field_type)
args.append("{field_name}: {field_type} = ...".format(
field_name=field_name_reserved_check,
field_type=field_type))
self.generate_multiline_list(args, before='def __init__', after=' -> None: ...')
def _generate_struct_class_properties(self, ns, struct):
# type: (ApiNamespace, Struct) -> None
to_emit = [] # type: typing.List[typing.Text]
for field in struct.all_fields:
field_name_reserved_check = fmt_func(field.name, check_reserved=True)
field_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
to_emit.append(
"{}: bb.Attribute[{}] = ...".format(field_name_reserved_check, field_type)
)
for s in to_emit:
self.emit(s)
def _generate_struct_or_union_class_custom_annotations(self):
"""
The _process_custom_annotations function allows client code to access
custom annotations defined in the spec.
"""
self.emit('def _process_custom_annotations(')
with self.indent():
self.emit('self,')
self.emit('annotation_type: Type[T],')
self.emit('field_path: Text,')
self.emit('processor: Callable[[T, U], U],')
self.import_tracker._register_typing_import('Type')
self.import_tracker._register_typing_import('Text')
self.import_tracker._register_typing_import('Callable')
self.emit(') -> None: ...')
self.emit()
def _get_pep_484_type_mapping_callbacks(self):
# type: () -> OverrideDefaultTypesDict
"""
Once-per-instance, generate a mapping from
"List" -> return pep4848-compatible List[SomeType]
"Nullable" -> return pep484-compatible Optional[SomeType]
This is per-instance because we have to also call `self._register_typing_import`, because
we need to potentially import some things.
"""
def upon_encountering_list(ns, data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_typing_import("List")
return "List[{}]".format(
map_stone_type_to_python_type(ns, data_type, override_dict)
)
def upon_encountering_map(ns, map_data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
map_type = cast(Map, map_data_type)
self.import_tracker._register_typing_import("Dict")
return "Dict[{}, {}]".format(
map_stone_type_to_python_type(ns, map_type.key_data_type, override_dict),
map_stone_type_to_python_type(ns, map_type.value_data_type, override_dict)
)
def upon_encountering_nullable(ns, data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_typing_import("Optional")
return "Optional[{}]".format(
map_stone_type_to_python_type(ns, data_type, override_dict)
)
def upon_encountering_timestamp(
ns, data_type, override_dict
): # pylint: disable=unused-argument
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_adhoc_import("import datetime")
return map_stone_type_to_python_type(ns, data_type)
def upon_encountering_string(
ns, data_type, override_dict
): # pylint: disable=unused-argument
# type: (...) -> typing.Text
self.import_tracker._register_typing_import("Text")
return "Text"
callback_dict = {
List: upon_encountering_list,
Map: upon_encountering_map,
Nullable: upon_encountering_nullable,
Timestamp: upon_encountering_timestamp,
String: upon_encountering_string,
} # type: OverrideDefaultTypesDict
return callback_dict
def map_stone_type_to_pep484_type(self, ns, data_type):
# type: (ApiNamespace, DataType) -> typing.Text
assert self._pep_484_type_mapping_callbacks
return map_stone_type_to_python_type(ns, data_type,
override_dict=self._pep_484_type_mapping_callbacks)
def _generate_routes(
self,
namespace, # type: ApiNamespace
):
# type: (...) -> None
check_route_name_conflict(namespace)
for route in namespace.routes:
self.emit(
"{method_name}: bb.Route = ...".format(
method_name=fmt_func(route.name, version=route.version)))
if namespace.routes:
self.emit()
def _generate_imports_needed_for_typing(self):
# type: () -> None
output_buffer = StringIO()
with self.capture_emitted_output(output_buffer):
if self.import_tracker.cur_namespace_typing_imports:
self.emit("")
self.emit('from typing import (')
with self.indent():
for to_import in sorted(self.import_tracker.cur_namespace_typing_imports):
self.emit("{},".format(to_import))
self.emit(')')
if self.import_tracker.cur_namespace_adhoc_imports:
self.emit("")
for to_import in self.import_tracker.cur_namespace_adhoc_imports:
self.emit(to_import)
self.add_named_placeholder('imports_needed_for_typing', output_buffer.getvalue())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,196 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
from stone.ir import (
Boolean,
Bytes,
DataType,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_list_type,
is_timestamp_type,
is_union_type,
is_user_defined_type,
unwrap_nullable,
)
from stone.backend import CodeBackend
from stone.backends.swift_helpers import (
fmt_class,
fmt_func,
fmt_obj,
fmt_type,
fmt_var,
)
_serial_type_table = {
Boolean: 'BoolSerializer',
Bytes: 'NSDataSerializer',
Float32: 'FloatSerializer',
Float64: 'DoubleSerializer',
Int32: 'Int32Serializer',
Int64: 'Int64Serializer',
List: 'ArraySerializer',
String: 'StringSerializer',
Timestamp: 'NSDateSerializer',
UInt32: 'UInt32Serializer',
UInt64: 'UInt64Serializer',
Void: 'VoidSerializer',
}
stone_warning = """\
///
/// Copyright (c) 2016 Dropbox, Inc. All rights reserved.
///
/// Auto-generated by Stone, do not modify.
///
"""
# This will be at the top of the generated file.
base = """\
{}\
import Foundation
""".format(stone_warning)
undocumented = '(no description)'
class SwiftBaseBackend(CodeBackend):
"""Wrapper class over Stone generator for Swift logic."""
# pylint: disable=abstract-method
@contextmanager
def function_block(self, func, args, return_type=None):
signature = '{}({})'.format(func, args)
if return_type:
signature += ' -> {}'.format(return_type)
with self.block(signature):
yield
def _func_args(self, args_list, newlines=False, force_first=False, not_init=False):
out = []
first = True
for k, v in args_list:
# this is a temporary hack -- injected client-side args
# do not have a separate field for default value. Right now,
# default values are stored along with the type, e.g.
# `Bool = True` is a type, hence this check.
if first and force_first and '=' not in v:
k = "{0} {0}".format(k)
if first and v is not None and not_init:
out.append('{}'.format(v))
elif v is not None:
out.append('{}: {}'.format(k, v))
first = False
sep = ', '
if newlines:
sep += '\n' + self.make_indent()
return sep.join(out)
@contextmanager
def class_block(self, thing, protocols=None):
protocols = protocols or []
extensions = []
if isinstance(thing, DataType):
name = fmt_class(thing.name)
if thing.parent_type:
extensions.append(fmt_type(thing.parent_type))
else:
name = thing
extensions.extend(protocols)
extend_suffix = ': {}'.format(', '.join(extensions)) if extensions else ''
with self.block('open class {}{}'.format(name, extend_suffix)):
yield
def _struct_init_args(self, data_type, namespace=None): # pylint: disable=unused-argument
args = []
for field in data_type.all_fields:
name = fmt_var(field.name)
value = fmt_type(field.data_type)
data_type, nullable = unwrap_nullable(field.data_type)
if field.has_default:
if is_union_type(data_type):
default = '.{}'.format(fmt_var(field.default.tag_name))
else:
default = fmt_obj(field.default)
value += ' = {}'.format(default)
elif nullable:
value += ' = nil'
arg = (name, value)
args.append(arg)
return args
def _docf(self, tag, val):
if tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
return fmt_func(val, version)
elif tag == 'field':
if '.' in val:
cls, field = val.split('.')
return ('{} in {}'.format(fmt_var(field),
fmt_class(cls)))
else:
return fmt_var(val)
elif tag in ('type', 'val', 'link'):
return val
else:
import pdb
pdb.set_trace()
return val
def fmt_serial_type(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}Serializer'
result = result.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _serial_type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '<{}>'.format(fmt_serial_type(data_type.data_type))
return result if not nullable else 'NullableSerializer'
def fmt_serial_obj(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}Serializer()'
result = result.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _serial_type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '({})'.format(fmt_serial_obj(data_type.data_type))
elif is_timestamp_type(data_type):
result = result + '("{}")'.format(data_type.format)
else:
result = 'Serialization._{}'.format(result)
return result if not nullable else 'NullableSerializer({})'.format(result)

View file

@ -0,0 +1,280 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
from stone.ir import (
is_struct_type,
is_union_type,
is_void_type,
)
from stone.backends.swift import (
base,
fmt_serial_type,
SwiftBaseBackend,
undocumented,
)
from stone.backends.swift_helpers import (
check_route_name_conflict,
fmt_class,
fmt_func,
fmt_var,
fmt_type,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(
prog='swift-client-backend',
description=(
'Generates a Swift class with an object for each namespace, and in each '
'namespace object, a method for each route. This class assumes that the '
'swift_types backend was used with the same output directory.'),
)
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=('The name of the Swift module to generate. Please exclude the .swift '
'file extension.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help=('The name of the Swift class that contains an object for each namespace, '
'and in each namespace object, a method for each route.')
)
_cmdline_parser.add_argument(
'-t',
'--transport-client-name',
required=True,
type=str,
help='The name of the Swift class that manages network API calls.',
)
_cmdline_parser.add_argument(
'-y',
'--client-args',
required=True,
type=str,
help='The client-side route arguments to append to each route by style type.',
)
_cmdline_parser.add_argument(
'-z',
'--style-to-request',
required=True,
type=str,
help='The dict that maps a style type to a Swift request object name.',
)
class SwiftBackend(SwiftBaseBackend):
"""
Generates Swift client base that implements route interfaces.
Examples:
```
open class ExampleClientBase {
/// Routes within the namespace1 namespace. See Namespace1 for details.
open var namespace1: Namespace1!
/// Routes within the namespace2 namespace. See Namespace2 for details.
open var namespace2: Namespace2!
public init(client: ExampleTransportClient) {
self.namespace1 = Namespace1(client: client)
self.namespace2 = Namespace2(client: client)
}
}
```
Here, `ExampleTransportClient` would contain the implementation of a handwritten,
project-specific networking client. Additionally, the `Namespace1` object would
have as its methods all routes in the `Namespace1` namespace. A hypothetical 'copy'
enpoding might be implemented like:
```
open func copy(fromPath fromPath: String, toPath: String) ->
ExampleRequestType<Namespace1.CopySerializer, Namespace1.CopyErrorSerializer> {
let route = Namespace1.copy
let serverArgs = Namespace1.CopyArg(fromPath: fromPath, toPath: toPath)
return client.request(route, serverArgs: serverArgs)
}
```
Here, ExampleRequestType is a project-specific request type, parameterized by response and
error serializers.
"""
cmdline_parser = _cmdline_parser
def generate(self, api):
for namespace in api.namespaces.values():
ns_class = fmt_class(namespace.name)
if namespace.routes:
with self.output_to_relative_path('{}Routes.swift'.format(ns_class)):
self._generate_routes(namespace)
with self.output_to_relative_path('{}.swift'.format(self.args.module_name)):
self._generate_client(api)
def _generate_client(self, api):
self.emit_raw(base)
self.emit('import Alamofire')
self.emit()
with self.block('open class {}'.format(self.args.class_name)):
namespace_fields = []
for namespace in api.namespaces.values():
if namespace.routes:
namespace_fields.append((namespace.name,
fmt_class(namespace.name)))
for var, typ in namespace_fields:
self.emit('/// Routes within the {} namespace. '
'See {}Routes for details.'.format(var, typ))
self.emit('open var {}: {}Routes!'.format(var, typ))
self.emit()
with self.function_block('public init', args=self._func_args(
[('client', '{}'.format(self.args.transport_client_name))])):
for var, typ in namespace_fields:
self.emit('self.{} = {}Routes(client: client)'.format(var, typ))
def _generate_routes(self, namespace):
check_route_name_conflict(namespace)
ns_class = fmt_class(namespace.name)
self.emit_raw(base)
self.emit('/// Routes for the {} namespace'.format(namespace.name))
with self.block('open class {}Routes'.format(ns_class)):
self.emit('public let client: {}'.format(self.args.transport_client_name))
args = [('client', '{}'.format(self.args.transport_client_name))]
with self.function_block('init', self._func_args(args)):
self.emit('self.client = client')
self.emit()
for route in namespace.routes:
self._generate_route(namespace, route)
def _get_route_args(self, namespace, route):
data_type = route.arg_data_type
arg_type = fmt_type(data_type)
if is_struct_type(data_type):
arg_list = self._struct_init_args(data_type, namespace=namespace)
doc_list = [(fmt_var(f.name), self.process_doc(f.doc, self._docf)
if f.doc else undocumented) for f in data_type.fields if f.doc]
elif is_union_type(data_type):
arg_list = [(fmt_var(data_type.name), '{}.{}'.format(
fmt_class(namespace.name), fmt_class(data_type.name)))]
doc_list = [(fmt_var(data_type.name),
self.process_doc(data_type.doc, self._docf)
if data_type.doc else 'The {} union'.format(fmt_class(data_type.name)))]
else:
arg_list = [] if is_void_type(data_type) else [('request', arg_type)]
doc_list = []
return arg_list, doc_list
def _emit_route(self, namespace, route, req_obj_name, extra_args=None, extra_docs=None):
arg_list, doc_list = self._get_route_args(namespace, route)
extra_args = extra_args or []
extra_docs = extra_docs or []
arg_type = fmt_type(route.arg_data_type)
func_name = fmt_func(route.name, route.version)
if route.doc:
route_doc = self.process_doc(route.doc, self._docf)
else:
route_doc = 'The {} route'.format(func_name)
self.emit_wrapped_text(route_doc, prefix='/// ', width=120)
self.emit('///')
for name, doc in doc_list + extra_docs:
param_doc = '- parameter {}: {}'.format(name, doc if doc is not None else undocumented)
self.emit_wrapped_text(param_doc, prefix='/// ', width=120)
self.emit('///')
output = (' - returns: Through the response callback, the caller will ' +
'receive a `{}` object on success or a `{}` object on failure.')
output = output.format(fmt_type(route.result_data_type),
fmt_type(route.error_data_type))
self.emit_wrapped_text(output, prefix='/// ', width=120)
func_args = [
('route', '{}.{}'.format(fmt_class(namespace.name), func_name)),
]
client_args = []
return_args = [('route', 'route')]
for name, value, typ in extra_args:
arg_list.append((name, typ))
func_args.append((name, value))
client_args.append((name, value))
rtype = fmt_serial_type(route.result_data_type)
etype = fmt_serial_type(route.error_data_type)
self._maybe_generate_deprecation_warning(route)
with self.function_block('@discardableResult open func {}'.format(func_name),
args=self._func_args(arg_list, force_first=False),
return_type='{}<{}, {}>'.format(req_obj_name, rtype, etype)):
self.emit('let route = {}.{}'.format(fmt_class(namespace.name), func_name))
if is_struct_type(route.arg_data_type):
args = [(name, name) for name, _ in self._struct_init_args(route.arg_data_type)]
func_args += [('serverArgs', '{}({})'.format(arg_type, self._func_args(args)))]
self.emit('let serverArgs = {}({})'.format(arg_type, self._func_args(args)))
elif is_union_type(route.arg_data_type):
self.emit('let serverArgs = {}'.format(fmt_var(route.arg_data_type.name)))
if not is_void_type(route.arg_data_type):
return_args += [('serverArgs', 'serverArgs')]
return_args += client_args
txt = 'return client.request({})'.format(
self._func_args(return_args, not_init=True)
)
self.emit(txt)
self.emit()
def _maybe_generate_deprecation_warning(self, route):
if route.deprecated:
msg = '{} is deprecated.'.format(fmt_func(route.name, route.version))
if route.deprecated.by:
msg += ' Use {}.'.format(
fmt_func(route.deprecated.by.name, route.deprecated.by.version))
self.emit('@available(*, unavailable, message:"{}")'.format(msg))
def _generate_route(self, namespace, route):
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
style_to_request = json.loads(self.args.style_to_request)
if route_type not in client_args.keys():
self._emit_route(namespace, route, style_to_request[route_type])
else:
for args_data in client_args[route_type]:
req_obj_key, type_data_list = tuple(args_data)
req_obj_name = style_to_request[req_obj_key]
extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
extra_docs = [(type_data[0], type_data[-1]) for type_data in type_data_list]
self._emit_route(namespace, route, req_obj_name, extra_args, extra_docs)

View file

@ -0,0 +1,173 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import pprint
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_boolean_type,
is_list_type,
is_numeric_type,
is_string_type,
is_tag_ref,
is_user_defined_type,
unwrap_nullable,
)
from .helpers import split_words
# This file defines *stylistic* choices for Swift
# (ie, that class names are UpperCamelCase and that variables are lowerCamelCase)
_type_table = {
Boolean: 'Bool',
Bytes: 'Data',
Float32: 'Float',
Float64: 'Double',
Int32: 'Int32',
Int64: 'Int64',
List: 'Array',
String: 'String',
Timestamp: 'Date',
UInt32: 'UInt32',
UInt64: 'UInt64',
Void: 'Void',
}
_reserved_words = {
'description',
'bool',
'nsdata'
'float',
'double',
'int32',
'int64',
'list',
'string',
'timestamp',
'uint32',
'uint64',
'void',
'associatedtype',
'class',
'deinit',
'enum',
'extension',
'func',
'import',
'init',
'inout',
'internal',
'let',
'operator',
'private',
'protocol',
'public',
'static',
'struct',
'subscript',
'typealias',
'var',
'default',
}
def fmt_obj(o):
assert not isinstance(o, dict), "Only use for base type literals"
if o is True:
return 'true'
if o is False:
return 'false'
if o is None:
return 'nil'
if o == u'':
return '""'
return pprint.pformat(o, width=1)
def _format_camelcase(name, lower_first=True):
words = [word.capitalize() for word in split_words(name)]
if lower_first:
words[0] = words[0].lower()
ret = ''.join(words)
if ret.lower() in _reserved_words:
ret += '_'
return ret
def fmt_class(name):
return _format_camelcase(name, lower_first=False)
def fmt_func(name, version):
if version > 1:
name = '{}_v{}'.format(name, version)
name = _format_camelcase(name)
return name
def fmt_type(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}'.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '<{}>'.format(fmt_type(data_type.data_type))
return result if not nullable else result + '?'
def fmt_var(name):
return _format_camelcase(name)
def fmt_default_value(namespace, field):
if is_tag_ref(field.default):
return '{}.{}Serializer().serialize(.{})'.format(
fmt_class(namespace.name),
fmt_class(field.default.union_data_type.name),
fmt_var(field.default.tag_name))
elif is_list_type(field.data_type):
return '.array({})'.format(field.default)
elif is_numeric_type(field.data_type):
return '.number({})'.format(field.default)
elif is_string_type(field.data_type):
return '.str({})'.format(fmt_obj(field.default))
elif is_boolean_type(field.data_type):
if field.default:
bool_str = '1'
else:
bool_str = '0'
return '.number({})'.format(bool_str)
else:
raise TypeError('Can\'t handle default value type %r' %
type(field.data_type))
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route

View file

@ -0,0 +1,486 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
import shutil
import six
from contextlib import contextmanager
from stone.ir import (
is_list_type,
is_numeric_type,
is_string_type,
is_struct_type,
is_union_type,
is_void_type,
unwrap_nullable,
)
from stone.backends.swift_helpers import (
check_route_name_conflict,
fmt_class,
fmt_default_value,
fmt_func,
fmt_var,
fmt_type,
)
from stone.backends.swift import (
base,
fmt_serial_obj,
SwiftBaseBackend,
undocumented,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(prog='swift-types-backend')
_cmdline_parser.add_argument(
'-r',
'--route-method',
help=('A string used to construct the location of a Swift method for a '
'given route; use {ns} as a placeholder for namespace name and '
'{route} for the route name.'),
)
class SwiftTypesBackend(SwiftBaseBackend):
"""
Generates Swift modules to represent the input Stone spec.
Examples for a hypothetical 'copy' enpoint:
Endpoint argument (struct):
```
open class CopyArg: CustomStringConvertible {
open let fromPath: String
open let toPath: String
public init(fromPath: String, toPath: String) {
stringValidator(pattern: "/(.|[\\r\\n])*")(value: fromPath)
self.fromPath = fromPath
stringValidator(pattern: "/(.|[\\r\\n])*")(value: toPath)
self.toPath = toPath
}
open var description: String {
return "\\(SerializeUtil.prepareJSONForSerialization(
CopyArgSerializer().serialize(self)))"
}
}
```
Endpoint error (union):
```
open enum CopyError: CustomStringConvertible {
case TooManyFiles
case Other
open var description: String {
return "\\(SerializeUtil.prepareJSONForSerialization(
CopyErrorSerializer().serialize(self)))"
}
}
```
Argument serializer (error serializer not listed):
```
open class CopyArgSerializer: JSONSerializer {
public init() { }
open func serialize(value: CopyArg) -> JSON {
let output = [
"from_path": Serialization.serialize(value.fromPath),
"to_path": Serialization.serialize(value.toPath),
]
return .Dictionary(output)
}
open func deserialize(json: JSON) -> CopyArg {
switch json {
case .Dictionary(let dict):
let fromPath = Serialization.deserialize(dict["from_path"] ?? .Null)
let toPath = Serialization.deserialize(dict["to_path"] ?? .Null)
return CopyArg(fromPath: fromPath, toPath: toPath)
default:
fatalError("Type error deserializing")
}
}
}
```
"""
cmdline_parser = _cmdline_parser
def generate(self, api):
rsrc_folder = os.path.join(os.path.dirname(__file__), 'swift_rsrc')
self.logger.info('Copying StoneValidators.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneValidators.swift'),
self.target_folder_path)
self.logger.info('Copying StoneSerializers.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneSerializers.swift'),
self.target_folder_path)
self.logger.info('Copying StoneBase.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneBase.swift'),
self.target_folder_path)
jazzy_cfg_path = os.path.join('../Format', 'jazzy.json')
with open(jazzy_cfg_path) as jazzy_file:
jazzy_cfg = json.load(jazzy_file)
for namespace in api.namespaces.values():
ns_class = fmt_class(namespace.name)
with self.output_to_relative_path('{}.swift'.format(ns_class)):
self._generate_base_namespace_module(api, namespace)
jazzy_cfg['custom_categories'][1]['children'].append(ns_class)
if namespace.routes:
jazzy_cfg['custom_categories'][0]['children'].append(ns_class + 'Routes')
with self.output_to_relative_path('../../../../.jazzy.json'):
self.emit_raw(json.dumps(jazzy_cfg, indent=2) + '\n')
def _generate_base_namespace_module(self, api, namespace):
self.emit_raw(base)
routes_base = 'Datatypes and serializers for the {} namespace'.format(namespace.name)
self.emit_wrapped_text(routes_base, prefix='/// ', width=120)
with self.block('open class {}'.format(fmt_class(namespace.name))):
for data_type in namespace.linearize_data_types():
if is_struct_type(data_type):
self._generate_struct_class(namespace, data_type)
self.emit()
elif is_union_type(data_type):
self._generate_union_type(namespace, data_type)
self.emit()
if namespace.routes:
self._generate_route_objects(api.route_schema, namespace)
def _generate_struct_class(self, namespace, data_type):
if data_type.doc:
doc = self.process_doc(data_type.doc, self._docf)
else:
doc = 'The {} struct'.format(fmt_class(data_type.name))
self.emit_wrapped_text(doc, prefix='/// ', width=120)
protocols = []
if not data_type.parent_type:
protocols.append('CustomStringConvertible')
with self.class_block(data_type, protocols=protocols):
for field in data_type.fields:
fdoc = self.process_doc(field.doc,
self._docf) if field.doc else undocumented
self.emit_wrapped_text(fdoc, prefix='/// ', width=120)
self.emit('public let {}: {}'.format(
fmt_var(field.name),
fmt_type(field.data_type),
))
self._generate_struct_init(namespace, data_type)
decl = 'open var' if not data_type.parent_type else 'open override var'
with self.block('{} description: String'.format(decl)):
cls = fmt_class(data_type.name) + 'Serializer'
self.emit('return "\\(SerializeUtil.prepareJSONForSerialization' +
'({}().serialize(self)))"'.format(cls))
self._generate_struct_class_serializer(namespace, data_type)
def _generate_struct_init(self, namespace, data_type): # pylint: disable=unused-argument
# init method
args = self._struct_init_args(data_type)
if data_type.parent_type and not data_type.fields:
return
with self.function_block('public init', self._func_args(args)):
for field in data_type.fields:
v = fmt_var(field.name)
validator = self._determine_validator_type(field.data_type, v)
if validator:
self.emit('{}({})'.format(validator, v))
self.emit('self.{0} = {0}'.format(v))
if data_type.parent_type:
func_args = [(fmt_var(f.name),
fmt_var(f.name))
for f in data_type.parent_type.all_fields]
self.emit('super.init({})'.format(self._func_args(func_args)))
def _determine_validator_type(self, data_type, value):
data_type, nullable = unwrap_nullable(data_type)
if is_list_type(data_type):
item_validator = self._determine_validator_type(data_type.data_type, value)
if item_validator:
v = "arrayValidator({})".format(
self._func_args([
("minItems", data_type.min_items),
("maxItems", data_type.max_items),
("itemValidator", item_validator),
])
)
else:
return None
elif is_numeric_type(data_type):
v = "comparableValidator({})".format(
self._func_args([
("minValue", data_type.min_value),
("maxValue", data_type.max_value),
])
)
elif is_string_type(data_type):
pat = data_type.pattern if data_type.pattern else None
pat = pat.encode('unicode_escape').replace(six.ensure_binary("\""),
six.ensure_binary("\\\"")) if pat else pat
v = "stringValidator({})".format(
self._func_args([
("minLength", data_type.min_length),
("maxLength", data_type.max_length),
("pattern", '"{}"'.format(six.ensure_str(pat)) if pat else None),
])
)
else:
return None
if nullable:
v = "nullableValidator({})".format(v)
return v
def _generate_enumerated_subtype_serializer(self, namespace, # pylint: disable=unused-argument
data_type):
with self.block('switch value'):
for tags, subtype in data_type.get_all_subtypes_with_tags():
assert len(tags) == 1, tags
tag = tags[0]
tagvar = fmt_var(tag)
self.emit('case let {} as {}:'.format(
tagvar,
fmt_type(subtype)
))
with self.indent():
block_txt = 'for (k, v) in Serialization.getFields({}.serialize({}))'.format(
fmt_serial_obj(subtype),
tagvar,
)
with self.block(block_txt):
self.emit('output[k] = v')
self.emit('output[".tag"] = .str("{}")'.format(tag))
self.emit('default: fatalError("Tried to serialize unexpected subtype")')
def _generate_struct_base_class_deserializer(self, namespace, data_type):
args = []
for field in data_type.all_fields:
var = fmt_var(field.name)
value = 'dict["{}"]'.format(field.name)
self.emit('let {} = {}.deserialize({} ?? {})'.format(
var,
fmt_serial_obj(field.data_type),
value,
fmt_default_value(namespace, field) if field.has_default else '.null'
))
args.append((var, var))
self.emit('return {}({})'.format(
fmt_class(data_type.name),
self._func_args(args)
))
def _generate_enumerated_subtype_deserializer(self, namespace, data_type):
self.emit('let tag = Serialization.getTag(dict)')
with self.block('switch tag'):
for tags, subtype in data_type.get_all_subtypes_with_tags():
assert len(tags) == 1, tags
tag = tags[0]
self.emit('case "{}":'.format(tag))
with self.indent():
self.emit('return {}.deserialize(json)'.format(fmt_serial_obj(subtype)))
self.emit('default:')
with self.indent():
if data_type.is_catch_all():
self._generate_struct_base_class_deserializer(namespace, data_type)
else:
self.emit('fatalError("Unknown tag \\(tag)")')
def _generate_struct_class_serializer(self, namespace, data_type):
with self.serializer_block(data_type):
with self.serializer_func(data_type):
if not data_type.all_fields:
self.emit('let output = [String: JSON]()')
else:
intro = 'var' if data_type.has_enumerated_subtypes() else 'let'
self.emit("{} output = [ ".format(intro))
for field in data_type.all_fields:
self.emit('"{}": {}.serialize(value.{}),'.format(
field.name,
fmt_serial_obj(field.data_type),
fmt_var(field.name)
))
self.emit(']')
if data_type.has_enumerated_subtypes():
self._generate_enumerated_subtype_serializer(namespace, data_type)
self.emit('return .dictionary(output)')
with self.deserializer_func(data_type):
with self.block("switch json"):
dict_name = "let dict" if data_type.all_fields else "_"
self.emit("case .dictionary({}):".format(dict_name))
with self.indent():
if data_type.has_enumerated_subtypes():
self._generate_enumerated_subtype_deserializer(namespace, data_type)
else:
self._generate_struct_base_class_deserializer(namespace, data_type)
self.emit("default:")
with self.indent():
self.emit('fatalError("Type error deserializing")')
def _format_tag_type(self, namespace, data_type): # pylint: disable=unused-argument
if is_void_type(data_type):
return ''
else:
return '({})'.format(fmt_type(data_type))
def _generate_union_type(self, namespace, data_type):
if data_type.doc:
doc = self.process_doc(data_type.doc, self._docf)
else:
doc = 'The {} union'.format(fmt_class(data_type.name))
self.emit_wrapped_text(doc, prefix='/// ', width=120)
class_type = fmt_class(data_type.name)
with self.block('public enum {}: CustomStringConvertible'.format(class_type)):
for field in data_type.all_fields:
typ = self._format_tag_type(namespace, field.data_type)
fdoc = self.process_doc(field.doc,
self._docf) if field.doc else 'An unspecified error.'
self.emit_wrapped_text(fdoc, prefix='/// ', width=120)
self.emit('case {}{}'.format(fmt_var(field.name), typ))
self.emit()
with self.block('public var description: String'):
cls = class_type + 'Serializer'
self.emit('return "\\(SerializeUtil.prepareJSONForSerialization' +
'({}().serialize(self)))"'.format(cls))
self._generate_union_serializer(data_type)
def _tag_type(self, data_type, field):
return "{}.{}".format(
fmt_class(data_type.name),
fmt_var(field.name)
)
def _generate_union_serializer(self, data_type):
with self.serializer_block(data_type):
with self.serializer_func(data_type), self.block('switch value'):
for field in data_type.all_fields:
field_type = field.data_type
case = '.{}{}'.format(fmt_var(field.name),
'' if is_void_type(field_type) else '(let arg)')
self.emit('case {}:'.format(case))
with self.indent():
if is_void_type(field_type):
self.emit('var d = [String: JSON]()')
elif (is_struct_type(field_type) and
not field_type.has_enumerated_subtypes()):
self.emit('var d = Serialization.getFields({}.serialize(arg))'.format(
fmt_serial_obj(field_type)))
else:
self.emit('var d = ["{}": {}.serialize(arg)]'.format(
field.name,
fmt_serial_obj(field_type)))
self.emit('d[".tag"] = .str("{}")'.format(field.name))
self.emit('return .dictionary(d)')
with self.deserializer_func(data_type):
with self.block("switch json"):
self.emit("case .dictionary(let d):")
with self.indent():
self.emit('let tag = Serialization.getTag(d)')
with self.block('switch tag'):
for field in data_type.all_fields:
field_type = field.data_type
self.emit('case "{}":'.format(field.name))
tag_type = self._tag_type(data_type, field)
with self.indent():
if is_void_type(field_type):
self.emit('return {}'.format(tag_type))
else:
if (is_struct_type(field_type) and
not field_type.has_enumerated_subtypes()):
subdict = 'json'
else:
subdict = 'd["{}"] ?? .null'.format(field.name)
self.emit('let v = {}.deserialize({})'.format(
fmt_serial_obj(field_type), subdict
))
self.emit('return {}(v)'.format(tag_type))
self.emit('default:')
with self.indent():
if data_type.catch_all_field:
self.emit('return {}'.format(
self._tag_type(data_type, data_type.catch_all_field)
))
else:
self.emit('fatalError("Unknown tag \\(tag)")')
self.emit("default:")
with self.indent():
self.emit('fatalError("Failed to deserialize")')
@contextmanager
def serializer_block(self, data_type):
with self.class_block(fmt_class(data_type.name) + 'Serializer',
protocols=['JSONSerializer']):
self.emit("public init() { }")
yield
@contextmanager
def serializer_func(self, data_type):
with self.function_block('open func serialize',
args=self._func_args([('_ value', fmt_class(data_type.name))]),
return_type='JSON'):
yield
@contextmanager
def deserializer_func(self, data_type):
with self.function_block('open func deserialize',
args=self._func_args([('_ json', 'JSON')]),
return_type=fmt_class(data_type.name)):
yield
def _generate_route_objects(self, route_schema, namespace):
check_route_name_conflict(namespace)
self.emit()
self.emit('/// Stone Route Objects')
self.emit()
for route in namespace.routes:
var_name = fmt_func(route.name, route.version)
with self.block('static let {} = Route('.format(var_name),
delim=(None, None), after=')'):
self.emit('name: \"{}\",'.format(route.name))
self.emit('version: {},'.format(route.version))
self.emit('namespace: \"{}\",'.format(namespace.name))
self.emit('deprecated: {},'.format('true' if route.deprecated
is not None else 'false'))
self.emit('argSerializer: {},'.format(fmt_serial_obj(route.arg_data_type)))
self.emit('responseSerializer: {},'.format(fmt_serial_obj(route.result_data_type)))
self.emit('errorSerializer: {},'.format(fmt_serial_obj(route.error_data_type)))
attrs = []
for field in route_schema.fields:
attr_key = field.name
attr_val = ("\"{}\"".format(route.attrs.get(attr_key))
if route.attrs.get(attr_key) else 'nil')
attrs.append('\"{}\": {}'.format(attr_key, attr_val))
self.generate_multiline_list(
attrs, delim=('attrs: [', ']'), compact=True)

View file

@ -0,0 +1,152 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import re
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.backend import CodeBackend
from stone.backends.tsd_helpers import (
check_route_name_conflict,
fmt_error_type,
fmt_func,
fmt_tag,
fmt_type,
)
from stone.ir import Void
_cmdline_parser = argparse.ArgumentParser(prog='tsd-client-backend')
_cmdline_parser.add_argument(
'template',
help=('A template to use when generating the TypeScript definition file.')
)
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single TypeScript definition file to contain '
'all of the emitted types.'),
)
_cmdline_parser.add_argument(
'-t',
'--template-string',
type=str,
default='ROUTES',
help=('The name of the template string to replace with route definitions. '
'Defaults to ROUTES, which replaces the string /*ROUTES*/ with route '
'definitions.')
)
_cmdline_parser.add_argument(
'-i',
'--indent-level',
type=int,
default=1,
help=('Indentation level to emit types at. Routes are automatically '
'indented one level further than this.')
)
_cmdline_parser.add_argument(
'-s',
'--spaces-per-indent',
type=int,
default=2,
help=('Number of spaces to use per indentation level.')
)
_cmdline_parser.add_argument(
'--wrap-response-in',
type=str,
default='',
help=('Wraps the response in a response class')
)
_header = """\
// Auto-generated by Stone, do not modify.
"""
class TSDClientBackend(CodeBackend):
"""Generates a TypeScript definition file with routes defined."""
cmdline_parser = _cmdline_parser
preserve_aliases = True
def generate(self, api):
spaces_per_indent = self.args.spaces_per_indent
indent_level = self.args.indent_level
template_path = os.path.join(self.target_folder_path, self.args.template)
template_string = self.args.template_string
with self.output_to_relative_path(self.args.filename):
if os.path.isfile(template_path):
with open(template_path, 'r') as template_file:
template = template_file.read()
else:
raise AssertionError('TypeScript template file does not exist.')
# /*ROUTES*/
r_match = re.search("/\\*%s\\*/" % (template_string), template)
if not r_match:
raise AssertionError(
'Missing /*%s*/ in TypeScript template file.' % template_string)
r_start = r_match.start()
r_end = r_match.end()
r_ends_with_newline = template[r_end - 1] == '\n'
t_end = len(template)
t_ends_with_newline = template[t_end - 1] == '\n'
self.emit_raw(template[0:r_start] + ('\n' if not r_ends_with_newline else ''))
self._generate_routes(api, spaces_per_indent, indent_level)
self.emit_raw(template[r_end + 1:t_end] + ('\n' if not t_ends_with_newline else ''))
def _generate_routes(self, api, spaces_per_indent, indent_level):
with self.indent(dent=spaces_per_indent * (indent_level + 1)):
for namespace in api.namespaces.values():
# first check for route name conflict
check_route_name_conflict(namespace)
for route in namespace.routes:
self._generate_route(
namespace, route)
def _generate_route(self, namespace, route):
function_name = fmt_func(namespace.name + '_' + route.name, route.version)
self.emit()
self.emit('/**')
if route.doc:
self.emit_wrapped_text(self.process_doc(route.doc, self._docf), prefix=' * ')
self.emit(' *')
self.emit_wrapped_text('When an error occurs, the route rejects the promise with type %s.'
% fmt_error_type(route.error_data_type), prefix=' * ')
if route.deprecated:
self.emit(' * @deprecated')
if route.arg_data_type.__class__ != Void:
self.emit(' * @param arg The request parameters.')
self.emit(' */')
return_type = None
if self.args.wrap_response_in:
return_type = 'Promise<%s<%s>>;' % (self.args.wrap_response_in,
fmt_type(route.result_data_type))
else:
return_type = 'Promise<%s>;' % (fmt_type(route.result_data_type))
arg = ''
if route.arg_data_type.__class__ != Void:
arg = 'arg: %s' % fmt_type(route.arg_data_type)
self.emit('public %s(%s): %s' % (function_name, arg, return_type))
def _docf(self, tag, val):
"""
Callback to process documentation references.
"""
return fmt_tag(None, tag, val)

View file

@ -0,0 +1,178 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from stone.backend import Backend
from stone.ir.api import ApiNamespace
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_alias,
is_list_type,
is_struct_type,
is_map_type,
is_user_defined_type,
)
from stone.backends.helpers import (
fmt_camel,
)
_base_type_table = {
Boolean: 'boolean',
Bytes: 'string',
Float32: 'number',
Float64: 'number',
Int32: 'number',
Int64: 'number',
List: 'Array',
String: 'string',
UInt32: 'number',
UInt64: 'number',
Timestamp: 'Timestamp',
Void: 'void',
}
def fmt_error_type(data_type, inside_namespace=None):
"""
Converts the error type into a TypeScript type.
inside_namespace should be set to the namespace that the reference
occurs in, or None if this parameter is not relevant.
"""
return 'Error<%s>' % fmt_type(data_type, inside_namespace)
def fmt_type_name(data_type, inside_namespace=None):
"""
Produces a TypeScript type name for the given data type.
inside_namespace should be set to the namespace that the reference
occurs in, or None if this parameter is not relevant.
"""
if is_user_defined_type(data_type) or is_alias(data_type):
if data_type.namespace == inside_namespace:
return data_type.name
else:
return '%s.%s' % (data_type.namespace.name, data_type.name)
else:
fmted_type = _base_type_table.get(data_type.__class__, 'Object')
if is_list_type(data_type):
fmted_type += '<' + fmt_type(data_type.data_type, inside_namespace) + '>'
elif is_map_type(data_type):
key_data_type = _base_type_table.get(data_type.key_data_type, 'string')
value_data_type = fmt_type_name(data_type.value_data_type, inside_namespace)
fmted_type = '{[key: %s]: %s}' % (key_data_type, value_data_type)
return fmted_type
def fmt_polymorphic_type_reference(data_type, inside_namespace=None):
"""
Produces a TypeScript type name for the meta-type that refers to the given
struct, which belongs to an enumerated subtypes tree. This meta-type contains the
.tag field that lets developers discriminate between subtypes.
"""
# NOTE: These types are not properly namespaced, so there could be a conflict
# with other user-defined types. If this ever surfaces as a problem, we
# can defer emitting these types until the end, and emit them in a
# nested namespace (e.g., files.references.MetadataReference).
return fmt_type_name(data_type, inside_namespace) + "Reference"
def fmt_type(data_type, inside_namespace=None):
"""
Returns a TypeScript type annotation for a data type.
May contain a union of enumerated subtypes.
inside_namespace should be set to the namespace that the type reference
occurs in, or None if this parameter is not relevant.
"""
if is_struct_type(data_type) and data_type.has_enumerated_subtypes():
possible_types = []
possible_subtypes = data_type.get_all_subtypes_with_tags()
for _, subtype in possible_subtypes:
possible_types.append(fmt_polymorphic_type_reference(subtype, inside_namespace))
if data_type.is_catch_all():
possible_types.append(fmt_polymorphic_type_reference(data_type, inside_namespace))
return fmt_union(possible_types)
else:
return fmt_type_name(data_type, inside_namespace)
def fmt_union(type_strings):
"""
Returns a union type of the given types.
"""
return '|'.join(type_strings) if len(type_strings) > 1 else type_strings[0]
def fmt_func(name, version):
if version == 1:
return fmt_camel(name)
return fmt_camel(name) + 'V{}'.format(version)
def fmt_var(name):
return fmt_camel(name)
def fmt_tag(cur_namespace, tag, val):
"""
Processes a documentation reference.
"""
if tag == 'type':
fq_val = val
if '.' not in val and cur_namespace is not None:
fq_val = cur_namespace.name + '.' + fq_val
return fq_val
elif tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
return fmt_func(val, version) + "()"
elif tag == 'link':
anchor, link = val.rsplit(' ', 1)
# There's no way to have links in TSDoc, so simply use JSDoc's formatting.
# It's entirely possible some editors support this.
return '[%s]{@link %s}' % (anchor, link)
elif tag == 'val':
# Value types seem to match JavaScript (true, false, null)
return val
elif tag == 'field':
return val
else:
raise RuntimeError('Unknown doc ref tag %r' % tag)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route
def generate_imports_for_referenced_namespaces(backend, namespace, module_name_prefix):
# type: (Backend, ApiNamespace, str) -> None
imported_namespaces = namespace.get_imported_namespaces()
if not imported_namespaces:
return
for ns in imported_namespaces:
backend.emit(
"import * as {namespace_name} from '{module_name_prefix}{namespace_name}';".format(
module_name_prefix=module_name_prefix,
namespace_name=ns.name
)
)
backend.emit()

View file

@ -0,0 +1,517 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
import re
import six
import sys
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.ir import ApiNamespace
from stone.ir import (
is_alias,
is_struct_type,
is_union_type,
is_user_defined_type,
is_void_type,
unwrap_nullable,
)
from stone.backend import CodeBackend
from stone.backends.helpers import (
fmt_pascal,
)
from stone.backends.tsd_helpers import (
fmt_polymorphic_type_reference,
fmt_tag,
fmt_type,
fmt_type_name,
fmt_union,
generate_imports_for_referenced_namespaces,
)
_cmdline_parser = argparse.ArgumentParser(prog='tsd-types-backend')
_cmdline_parser.add_argument(
'template',
help=('A template to use when generating the TypeScript definition file. '
'Replaces the string /*TYPES*/ with stone type definitions.')
)
_cmdline_parser.add_argument(
'filename',
nargs='?',
help=('The name of the generated typeScript definition file that contains '
'all of the emitted types.'),
)
_cmdline_parser.add_argument(
'--exclude_error_types',
default=False,
action='store_true',
help='If true, the output will exclude the interface for Error type.',
)
_cmdline_parser.add_argument(
'-e',
'--extra-arg',
action='append',
type=str,
default=[],
help=("Additional argument to add to a route's argument based "
"on if the route has a certain attribute set. Format (JSON): "
'{"match": ["ROUTE_ATTR", ROUTE_VALUE_TO_MATCH], '
'"arg_name": "ARG_NAME", "arg_type": "ARG_TYPE", '
'"arg_docstring": "ARG_DOCSTRING"}'),
)
_cmdline_parser.add_argument(
'-i',
'--indent-level',
type=int,
default=1,
help=('Indentation level to emit types at. Routes are automatically '
'indented one level further than this.')
)
_cmdline_parser.add_argument(
'-s',
'--spaces-per-indent',
type=int,
default=2,
help=('Number of spaces to use per indentation level.')
)
_cmdline_parser.add_argument(
'-p',
'--module-name-prefix',
type=str,
default='',
help=('Prefix for data type module names. '
'This is useful for repo which requires absolute path as '
'module name')
)
_cmdline_parser.add_argument(
'--export-namespaces',
default=False,
action='store_true',
help=('Adds the export tag to each namespace.'
'This is useful is you are not placing each namespace '
'inside of a module and want to export each namespace individually')
)
_header = """\
// Auto-generated by Stone, do not modify.
"""
_types_header = """\
/**
* An Error object returned from a route.
*/
interface Error<T> {
\t// Text summary of the error.
\terror_summary: string;
\t// The error object.
\terror: T;
\t// User-friendly error message.
\tuser_message: UserMessage;
}
/**
* User-friendly error message.
*/
interface UserMessage {
\t// The message.
\ttext: string;
\t// The locale of the message.
\tlocale: string;
}
"""
_timestamp_definition = "type Timestamp = string;"
class TSDTypesBackend(CodeBackend):
"""
Generates a single TypeScript definition file with all of the types defined, organized
as namespaces, if a filename is provided in input arguments. Otherwise generates one
declaration file for each namespace with the corresponding typescript definitions.
If a single output file is generated, a top level type definition will be added for the
Timestamp data type. Otherwise, each namespace will have the type definition for Timestamp.
Also, note that namespace definitions are emitted as declaration files. Hence any template
provided as argument must not have a top level declare statement. If namespaces are emitted
into a single file, the template file can be used to wrap them around a declare statement.
"""
cmdline_parser = _cmdline_parser
preserve_aliases = True
# Instance var of the current namespace being generated
cur_namespace = None # type: typing.Optional[ApiNamespace]
# Instance var to denote if one file is output for each namespace.
split_by_namespace = False
def generate(self, api):
extra_args = self._parse_extra_args(api, self.args.extra_arg)
template = self._read_template()
if self.args.filename:
self._generate_base_namespace_module(api.namespaces.values(), self.args.filename,
template, extra_args,
exclude_error_types=self.args.exclude_error_types)
else:
self.split_by_namespace = True
for namespace in api.namespaces.values():
filename = '{}.d.ts'.format(namespace.name)
self._generate_base_namespace_module(
[namespace], filename, template,
extra_args,
exclude_error_types=self.args.exclude_error_types)
def _read_template(self):
template_path = os.path.join(self.target_folder_path, self.args.template)
if os.path.isfile(template_path):
with open(template_path, 'r') as template_file:
return template_file.read()
else:
raise AssertionError('TypeScript template file does not exist.')
def _get_data_types(self, namespace):
return namespace.data_types + namespace.aliases
def _generate_base_namespace_module(self, namespace_list, filename,
template, extra_args,
exclude_error_types=False):
# Skip namespaces that do not contain types.
if all([len(self._get_data_types(ns)) == 0 for ns in namespace_list]):
return
spaces_per_indent = self.args.spaces_per_indent
indent_level = self.args.indent_level
with self.output_to_relative_path(filename):
# /*TYPES*/
t_match = re.search("/\\*TYPES\\*/", template)
if not t_match:
raise AssertionError('Missing /*TYPES*/ in TypeScript template file.')
t_start = t_match.start()
t_end = t_match.end()
t_ends_with_newline = template[t_end - 1] == '\n'
temp_end = len(template)
temp_ends_with_newline = template[temp_end - 1] == '\n'
self.emit_raw(template[0:t_start] + ("\n" if not t_ends_with_newline else ''))
indent = spaces_per_indent * indent_level
indent_spaces = (' ' * indent)
with self.indent(dent=indent):
if not exclude_error_types:
indented_types_header = indent_spaces + (
('\n' + indent_spaces)
.join(_types_header.split('\n'))
.replace('\t', ' ' * spaces_per_indent)
)
self.emit_raw(indented_types_header + '\n')
if not self.split_by_namespace:
self.emit(_timestamp_definition)
self.emit()
for namespace in namespace_list:
self._generate_types(namespace, spaces_per_indent, extra_args)
self.emit_raw(template[t_end + 1:temp_end] +
("\n" if not temp_ends_with_newline else ''))
def _generate_types(self, namespace, spaces_per_indent, extra_args):
self.cur_namespace = namespace
# Count aliases as data types too!
data_types = self._get_data_types(namespace)
# Skip namespaces that do not contain types.
if len(data_types) == 0:
return
if self.split_by_namespace:
generate_imports_for_referenced_namespaces(
backend=self, namespace=namespace, module_name_prefix=self.args.module_name_prefix
)
if namespace.doc:
self._emit_tsdoc_header(namespace.doc)
self.emit_wrapped_text(self._get_top_level_declaration(namespace.name))
with self.indent(dent=spaces_per_indent):
for data_type in data_types:
self._generate_type(data_type, spaces_per_indent,
extra_args.get(data_type, []))
if self.split_by_namespace:
with self.indent(dent=spaces_per_indent):
# TODO(Pranay): May avoid adding an unused definition if needed.
self.emit(_timestamp_definition)
self.emit('}')
self.emit()
def _get_top_level_declaration(self, name):
if self.split_by_namespace:
# Use module for when emitting declaration files.
return "declare module '%s%s' {" % (self.args.module_name_prefix, name)
else:
if self.args.export_namespaces:
return "export namespace %s {" % name
else:
# Use namespace for organizing code with-in the file.
return "namespace %s {" % name
def _parse_extra_args(self, api, extra_args_raw):
"""
Parses extra arguments into a map keyed on particular data types.
"""
extra_args = {}
def invalid(msg, extra_arg_raw):
print('Invalid --extra-arg:%s: %s' % (msg, extra_arg_raw),
file=sys.stderr)
sys.exit(1)
for extra_arg_raw in extra_args_raw:
try:
extra_arg = json.loads(extra_arg_raw)
except ValueError as e:
invalid(str(e), extra_arg_raw)
# Validate extra_arg JSON blob
if 'match' not in extra_arg:
invalid('No match key', extra_arg_raw)
elif (not isinstance(extra_arg['match'], list) or
len(extra_arg['match']) != 2):
invalid('match key is not a list of two strings', extra_arg_raw)
elif (not isinstance(extra_arg['match'][0], six.text_type) or
not isinstance(extra_arg['match'][1], six.text_type)):
print(type(extra_arg['match'][0]))
invalid('match values are not strings', extra_arg_raw)
elif 'arg_name' not in extra_arg:
invalid('No arg_name key', extra_arg_raw)
elif not isinstance(extra_arg['arg_name'], six.text_type):
invalid('arg_name is not a string', extra_arg_raw)
elif 'arg_type' not in extra_arg:
invalid('No arg_type key', extra_arg_raw)
elif not isinstance(extra_arg['arg_type'], six.text_type):
invalid('arg_type is not a string', extra_arg_raw)
elif ('arg_docstring' in extra_arg and
not isinstance(extra_arg['arg_docstring'], six.text_type)):
invalid('arg_docstring is not a string', extra_arg_raw)
attr_key, attr_val = extra_arg['match'][0], extra_arg['match'][1]
extra_args.setdefault(attr_key, {})[attr_val] = \
(extra_arg['arg_name'], extra_arg['arg_type'],
extra_arg.get('arg_docstring'))
# Extra arguments, keyed on data type objects.
extra_args_for_types = {}
# Locate data types that contain extra arguments
for namespace in api.namespaces.values():
for route in namespace.routes:
extra_parameters = []
if is_user_defined_type(route.arg_data_type):
for attr_key in route.attrs:
if attr_key not in extra_args:
continue
attr_val = route.attrs[attr_key]
if attr_val in extra_args[attr_key]:
extra_parameters.append(extra_args[attr_key][attr_val])
if len(extra_parameters) > 0:
extra_args_for_types[route.arg_data_type] = extra_parameters
return extra_args_for_types
def _emit_tsdoc_header(self, docstring):
self.emit('/**')
self.emit_wrapped_text(self.process_doc(docstring, self._docf), prefix=' * ')
self.emit(' */')
def _generate_type(self, data_type, indent_spaces, extra_args):
"""
Generates a TypeScript type for the given type.
"""
if is_alias(data_type):
self._generate_alias_type(data_type)
elif is_struct_type(data_type):
self._generate_struct_type(data_type, indent_spaces, extra_args)
elif is_union_type(data_type):
self._generate_union_type(data_type, indent_spaces)
def _generate_alias_type(self, alias_type):
"""
Generates a TypeScript type for a stone alias.
"""
namespace = alias_type.namespace
self.emit('export type %s = %s;' % (fmt_type_name(alias_type, namespace),
fmt_type_name(alias_type.data_type, namespace)))
self.emit()
def _generate_struct_type(self, struct_type, indent_spaces, extra_parameters):
"""
Generates a TypeScript interface for a stone struct.
"""
namespace = struct_type.namespace
if struct_type.doc:
self._emit_tsdoc_header(struct_type.doc)
parent_type = struct_type.parent_type
extends_line = ' extends %s' % fmt_type_name(parent_type, namespace) if parent_type else ''
self.emit('export interface %s%s {' % (fmt_type_name(struct_type, namespace), extends_line))
with self.indent(dent=indent_spaces):
for param_name, param_type, param_docstring in extra_parameters:
if param_docstring:
self._emit_tsdoc_header(param_docstring)
self.emit('%s: %s;' % (param_name, param_type))
for field in struct_type.fields:
doc = field.doc
field_type, nullable = unwrap_nullable(field.data_type)
field_ts_type = fmt_type(field_type, namespace)
optional = nullable or field.has_default
if field.has_default:
# doc may be None. If it is not empty, add newlines
# before appending to it.
doc = doc + '\n\n' if doc else ''
doc = "Defaults to %s." % field.default
if doc:
self._emit_tsdoc_header(doc)
# Translate nullable types into optional properties.
field_name = '%s?' % field.name if optional else field.name
self.emit('%s: %s;' % (field_name, field_ts_type))
self.emit('}')
self.emit()
# Some structs can explicitly list their subtypes. These structs have a .tag field that
# indicate which subtype they are, which is only present when a type reference is
# ambiguous.
# Emit a special interface that contains this extra field, and refer to it whenever we
# encounter a reference to a type with enumerated subtypes.
if struct_type.is_member_of_enumerated_subtypes_tree():
if struct_type.has_enumerated_subtypes():
# This struct is the parent to multiple subtypes. Determine all of the possible
# values of the .tag property.
tag_values = []
for tags, _ in struct_type.get_all_subtypes_with_tags():
for tag in tags:
tag_values.append('"%s"' % tag)
tag_union = fmt_union(tag_values)
self._emit_tsdoc_header('Reference to the %s polymorphic type. Contains a .tag '
'property to let you discriminate between possible '
'subtypes.' % fmt_type_name(struct_type, namespace))
self.emit('export interface %s extends %s {' %
(fmt_polymorphic_type_reference(struct_type, namespace),
fmt_type_name(struct_type, namespace)))
with self.indent(dent=indent_spaces):
self._emit_tsdoc_header('Tag identifying the subtype variant.')
self.emit('\'.tag\': %s;' % tag_union)
self.emit('}')
self.emit()
else:
# This struct is a particular subtype. Find the applicable .tag value from the
# parent type, which may be an arbitrary number of steps up the inheritance
# hierarchy.
parent = struct_type.parent_type
while not parent.has_enumerated_subtypes():
parent = parent.parent_type
# parent now contains the closest parent type in the inheritance hierarchy that has
# enumerated subtypes. Determine which subtype this is.
for subtype in parent.get_enumerated_subtypes():
if subtype.data_type == struct_type:
self._emit_tsdoc_header('Reference to the %s type, identified by the '
'value of the .tag property.' %
fmt_type_name(struct_type, namespace))
self.emit('export interface %s extends %s {' %
(fmt_polymorphic_type_reference(struct_type, namespace),
fmt_type_name(struct_type, namespace)))
with self.indent(dent=indent_spaces):
self._emit_tsdoc_header('Tag identifying this subtype variant. This '
'field is only present when needed to '
'discriminate between multiple possible '
'subtypes.')
self.emit_wrapped_text('\'.tag\': \'%s\';' % subtype.name)
self.emit('}')
self.emit()
break
def _generate_union_type(self, union_type, indent_spaces):
"""
Generates a TypeScript interface for a stone union.
"""
# Emit an interface for each variant. TypeScript 2.0 supports these tagged unions.
# https://github.com/Microsoft/TypeScript/wiki/What%27s-new-in-TypeScript#tagged-union-types
parent_type = union_type.parent_type
namespace = union_type.namespace
union_type_name = fmt_type_name(union_type, namespace)
variant_type_names = []
if parent_type:
variant_type_names.append(fmt_type_name(parent_type, namespace))
def _is_struct_without_enumerated_subtypes(data_type):
"""
:param data_type: any data type.
:return: True if the given data type is a struct which has no enumerated subtypes.
"""
return is_struct_type(data_type) and (
not data_type.has_enumerated_subtypes())
for variant in union_type.fields:
if variant.doc:
self._emit_tsdoc_header(variant.doc)
variant_name = '%s%s' % (union_type_name, fmt_pascal(variant.name))
variant_type_names.append(variant_name)
is_struct_without_enumerated_subtypes = _is_struct_without_enumerated_subtypes(
variant.data_type)
if is_struct_without_enumerated_subtypes:
self.emit('export interface %s extends %s {' % (
variant_name, fmt_type(variant.data_type, namespace)))
else:
self.emit('export interface %s {' % variant_name)
with self.indent(dent=indent_spaces):
# Since field contains non-alphanumeric character, we need to enclose
# it in quotation marks.
self.emit("'.tag': '%s';" % variant.name)
if is_void_type(variant.data_type) is False and (
not is_struct_without_enumerated_subtypes
):
self.emit("%s: %s;" % (variant.name, fmt_type(variant.data_type, namespace)))
self.emit('}')
self.emit()
if union_type.doc:
self._emit_tsdoc_header(union_type.doc)
self.emit('export type %s = %s;' % (union_type_name, ' | '.join(variant_type_names)))
self.emit()
def _docf(self, tag, val):
"""
Callback to process documentation references.
"""
return fmt_tag(self.cur_namespace, tag, val)