Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,499 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from abc import ABCMeta, abstractmethod
from contextlib import contextmanager
import argparse
import logging
import os
import six
import textwrap
from stone.frontend.ir_generator import doc_ref_re
from stone.ir import (
is_alias,
resolve_aliases,
strip_alias
)
_MYPY = False
if _MYPY:
from stone.ir import Api
import typing # pylint: disable=import-error,useless-suppression
# Generic Dict key-val types
DelimTuple = typing.Tuple[typing.Text, typing.Text]
K = typing.TypeVar('K')
V = typing.TypeVar('V')
def remove_aliases_from_api(api):
# Resolve nested aliases from each namespace first. This way, when we replace an alias with
# its source later on, it too is alias free.
for namespace in api.namespaces.values():
for alias in namespace.aliases:
# This loops through each alias type chain, resolving each (nested) alias
# to its underlying type at the end of the chain (see resolve_aliases fn).
#
# It will continue until it no longer encounters a type
# with a data_type attribute - this ensures it resolves aliases
# that are subtypes of composites e.g. Lists
curr_type = alias
while hasattr(curr_type, 'data_type'):
curr_type.data_type = resolve_aliases(curr_type.data_type)
curr_type = curr_type.data_type
# Remove alias layers from each data type
for namespace in api.namespaces.values():
for data_type in namespace.data_types:
for field in data_type.fields:
strip_alias(field)
for route in namespace.routes:
# Strip inner aliases
strip_alias(route.arg_data_type)
strip_alias(route.result_data_type)
strip_alias(route.error_data_type)
# Strip top-level aliases
if is_alias(route.arg_data_type):
route.arg_data_type = route.arg_data_type.data_type
if is_alias(route.result_data_type):
route.result_data_type = route.result_data_type.data_type
if is_alias(route.error_data_type):
route.error_data_type = route.error_data_type.data_type
# Clear aliases
namespace.aliases = []
namespace.alias_by_name = {}
return api
@six.add_metaclass(ABCMeta)
class Backend(object):
"""
The parent class for all backends. All backends should extend this
class to be recognized as such.
You will want to implement the generate() function to do the generation
that you need.
Here's roughly what you need to do in generate().
1. Use the context manager output_to_relative_path() to specify an output file.
with output_to_relative_path('generated_code.py'):
...
2. Use the family of emit*() functions to write to the output file.
The target_folder_path attribute is the path to the folder where all
generated files should be created.
"""
# Can be overridden by a subclass
tabs_for_indents = False
# Can be overridden with an argparse.ArgumentParser object.
cmdline_parser = None # type: argparse.ArgumentParser
# Can be overridden by a subclass. If true, stone.data_type.Alias
# objects will be present in the API object. If false, aliases are masked
# by replacing them with duplicate type definitions as the source type.
# For backwards compatibility with existing backends defaults to false.
preserve_aliases = False
def __init__(self, target_folder_path, args):
# type: (str, typing.Optional[typing.Sequence[str]]) -> None
"""
Args:
target_folder_path (str): Path to the folder where all generated
files should be created.
"""
self.logger = logging.getLogger('Backend<%s>' %
self.__class__.__name__)
self.target_folder_path = target_folder_path
# Output is a list of strings that should be concatenated together for
# the final output.
self.output = [] # type: typing.List[typing.Text]
self.lineno = 1
self.cur_indent = 0
self.positional_placeholders = [] # type: typing.List[typing.Text]
self.named_placeholders = {} # type: typing.Dict[typing.Text, typing.Text]
self.args = None # type: typing.Optional[argparse.Namespace]
if self.cmdline_parser:
assert isinstance(self.cmdline_parser, argparse.ArgumentParser), (
'expected cmdline_parser to be ArgumentParser, got %r' %
self.cmdline_parser)
try:
self.args = self.cmdline_parser.parse_args(args)
except SystemExit:
print('Note: This is for backend-specific arguments which '
'follow arguments to Stone after a "--" delimiter.')
raise
@abstractmethod
def generate(self, api):
# type: (Api) -> None
"""
Subclasses should override this method. It's the entry point that is
invoked by the rest of the toolchain.
Args:
api (stone.api.Api): The API specification.
"""
raise NotImplementedError
@contextmanager
def output_to_relative_path(self, relative_path, mode='wb'):
# type: (typing.Text, typing.Text) -> typing.Iterator[None]
"""
Sets up backend so that all emits are directed towards the new file
created at :param:`relative_path`.
Clears the output buffer on enter and exit.
"""
full_path = os.path.join(self.target_folder_path, relative_path)
directory = os.path.dirname(full_path)
if not os.path.exists(directory):
self.logger.info('Creating %s', directory)
os.makedirs(directory)
self.logger.info('Generating %s', full_path)
self.clear_output_buffer()
yield
with open(full_path, mode) as f:
f.write(self.output_buffer_to_string().encode('utf-8'))
self.clear_output_buffer()
def output_buffer_to_string(self):
# type: () -> typing.Text
"""Returns the contents of the output buffer as a string."""
return ''.join(self.output).format(
*self.positional_placeholders,
**self.named_placeholders)
def clear_output_buffer(self):
self.output = []
self.positional_placeholders = []
self.named_placeholders = {}
def indent_step(self):
# type: () -> int
"""
Returns the size of a single indentation step.
"""
return 1 if self.tabs_for_indents else 4
@contextmanager
def indent(self, dent=None):
# type: (typing.Optional[int]) -> typing.Iterator[None]
"""
For the duration of the context manager, indentation will be increased
by dent. Dent is in units of spaces or tabs depending on the value of
the class variable tabs_for_indents. If dent is None, indentation will
increase by either four spaces or one tab.
"""
assert dent is None or dent >= 0, 'dent must be >= 0.'
if dent is None:
dent = self.indent_step()
self.cur_indent += dent
yield
self.cur_indent -= dent
def make_indent(self):
# type: () -> typing.Text
"""
Returns a string representing the current indentation. Indents can be
either spaces or tabs, depending on the value of the class variable
tabs_for_indents.
"""
if self.tabs_for_indents:
return '\t' * self.cur_indent
else:
return ' ' * self.cur_indent
@contextmanager
def capture_emitted_output(self, output_buffer):
# type: (six.StringIO) -> typing.Iterator[None]
original_output = self.output
self.output = []
yield
output_buffer.write(''.join(self.output))
self.output = original_output
def emit_raw(self, s):
# type: (typing.Text) -> None
"""
Adds the input string to the output buffer. The string must end in a
newline. It may contain any number of newline characters. No
indentation is generated.
"""
self.lineno += s.count('\n')
self._append_output(s.replace('{', '{{').replace('}', '}}'))
if len(s) > 0 and s[-1] != '\n':
raise AssertionError(
'Input string to emit_raw must end with a newline.')
def _append_output(self, s):
# type: (typing.Text) -> None
self.output.append(s)
def emit(self, s=''):
# type: (typing.Text) -> None
"""
Adds indentation, then the input string, and lastly a newline to the
output buffer. If s is an empty string (default) then an empty line is
created with no indentation.
"""
assert isinstance(s, six.text_type), 's must be a unicode string'
assert '\n' not in s, \
'String to emit cannot contain newline strings.'
if s:
self.emit_raw('%s%s\n' % (self.make_indent(), s))
else:
self.emit_raw('\n')
def emit_wrapped_text(
self,
s, # type: typing.Text
prefix='', # type: typing.Text
initial_prefix='', # type: typing.Text
subsequent_prefix='', # type: typing.Text
width=80, # type: int
break_long_words=False, # type: bool
break_on_hyphens=False # type: bool
):
# type: (...) -> None
"""
Adds the input string to the output buffer with indentation and
wrapping. The wrapping is performed by the :func:`textwrap.fill` Python
library function.
Args:
s (str): The input string to wrap.
prefix (str): The string to prepend to *every* line.
initial_prefix (str): The string to prepend to the first line of
the wrapped string. Note that the current indentation is
already added to each line.
subsequent_prefix (str): The string to prepend to every line after
the first. Note that the current indentation is already added
to each line.
width (int): The target width of each line including indentation
and text.
break_long_words (bool): Break words longer than width. If false,
those words will not be broken, and some lines might be longer
than width.
break_on_hyphens (bool): Allow breaking hyphenated words. If true,
wrapping will occur preferably on whitespaces and right after
hyphens part of compound words.
"""
indent = self.make_indent()
prefix = indent + prefix
self.emit_raw(textwrap.fill(s,
initial_indent=prefix + initial_prefix,
subsequent_indent=prefix + subsequent_prefix,
width=width,
break_long_words=break_long_words,
break_on_hyphens=break_on_hyphens,
) + '\n')
def emit_placeholder(self, s=''):
# type: (typing.Text) -> None
"""
Emits replacements fields that can be used to format the output string later.
"""
self._append_output('{%s}' % s)
def add_positional_placeholder(self, s):
# type: (typing.Text) -> None
"""
Format replacement fields corresponding to empty calls to emit_placeholder.
"""
self.positional_placeholders.append(s)
def add_named_placeholder(self, name, s):
# type: (typing.Text, typing.Text) -> None
"""
Format replacement fields corresponding to non-empty calls to emit_placeholder.
"""
self.named_placeholders[name] = s
@classmethod
def process_doc(cls, doc, handler):
# type: (str, typing.Callable[[str, str], str]) -> typing.Text
"""
Helper for parsing documentation references in Stone docstrings and
replacing them with more suitable annotations for the generated output.
Args:
doc (str): A Stone docstring.
handler: A function with the following signature:
`(tag: str, value: str) -> str`. It will be called for every
reference found in the docstring with the tag and value parsed
for you. The returned string will be substituted in the
docstring in place of the reference.
"""
assert isinstance(doc, six.text_type), \
'Expected string (unicode in PY2), got %r.' % type(doc)
cur_index = 0
parts = []
for match in doc_ref_re.finditer(doc):
# Append the part of the doc that is not part of any reference.
start, end = match.span()
parts.append(doc[cur_index:start])
cur_index = end
# Call the handler with the next tag and value.
tag = match.group('tag')
val = match.group('val')
sub = handler(tag, val)
parts.append(sub)
parts.append(doc[cur_index:])
return ''.join(parts)
class CodeBackend(Backend):
"""
Extend this instead of :class:`Backend` when generating source code.
Contains helper functions specific to code generation.
"""
# pylint: disable=abstract-method
def filter_out_none_valued_keys(self, d):
# type: (typing.Dict[K, V]) -> typing.Dict[K, V]
"""Given a dict, returns a new dict with all the same key/values except
for keys that had values of None."""
new_d = {}
for k, v in d.items():
if v is not None:
new_d[k] = v
return new_d
def generate_multiline_list(
self,
items, # type: typing.List[typing.Text]
before='', # type: typing.Text
after='', # type: typing.Text
delim=('(', ')'), # type: DelimTuple
compact=True, # type: bool
sep=',', # type: typing.Text
skip_last_sep=False # type: bool
):
# type: (...) -> None
"""
Given a list of items, emits one item per line.
This is convenient for function prototypes and invocations, as well as
for instantiating arrays, sets, and maps in some languages.
TODO(kelkabany): A backend that uses tabs cannot be used with this
if compact is false.
Args:
items (list[str]): Should contain the items to generate a list of.
before (str): The string to come before the list of items.
after (str): The string to follow the list of items.
delim (str, str): The first element is added immediately following
`before`. The second element is added prior to `after`.
compact (bool): In compact mode, the enclosing parentheses are on
the same lines as the first and last list item.
sep (str): The string that follows each list item when compact is
true. If compact is false, the separator is omitted for the
last item.
skip_last_sep (bool): When compact is false, whether the last line
should have a trailing separator. Ignored when compact is true.
"""
assert len(delim) == 2 and isinstance(delim[0], six.text_type) and \
isinstance(delim[1], six.text_type), 'delim must be a tuple of two unicode strings.'
if len(items) == 0:
self.emit(before + delim[0] + delim[1] + after)
return
if len(items) == 1:
self.emit(before + delim[0] + items[0] + delim[1] + after)
return
if compact:
self.emit(before + delim[0] + items[0] + sep)
def emit_list(items):
items = items[1:]
for (i, item) in enumerate(items):
if i == len(items) - 1:
self.emit(item + delim[1] + after)
else:
self.emit(item + sep)
if before or delim[0]:
with self.indent(len(before) + len(delim[0])):
emit_list(items)
else:
emit_list(items)
else:
if before or delim[0]:
self.emit(before + delim[0])
with self.indent():
for (i, item) in enumerate(items):
if i == len(items) - 1 and skip_last_sep:
self.emit(item)
else:
self.emit(item + sep)
if delim[1] or after:
self.emit(delim[1] + after)
elif delim[1]:
self.emit(delim[1])
@contextmanager
def block(
self,
before='', # type: typing.Text
after='', # type: typing.Text
delim=('{', '}'), # type: DelimTuple
dent=None, # type: typing.Optional[int]
allman=False # type: bool
):
# type: (...) -> typing.Iterator[None]
"""
A context manager that emits configurable lines before and after an
indented block of text.
This is convenient for class and function definitions in some
languages.
Args:
before (str): The string to be output in the first line which is
not indented..
after (str): The string to be output in the last line which is
not indented.
delim (str, str): The first element is added immediately following
`before` and a space. The second element is added prior to a
space and then `after`.
dent (int): The amount to indent the block. If none, the default
indentation increment is used (four spaces or one tab).
allman (bool): Indicates whether to use `Allman` style indentation,
or the default `K&R` style. If there is no `before` string this
is ignored. For more details about indent styles see
http://en.wikipedia.org/wiki/Indent_style
"""
assert len(delim) == 2, 'delim must be a tuple of length 2'
assert (isinstance(delim[0], (six.text_type, type(None))) and
isinstance(delim[1], (six.text_type, type(None)))), (
'delim must be a tuple of two optional strings.')
if before and not allman:
if delim[0] is not None:
self.emit('{} {}'.format(before, delim[0]))
else:
self.emit(before)
else:
if before:
self.emit(before)
if delim[0] is not None:
self.emit(delim[0])
with self.indent(dent):
yield
if delim[1] is not None:
self.emit(delim[1] + after)
else:
self.emit(after)

View file

@ -0,0 +1,59 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import re
_split_words_capitalization_re = re.compile(
'^[a-z0-9]+|[A-Z][a-z0-9]+|[A-Z]+(?=[A-Z][a-z0-9])|[A-Z]+$'
)
_split_words_dashes_re = re.compile('[-_/]+')
def split_words(name):
"""
Splits name based on capitalization, dashes, and underscores.
Example: 'GetFile' -> ['Get', 'File']
Example: 'get_file' -> ['get', 'file']
"""
all_words = []
for word in re.split(_split_words_dashes_re, name):
vals = _split_words_capitalization_re.findall(word)
if vals:
all_words.extend(vals)
else:
all_words.append(word)
return all_words
def fmt_camel(name):
"""
Converts name to lower camel case. Words are identified by capitalization,
dashes, and underscores.
"""
words = split_words(name)
assert len(words) > 0
first = words.pop(0).lower()
return first + ''.join([word.capitalize() for word in words])
def fmt_dashes(name):
"""
Converts name to words separated by dashes. Words are identified by
capitalization, dashes, and underscores.
"""
return '-'.join([word.lower() for word in split_words(name)])
def fmt_pascal(name):
"""
Converts name to pascal case. Words are identified by capitalization,
dashes, and underscores.
"""
return ''.join([word.capitalize() for word in split_words(name)])
def fmt_underscores(name):
"""
Converts name to words separated by underscores. Words are identified by
capitalization, dashes, and underscores.
"""
return '_'.join([word.lower() for word in split_words(name)])

View file

@ -0,0 +1,150 @@
from __future__ import absolute_import, division, print_function, unicode_literals
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
from stone.ir import ApiNamespace
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.backend import CodeBackend
from stone.backends.js_helpers import (
check_route_name_conflict,
fmt_error_type,
fmt_func,
fmt_obj,
fmt_type,
fmt_url,
)
from stone.ir import Void
_cmdline_parser = argparse.ArgumentParser(prog='js-client-backend')
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single Javascript file that is created and '
'contains all of the routes.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
type=str,
help=('The name of the class the generated functions will be attached to. '
'The name will be added to each function documentation, which makes '
'it available for tools like JSDoc.'),
)
_cmdline_parser.add_argument(
'--wrap-response-in',
type=str,
default='',
help=('Wraps the response in a response class')
)
_header = """\
// Auto-generated by Stone, do not modify.
var routes = {};
"""
class JavascriptClientBackend(CodeBackend):
"""Generates a single Javascript file with all of the routes defined."""
cmdline_parser = _cmdline_parser
# Instance var of the current namespace being generated
cur_namespace = None # type: typing.Optional[ApiNamespace]
preserve_aliases = True
def generate(self, api):
# first check for route name conflict
with self.output_to_relative_path(self.args.filename):
self.emit_raw(_header)
for namespace in api.namespaces.values():
# Hack: needed for _docf()
self.cur_namespace = namespace
check_route_name_conflict(namespace)
for route in namespace.routes:
self._generate_route(api.route_schema, namespace, route)
self.emit()
self.emit('export { routes };')
def _generate_route(self, route_schema, namespace, route):
function_name = fmt_func(namespace.name + '_' + route.name, route.version)
self.emit()
self.emit('/**')
if route.doc:
self.emit_wrapped_text(self.process_doc(route.doc, self._docf), prefix=' * ')
if self.args.class_name:
self.emit(' * @function {}#{}'.format(self.args.class_name,
function_name))
if route.deprecated:
self.emit(' * @deprecated')
return_type = None
if self.args.wrap_response_in:
return_type = '%s<%s>' % (self.args.wrap_response_in,
fmt_type(route.result_data_type))
else:
return_type = fmt_type(route.result_data_type)
if route.arg_data_type.__class__ != Void:
self.emit(' * @arg {%s} arg - The request parameters.' %
fmt_type(route.arg_data_type))
self.emit(' * @returns {Promise.<%s, %s>}' %
(return_type,
fmt_error_type(route.error_data_type)))
self.emit(' */')
if route.arg_data_type.__class__ != Void:
self.emit('routes.%s = function (arg) {' % (function_name))
else:
self.emit('routes.%s = function () {' % (function_name))
with self.indent(dent=2):
url = fmt_url(namespace.name, route.name, route.version)
if route_schema.fields:
additional_args = []
for field in route_schema.fields:
additional_args.append(fmt_obj(route.attrs[field.name]))
if route.arg_data_type.__class__ != Void:
self.emit(
"return this.request('{}', arg, {});".format(
url, ', '.join(additional_args)))
else:
self.emit(
"return this.request('{}', null, {});".format(
url, ', '.join(additional_args)))
else:
if route.arg_data_type.__class__ != Void:
self.emit(
'return this.request("%s", arg);' % url)
else:
self.emit(
'return this.request("%s", null);' % url)
self.emit('};')
def _docf(self, tag, val):
"""
Callback used as the handler argument to process_docs(). This converts
Stone doc references to JSDoc-friendly annotations.
"""
# TODO(kelkabany): We're currently just dropping all doc ref tags ...
# NOTE(praneshp): ... except for versioned routes
if tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
url = fmt_url(self.cur_namespace.name, val, version)
# NOTE: In js, for comments, we drop the namespace name and the '/' when
# documenting URLs
return url[(len(self.cur_namespace.name) + 1):]
return val

View file

@ -0,0 +1,127 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import six
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_list_type,
is_struct_type,
is_user_defined_type,
)
from stone.backends.helpers import (
fmt_camel,
fmt_pascal,
)
_base_type_table = {
Boolean: 'boolean',
Bytes: 'string',
Float32: 'number',
Float64: 'number',
Int32: 'number',
Int64: 'number',
List: 'Array',
String: 'string',
UInt32: 'number',
UInt64: 'number',
Timestamp: 'Timestamp',
Void: 'void',
}
def fmt_obj(o):
if isinstance(o, six.text_type):
# Prioritize single-quoted strings per JS style guides.
return repr(o).lstrip('u')
else:
return json.dumps(o, indent=2)
def fmt_error_type(data_type):
"""
Converts the error type into a JSDoc type.
"""
return 'Error.<%s>' % fmt_type(data_type)
def fmt_type_name(data_type):
"""
Returns the JSDoc name for the given data type.
(Does not attempt to enumerate subtypes.)
"""
if is_user_defined_type(data_type):
return fmt_pascal('%s%s' % (data_type.namespace.name, data_type.name))
else:
fmted_type = _base_type_table.get(data_type.__class__, 'Object')
if is_list_type(data_type):
fmted_type += '.<' + fmt_type(data_type.data_type) + '>'
return fmted_type
def fmt_type(data_type):
"""
Returns a JSDoc annotation for a data type.
May contain a union of enumerated subtypes.
"""
if is_struct_type(data_type) and data_type.has_enumerated_subtypes():
possible_types = []
possible_subtypes = data_type.get_all_subtypes_with_tags()
for _, subtype in possible_subtypes:
possible_types.append(fmt_type_name(subtype))
if data_type.is_catch_all():
possible_types.append(fmt_type_name(data_type))
return fmt_jsdoc_union(possible_types)
else:
return fmt_type_name(data_type)
def fmt_jsdoc_union(type_strings):
"""
Returns a JSDoc union of the given type strings.
"""
return '(' + '|'.join(type_strings) + ')' if len(type_strings) > 1 else type_strings[0]
def fmt_func(name, version):
if version == 1:
return fmt_camel(name)
return fmt_camel(name) + 'V{}'.format(version)
def fmt_url(namespace_name, route_name, route_version):
if route_version != 1:
return '{}/{}_v{}'.format(namespace_name, route_name, route_version)
else:
return '{}/{}'.format(namespace_name, route_name)
def fmt_var(name):
return fmt_camel(name)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, version=route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route

View file

@ -0,0 +1,285 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import six
import sys
from stone.ir import (
is_user_defined_type,
is_union_type,
is_struct_type,
is_void_type,
unwrap,
)
from stone.backend import CodeBackend
from stone.backends.js_helpers import (
fmt_jsdoc_union,
fmt_type,
fmt_type_name,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(prog='js-types-backend')
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single Javascript file that is created and '
'contains all of the JSDoc types.'),
)
_cmdline_parser.add_argument(
'-e',
'--extra-arg',
action='append',
type=str,
default=[],
help=("Additional properties to add to a route's argument type based "
"on if the route has a certain attribute set. Format (JSON): "
'{"match": ["ROUTE_ATTR", ROUTE_VALUE_TO_MATCH], '
'"arg_name": "ARG_NAME", "arg_type": "ARG_TYPE", '
'"arg_docstring": "ARG_DOCSTRING"}'),
)
_header = """\
// Auto-generated by Stone, do not modify.
/**
* An Error object returned from a route.
* @typedef {Object} Error
* @property {string} error_summary - Text summary of the error.
* @property {T} error - The error object.
* @property {UserMessage} user_message - An optional field. If present, it includes a
message that can be shown directly to the end user of your app. You should show this message
if your app is unprepared to programmatically handle the error returned by an endpoint.
* @template T
*/
/**
* User-friendly error message.
* @typedef {Object} UserMessage
* @property {string} text - The message.
* @property {string} locale
*/
/**
* @typedef {string} Timestamp
*/
"""
class JavascriptTypesBackend(CodeBackend):
"""Generates a single Javascript file with all of the data types defined in JSDoc."""
cmdline_parser = _cmdline_parser
preserve_aliases = True
def generate(self, api):
with self.output_to_relative_path(self.args.filename):
self.emit_raw(_header)
extra_args = self._parse_extra_args(api, self.args.extra_arg)
for namespace in api.namespaces.values():
for data_type in namespace.data_types:
self._generate_type(data_type, extra_args.get(data_type, []))
def _parse_extra_args(self, api, extra_args_raw):
"""
Parses extra arguments into a map keyed on particular data types.
"""
extra_args = {}
def die(m, extra_arg_raw):
print('Invalid --extra-arg:%s: %s' % (m, extra_arg_raw),
file=sys.stderr)
sys.exit(1)
for extra_arg_raw in extra_args_raw:
try:
extra_arg = json.loads(extra_arg_raw)
except ValueError as e:
die(str(e), extra_arg_raw)
# Validate extra_arg JSON blob
if 'match' not in extra_arg:
die('No match key', extra_arg_raw)
elif (not isinstance(extra_arg['match'], list) or
len(extra_arg['match']) != 2):
die('match key is not a list of two strings', extra_arg_raw)
elif (not isinstance(extra_arg['match'][0], six.text_type) or
not isinstance(extra_arg['match'][1], six.text_type)):
print(type(extra_arg['match'][0]))
die('match values are not strings', extra_arg_raw)
elif 'arg_name' not in extra_arg:
die('No arg_name key', extra_arg_raw)
elif not isinstance(extra_arg['arg_name'], six.text_type):
die('arg_name is not a string', extra_arg_raw)
elif 'arg_type' not in extra_arg:
die('No arg_type key', extra_arg_raw)
elif not isinstance(extra_arg['arg_type'], six.text_type):
die('arg_type is not a string', extra_arg_raw)
elif ('arg_docstring' in extra_arg and
not isinstance(extra_arg['arg_docstring'], six.text_type)):
die('arg_docstring is not a string', extra_arg_raw)
attr_key, attr_val = extra_arg['match'][0], extra_arg['match'][1]
extra_args.setdefault(attr_key, {})[attr_val] = \
(extra_arg['arg_name'], extra_arg['arg_type'],
extra_arg.get('arg_docstring'))
# Extra arguments, keyed on data type objects.
extra_args_for_types = {}
# Locate data types that contain extra arguments
for namespace in api.namespaces.values():
for route in namespace.routes:
extra_parameters = []
if is_user_defined_type(route.arg_data_type):
for attr_key in route.attrs:
if attr_key not in extra_args:
continue
attr_val = route.attrs[attr_key]
if attr_val in extra_args[attr_key]:
extra_parameters.append(extra_args[attr_key][attr_val])
if len(extra_parameters) > 0:
extra_args_for_types[route.arg_data_type] = extra_parameters
return extra_args_for_types
def _generate_type(self, data_type, extra_parameters):
if is_struct_type(data_type):
self._generate_struct(data_type, extra_parameters)
elif is_union_type(data_type):
self._generate_union(data_type)
def _emit_jsdoc_header(self, doc=None):
self.emit()
self.emit('/**')
if doc:
self.emit_wrapped_text(self.process_doc(doc, self._docf), prefix=' * ')
def _generate_struct(self, struct_type, extra_parameters=None, nameOverride=None):
"""
Emits a JSDoc @typedef for a struct.
"""
extra_parameters = extra_parameters if extra_parameters is not None else []
self._emit_jsdoc_header(struct_type.doc)
self.emit(
' * @typedef {Object} %s' % (
nameOverride if nameOverride else fmt_type_name(struct_type)
)
)
# Some structs can explicitly list their subtypes. These structs
# have a .tag field that indicate which subtype they are.
if struct_type.is_member_of_enumerated_subtypes_tree():
if struct_type.has_enumerated_subtypes():
# This struct is the parent to multiple subtypes.
# Determine all of the possible values of the .tag
# property.
tag_values = []
for tags, _ in struct_type.get_all_subtypes_with_tags():
for tag in tags:
tag_values.append('"%s"' % tag)
jsdoc_tag_union = fmt_jsdoc_union(tag_values)
txt = '@property {%s} .tag - Tag identifying the subtype variant.' % \
jsdoc_tag_union
self.emit_wrapped_text(txt)
else:
# This struct is a particular subtype. Find the applicable
# .tag value from the parent type, which may be an
# arbitrary number of steps up the inheritance hierarchy.
parent = struct_type.parent_type
while not parent.has_enumerated_subtypes():
parent = parent.parent_type
# parent now contains the closest parent type in the
# inheritance hierarchy that has enumerated subtypes.
# Determine which subtype this is.
for subtype in parent.get_enumerated_subtypes():
if subtype.data_type == struct_type:
txt = '@property {\'%s\'} [.tag] - Tag identifying ' \
'this subtype variant. This field is only ' \
'present when needed to discriminate ' \
'between multiple possible subtypes.' % \
subtype.name
self.emit_wrapped_text(txt)
break
for param_name, param_type, param_docstring in extra_parameters:
param_docstring = ' - %s' % param_docstring if param_docstring else ''
self.emit_wrapped_text(
'@property {%s} %s%s' % (
param_type,
param_name,
param_docstring,
),
prefix=' * ',
)
# NOTE: JSDoc @typedef does not support inheritance. Using @class would be inappropriate,
# since these are not nominal types backed by a constructor. Thus, we emit all_fields,
# which includes fields on parent types.
for field in struct_type.all_fields:
field_doc = ' - ' + field.doc if field.doc else ''
field_type, nullable, _ = unwrap(field.data_type)
field_js_type = fmt_type(field_type)
# Translate nullable types into optional properties.
field_name = '[' + field.name + ']' if nullable else field.name
self.emit_wrapped_text(
'@property {%s} %s%s' % (
field_js_type,
field_name,
self.process_doc(field_doc, self._docf),
),
prefix=' * ',
)
self.emit(' */')
def _generate_union(self, union_type):
"""
Emits a JSDoc @typedef for a union type.
"""
union_name = fmt_type_name(union_type)
self._emit_jsdoc_header(union_type.doc)
self.emit(' * @typedef {Object} %s' % union_name)
variant_types = []
for variant in union_type.all_fields:
variant_types.append("'%s'" % variant.name)
variant_data_type, _, _ = unwrap(variant.data_type)
# Don't emit fields for void types.
if not is_void_type(variant_data_type):
variant_doc = ' - Available if .tag is %s.' % variant.name
if variant.doc:
variant_doc += ' ' + variant.doc
self.emit_wrapped_text(
'@property {%s} [%s]%s' % (
fmt_type(variant_data_type),
variant.name,
variant_doc,
),
prefix=' * ',
)
jsdoc_tag_union = fmt_jsdoc_union(variant_types)
self.emit(' * @property {%s} .tag - Tag identifying the union variant.' % jsdoc_tag_union)
self.emit(' */')
def _docf(self, tag, val): # pylint: disable=unused-argument
"""
Callback used as the handler argument to process_docs(). This converts
Stone doc references to JSDoc-friendly annotations.
"""
# TODO(kelkabany): We're currently just dropping all doc ref tags.
return val

View file

@ -0,0 +1,283 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
from stone.ir import (
is_list_type,
is_map_type,
is_struct_type,
is_union_type,
is_nullable_type,
is_user_defined_type,
is_void_type,
unwrap_nullable, )
from stone.backend import CodeBackend
from stone.backends.obj_c_helpers import (
fmt_camel_upper,
fmt_class,
fmt_class_prefix,
fmt_import, )
stone_warning = """\
///
/// Copyright (c) 2016 Dropbox, Inc. All rights reserved.
///
/// Auto-generated by Stone, do not modify.
///
"""
# This will be at the top of the generated file.
base_file_comment = """\
{}\
""".format(stone_warning)
undocumented = '(no description).'
comment_prefix = '/// '
class ObjCBaseBackend(CodeBackend):
"""Wrapper class over Stone generator for Obj C logic."""
# pylint: disable=abstract-method
@contextmanager
def block_m(self, class_name):
with self.block(
'@implementation {}'.format(class_name),
delim=('', '@end'),
dent=0):
self.emit()
yield
@contextmanager
def block_h_from_data_type(self, data_type, protocol=None):
assert is_user_defined_type(data_type), \
'Expected user-defined type, got %r' % type(data_type)
if not protocol:
extensions = []
if data_type.parent_type and is_struct_type(data_type):
extensions.append(fmt_class_prefix(data_type.parent_type))
else:
if is_union_type(data_type):
# Use a handwritten base class
extensions.append('NSObject')
else:
extensions.append('NSObject')
extend_suffix = ' : {}'.format(
', '.join(extensions)) if extensions else ''
else:
base = fmt_class_prefix(data_type.parent_type) if (
data_type.parent_type and
not is_union_type(data_type)) else 'NSObject'
extend_suffix = ' : {} <{}>'.format(base, ', '.join(protocol))
with self.block(
'@interface {}{}'.format(
fmt_class_prefix(data_type), extend_suffix),
delim=('', '@end'),
dent=0):
self.emit()
yield
@contextmanager
def block_h(self,
class_name,
protocol=None,
extensions=None,
protected=None):
if not extensions:
extensions = ['NSObject']
if not protocol:
extend_suffix = ' : {}'.format(', '.join(extensions))
else:
extend_suffix = ' : {} <{}>'.format(', '.join(extensions),
fmt_class(protocol))
base_interface_str = '@interface {}{} {{' if protected else '@interface {}{}'
with self.block(
base_interface_str.format(class_name, extend_suffix),
delim=('', '@end'),
dent=0):
if protected:
with self.block('', delim=('', '')):
self.emit('@protected')
for field_name, field_type in protected:
self.emit('{} _{};'.format(field_type, field_name))
self.emit('}')
self.emit()
yield
@contextmanager
def block_init(self):
with self.block('if (self)'):
yield
self.emit('return self;')
@contextmanager
def block_func(self, func, args=None, return_type='void',
class_func=False):
args = args if args is not None else []
modifier = '-' if not class_func else '+'
base_string = '{} ({}){}:{}' if args else '{} ({}){}'
signature = base_string.format(modifier, return_type, func, args)
with self.block(signature):
yield
def _get_imports_m(self, data_types, default_imports):
"""Emits all necessary implementation file imports for the given Stone data type."""
if not isinstance(data_types, list):
data_types = [data_types]
import_classes = default_imports
for data_type in data_types:
import_classes.append(fmt_class_prefix(data_type))
if data_type.parent_type:
import_classes.append(fmt_class_prefix(data_type.parent_type))
if is_struct_type(
data_type) and data_type.has_enumerated_subtypes():
for _, subtype in data_type.get_all_subtypes_with_tags():
import_classes.append(fmt_class_prefix(subtype))
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
# unpack list or map
while is_list_type(data_type) or is_map_type(data_type):
data_type = (data_type.value_data_type if
is_map_type(data_type) else data_type.data_type)
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
if import_classes:
import_classes = list(set(import_classes))
import_classes.sort()
return import_classes
def _get_imports_h(self, data_types):
"""Emits all necessary header file imports for the given Stone data type."""
if not isinstance(data_types, list):
data_types = [data_types]
import_classes = []
for data_type in data_types:
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
# unpack list or map
while is_list_type(data_type) or is_map_type(data_type):
data_type = (data_type.value_data_type if
is_map_type(data_type) else data_type.data_type)
if is_user_defined_type(data_type):
import_classes.append(fmt_class_prefix(data_type))
import_classes = list(set(import_classes))
import_classes.sort()
return import_classes
def _generate_imports_h(self, import_classes):
import_classes = list(set(import_classes))
import_classes.sort()
for import_class in import_classes:
self.emit('@class {};'.format(import_class))
if import_classes:
self.emit()
def _generate_imports_m(self, import_classes):
import_classes = list(set(import_classes))
import_classes.sort()
for import_class in import_classes:
self.emit(fmt_import(import_class))
self.emit()
def _generate_init_imports_h(self, data_type):
self.emit('#import <Foundation/Foundation.h>')
self.emit()
self.emit('#import "DBSerializableProtocol.h"')
if data_type.parent_type and not is_union_type(data_type):
self.emit(fmt_import(fmt_class_prefix(data_type.parent_type)))
self.emit()
def _get_namespace_route_imports(self,
namespace,
include_route_args=True,
include_route_deep_args=False):
result = []
def _unpack_and_store_data_type(data_type):
data_type, _ = unwrap_nullable(data_type)
if is_list_type(data_type):
while is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
if not is_void_type(data_type) and is_user_defined_type(data_type):
result.append(data_type)
for route in namespace.routes:
if include_route_args:
data_type, _ = unwrap_nullable(route.arg_data_type)
_unpack_and_store_data_type(data_type)
elif include_route_deep_args:
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_union_type(data_type) or is_list_type(data_type):
_unpack_and_store_data_type(data_type)
elif not is_void_type(data_type):
for field in data_type.all_fields:
data_type, _ = unwrap_nullable(field.data_type)
if (is_struct_type(data_type) or
is_union_type(data_type) or
is_list_type(data_type)):
_unpack_and_store_data_type(data_type)
_unpack_and_store_data_type(route.result_data_type)
_unpack_and_store_data_type(route.error_data_type)
return result
def _cstor_name_from_fields(self, fields):
"""Returns an Obj C appropriate name for a constructor based on
the name of the first argument."""
if fields:
return self._cstor_name_from_field(fields[0])
else:
return 'initDefault'
def _cstor_name_from_field(self, field):
"""Returns an Obj C appropriate name for a constructor based on
the name of the supplied argument."""
return 'initWith{}'.format(fmt_camel_upper(field.name))
def _cstor_name_from_fields_names(self, fields_names):
"""Returns an Obj C appropriate name for a constructor based on
the name of the first argument."""
if fields_names:
return 'initWith{}'.format(fmt_camel_upper(fields_names[0][0]))
else:
return 'initDefault'
def _struct_has_defaults(self, struct):
"""Returns whether the given struct has any default values."""
return [
f for f in struct.all_fields
if f.has_default or is_nullable_type(f.data_type)
]

View file

@ -0,0 +1,618 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
from stone.ir import (
is_nullable_type,
is_struct_type,
is_union_type,
is_void_type,
unwrap_nullable, )
from stone.backends.obj_c_helpers import (
fmt_alloc_call,
fmt_camel_upper,
fmt_class,
fmt_class_prefix,
fmt_func,
fmt_func_args,
fmt_func_args_declaration,
fmt_func_call,
fmt_import,
fmt_property_str,
fmt_route_obj_class,
fmt_route_func,
fmt_route_var,
fmt_routes_class,
fmt_signature,
fmt_type,
fmt_var, )
from stone.backends.obj_c import (
base_file_comment,
comment_prefix,
ObjCBaseBackend,
stone_warning,
undocumented, )
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(
prog='objc-client-backend',
description=(
'Generates a ObjC class with an object for each namespace, and in each '
'namespace object, a method for each route. This class assumes that the '
'obj_c_types backend was used with the same output directory.'), )
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=(
'The name of the ObjC module to generate. Please exclude the {.h,.m} '
'file extension.'), )
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help=(
'The name of the ObjC class that contains an object for each namespace, '
'and in each namespace object, a method for each route.'))
_cmdline_parser.add_argument(
'-t',
'--transport-client-name',
required=True,
type=str,
help='The name of the ObjC class that manages network API calls.', )
_cmdline_parser.add_argument(
'-w',
'--auth-type',
type=str,
help='The auth type of the client to generate.', )
_cmdline_parser.add_argument(
'-y',
'--client-args',
required=True,
type=str,
help='The client-side route arguments to append to each route by style type.', )
_cmdline_parser.add_argument(
'-z',
'--style-to-request',
required=True,
type=str,
help='The dict that maps a style type to a ObjC request object name.', )
class ObjCBackend(ObjCBaseBackend):
"""Generates ObjC client base that implements route interfaces."""
cmdline_parser = _cmdline_parser
obj_name_to_namespace = {} # type: typing.Dict[str, int]
namespace_to_has_routes = {} # type: typing.Dict[typing.Any, bool]
def generate(self, api):
for namespace in api.namespaces.values():
self.namespace_to_has_routes[namespace] = False
if namespace.routes:
for route in namespace.routes:
if self._should_generate_route(route):
self.namespace_to_has_routes[namespace] = True
break
for namespace in api.namespaces.values():
for data_type in namespace.linearize_data_types():
self.obj_name_to_namespace[data_type.name] = fmt_class_prefix(
data_type)
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
import_classes = [
fmt_routes_class(namespace.name, self.args.auth_type),
fmt_route_obj_class(namespace.name),
'{}Protocol'.format(self.args.transport_client_name),
'DBStoneBase',
'DBRequestErrors',
]
with self.output_to_relative_path('Routes/{}.m'.format(
fmt_routes_class(namespace.name,
self.args.auth_type))):
self.emit_raw(stone_warning)
imports_classes_m = import_classes + \
self._get_imports_m(
self._get_namespace_route_imports(namespace), [])
self._generate_imports_m(imports_classes_m)
self._generate_routes_m(namespace)
with self.output_to_relative_path('Routes/{}.h'.format(
fmt_routes_class(namespace.name,
self.args.auth_type))):
self.emit_raw(base_file_comment)
self.emit('#import <Foundation/Foundation.h>')
self.emit()
self.emit(fmt_import('DBTasks'))
self.emit()
import_classes_h = [
'DBNilObject',
]
import_classes_h = (import_classes_h + self._get_imports_h(
self._get_namespace_route_imports(
namespace,
include_route_args=False,
include_route_deep_args=True)))
self._generate_imports_h(import_classes_h)
self.emit(
'@protocol {};'.format(
self.args.transport_client_name), )
self.emit()
self._generate_routes_h(namespace)
with self.output_to_relative_path(
'Client/{}.m'.format(self.args.module_name)):
self._generate_client_m(api)
with self.output_to_relative_path(
'Client/{}.h'.format(self.args.module_name)):
self._generate_client_h(api)
def _generate_client_m(self, api):
"""Generates client base implementation file. For each namespace, the client will
have an object field that encapsulates each route in the particular namespace."""
self.emit_raw(base_file_comment)
import_classes = [self.args.module_name]
import_classes += [
fmt_routes_class(ns.name, self.args.auth_type)
for ns in api.namespaces.values()
if ns.routes and self.namespace_to_has_routes[ns]
]
import_classes.append(
'{}Protocol'.format(self.args.transport_client_name))
self._generate_imports_m(import_classes)
with self.block_m(self.args.class_name):
client_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
with self.block_func(
func='initWithTransportClient',
args=client_args,
return_type='instancetype'):
self.emit('self = [super init];')
with self.block_init():
self.emit('_transportClient = client;')
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
base_string = '_{}Routes = [[{} alloc] init:client];'
self.emit(
base_string.format(
fmt_var(namespace.name),
fmt_routes_class(namespace.name,
self.args.auth_type)))
def _generate_client_h(self, api):
"""Generates client base header file. For each namespace, the client will
have an object field that encapsulates each route in the particular namespace."""
self.emit_raw(stone_warning)
self.emit('#import <Foundation/Foundation.h>')
import_classes = [
fmt_routes_class(ns.name, self.args.auth_type)
for ns in api.namespaces.values()
if ns.routes and self.namespace_to_has_routes[ns]
]
import_classes.append('DBRequestErrors')
import_classes.append('DBTasks')
self._generate_imports_m(import_classes)
self.emit()
self.emit('NS_ASSUME_NONNULL_BEGIN')
self.emit()
self.emit('@protocol {};'.format(self.args.transport_client_name))
self.emit()
self.emit(comment_prefix)
description_str = (
'Base client object that contains an instance field for '
'each namespace, each of which contains references to all routes within '
'that namespace. Fully-implemented API clients will inherit this class.'
)
self.emit_wrapped_text(description_str, prefix=comment_prefix)
self.emit(comment_prefix)
with self.block_h(
self.args.class_name,
protected=[
('transportClient',
'id<{}>'.format(self.args.transport_client_name))
]):
self.emit()
for namespace in api.namespaces.values():
if namespace.routes and self.namespace_to_has_routes[namespace]:
class_doc = 'Routes within the `{}` namespace.'.format(
fmt_var(namespace.name))
self.emit_wrapped_text(class_doc, prefix=comment_prefix)
prop = '{}Routes'.format(fmt_var(namespace.name))
typ = '{} *'.format(
fmt_routes_class(namespace.name, self.args.auth_type))
self.emit(fmt_property_str(prop=prop, typ=typ))
self.emit()
client_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
description_str = (
'Initializes the `{}` object with a networking client.')
self.emit_wrapped_text(
description_str.format(self.args.class_name),
prefix=comment_prefix)
init_signature = fmt_signature(
func='initWithTransportClient',
args=client_args,
return_type='instancetype')
self.emit('{};'.format(init_signature))
self.emit()
self.emit()
self.emit('NS_ASSUME_NONNULL_END')
def _auth_type_in_route(self, route, desired_auth_type):
for auth_type in route.attrs.get('auth').split(','):
if auth_type.strip() == desired_auth_type:
return True
return False
def _route_is_special_noauth_case(self, route):
return self._auth_type_in_route(route, 'noauth') and self.args.auth_type == 'user'
def _should_generate_route(self, route):
return (self._auth_type_in_route(route, self.args.auth_type) or
self._route_is_special_noauth_case(route))
def _generate_routes_m(self, namespace):
"""Generates implementation file for namespace object that has as methods
all routes within the namespace."""
with self.block_m(
fmt_routes_class(namespace.name, self.args.auth_type)):
init_args = fmt_func_args_declaration([(
'client', 'id<{}>'.format(self.args.transport_client_name))])
with self.block_func(
func='init', args=init_args, return_type='instancetype'):
self.emit('self = [super init];')
with self.block_init():
self.emit('_client = client;')
self.emit()
style_to_request = json.loads(self.args.style_to_request)
for route in namespace.routes:
if not self._should_generate_route(route):
continue
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
if route_type in client_args.keys():
for args_data in client_args[route_type]:
task_type_key, type_data_dict = tuple(args_data)
task_type_name = style_to_request[task_type_key]
func_suffix = type_data_dict[0]
extra_args = [
tuple(type_data[:-1])
for type_data in type_data_dict[1]
]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, _ = self._get_default_route_args(
namespace, route)
self._generate_route_m(route, namespace,
route_args, extra_args,
task_type_name, func_suffix)
route_args, _ = self._get_route_args(namespace, route)
self._generate_route_m(route, namespace, route_args,
extra_args, task_type_name,
func_suffix)
else:
task_type_name = style_to_request[route_type]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, _ = self._get_default_route_args(
namespace, route)
self._generate_route_m(route, namespace, route_args,
[], task_type_name, '')
route_args, _ = self._get_route_args(namespace, route)
self._generate_route_m(route, namespace, route_args, [],
task_type_name, '')
def _generate_route_m(self, route, namespace, route_args, extra_args,
task_type_name, func_suffix):
"""Generates route method implementation for the given route."""
user_args = list(route_args)
transport_args = [
('route', 'route'),
('arg', 'arg' if not is_void_type(route.arg_data_type) else 'nil'),
]
for name, value, typ in extra_args:
user_args.append((name, typ))
transport_args.append((name, value))
with self.block_func(
func='{}{}'.format(fmt_route_func(route), func_suffix),
args=fmt_func_args_declaration(user_args),
return_type='{} *'.format(task_type_name)):
self.emit('DBRoute *route = {}.{};'.format(
fmt_route_obj_class(namespace.name),
fmt_route_var(namespace.name, route)))
if is_union_type(route.arg_data_type):
self.emit('{} *arg = {};'.format(
fmt_class_prefix(route.arg_data_type),
fmt_var(route.arg_data_type.name)))
elif not is_void_type(route.arg_data_type):
init_call = fmt_func_call(
caller=fmt_alloc_call(
caller=fmt_class_prefix(route.arg_data_type)),
callee=self._cstor_name_from_fields_names(route_args),
args=fmt_func_args([(f[0], f[0]) for f in route_args]))
self.emit('{} *arg = {};'.format(
fmt_class_prefix(route.arg_data_type), init_call))
request_call = fmt_func_call(
caller='self.client',
callee='request{}'.format(
fmt_camel_upper(route.attrs.get('style'))),
args=fmt_func_args(transport_args))
self.emit('return {};'.format(request_call))
self.emit()
def _generate_routes_h(self, namespace):
"""Generates header file for namespace object that has as methods
all routes within the namespace."""
self.emit(comment_prefix)
self.emit_wrapped_text(
'Routes for the `{}` namespace'.format(fmt_class(namespace.name)),
prefix=comment_prefix)
self.emit(comment_prefix)
self.emit()
self.emit('NS_ASSUME_NONNULL_BEGIN')
self.emit()
with self.block_h(
fmt_routes_class(namespace.name, self.args.auth_type)):
description_str = (
'An instance of the networking client that each '
'route will use to submit a request.')
self.emit_wrapped_text(description_str, prefix=comment_prefix)
self.emit(
fmt_property_str(
prop='client',
typ='id<{}>'.format(
self.args.transport_client_name)))
self.emit()
routes_obj_args = fmt_func_args_declaration(
[('client',
'id<{}>'.format(self.args.transport_client_name))])
init_signature = fmt_signature(
func='init',
args=routes_obj_args,
return_type='instancetype')
description_str = (
'Initializes the `{}` namespace container object '
'with a networking client.')
self.emit_wrapped_text(
description_str.format(
fmt_routes_class(namespace.name, self.args.auth_type)),
prefix=comment_prefix)
self.emit('{};'.format(init_signature))
self.emit()
style_to_request = json.loads(self.args.style_to_request)
for route in namespace.routes:
if not self._should_generate_route(route):
continue
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
if route_type in client_args.keys():
for args_data in client_args[route_type]:
task_type_key, type_data_dict = tuple(args_data)
task_type_name = style_to_request[task_type_key]
func_suffix = type_data_dict[0]
extra_args = [
tuple(type_data[:-1])
for type_data in type_data_dict[1]
]
extra_docs = [(type_data[0], type_data[-1])
for type_data in type_data_dict[1]]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, doc_list = self._get_default_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, extra_args,
doc_list + extra_docs, task_type_name,
func_suffix)
route_args, doc_list = self._get_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, extra_args,
doc_list + extra_docs, task_type_name, func_suffix)
else:
task_type_name = style_to_request[route_type]
if (is_struct_type(route.arg_data_type) and
self._struct_has_defaults(route.arg_data_type)):
route_args, doc_list = self._get_default_route_args(
namespace, route, tag=True)
self._generate_route_signature(
route, namespace, route_args, [], doc_list,
task_type_name, '')
route_args, doc_list = self._get_route_args(
namespace, route, tag=True)
self._generate_route_signature(route, namespace,
route_args, [], doc_list,
task_type_name, '')
self.emit()
self.emit('NS_ASSUME_NONNULL_END')
self.emit()
def _generate_route_signature(
self,
route,
namespace, # pylint: disable=unused-argument
route_args,
extra_args,
doc_list,
task_type_name,
func_suffix):
"""Generates route method signature for the given route."""
for name, _, typ in extra_args:
route_args.append((name, typ))
deprecated = 'DEPRECATED: ' if route.deprecated else ''
func_name = '{}{}'.format(fmt_route_func(route), func_suffix)
self.emit(comment_prefix)
if route.doc:
route_doc = self.process_doc(route.doc, self._docf)
else:
route_doc = 'The {} route'.format(func_name)
self.emit_wrapped_text(
deprecated + route_doc, prefix=comment_prefix, width=120)
self.emit(comment_prefix)
for name, doc in doc_list:
self.emit_wrapped_text(
'@param {} {}'.format(name, doc if doc else undocumented),
prefix=comment_prefix,
width=120)
self.emit(comment_prefix)
output = (
'@return Through the response callback, the caller will ' +
'receive a `{}` object on success or a `{}` object on failure.')
output = output.format(
fmt_type(route.result_data_type, tag=False, no_ptr=True),
fmt_type(route.error_data_type, tag=False, no_ptr=True))
self.emit_wrapped_text(output, prefix=comment_prefix, width=120)
self.emit(comment_prefix)
result_type_str = fmt_type(route.result_data_type) if not is_void_type(
route.result_data_type) else 'DBNilObject *'
error_type_str = fmt_type(route.error_data_type) if not is_void_type(
route.error_data_type) else 'DBNilObject *'
return_type = '{}<{}, {}> *'.format(task_type_name, result_type_str,
error_type_str)
deprecated = self._get_deprecation_warning(route)
route_signature = fmt_signature(
func=func_name,
args=fmt_func_args_declaration(route_args),
return_type='{}'.format(return_type))
self.emit('{}{};'.format(route_signature, deprecated))
self.emit()
def _get_deprecation_warning(self, route):
"""Returns a deprecation tag / message, if route is deprecated."""
result = ''
if route.deprecated:
msg = '{} is deprecated.'.format(fmt_route_func(route))
if route.deprecated.by:
msg += ' Use {}.'.format(fmt_var(route.deprecated.by.name))
result = ' __deprecated_msg("{}")'.format(msg)
return result
def _get_route_args(self, namespace, route, tag=False): # pylint: disable=unused-argument
"""Returns a list of name / value string pairs representing the arguments for
a particular route."""
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_struct_type(data_type):
arg_list = []
for field in data_type.all_fields:
arg_list.append((fmt_var(field.name), fmt_type(
field.data_type, tag=tag, has_default=field.has_default)))
doc_list = [(fmt_var(f.name), self.process_doc(f.doc, self._docf))
for f in data_type.fields if f.doc]
elif is_union_type(data_type):
arg_list = [(fmt_var(data_type.name), fmt_type(
route.arg_data_type, tag=tag))]
doc_list = [(fmt_var(data_type.name),
self.process_doc(data_type.doc,
self._docf) if data_type.doc
else 'The {} union'.format(
fmt_class(data_type
.name)))]
else:
arg_list = []
doc_list = []
return arg_list, doc_list
def _get_default_route_args(
self,
namespace, # pylint: disable=unused-argument
route,
tag=False):
"""Returns a list of name / value string pairs representing the default arguments for
a particular route."""
data_type, _ = unwrap_nullable(route.arg_data_type)
if is_struct_type(data_type):
arg_list = []
for field in data_type.all_fields:
if not field.has_default and not is_nullable_type(
field.data_type):
arg_list.append((fmt_var(field.name), fmt_type(
field.data_type, tag=tag)))
doc_list = ([(fmt_var(f.name), self.process_doc(f.doc, self._docf))
for f in data_type.fields
if f.doc and not f.has_default and
not is_nullable_type(f.data_type)])
else:
arg_list = []
doc_list = []
return arg_list, doc_list
def _docf(self, tag, val):
if tag == 'route':
return '`{}`'.format(fmt_func(val))
elif tag == 'field':
if '.' in val:
cls_name, field = val.split('.')
return ('`{}` in `{}`'.format(
fmt_var(field), self.obj_name_to_namespace[cls_name]))
else:
return fmt_var(val)
elif tag in ('type', 'val', 'link'):
return val
else:
return val

View file

@ -0,0 +1,483 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import pprint
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
Map,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_boolean_type,
is_list_type,
is_map_type,
is_numeric_type,
is_string_type,
is_tag_ref,
is_user_defined_type,
is_void_type,
unwrap_nullable, )
from .helpers import split_words
# This file defines *stylistic* choices for Swift
# (ie, that class names are UpperCamelCase and that variables are lowerCamelCase)
_primitive_table = {
Boolean: 'NSNumber *',
Bytes: 'NSData',
Float32: 'NSNumber *',
Float64: 'NSNumber *',
Int32: 'NSNumber *',
Int64: 'NSNumber *',
List: 'NSArray',
Map: 'NSDictionary',
String: 'NSString *',
Timestamp: 'NSDate *',
UInt32: 'NSNumber *',
UInt64: 'NSNumber *',
Void: 'void',
}
_primitive_table_user_interface = {
Boolean: 'BOOL',
Bytes: 'NSData',
Float32: 'double',
Float64: 'double',
Int32: 'int',
Int64: 'long',
List: 'NSArray',
Map: 'NSDictionary',
String: 'NSString *',
Timestamp: 'NSDate *',
UInt32: 'unsigned int',
UInt64: 'unsigned long',
Void: 'void',
}
_serial_table = {
Boolean: 'DBBoolSerializer',
Bytes: 'DBNSDataSerializer',
Float32: 'DBNSNumberSerializer',
Float64: 'DBNSNumberSerializer',
Int32: 'DBNSNumberSerializer',
Int64: 'DBNSNumberSerializer',
List: 'DBArraySerializer',
Map: 'DBMapSerializer',
String: 'DBStringSerializer',
Timestamp: 'DBNSDateSerializer',
UInt32: 'DBNSNumberSerializer',
UInt64: 'DBNSNumberSerializer',
}
_validator_table = {
Float32: 'numericValidator',
Float64: 'numericValidator',
Int32: 'numericValidator',
Int64: 'numericValidator',
List: 'arrayValidator',
Map: 'mapValidator',
String: 'stringValidator',
UInt32: 'numericValidator',
UInt64: 'numericValidator',
}
_wrapper_primitives = {
Boolean,
Float32,
Float64,
UInt32,
UInt64,
Int32,
Int64,
String,
}
_reserved_words = {
'auto',
'else',
'long',
'switch',
'break',
'enum',
'register',
'typedef',
'case',
'extern',
'return',
'union',
'char',
'float',
'short',
'unsigned',
'const',
'for',
'signed',
'void',
'continue',
'goto',
'sizeof',
'volatile',
'default',
'if',
'static',
'while',
'do',
'int',
'struct',
'_Packed',
'double',
'protocol',
'interface',
'implementation',
'NSObject',
'NSInteger',
'NSNumber',
'CGFloat',
'property',
'nonatomic',
'retain',
'strong',
'weak',
'unsafe_unretained',
'readwrite',
'description',
'id',
'delete',
}
_reserved_prefixes = {
'copy',
'new',
}
def fmt_obj(o):
assert not isinstance(o, dict), "Only use for base type literals"
if o is True:
return 'true'
if o is False:
return 'false'
if o is None:
return 'nil'
return pprint.pformat(o, width=1)
def fmt_camel(name, upper_first=False, reserved=True):
name = str(name)
words = [word.capitalize() for word in split_words(name)]
if not upper_first:
words[0] = words[0].lower()
ret = ''.join(words)
if reserved:
if ret.lower() in _reserved_words:
ret += '_'
# properties can't begin with certain keywords
for reserved_prefix in _reserved_prefixes:
if ret.lower().startswith(reserved_prefix):
new_prefix = 'd' if not upper_first else 'D'
ret = new_prefix + ret[0].upper() + ret[1:]
continue
return ret
def fmt_enum_name(field_name, union):
return 'DB{}{}{}'.format(
fmt_class_caps(union.namespace.name),
fmt_camel_upper(union.name), fmt_camel_upper(field_name))
def fmt_camel_upper(name, reserved=True):
return fmt_camel(name, upper_first=True, reserved=reserved)
def fmt_public_name(name):
return fmt_camel_upper(name)
def fmt_class(name):
return fmt_camel_upper(name)
def fmt_class_caps(name):
return fmt_camel_upper(name).upper()
def fmt_class_type(data_type, suppress_ptr=False):
data_type, _ = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}'.format(fmt_class_prefix(data_type))
else:
result = _primitive_table.get(data_type.__class__,
fmt_class(data_type.name))
if suppress_ptr:
result = result.replace(' *', '')
result = result.replace('*', '')
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
result = result + '<{}>'.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
result = result + '<NSString *, {}>'.format(fmt_type(data_type))
return result
def fmt_func(name):
return fmt_camel(name)
def fmt_type(data_type, tag=False, has_default=False, no_ptr=False, is_prop=False):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
base = '{}' if no_ptr else '{} *'
result = base.format(fmt_class_prefix(data_type))
else:
result = _primitive_table.get(data_type.__class__,
fmt_class(data_type.name))
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
base = '<{}>' if no_ptr else '<{}> *'
result = result + base.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
base = '<NSString *, {}>' if no_ptr else '<NSString *, {}> *'
result = result + base.format(fmt_type(data_type))
if tag:
if (nullable or has_default) and not is_prop:
result = 'nullable ' + result
return result
def fmt_route_type(data_type, tag=False, has_default=False):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{} *'.format(fmt_class_prefix(data_type))
else:
result = _primitive_table_user_interface.get(data_type.__class__,
fmt_class(data_type.name))
if is_list_type(data_type):
data_type, _ = unwrap_nullable(data_type.data_type)
result = result + '<{}> *'.format(fmt_type(data_type))
elif is_map_type(data_type):
data_type, _ = unwrap_nullable(data_type.value_data_type)
result = result + '<NSString *, {}>'.format(fmt_type(data_type))
if is_user_defined_type(data_type) and tag:
if nullable or has_default:
result = 'nullable ' + result
elif not is_void_type(data_type):
result += ''
return result
def fmt_class_prefix(data_type):
return 'DB{}{}'.format(
fmt_class_caps(data_type.namespace.name), fmt_class(data_type.name))
def fmt_validator(data_type):
return _validator_table.get(data_type.__class__, fmt_class(data_type.name))
def fmt_serial_obj(data_type):
data_type, _ = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = fmt_serial_class(fmt_class_prefix(data_type))
else:
result = _serial_table.get(data_type.__class__,
fmt_class(data_type.name))
return result
def fmt_serial_class(class_name):
return '{}Serializer'.format(class_name)
def fmt_route_obj_class(namespace_name):
return 'DB{}RouteObjects'.format(fmt_class_caps(namespace_name))
def fmt_routes_class(namespace_name, auth_type):
auth_type_to_use = auth_type
if auth_type == 'noauth':
auth_type_to_use = 'user'
return 'DB{}{}AuthRoutes'.format(
fmt_class_caps(namespace_name), fmt_camel_upper(auth_type_to_use))
def fmt_route_var(namespace_name, route):
ret = 'DB{}{}'.format(
fmt_class_caps(namespace_name), fmt_camel_upper(route.name))
if route.version != 1:
ret = '{}V{}'.format(ret, route.version)
return ret
def fmt_route_func(route):
ret = fmt_var(route.name)
if route.version != 1:
ret = '{}V{}'.format(ret, route.version)
return ret
def fmt_func_args(arg_str_pairs):
result = []
first_arg = True
for arg_name, arg_value in arg_str_pairs:
if first_arg:
result.append('{}'.format(arg_value))
first_arg = False
else:
result.append('{}:{}'.format(arg_name, arg_value))
return ' '.join(result)
def fmt_func_args_declaration(arg_str_pairs):
result = []
first_arg = True
for arg_name, arg_type in arg_str_pairs:
if first_arg:
result.append('({}){}'.format(arg_type, arg_name))
first_arg = False
else:
result.append('{0}:({1}){0}'.format(arg_name, arg_type))
return ' '.join(result)
def fmt_func_args_from_fields(args):
result = []
first_arg = True
for arg in args:
if first_arg:
result.append(
'({}){}'.format(fmt_type(arg.data_type), fmt_var(arg.name)))
first_arg = False
else:
result.append('{}:({}){}'.format(
fmt_var(arg.name), fmt_type(arg.data_type), fmt_var(arg.name)))
return ' '.join(result)
def fmt_func_call(caller, callee, args=None):
if args:
result = '[{} {}:{}]'.format(caller, callee, args)
else:
result = '[{} {}]'.format(caller, callee)
return result
def fmt_alloc_call(caller):
return '[{} alloc]'.format(caller)
def fmt_default_value(field):
if is_tag_ref(field.default):
return '[[{} alloc] initWith{}]'.format(
fmt_class_prefix(field.default.union_data_type),
fmt_class(field.default.tag_name))
elif is_numeric_type(field.data_type):
return '@({})'.format(field.default)
elif is_boolean_type(field.data_type):
if field.default:
bool_str = 'YES'
else:
bool_str = 'NO'
return '@{}'.format(bool_str)
elif is_string_type(field.data_type):
return '@"{}"'.format(field.default)
else:
raise TypeError(
'Can\'t handle default value type %r' % type(field.data_type))
def fmt_ns_number_call(data_type):
result = ''
if is_numeric_type(data_type):
if isinstance(data_type, UInt32):
result = 'numberWithUnsignedInt'
elif isinstance(data_type, UInt64):
result = 'numberWithUnsignedLong'
elif isinstance(data_type, Int32):
result = 'numberWithInt'
elif isinstance(data_type, Int64):
result = 'numberWithLong'
elif isinstance(data_type, Float32):
result = 'numberWithDouble'
elif isinstance(data_type, Float64):
result = 'numberWithDouble'
elif is_boolean_type(data_type):
result = 'numberWithBool'
return result
def fmt_signature(func, args, return_type='void', class_func=False):
modifier = '-' if not class_func else '+'
if args:
result = '{} ({}){}:{}'.format(modifier, return_type, func, args)
else:
result = '{} ({}){}'.format(modifier, return_type, func)
return result
def is_primitive_type(data_type):
data_type, _ = unwrap_nullable(data_type)
return data_type.__class__ in _wrapper_primitives
def fmt_var(name):
return fmt_camel(name)
def fmt_property(field):
attrs = ['nonatomic', 'readonly']
data_type, nullable = unwrap_nullable(field.data_type)
if is_string_type(data_type):
attrs.append('copy')
if nullable:
attrs.append('nullable')
base_string = '@property ({}) {}{};'
return base_string.format(', '.join(attrs),
fmt_type(field.data_type, tag=True, is_prop=True),
fmt_var(field.name))
def fmt_import(header_file):
return '#import "{}.h"'.format(header_file)
def fmt_property_str(prop, typ, attrs=None):
if not attrs:
attrs = ['nonatomic', 'readonly']
base_string = '@property ({}) {} {};'
return base_string.format(', '.join(attrs), typ, prop)
def append_to_jazzy_category_dict(jazzy_dict, label, item):
for category_dict in jazzy_dict['custom_categories']:
if category_dict['name'] == label:
category_dict['children'].append(item)
return
return None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,560 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import re
from stone.backend import CodeBackend
from stone.backends.helpers import fmt_underscores
from stone.backends.python_helpers import (
check_route_name_conflict,
fmt_class,
fmt_func,
fmt_namespace,
fmt_obj,
fmt_type,
fmt_var,
)
from stone.backends.python_types import (
class_name_for_data_type,
)
from stone.ir import (
is_nullable_type,
is_list_type,
is_map_type,
is_struct_type,
is_tag_ref,
is_union_type,
is_user_defined_type,
is_void_type,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
# This will be at the top of the generated file.
base = """\
# -*- coding: utf-8 -*-
# Auto-generated by Stone, do not modify.
# flake8: noqa
# pylint: skip-file
from abc import ABCMeta, abstractmethod
"""
# Matches format of Babel doc tags
doc_sub_tag_re = re.compile(':(?P<tag>[A-z]*):`(?P<val>.*?)`')
DOCSTRING_CLOSE_RESPONSE = """\
If you do not consume the entire response body, then you must call close on the
response object, otherwise you will max out your available connections. We
recommend using the `contextlib.closing
<https://docs.python.org/2/library/contextlib.html#contextlib.closing>`_
context manager to ensure this."""
_cmdline_parser = argparse.ArgumentParser(
prog='python-client-backend',
description=(
'Generates a Python class with a method for each route. Extend the '
'generated class and implement the abstract request() method. This '
'class assumes that the python_types backend was used with the same '
'output directory.'),
)
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=('The name of the Python module to generate. Please exclude the .py '
'file extension.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help='The name of the Python class that contains each route as a method.',
)
_cmdline_parser.add_argument(
'-t',
'--types-package',
required=True,
type=str,
help='The output Python package of the python_types backend.',
)
_cmdline_parser.add_argument(
'-e',
'--error-class-path',
default='.exceptions.ApiError',
type=str,
help=(
"The path to the class that's raised when a route returns an error. "
"The class name is inserted into the doc for route methods."),
)
_cmdline_parser.add_argument(
'-w',
'--auth-type',
type=str,
help='The auth type of the client to generate.',
)
class PythonClientBackend(CodeBackend):
cmdline_parser = _cmdline_parser
supported_auth_types = None
def generate(self, api):
"""Generates a module called "base".
The module will contain a base class that will have a method for
each route across all namespaces.
"""
with self.output_to_relative_path('%s.py' % self.args.module_name):
self.emit_raw(base)
# Import "warnings" if any of the routes are deprecated.
found_deprecated = False
for namespace in api.namespaces.values():
for route in namespace.routes:
if route.deprecated:
self.emit('import warnings')
found_deprecated = True
break
if found_deprecated:
break
self.emit()
self._generate_imports(api.namespaces.values())
self.emit()
self.emit() # PEP-8 expects two-blank lines before class def
self.emit('class %s(object):' % self.args.class_name)
with self.indent():
self.emit('__metaclass__ = ABCMeta')
self.emit()
self.emit('@abstractmethod')
self.emit(
'def request(self, route, namespace, arg, arg_binary=None):')
with self.indent():
self.emit('pass')
self.emit()
self._generate_route_methods(api.namespaces.values())
def _generate_imports(self, namespaces):
# Only import namespaces that have user-defined types defined.
for namespace in namespaces:
if namespace.data_types:
self.emit('from {} import {}'.format(self.args.types_package, fmt_namespace(namespace.name)))
def _generate_route_methods(self, namespaces):
"""Creates methods for the routes in each namespace. All data types
and routes are represented as Python classes."""
self.cur_namespace = None
for namespace in namespaces:
if namespace.routes:
self.emit('# ------------------------------------------')
self.emit('# Routes in {} namespace'.format(namespace.name))
self.emit()
self._generate_routes(namespace)
def _generate_routes(self, namespace):
"""
Generates Python methods that correspond to routes in the namespace.
"""
# Hack: needed for _docf()
self.cur_namespace = namespace
# list of auth_types supported in this base class.
# this is passed with the new -w flag
if self.args.auth_type is not None:
self.supported_auth_types = [auth_type.strip().lower() for auth_type in self.args.auth_type.split(',')]
check_route_name_conflict(namespace)
for route in namespace.routes:
# compatibility mode : included routes are passed by whitelist
# actual auth attr inluded in the route is ignored in this mode.
if self.supported_auth_types is None:
self._generate_route_helper(namespace, route)
if route.attrs.get('style') == 'download':
self._generate_route_helper(namespace, route, True)
else:
route_auth_attr = None
if route.attrs is not None:
route_auth_attr = route.attrs.get('auth')
if route_auth_attr is None:
continue
route_auth_modes = [mode.strip().lower() for mode in route_auth_attr.split(',')]
for base_auth_type in self.supported_auth_types:
if base_auth_type in route_auth_modes:
self._generate_route_helper(namespace, route)
if route.attrs.get('style') == 'download':
self._generate_route_helper(namespace, route, True)
break # to avoid duplicate method declaration in the same base class
def _generate_route_helper(self, namespace, route, download_to_file=False):
"""Generate a Python method that corresponds to a route.
:param namespace: Namespace that the route belongs to.
:param stone.ir.ApiRoute route: IR node for the route.
:param bool download_to_file: Whether a special version of the route
that downloads the response body to a file should be generated.
This can only be used for download-style routes.
"""
arg_data_type = route.arg_data_type
result_data_type = route.result_data_type
request_binary_body = route.attrs.get('style') == 'upload'
response_binary_body = route.attrs.get('style') == 'download'
if download_to_file:
assert response_binary_body, 'download_to_file can only be set ' \
'for download-style routes.'
self._generate_route_method_decl(namespace,
route,
arg_data_type,
request_binary_body,
method_name_suffix='_to_file',
extra_args=['download_path'])
else:
self._generate_route_method_decl(namespace,
route,
arg_data_type,
request_binary_body)
with self.indent():
extra_request_args = None
extra_return_arg = None
footer = None
if request_binary_body:
extra_request_args = [('f',
'bytes',
'Contents to upload.')]
elif download_to_file:
extra_request_args = [('download_path',
'str',
'Path on local machine to save file.')]
if response_binary_body and not download_to_file:
extra_return_arg = ':class:`requests.models.Response`'
footer = DOCSTRING_CLOSE_RESPONSE
if route.doc:
func_docstring = self.process_doc(route.doc, self._docf)
else:
func_docstring = None
self._generate_docstring_for_func(
namespace,
arg_data_type,
result_data_type,
route.error_data_type,
overview=func_docstring,
extra_request_args=extra_request_args,
extra_return_arg=extra_return_arg,
footer=footer,
)
self._maybe_generate_deprecation_warning(route)
# Code to instantiate a class for the request data type
if is_void_type(arg_data_type):
self.emit('arg = None')
elif is_struct_type(arg_data_type):
self.generate_multiline_list(
[f.name for f in arg_data_type.all_fields],
before='arg = {}.{}'.format(
fmt_namespace(arg_data_type.namespace.name),
fmt_class(arg_data_type.name)),
)
elif not is_union_type(arg_data_type):
raise AssertionError('Unhandled request type %r' %
arg_data_type)
# Code to make the request
args = [
'{}.{}'.format(fmt_namespace(namespace.name),
fmt_func(route.name, version=route.version)),
"'{}'".format(namespace.name),
'arg']
if request_binary_body:
args.append('f')
else:
args.append('None')
self.generate_multiline_list(args, 'r = self.request', compact=False)
if download_to_file:
self.emit('self._save_body_to_file(download_path, r[1])')
if is_void_type(result_data_type):
self.emit('return None')
else:
self.emit('return r[0]')
else:
if is_void_type(result_data_type):
self.emit('return None')
else:
self.emit('return r')
self.emit()
def _generate_route_method_decl(
self, namespace, route, arg_data_type, request_binary_body,
method_name_suffix='', extra_args=None):
"""Generates the method prototype for a route."""
args = ['self']
if extra_args:
args += extra_args
if request_binary_body:
args.append('f')
if is_struct_type(arg_data_type):
for field in arg_data_type.all_fields:
if is_nullable_type(field.data_type):
args.append('{}=None'.format(field.name))
elif field.has_default:
# TODO(kelkabany): Decide whether we really want to set the
# default in the argument list. This will send the default
# over the wire even if it isn't overridden. The benefit is
# it locks in a default even if it is changed server-side.
if is_user_defined_type(field.data_type):
ns = field.data_type.namespace
else:
ns = None
arg = '{}={}'.format(
field.name,
self._generate_python_value(ns, field.default))
args.append(arg)
else:
args.append(field.name)
elif is_union_type(arg_data_type):
args.append('arg')
elif not is_void_type(arg_data_type):
raise AssertionError('Unhandled request type: %r' %
arg_data_type)
method_name = fmt_func(route.name + method_name_suffix, version=route.version)
namespace_name = fmt_underscores(namespace.name)
self.generate_multiline_list(args, 'def {}_{}'.format(namespace_name, method_name), ':')
def _maybe_generate_deprecation_warning(self, route):
if route.deprecated:
msg = '{} is deprecated.'.format(route.name)
if route.deprecated.by:
msg += ' Use {}.'.format(route.deprecated.by.name)
args = ["'{}'".format(msg), 'DeprecationWarning']
self.generate_multiline_list(
args,
before='warnings.warn',
delim=('(', ')'),
compact=False,
)
def _generate_docstring_for_func(self, namespace, arg_data_type,
result_data_type=None, error_data_type=None,
overview=None, extra_request_args=None,
extra_return_arg=None, footer=None):
"""
Generates a docstring for a function or method.
This function is versatile. It will create a docstring using all the
data that is provided.
:param arg_data_type: The data type describing the argument to the
route. The data type should be a struct, and each field will be
treated as an input parameter of the method.
:param result_data_type: The data type of the route result.
:param error_data_type: The data type of the route result in the case
of an error.
:param str overview: A description of the route that will be located
at the top of the docstring.
:param extra_request_args: [(field name, field type, field doc), ...]
Describes any additional parameters for the method that aren't a
field in arg_data_type.
:param str extra_return_arg: Name of an additional return type that. If
this is specified, it is assumed that the return of the function
will be a tuple of return_data_type and extra_return-arg.
:param str footer: Additional notes at the end of the docstring.
"""
fields = [] if is_void_type(arg_data_type) else arg_data_type.fields
if not fields and not overview:
# If we don't have an overview or any input parameters, we skip the
# docstring altogether.
return
self.emit('"""')
if overview:
self.emit_wrapped_text(overview)
# Description of all input parameters
if extra_request_args or fields:
if overview:
# Add a blank line if we had an overview
self.emit()
if extra_request_args:
for name, data_type_name, doc in extra_request_args:
if data_type_name:
field_doc = ':param {} {}: {}'.format(data_type_name,
name, doc)
self.emit_wrapped_text(field_doc,
subsequent_prefix=' ')
else:
self.emit_wrapped_text(
':param {}: {}'.format(name, doc),
subsequent_prefix=' ')
if is_struct_type(arg_data_type):
for field in fields:
if field.doc:
if is_user_defined_type(field.data_type):
field_doc = ':param {}: {}'.format(
field.name, self.process_doc(field.doc, self._docf))
else:
field_doc = ':param {} {}: {}'.format(
self._format_type_in_doc(namespace, field.data_type),
field.name,
self.process_doc(field.doc, self._docf),
)
self.emit_wrapped_text(
field_doc, subsequent_prefix=' ')
if is_user_defined_type(field.data_type):
# It's clearer to declare the type of a composite on
# a separate line since it references a class in
# another module
self.emit(':type {}: {}'.format(
field.name,
self._format_type_in_doc(namespace, field.data_type),
))
else:
# If the field has no docstring, then just document its
# type.
field_doc = ':type {}: {}'.format(
field.name,
self._format_type_in_doc(namespace, field.data_type),
)
self.emit_wrapped_text(field_doc)
elif is_union_type(arg_data_type):
if arg_data_type.doc:
self.emit_wrapped_text(':param arg: {}'.format(
self.process_doc(arg_data_type.doc, self._docf)),
subsequent_prefix=' ')
self.emit(':type arg: {}'.format(
self._format_type_in_doc(namespace, arg_data_type)))
if overview and not (extra_request_args or fields):
# Only output an empty line if we had an overview and haven't
# started a section on declaring types.
self.emit()
if extra_return_arg:
# Special case where the function returns a tuple. The first
# element is the JSON response. The second element is the
# the extra_return_arg param.
args = []
if is_void_type(result_data_type):
args.append('None')
else:
rtype = self._format_type_in_doc(namespace,
result_data_type)
args.append(rtype)
args.append(extra_return_arg)
self.generate_multiline_list(args, ':rtype: ')
else:
if is_void_type(result_data_type):
self.emit(':rtype: None')
else:
rtype = self._format_type_in_doc(namespace, result_data_type)
self.emit(':rtype: {}'.format(rtype))
if not is_void_type(error_data_type) and error_data_type.fields:
self.emit(':raises: :class:`{}`'.format(self.args.error_class_path))
self.emit()
# To provide more clarity to a dev who reads the docstring, suggest
# the route's error class. This is confusing, however, because we
# don't know where the error object that's raised will store
# the more detailed route error defined in stone.
error_class_name = self.args.error_class_path.rsplit('.', 1)[-1]
self.emit('If this raises, {} will contain:'.format(error_class_name))
with self.indent():
self.emit(self._format_type_in_doc(namespace, error_data_type))
if footer:
self.emit()
self.emit_wrapped_text(footer)
self.emit('"""')
def _docf(self, tag, val):
"""
Callback used as the handler argument to process_docs(). This converts
Babel doc references to Sphinx-friendly annotations.
"""
if tag == 'type':
fq_val = val
if '.' not in val:
fq_val = self.cur_namespace.name + '.' + fq_val
return ':class:`{}.{}`'.format(self.args.types_package, fq_val)
elif tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
if '.' in val:
return ':meth:`{}`'.format(fmt_func(val, version=version))
else:
return ':meth:`{}_{}`'.format(
self.cur_namespace.name, fmt_func(val, version=version))
elif tag == 'link':
anchor, link = val.rsplit(' ', 1)
return '`{} <{}>`_'.format(anchor, link)
elif tag == 'val':
if val == 'null':
return 'None'
elif val == 'true' or val == 'false':
return '``{}``'.format(val.capitalize())
else:
return val
elif tag == 'field':
return '``{}``'.format(val)
else:
raise RuntimeError('Unknown doc ref tag %r' % tag)
def _format_type_in_doc(self, namespace, data_type):
"""
Returns a string that can be recognized by Sphinx as a type reference
in a docstring.
"""
if is_void_type(data_type):
return 'None'
elif is_user_defined_type(data_type):
return ':class:`{}.{}.{}`'.format(
self.args.types_package, namespace.name, fmt_type(data_type))
elif is_nullable_type(data_type):
return 'Nullable[{}]'.format(
self._format_type_in_doc(namespace, data_type.data_type),
)
elif is_list_type(data_type):
return 'List[{}]'.format(
self._format_type_in_doc(namespace, data_type.data_type),
)
elif is_map_type(data_type):
return 'Map[{}, {}]'.format(
self._format_type_in_doc(namespace, data_type.key_data_type),
self._format_type_in_doc(namespace, data_type.value_data_type),
)
else:
return fmt_type(data_type)
def _generate_python_value(self, namespace, value):
if is_tag_ref(value):
return '{}.{}.{}'.format(
fmt_namespace(namespace.name),
class_name_for_data_type(value.union_data_type),
fmt_var(value.tag_name))
else:
return fmt_obj(value)

View file

@ -0,0 +1,201 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
import pprint
from stone.backend import Backend, CodeBackend
from stone.backends.helpers import (
fmt_pascal,
fmt_underscores,
)
from stone.ir import ApiNamespace
from stone.ir import (
AnnotationType,
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
is_user_defined_type,
is_alias,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
_type_table = {
Boolean: 'bool',
Bytes: 'bytes',
Float32: 'float',
Float64: 'float',
Int32: 'int',
Int64: 'int',
List: 'list',
String: 'str',
Timestamp: 'datetime',
UInt32: 'int',
UInt64: 'int',
}
_reserved_keywords = {
'break',
'class',
'continue',
'for',
'pass',
'while',
'async',
}
@contextmanager
def emit_pass_if_nothing_emitted(codegen):
# type: (CodeBackend) -> typing.Iterator[None]
starting_lineno = codegen.lineno
yield
ending_lineno = codegen.lineno
if starting_lineno == ending_lineno:
codegen.emit("pass")
codegen.emit()
def _rename_if_reserved(s):
if s in _reserved_keywords:
return s + '_'
else:
return s
def fmt_class(name, check_reserved=False):
s = fmt_pascal(name)
return _rename_if_reserved(s) if check_reserved else s
def fmt_func(name, check_reserved=False, version=1):
name = fmt_underscores(name)
if check_reserved:
name = _rename_if_reserved(name)
if version > 1:
name = '{}_v{}'.format(name, version)
return name
def fmt_obj(o):
return pprint.pformat(o, width=1)
def fmt_type(data_type):
return _type_table.get(data_type.__class__, fmt_class(data_type.name))
def fmt_var(name, check_reserved=False):
s = fmt_underscores(name)
return _rename_if_reserved(s) if check_reserved else s
def fmt_namespaced_var(ns_name, data_type_name, field_name):
return ".".join([ns_name, data_type_name, fmt_var(field_name)])
def fmt_namespace(name):
return _rename_if_reserved(name)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, version=route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route
TYPE_IGNORE_COMMENT = " # type: ignore"
def generate_imports_for_referenced_namespaces(
backend, namespace, package, insert_type_ignore=False):
# type: (Backend, ApiNamespace, typing.Text, bool) -> None
"""
Both the true Python backend and the Python PEP 484 Type Stub backend have
to perform the same imports.
:param insert_type_ignore: add a MyPy type-ignore comment to the imports in
the except: clause.
"""
imported_namespaces = namespace.get_imported_namespaces(consider_annotation_types=True)
if not imported_namespaces:
return
type_ignore_comment = TYPE_IGNORE_COMMENT if insert_type_ignore else ""
for ns in imported_namespaces:
backend.emit('from {package} import {namespace_name}{type_ignore_comment}'.format(
package=package,
namespace_name=fmt_namespace(ns.name),
type_ignore_comment=type_ignore_comment
))
backend.emit()
def generate_module_header(backend):
backend.emit('# -*- coding: utf-8 -*-')
backend.emit('# Auto-generated by Stone, do not modify.')
# Silly way to not type ATgenerated in our code to avoid having this
# file marked as auto-generated by our code review tool.
backend.emit('# @{}'.format('generated'))
backend.emit('# flake8: noqa')
backend.emit('# pylint: skip-file')
# This will be at the top of every generated file.
_validators_import_template = """\
from stone.backends.python_rsrc import stone_base as bb{type_ignore_comment}
from stone.backends.python_rsrc import stone_validators as bv{type_ignore_comment}
"""
validators_import = _validators_import_template.format(type_ignore_comment="")
validators_import_with_type_ignore = _validators_import_template.format(
type_ignore_comment=TYPE_IGNORE_COMMENT
)
def prefix_with_ns_if_necessary(name, name_ns, source_ns):
# type: (typing.Text, ApiNamespace, ApiNamespace) -> typing.Text
"""
Returns a name that can be used to reference `name` in namespace `name_ns`
from `source_ns`.
If `source_ns` and `name_ns` are the same, that's just `name`. Otherwise
it's `name_ns`.`name`.
"""
if source_ns == name_ns:
return name
return '{}.{}'.format(fmt_namespace(name_ns.name), name)
def class_name_for_data_type(data_type, ns=None):
"""
Returns the name of the Python class that maps to a user-defined type.
The name is identical to the name in the spec.
If ``ns`` is set to a Namespace and the namespace of `data_type` does
not match, then a namespace prefix is added to the returned name.
For example, ``foreign_ns.TypeName``.
"""
assert is_user_defined_type(data_type) or is_alias(data_type), \
'Expected composite type, got %r' % type(data_type)
name = fmt_class(data_type.name)
if ns:
return prefix_with_ns_if_necessary(name, data_type.namespace, ns)
return name
def class_name_for_annotation_type(annotation_type, ns=None):
"""
Same as class_name_for_data_type, but works with annotation types.
"""
assert isinstance(annotation_type, AnnotationType)
name = fmt_class(annotation_type.name)
if ns:
return prefix_with_ns_if_necessary(name, annotation_type.namespace, ns)
return name

View file

@ -0,0 +1 @@
# Make this a package so that the Python backend tests can import these.

View file

@ -0,0 +1,250 @@
"""
Helpers for representing Stone data types in Python.
"""
from __future__ import absolute_import, unicode_literals
import functools
from stone.backends.python_rsrc import stone_validators as bv
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class AnnotationType(object):
# This is a base class for all annotation types.
pass
if _MYPY:
T = typing.TypeVar('T', bound=AnnotationType)
U = typing.TypeVar('U')
class NotSet(object):
__slots__ = ()
def __copy__(self):
# type: () -> NotSet
# disable copying so we can do identity comparison even after copying stone objects
return self
def __deepcopy__(self, memo):
# type: (typing.Dict[typing.Text, typing.Any]) -> NotSet
# disable copying so we can do identity comparison even after copying stone objects
return self
def __repr__(self):
return "NOT_SET"
NOT_SET = NotSet() # dummy object to denote that a field has not been set
NO_DEFAULT = object()
class Attribute(object):
__slots__ = ("name", "default", "nullable", "user_defined", "validator")
def __init__(self, name, nullable=False, user_defined=False):
# type: (typing.Text, bool, bool) -> None
# Internal name to store actual value for attribute.
self.name = "_{}_value".format(name)
self.nullable = nullable
self.user_defined = user_defined
# These should be set later, because of possible cross-references.
self.validator = None # type: typing.Any
self.default = NO_DEFAULT
def __get__(self, instance, owner):
# type: (typing.Any, typing.Any) -> typing.Any
if instance is None:
return self
value = getattr(instance, self.name)
if value is not NOT_SET:
return value
if self.nullable:
return None
if self.default is not NO_DEFAULT:
return self.default
# No luck, give a nice error.
raise AttributeError("missing required field '{}'".format(public_name(self.name)))
def __set__(self, instance, value):
# type: (typing.Any, typing.Any) -> None
if self.nullable and value is None:
setattr(instance, self.name, NOT_SET)
return
if self.user_defined:
self.validator.validate_type_only(value)
else:
value = self.validator.validate(value)
setattr(instance, self.name, value)
def __delete__(self, instance):
# type: (typing.Any) -> None
setattr(instance, self.name, NOT_SET)
class Struct(object):
# This is a base class for all classes representing Stone structs.
# every parent class in the inheritance tree must define __slots__ in order to get full memory
# savings
__slots__ = ()
_all_field_names_ = set() # type: typing.Set[str]
def __eq__(self, other):
# type: (object) -> bool
if not isinstance(other, Struct):
return False
if self._all_field_names_ != other._all_field_names_:
return False
if not isinstance(other, self.__class__) and not isinstance(self, other.__class__):
return False
for field_name in self._all_field_names_:
if getattr(self, field_name) != getattr(other, field_name):
return False
return True
def __ne__(self, other):
# type: (object) -> bool
return not self == other
def __repr__(self):
args = ["{}={!r}".format(name, getattr(self, "_{}_value".format(name)))
for name in sorted(self._all_field_names_)]
return "{}({})".format(type(self).__name__, ", ".join(args))
def _process_custom_annotations(self, annotation_type, field_path, processor):
# type: (typing.Type[T], typing.Text, typing.Callable[[T, U], U]) -> None
pass
class Union(object):
# TODO(kelkabany): Possible optimization is to remove _value if a
# union is composed of only symbols.
__slots__ = ['_tag', '_value']
_tagmap = {} # type: typing.Dict[str, bv.Validator]
_permissioned_tagmaps = set() # type: typing.Set[typing.Text]
def __init__(self, tag, value=None):
validator = None
tagmap_names = ['_{}_tagmap'.format(map_name) for map_name in self._permissioned_tagmaps]
for tagmap_name in ['_tagmap'] + tagmap_names:
if tag in getattr(self, tagmap_name):
validator = getattr(self, tagmap_name)[tag]
assert validator is not None, 'Invalid tag %r.' % tag
if isinstance(validator, bv.Void):
assert value is None, 'Void type union member must have None value.'
elif isinstance(validator, (bv.Struct, bv.Union)):
validator.validate_type_only(value)
else:
validator.validate(value)
self._tag = tag
self._value = value
def __eq__(self, other):
# Also need to check if one class is a subclass of another. If one union extends another,
# the common fields should be able to be compared to each other.
return (
isinstance(other, Union) and
(isinstance(self, other.__class__) or isinstance(other, self.__class__)) and
self._tag == other._tag and self._value == other._value
)
def __ne__(self, other):
return not self == other
def __hash__(self):
return hash((self._tag, self._value))
def __repr__(self):
return "{}({!r}, {!r})".format(type(self).__name__, self._tag, self._value)
def _process_custom_annotations(self, annotation_type, field_path, processor):
# type: (typing.Type[T], typing.Text, typing.Callable[[T, U], U]) -> None
pass
@classmethod
def _is_tag_present(cls, tag, caller_permissions):
assert tag is not None, 'tag value should not be None'
if tag in cls._tagmap:
return True
for extra_permission in caller_permissions.permissions:
tagmap_name = '_{}_tagmap'.format(extra_permission)
if hasattr(cls, tagmap_name) and tag in getattr(cls, tagmap_name):
return True
return False
@classmethod
def _get_val_data_type(cls, tag, caller_permissions):
assert tag is not None, 'tag value should not be None'
for extra_permission in caller_permissions.permissions:
tagmap_name = '_{}_tagmap'.format(extra_permission)
if hasattr(cls, tagmap_name) and tag in getattr(cls, tagmap_name):
return getattr(cls, tagmap_name)[tag]
return cls._tagmap[tag]
class Route(object):
__slots__ = ("name", "version", "deprecated", "arg_type", "result_type", "error_type", "attrs")
def __init__(self, name, version, deprecated, arg_type, result_type, error_type, attrs):
self.name = name
self.version = version
self.deprecated = deprecated
self.arg_type = arg_type
self.result_type = result_type
self.error_type = error_type
assert isinstance(attrs, dict), 'Expected dict, got %r' % attrs
self.attrs = attrs
def __repr__(self):
return 'Route({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})'.format(
self.name,
self.version,
self.deprecated,
self.arg_type,
self.result_type,
self.error_type,
self.attrs)
# helper functions used when constructing custom annotation processors
# put this here so that every other file doesn't need to import functools
partially_apply = functools.partial
def make_struct_annotation_processor(annotation_type, processor):
def g(field_path, struct):
if struct is None:
return struct
struct._process_custom_annotations(annotation_type, field_path, processor)
return struct
return g
def make_list_annotation_processor(processor):
def g(field_path, list_):
if list_ is None:
return list_
return [processor('{}[{}]'.format(field_path, idx), x) for idx, x in enumerate(list_)]
return g
def make_map_value_annotation_processor(processor):
def g(field_path, map_):
if map_ is None:
return map_
return {k: processor('{}[{}]'.format(field_path, repr(k)), v) for k, v in map_.items()}
return g
def public_name(name):
# _some_attr_value -> some_attr
return "_".join(name.split("_")[1:-1])

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,722 @@
"""
Defines classes to represent each Stone type in Python. These classes should
be used to validate Python objects and normalize them for a given type.
The data types defined here should not be specific to an RPC or serialization
format.
"""
from __future__ import absolute_import, unicode_literals
import datetime
import hashlib
import math
import numbers
import re
from abc import ABCMeta, abstractmethod
import six
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# See <http://python3porting.com/differences.html#buffer>
if six.PY3:
_binary_types = (bytes, memoryview) # noqa: E501,F821 # pylint: disable=undefined-variable,useless-suppression
else:
_binary_types = (bytes, buffer) # noqa: E501,F821 # pylint: disable=undefined-variable,useless-suppression
class ValidationError(Exception):
"""Raised when a value doesn't pass validation by its validator."""
def __init__(self, message, parent=None):
"""
Args:
message (str): Error message detailing validation failure.
parent (str): Adds the parent as the closest reference point for
the error. Use :meth:`add_parent` to add more.
"""
super(ValidationError, self).__init__(message)
self.message = message
self._parents = []
if parent:
self._parents.append(parent)
def add_parent(self, parent):
"""
Args:
parent (str): Adds the parent to the top of the tree of references
that lead to the validator that failed.
"""
self._parents.append(parent)
def __str__(self):
"""
Returns:
str: A descriptive message of the validation error that may also
include the path to the validator that failed.
"""
if self._parents:
return '{}: {}'.format('.'.join(self._parents[::-1]), self.message)
else:
return self.message
def __repr__(self):
# Not a perfect repr, but includes the error location information.
return 'ValidationError(%r)' % six.text_type(self)
def type_name_with_module(t):
# type: (typing.Type[typing.Any]) -> six.text_type
return '%s.%s' % (t.__module__, t.__name__)
def generic_type_name(v):
# type: (typing.Any) -> six.text_type
"""Return a descriptive type name that isn't Python specific. For example,
an int value will return 'integer' rather than 'int'."""
if isinstance(v, bool):
# Must come before any numbers checks since booleans are integers too
return 'boolean'
elif isinstance(v, numbers.Integral):
# Must come before real numbers check since integrals are reals too
return 'integer'
elif isinstance(v, numbers.Real):
return 'float'
elif isinstance(v, (tuple, list)):
return 'list'
elif isinstance(v, six.string_types):
return 'string'
elif v is None:
return 'null'
else:
return type_name_with_module(type(v))
class Validator(six.with_metaclass(ABCMeta, object)):
"""All primitive and composite data types should be a subclass of this."""
__slots__ = ("_redact",)
@abstractmethod
def validate(self, val):
"""Validates that val is of this data type.
Returns: A normalized value if validation succeeds.
Raises: ValidationError
"""
def has_default(self):
return False
def get_default(self):
raise AssertionError('No default available.')
class Primitive(Validator): # pylint: disable=abstract-method
"""A basic type that is defined by Stone."""
__slots__ = ()
class Boolean(Primitive):
__slots__ = ()
def validate(self, val):
if not isinstance(val, bool):
raise ValidationError('%r is not a valid boolean' % val)
return val
class Integer(Primitive):
"""
Do not use this class directly. Extend it and specify a 'default_minimum' and
'default_maximum' value as class variables for a more restrictive integer range.
"""
__slots__ = ("minimum", "maximum")
default_minimum = None # type: typing.Optional[int]
default_maximum = None # type: typing.Optional[int]
def __init__(self, min_value=None, max_value=None):
"""
A more restrictive minimum or maximum value can be specified than the
range inherent to the defined type.
"""
if min_value is not None:
assert isinstance(min_value, numbers.Integral), \
'min_value must be an integral number'
assert min_value >= self.default_minimum, \
'min_value cannot be less than the minimum value for this ' \
'type (%d < %d)' % (min_value, self.default_minimum)
self.minimum = min_value
else:
self.minimum = self.default_minimum
if max_value is not None:
assert isinstance(max_value, numbers.Integral), \
'max_value must be an integral number'
assert max_value <= self.default_maximum, \
'max_value cannot be greater than the maximum value for ' \
'this type (%d < %d)' % (max_value, self.default_maximum)
self.maximum = max_value
else:
self.maximum = self.default_maximum
def validate(self, val):
if not isinstance(val, numbers.Integral):
raise ValidationError('expected integer, got %s'
% generic_type_name(val))
elif not (self.minimum <= val <= self.maximum):
raise ValidationError('%d is not within range [%d, %d]'
% (val, self.minimum, self.maximum))
return val
def __repr__(self):
return '%s()' % self.__class__.__name__
class Int32(Integer):
__slots__ = ()
default_minimum = -2**31
default_maximum = 2**31 - 1
class UInt32(Integer):
__slots__ = ()
default_minimum = 0
default_maximum = 2**32 - 1
class Int64(Integer):
__slots__ = ()
default_minimum = -2**63
default_maximum = 2**63 - 1
class UInt64(Integer):
__slots__ = ()
default_minimum = 0
default_maximum = 2**64 - 1
class Real(Primitive):
"""
Do not use this class directly. Extend it and optionally set a 'default_minimum'
and 'default_maximum' value to enforce a range that's a subset of the Python float
implementation. Python floats are doubles.
"""
__slots__ = ("minimum", "maximum")
default_minimum = None # type: typing.Optional[float]
default_maximum = None # type: typing.Optional[float]
def __init__(self, min_value=None, max_value=None):
"""
A more restrictive minimum or maximum value can be specified than the
range inherent to the defined type.
"""
if min_value is not None:
assert isinstance(min_value, numbers.Real), \
'min_value must be a real number'
if not isinstance(min_value, float):
try:
min_value = float(min_value)
except OverflowError:
raise AssertionError('min_value is too small for a float')
if self.default_minimum is not None and min_value < self.default_minimum:
raise AssertionError('min_value cannot be less than the '
'minimum value for this type (%f < %f)' %
(min_value, self.default_minimum))
self.minimum = min_value
else:
self.minimum = self.default_minimum
if max_value is not None:
assert isinstance(max_value, numbers.Real), \
'max_value must be a real number'
if not isinstance(max_value, float):
try:
max_value = float(max_value)
except OverflowError:
raise AssertionError('max_value is too large for a float')
if self.default_maximum is not None and max_value > self.default_maximum:
raise AssertionError('max_value cannot be greater than the '
'maximum value for this type (%f < %f)' %
(max_value, self.default_maximum))
self.maximum = max_value
else:
self.maximum = self.default_maximum
def validate(self, val):
if not isinstance(val, numbers.Real):
raise ValidationError('expected real number, got %s' %
generic_type_name(val))
if not isinstance(val, float):
# This checks for the case where a number is passed in with a
# magnitude larger than supported by float64.
try:
val = float(val)
except OverflowError:
raise ValidationError('too large for float')
if math.isnan(val) or math.isinf(val):
raise ValidationError('%f values are not supported' % val)
if self.minimum is not None and val < self.minimum:
raise ValidationError('%f is not greater than %f' %
(val, self.minimum))
if self.maximum is not None and val > self.maximum:
raise ValidationError('%f is not less than %f' %
(val, self.maximum))
return val
def __repr__(self):
return '%s()' % self.__class__.__name__
class Float32(Real):
__slots__ = ()
# Maximum and minimums from the IEEE 754-1985 standard
default_minimum = -3.40282 * 10**38
default_maximum = 3.40282 * 10**38
class Float64(Real):
__slots__ = ()
class String(Primitive):
"""Represents a unicode string."""
__slots__ = ("min_length", "max_length", "pattern", "pattern_re")
def __init__(self, min_length=None, max_length=None, pattern=None):
if min_length is not None:
assert isinstance(min_length, numbers.Integral), \
'min_length must be an integral number'
assert min_length >= 0, 'min_length must be >= 0'
if max_length is not None:
assert isinstance(max_length, numbers.Integral), \
'max_length must be an integral number'
assert max_length > 0, 'max_length must be > 0'
if min_length and max_length:
assert max_length >= min_length, 'max_length must be >= min_length'
if pattern is not None:
assert isinstance(pattern, six.string_types), \
'pattern must be a string'
self.min_length = min_length
self.max_length = max_length
self.pattern = pattern
self.pattern_re = None
if pattern:
try:
self.pattern_re = re.compile(r"\A(?:" + pattern + r")\Z")
except re.error as e:
raise AssertionError('Regex {!r} failed: {}'.format(
pattern, e.args[0]))
def validate(self, val):
"""
A unicode string of the correct length and pattern will pass validation.
In PY2, we enforce that a str type must be valid utf-8, and a unicode
string will be returned.
"""
if not isinstance(val, six.string_types):
raise ValidationError("'%s' expected to be a string, got %s"
% (val, generic_type_name(val)))
if not six.PY3 and isinstance(val, str):
try:
val = val.decode('utf-8')
except UnicodeDecodeError:
raise ValidationError("'%s' was not valid utf-8")
if self.max_length is not None and len(val) > self.max_length:
raise ValidationError("'%s' must be at most %d characters, got %d"
% (val, self.max_length, len(val)))
if self.min_length is not None and len(val) < self.min_length:
raise ValidationError("'%s' must be at least %d characters, got %d"
% (val, self.min_length, len(val)))
if self.pattern and not self.pattern_re.match(val):
raise ValidationError("'%s' did not match pattern '%s'"
% (val, self.pattern))
return val
class Bytes(Primitive):
__slots__ = ("min_length", "max_length")
def __init__(self, min_length=None, max_length=None):
if min_length is not None:
assert isinstance(min_length, numbers.Integral), \
'min_length must be an integral number'
assert min_length >= 0, 'min_length must be >= 0'
if max_length is not None:
assert isinstance(max_length, numbers.Integral), \
'max_length must be an integral number'
assert max_length > 0, 'max_length must be > 0'
if min_length is not None and max_length is not None:
assert max_length >= min_length, 'max_length must be >= min_length'
self.min_length = min_length
self.max_length = max_length
def validate(self, val):
if not isinstance(val, _binary_types):
raise ValidationError("expected bytes type, got %s"
% generic_type_name(val))
elif self.max_length is not None and len(val) > self.max_length:
raise ValidationError("'%s' must have at most %d bytes, got %d"
% (val, self.max_length, len(val)))
elif self.min_length is not None and len(val) < self.min_length:
raise ValidationError("'%s' has fewer than %d bytes, got %d"
% (val, self.min_length, len(val)))
return val
class Timestamp(Primitive):
"""Note that while a format is specified, it isn't used in validation
since a native Python datetime object is preferred. The format, however,
can and should be used by serializers."""
__slots__ = ("format",)
def __init__(self, fmt):
"""fmt must be composed of format codes that the C standard (1989)
supports, most notably in its strftime() function."""
assert isinstance(fmt, six.text_type), 'format must be a string'
self.format = fmt
def validate(self, val):
if not isinstance(val, datetime.datetime):
raise ValidationError('expected timestamp, got %s'
% generic_type_name(val))
elif val.tzinfo is not None and \
val.tzinfo.utcoffset(val).total_seconds() != 0:
raise ValidationError('timestamp should have either a UTC '
'timezone or none set at all')
return val
class Composite(Validator): # pylint: disable=abstract-method
"""Validator for a type that builds on other primitive and composite
types."""
__slots__ = ()
class List(Composite):
"""Assumes list contents are homogeneous with respect to types."""
__slots__ = ("item_validator", "min_items", "max_items")
def __init__(self, item_validator, min_items=None, max_items=None):
"""Every list item will be validated with item_validator."""
self.item_validator = item_validator
if min_items is not None:
assert isinstance(min_items, numbers.Integral), \
'min_items must be an integral number'
assert min_items >= 0, 'min_items must be >= 0'
if max_items is not None:
assert isinstance(max_items, numbers.Integral), \
'max_items must be an integral number'
assert max_items > 0, 'max_items must be > 0'
if min_items is not None and max_items is not None:
assert max_items >= min_items, 'max_items must be >= min_items'
self.min_items = min_items
self.max_items = max_items
def validate(self, val):
if not isinstance(val, (tuple, list)):
raise ValidationError('%r is not a valid list' % val)
elif self.max_items is not None and len(val) > self.max_items:
raise ValidationError('%r has more than %s items'
% (val, self.max_items))
elif self.min_items is not None and len(val) < self.min_items:
raise ValidationError('%r has fewer than %s items'
% (val, self.min_items))
return [self.item_validator.validate(item) for item in val]
class Map(Composite):
"""Assumes map keys and values are homogeneous with respect to types."""
__slots__ = ("key_validator", "value_validator")
def __init__(self, key_validator, value_validator):
"""
Every Map key/value pair will be validated with item_validator.
key validators must be a subclass of a String validator
"""
self.key_validator = key_validator
self.value_validator = value_validator
def validate(self, val):
if not isinstance(val, dict):
raise ValidationError('%r is not a valid dict' % val)
return {
self.key_validator.validate(key):
self.value_validator.validate(value) for key, value in val.items()
}
class Struct(Composite):
__slots__ = ("definition",)
def __init__(self, definition):
"""
Args:
definition (class): A generated class representing a Stone struct
from a spec. Must have a _fields_ attribute with the following
structure:
_fields_ = [(field_name, validator), ...]
where
field_name: Name of the field (str).
validator: Validator object.
"""
super(Struct, self).__init__()
self.definition = definition
def validate(self, val):
"""
For a val to pass validation, val must be of the correct type and have
all required fields present.
"""
self.validate_type_only(val)
self.validate_fields_only(val)
return val
def validate_with_permissions(self, val, caller_permissions):
"""
For a val to pass validation, val must be of the correct type and have
all required permissioned fields present. Should only be called
for callers with extra permissions.
"""
self.validate(val)
self.validate_fields_only_with_permissions(val, caller_permissions)
return val
def validate_fields_only(self, val):
"""
To pass field validation, no required field should be missing.
This method assumes that the contents of each field have already been
validated on assignment, so it's merely a presence check.
FIXME(kelkabany): Since the definition object does not maintain a list
of which fields are required, all fields are scanned.
"""
for field_name in self.definition._all_field_names_:
if not hasattr(val, field_name):
raise ValidationError("missing required field '%s'" %
field_name)
def validate_fields_only_with_permissions(self, val, caller_permissions):
"""
To pass field validation, no required field should be missing.
This method assumes that the contents of each field have already been
validated on assignment, so it's merely a presence check.
Should only be called for callers with extra permissions.
"""
self.validate_fields_only(val)
# check if type has been patched
for extra_permission in caller_permissions.permissions:
all_field_names = '_all_{}_field_names_'.format(extra_permission)
for field_name in getattr(self.definition, all_field_names, set()):
if not hasattr(val, field_name):
raise ValidationError("missing required field '%s'" % field_name)
def validate_type_only(self, val):
"""
Use this when you only want to validate that the type of an object
is correct, but not yet validate each field.
"""
# Since the definition maintains the list of fields for serialization,
# we're okay with a subclass that might have extra information. This
# makes it easier to return one subclass for two routes, one of which
# relies on the parent class.
if not isinstance(val, self.definition):
raise ValidationError('expected type %s, got %s' %
(
type_name_with_module(self.definition),
generic_type_name(val),
),
)
def has_default(self):
return not self.definition._has_required_fields
def get_default(self):
assert not self.definition._has_required_fields, 'No default available.'
return self.definition()
class StructTree(Struct):
"""Validator for structs with enumerated subtypes.
NOTE: validate_fields_only() validates the fields known to this base
struct, but does not do any validation specific to the subtype.
"""
__slots__ = ()
# See PyCQA/pylint#1043 for why this is disabled; this should show up
# as a usless-suppression (and can be removed) once a fix is released
def __init__(self, definition): # pylint: disable=useless-super-delegation
super(StructTree, self).__init__(definition)
class Union(Composite):
__slots__ = ("definition",)
def __init__(self, definition):
"""
Args:
definition (class): A generated class representing a Stone union
from a spec. Must have a _tagmap attribute with the following
structure:
_tagmap = {field_name: validator, ...}
where
field_name (str): Tag name.
validator (Validator): Tag value validator.
"""
self.definition = definition
def validate(self, val):
"""
For a val to pass validation, it must have a _tag set. This assumes
that the object validated that _tag is a valid tag, and that any
associated value has also been validated.
"""
self.validate_type_only(val)
if not hasattr(val, '_tag') or val._tag is None:
raise ValidationError('no tag set')
return val
def validate_type_only(self, val):
"""
Use this when you only want to validate that the type of an object
is correct, but not yet validate each field.
We check whether val is a Python parent class of the definition. This
is because Union subtyping works in the opposite direction of Python
inheritance. For example, if a union U2 extends U1 in Python, this
validator will accept U1 in places where U2 is expected.
"""
if not issubclass(self.definition, type(val)):
raise ValidationError('expected type %s or subtype, got %s' %
(
type_name_with_module(self.definition),
generic_type_name(val),
),
)
class Void(Primitive):
__slots__ = ()
def validate(self, val):
if val is not None:
raise ValidationError('expected NoneType, got %s' %
generic_type_name(val))
def has_default(self):
return True
def get_default(self):
return None
class Nullable(Validator):
__slots__ = ("validator",)
def __init__(self, validator):
super(Nullable, self).__init__()
assert isinstance(validator, (Primitive, Composite)), \
'validator must be for a primitive or composite type'
assert not isinstance(validator, Nullable), \
'nullables cannot be stacked'
assert not isinstance(validator, Void), \
'void cannot be made nullable'
self.validator = validator
def validate(self, val):
if val is None:
return
else:
return self.validator.validate(val)
def validate_type_only(self, val):
"""Use this only if Nullable is wrapping a Composite."""
if val is None:
return
else:
return self.validator.validate_type_only(val)
def has_default(self):
return True
def get_default(self):
return None
class Redactor(object):
__slots__ = ("regex",)
def __init__(self, regex):
"""
Args:
regex: What parts of the field to redact.
"""
self.regex = regex
@abstractmethod
def apply(self, val):
"""Redacts information from annotated field.
Returns: A redacted version of the string provided.
"""
def _get_matches(self, val):
if not self.regex:
return None
try:
return re.search(self.regex, val)
except TypeError:
return None
class HashRedactor(Redactor):
__slots__ = ()
def apply(self, val):
matches = self._get_matches(val)
val_to_hash = str(val) if isinstance(val, int) or isinstance(val, float) else val
try:
# add string literal to ensure unicode
hashed = hashlib.md5(val_to_hash.encode('utf-8')).hexdigest() + ''
except [AttributeError, ValueError]:
hashed = None
if matches:
blotted = '***'.join(matches.groups())
if hashed:
return '{} ({})'.format(hashed, blotted)
return blotted
return hashed
class BlotRedactor(Redactor):
__slots__ = ()
def apply(self, val):
matches = self._get_matches(val)
if matches:
return '***'.join(matches.groups())
return '********'

View file

@ -0,0 +1,114 @@
from stone.ir import (
Alias,
ApiNamespace,
DataType,
List,
Map,
Nullable,
Timestamp,
UserDefined,
is_alias,
is_boolean_type,
is_bytes_type,
is_float_type,
is_integer_type,
is_list_type,
is_map_type,
is_nullable_type,
is_string_type,
is_timestamp_type,
is_user_defined_type,
is_void_type,
)
from stone.backends.python_helpers import class_name_for_data_type, fmt_namespace
from stone.ir.data_types import String
from stone.typing_hacks import cast
MYPY = False
if MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
DataTypeCls = typing.Type[DataType]
# Unfortunately these are codependent, so I'll weakly type the Dict in Callback
Callback = typing.Callable[
[ApiNamespace, DataType, typing.Dict[typing.Any, typing.Any]],
typing.Text
]
OverrideDefaultTypesDict = typing.Dict[DataTypeCls, Callback]
else:
OverrideDefaultTypesDict = "OverrideDefaultTypesDict"
def map_stone_type_to_python_type(ns, data_type, override_dict=None):
# type: (ApiNamespace, DataType, typing.Optional[OverrideDefaultTypesDict]) -> typing.Text
"""
Args:
override_dict: lets you override the default behavior for a given type by hooking into
a callback. (Currently only hooked up for stone's List and Nullable)
"""
override_dict = override_dict or {}
if is_string_type(data_type):
string_override = override_dict.get(String, None)
if string_override:
return string_override(ns, data_type, override_dict)
return 'str'
elif is_bytes_type(data_type):
return 'bytes'
elif is_boolean_type(data_type):
return 'bool'
elif is_float_type(data_type):
return 'float'
elif is_integer_type(data_type):
return 'int'
elif is_void_type(data_type):
return 'None'
elif is_timestamp_type(data_type):
timestamp_override = override_dict.get(Timestamp, None)
if timestamp_override:
return timestamp_override(ns, data_type, override_dict)
return 'datetime.datetime'
elif is_alias(data_type):
alias_type = cast(Alias, data_type)
return map_stone_type_to_python_type(ns, alias_type.data_type, override_dict)
elif is_user_defined_type(data_type):
user_defined_type = cast(UserDefined, data_type)
class_name = class_name_for_data_type(user_defined_type)
if user_defined_type.namespace.name != ns.name:
return '{}.{}'.format(
fmt_namespace(user_defined_type.namespace.name), class_name)
else:
return class_name
elif is_list_type(data_type):
list_type = cast(List, data_type)
if List in override_dict:
return override_dict[List](ns, list_type.data_type, override_dict)
# PyCharm understands this description format for a list
return 'list of [{}]'.format(
map_stone_type_to_python_type(ns, list_type.data_type, override_dict)
)
elif is_map_type(data_type):
map_type = cast(Map, data_type)
if Map in override_dict:
return override_dict[Map](
ns,
data_type,
override_dict
)
return 'dict of [{}:{}]'.format(
map_stone_type_to_python_type(ns, map_type.key_data_type, override_dict),
map_stone_type_to_python_type(ns, map_type.value_data_type, override_dict)
)
elif is_nullable_type(data_type):
nullable_type = cast(Nullable, data_type)
if Nullable in override_dict:
return override_dict[Nullable](ns, nullable_type.data_type, override_dict)
return 'Optional[{}]'.format(
map_stone_type_to_python_type(ns, nullable_type.data_type, override_dict)
)
else:
raise TypeError('Unknown data type %r' % data_type)

View file

@ -0,0 +1,486 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
from six import StringIO
from stone.backend import CodeBackend
from stone.backends.python_helpers import (
check_route_name_conflict,
class_name_for_annotation_type,
class_name_for_data_type,
emit_pass_if_nothing_emitted,
fmt_func,
fmt_namespace,
fmt_var,
generate_imports_for_referenced_namespaces,
generate_module_header,
validators_import_with_type_ignore,
)
from stone.backends.python_type_mapping import (
map_stone_type_to_python_type,
OverrideDefaultTypesDict,
)
from stone.ir import (
Alias,
AnnotationType,
Api,
ApiNamespace,
DataType,
is_nullable_type,
is_struct_type,
is_union_type,
is_user_defined_type,
is_void_type,
List,
Map,
Nullable,
Struct,
Timestamp,
Union,
unwrap_aliases,
)
from stone.ir.data_types import String
from stone.typing_hacks import cast
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class ImportTracker(object):
def __init__(self):
# type: () -> None
self.cur_namespace_typing_imports = set() # type: typing.Set[typing.Text]
self.cur_namespace_adhoc_imports = set() # type: typing.Set[typing.Text]
def clear(self):
# type: () -> None
self.cur_namespace_typing_imports.clear()
self.cur_namespace_adhoc_imports.clear()
def _register_typing_import(self, s):
# type: (typing.Text) -> None
"""
Denotes that we need to import something specifically from the `typing` module.
For example, _register_typing_import("Optional")
"""
self.cur_namespace_typing_imports.add(s)
def _register_adhoc_import(self, s):
# type: (typing.Text) -> None
"""
Denotes an ad-hoc import.
For example,
_register_adhoc_import("import datetime")
or
_register_adhoc_import("from xyz import abc")
"""
self.cur_namespace_adhoc_imports.add(s)
_cmdline_parser = argparse.ArgumentParser(prog='python-types-backend')
_cmdline_parser.add_argument(
'-p',
'--package',
type=str,
required=True,
help='Package prefix for absolute imports in generated files.',
)
class PythonTypeStubsBackend(CodeBackend):
"""Generates Python modules to represent the input Stone spec."""
cmdline_parser = _cmdline_parser
# Instance var of the current namespace being generated
cur_namespace = None
preserve_aliases = True
import_tracker = ImportTracker()
def __init__(self, *args, **kwargs):
# type: (...) -> None
super(PythonTypeStubsBackend, self).__init__(*args, **kwargs)
self._pep_484_type_mapping_callbacks = self._get_pep_484_type_mapping_callbacks()
def generate(self, api):
# type: (Api) -> None
"""
Generates a module for each namespace.
Each namespace will have Python classes to represent data types and
routes in the Stone spec.
"""
for namespace in api.namespaces.values():
with self.output_to_relative_path('{}.pyi'.format(fmt_namespace(namespace.name))):
self._generate_base_namespace_module(namespace)
def _generate_base_namespace_module(self, namespace):
# type: (ApiNamespace) -> None
"""Creates a module for the namespace. All data types and routes are
represented as Python classes."""
self.cur_namespace = namespace
self.import_tracker.clear()
generate_module_header(self)
self.emit_placeholder('imports_needed_for_typing')
self.emit_raw(validators_import_with_type_ignore)
# Generate import statements for all referenced namespaces.
self._generate_imports_for_referenced_namespaces(namespace)
self._generate_typevars()
for annotation_type in namespace.annotation_types:
self._generate_annotation_type_class(namespace, annotation_type)
for data_type in namespace.linearize_data_types():
if isinstance(data_type, Struct):
self._generate_struct_class(namespace, data_type)
elif isinstance(data_type, Union):
self._generate_union_class(namespace, data_type)
else:
raise TypeError('Cannot handle type %r' % type(data_type))
for alias in namespace.linearize_aliases():
self._generate_alias_definition(namespace, alias)
self._generate_routes(namespace)
self._generate_imports_needed_for_typing()
def _generate_imports_for_referenced_namespaces(self, namespace):
# type: (ApiNamespace) -> None
assert self.args is not None
generate_imports_for_referenced_namespaces(
backend=self,
namespace=namespace,
package=self.args.package,
insert_type_ignore=True,
)
def _generate_typevars(self):
# type: () -> None
"""
Creates type variables that are used by the type signatures for
_process_custom_annotations.
"""
self.emit("T = TypeVar('T', bound=bb.AnnotationType)")
self.emit("U = TypeVar('U')")
self.import_tracker._register_typing_import('TypeVar')
self.emit()
def _generate_annotation_type_class(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
"""Defines a Python class that represents an annotation type in Stone."""
self.emit('class {}(object):'.format(class_name_for_annotation_type(annotation_type, ns)))
with self.indent():
self._generate_annotation_type_class_init(ns, annotation_type)
self._generate_annotation_type_class_properties(ns, annotation_type)
self.emit()
def _generate_annotation_type_class_init(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
args = ['self']
for param in annotation_type.params:
param_name = fmt_var(param.name, True)
param_type = self.map_stone_type_to_pep484_type(ns, param.data_type)
if not is_nullable_type(param.data_type):
self.import_tracker._register_typing_import('Optional')
param_type = 'Optional[{}]'.format(param_type)
args.append(
"{param_name}: {param_type} = ...".format(
param_name=param_name,
param_type=param_type))
self.generate_multiline_list(args, before='def __init__', after=' -> None: ...')
self.emit()
def _generate_annotation_type_class_properties(self, ns, annotation_type):
# type: (ApiNamespace, AnnotationType) -> None
for param in annotation_type.params:
prop_name = fmt_var(param.name, True)
param_type = self.map_stone_type_to_pep484_type(ns, param.data_type)
self.emit('@property')
self.emit('def {prop_name}(self) -> {param_type}: ...'.format(
prop_name=prop_name,
param_type=param_type,
))
self.emit()
def _generate_struct_class(self, ns, data_type):
# type: (ApiNamespace, Struct) -> None
"""Defines a Python class that represents a struct in Stone."""
self.emit(self._class_declaration_for_type(ns, data_type))
with self.indent():
self._generate_struct_class_init(ns, data_type)
self._generate_struct_class_properties(ns, data_type)
self._generate_struct_or_union_class_custom_annotations()
self._generate_validator_for(data_type)
self.emit()
def _generate_validator_for(self, data_type):
# type: (DataType) -> None
cls_name = class_name_for_data_type(data_type)
self.emit("{}_validator: bv.Validator = ...".format(
cls_name
))
def _generate_union_class(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
self.emit(self._class_declaration_for_type(ns, data_type))
with self.indent(), emit_pass_if_nothing_emitted(self):
self._generate_union_class_vars(ns, data_type)
self._generate_union_class_is_set(data_type)
self._generate_union_class_variant_creators(ns, data_type)
self._generate_union_class_get_helpers(ns, data_type)
self._generate_struct_or_union_class_custom_annotations()
self._generate_validator_for(data_type)
self.emit()
def _generate_union_class_vars(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
lineno = self.lineno
# Generate stubs for class variables so that IDEs like PyCharms have an
# easier time detecting their existence.
for field in data_type.fields:
if is_void_type(field.data_type):
field_name = fmt_var(field.name)
field_type = class_name_for_data_type(data_type, ns)
self.emit('{field_name}: {field_type} = ...'.format(
field_name=field_name,
field_type=field_type,
))
if lineno != self.lineno:
self.emit()
def _generate_union_class_is_set(self, union):
# type: (Union) -> None
for field in union.fields:
field_name = fmt_func(field.name)
self.emit('def is_{}(self) -> bool: ...'.format(field_name))
self.emit()
def _generate_union_class_variant_creators(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
"""
Generate the following section in the 'union Shape' example:
@classmethod
def circle(cls, val: float) -> Shape: ...
"""
union_type = class_name_for_data_type(data_type)
for field in data_type.fields:
if not is_void_type(field.data_type):
field_name_reserved_check = fmt_func(field.name, check_reserved=True)
val_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
self.emit('@classmethod')
self.emit('def {field_name}(cls, val: {val_type}) -> {union_type}: ...'.format(
field_name=field_name_reserved_check,
val_type=val_type,
union_type=union_type,
))
self.emit()
def _generate_union_class_get_helpers(self, ns, data_type):
# type: (ApiNamespace, Union) -> None
"""
Generates the following section in the 'union Shape' example:
def get_circle(self) -> float: ...
"""
for field in data_type.fields:
field_name = fmt_func(field.name)
if not is_void_type(field.data_type):
# generate getter for field
val_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
self.emit('def get_{field_name}(self) -> {val_type}: ...'.format(
field_name=field_name,
val_type=val_type,
))
self.emit()
def _generate_alias_definition(self, namespace, alias):
# type: (ApiNamespace, Alias) -> None
self._generate_validator_for(alias)
unwrapped_dt, _ = unwrap_aliases(alias)
if is_user_defined_type(unwrapped_dt):
# If the alias is to a composite type, we want to alias the
# generated class as well.
self.emit('{} = {}'.format(
alias.name,
class_name_for_data_type(alias.data_type, namespace)))
def _class_declaration_for_type(self, ns, data_type):
# type: (ApiNamespace, typing.Union[Struct, Union]) -> typing.Text
assert is_user_defined_type(data_type), \
'Expected struct, got %r' % type(data_type)
if data_type.parent_type:
extends = class_name_for_data_type(data_type.parent_type, ns)
else:
if is_struct_type(data_type):
# Use a handwritten base class
extends = 'bb.Struct'
elif is_union_type(data_type):
extends = 'bb.Union'
else:
extends = 'object'
return 'class {}({}):'.format(
class_name_for_data_type(data_type), extends)
def _generate_struct_class_init(self, ns, struct):
# type: (ApiNamespace, Struct) -> None
args = ["self"]
for field in struct.all_fields:
field_name_reserved_check = fmt_var(field.name, True)
field_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
if field.has_default:
self.import_tracker._register_typing_import('Optional')
field_type = 'Optional[{}]'.format(field_type)
args.append("{field_name}: {field_type} = ...".format(
field_name=field_name_reserved_check,
field_type=field_type))
self.generate_multiline_list(args, before='def __init__', after=' -> None: ...')
def _generate_struct_class_properties(self, ns, struct):
# type: (ApiNamespace, Struct) -> None
to_emit = [] # type: typing.List[typing.Text]
for field in struct.all_fields:
field_name_reserved_check = fmt_func(field.name, check_reserved=True)
field_type = self.map_stone_type_to_pep484_type(ns, field.data_type)
to_emit.append(
"{}: bb.Attribute[{}] = ...".format(field_name_reserved_check, field_type)
)
for s in to_emit:
self.emit(s)
def _generate_struct_or_union_class_custom_annotations(self):
"""
The _process_custom_annotations function allows client code to access
custom annotations defined in the spec.
"""
self.emit('def _process_custom_annotations(')
with self.indent():
self.emit('self,')
self.emit('annotation_type: Type[T],')
self.emit('field_path: Text,')
self.emit('processor: Callable[[T, U], U],')
self.import_tracker._register_typing_import('Type')
self.import_tracker._register_typing_import('Text')
self.import_tracker._register_typing_import('Callable')
self.emit(') -> None: ...')
self.emit()
def _get_pep_484_type_mapping_callbacks(self):
# type: () -> OverrideDefaultTypesDict
"""
Once-per-instance, generate a mapping from
"List" -> return pep4848-compatible List[SomeType]
"Nullable" -> return pep484-compatible Optional[SomeType]
This is per-instance because we have to also call `self._register_typing_import`, because
we need to potentially import some things.
"""
def upon_encountering_list(ns, data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_typing_import("List")
return "List[{}]".format(
map_stone_type_to_python_type(ns, data_type, override_dict)
)
def upon_encountering_map(ns, map_data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
map_type = cast(Map, map_data_type)
self.import_tracker._register_typing_import("Dict")
return "Dict[{}, {}]".format(
map_stone_type_to_python_type(ns, map_type.key_data_type, override_dict),
map_stone_type_to_python_type(ns, map_type.value_data_type, override_dict)
)
def upon_encountering_nullable(ns, data_type, override_dict):
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_typing_import("Optional")
return "Optional[{}]".format(
map_stone_type_to_python_type(ns, data_type, override_dict)
)
def upon_encountering_timestamp(
ns, data_type, override_dict
): # pylint: disable=unused-argument
# type: (ApiNamespace, DataType, OverrideDefaultTypesDict) -> typing.Text
self.import_tracker._register_adhoc_import("import datetime")
return map_stone_type_to_python_type(ns, data_type)
def upon_encountering_string(
ns, data_type, override_dict
): # pylint: disable=unused-argument
# type: (...) -> typing.Text
self.import_tracker._register_typing_import("Text")
return "Text"
callback_dict = {
List: upon_encountering_list,
Map: upon_encountering_map,
Nullable: upon_encountering_nullable,
Timestamp: upon_encountering_timestamp,
String: upon_encountering_string,
} # type: OverrideDefaultTypesDict
return callback_dict
def map_stone_type_to_pep484_type(self, ns, data_type):
# type: (ApiNamespace, DataType) -> typing.Text
assert self._pep_484_type_mapping_callbacks
return map_stone_type_to_python_type(ns, data_type,
override_dict=self._pep_484_type_mapping_callbacks)
def _generate_routes(
self,
namespace, # type: ApiNamespace
):
# type: (...) -> None
check_route_name_conflict(namespace)
for route in namespace.routes:
self.emit(
"{method_name}: bb.Route = ...".format(
method_name=fmt_func(route.name, version=route.version)))
if namespace.routes:
self.emit()
def _generate_imports_needed_for_typing(self):
# type: () -> None
output_buffer = StringIO()
with self.capture_emitted_output(output_buffer):
if self.import_tracker.cur_namespace_typing_imports:
self.emit("")
self.emit('from typing import (')
with self.indent():
for to_import in sorted(self.import_tracker.cur_namespace_typing_imports):
self.emit("{},".format(to_import))
self.emit(')')
if self.import_tracker.cur_namespace_adhoc_imports:
self.emit("")
for to_import in self.import_tracker.cur_namespace_adhoc_imports:
self.emit(to_import)
self.add_named_placeholder('imports_needed_for_typing', output_buffer.getvalue())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,196 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
from stone.ir import (
Boolean,
Bytes,
DataType,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_list_type,
is_timestamp_type,
is_union_type,
is_user_defined_type,
unwrap_nullable,
)
from stone.backend import CodeBackend
from stone.backends.swift_helpers import (
fmt_class,
fmt_func,
fmt_obj,
fmt_type,
fmt_var,
)
_serial_type_table = {
Boolean: 'BoolSerializer',
Bytes: 'NSDataSerializer',
Float32: 'FloatSerializer',
Float64: 'DoubleSerializer',
Int32: 'Int32Serializer',
Int64: 'Int64Serializer',
List: 'ArraySerializer',
String: 'StringSerializer',
Timestamp: 'NSDateSerializer',
UInt32: 'UInt32Serializer',
UInt64: 'UInt64Serializer',
Void: 'VoidSerializer',
}
stone_warning = """\
///
/// Copyright (c) 2016 Dropbox, Inc. All rights reserved.
///
/// Auto-generated by Stone, do not modify.
///
"""
# This will be at the top of the generated file.
base = """\
{}\
import Foundation
""".format(stone_warning)
undocumented = '(no description)'
class SwiftBaseBackend(CodeBackend):
"""Wrapper class over Stone generator for Swift logic."""
# pylint: disable=abstract-method
@contextmanager
def function_block(self, func, args, return_type=None):
signature = '{}({})'.format(func, args)
if return_type:
signature += ' -> {}'.format(return_type)
with self.block(signature):
yield
def _func_args(self, args_list, newlines=False, force_first=False, not_init=False):
out = []
first = True
for k, v in args_list:
# this is a temporary hack -- injected client-side args
# do not have a separate field for default value. Right now,
# default values are stored along with the type, e.g.
# `Bool = True` is a type, hence this check.
if first and force_first and '=' not in v:
k = "{0} {0}".format(k)
if first and v is not None and not_init:
out.append('{}'.format(v))
elif v is not None:
out.append('{}: {}'.format(k, v))
first = False
sep = ', '
if newlines:
sep += '\n' + self.make_indent()
return sep.join(out)
@contextmanager
def class_block(self, thing, protocols=None):
protocols = protocols or []
extensions = []
if isinstance(thing, DataType):
name = fmt_class(thing.name)
if thing.parent_type:
extensions.append(fmt_type(thing.parent_type))
else:
name = thing
extensions.extend(protocols)
extend_suffix = ': {}'.format(', '.join(extensions)) if extensions else ''
with self.block('open class {}{}'.format(name, extend_suffix)):
yield
def _struct_init_args(self, data_type, namespace=None): # pylint: disable=unused-argument
args = []
for field in data_type.all_fields:
name = fmt_var(field.name)
value = fmt_type(field.data_type)
data_type, nullable = unwrap_nullable(field.data_type)
if field.has_default:
if is_union_type(data_type):
default = '.{}'.format(fmt_var(field.default.tag_name))
else:
default = fmt_obj(field.default)
value += ' = {}'.format(default)
elif nullable:
value += ' = nil'
arg = (name, value)
args.append(arg)
return args
def _docf(self, tag, val):
if tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
return fmt_func(val, version)
elif tag == 'field':
if '.' in val:
cls, field = val.split('.')
return ('{} in {}'.format(fmt_var(field),
fmt_class(cls)))
else:
return fmt_var(val)
elif tag in ('type', 'val', 'link'):
return val
else:
import pdb
pdb.set_trace()
return val
def fmt_serial_type(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}Serializer'
result = result.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _serial_type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '<{}>'.format(fmt_serial_type(data_type.data_type))
return result if not nullable else 'NullableSerializer'
def fmt_serial_obj(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}Serializer()'
result = result.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _serial_type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '({})'.format(fmt_serial_obj(data_type.data_type))
elif is_timestamp_type(data_type):
result = result + '("{}")'.format(data_type.format)
else:
result = 'Serialization._{}'.format(result)
return result if not nullable else 'NullableSerializer({})'.format(result)

View file

@ -0,0 +1,280 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
from stone.ir import (
is_struct_type,
is_union_type,
is_void_type,
)
from stone.backends.swift import (
base,
fmt_serial_type,
SwiftBaseBackend,
undocumented,
)
from stone.backends.swift_helpers import (
check_route_name_conflict,
fmt_class,
fmt_func,
fmt_var,
fmt_type,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(
prog='swift-client-backend',
description=(
'Generates a Swift class with an object for each namespace, and in each '
'namespace object, a method for each route. This class assumes that the '
'swift_types backend was used with the same output directory.'),
)
_cmdline_parser.add_argument(
'-m',
'--module-name',
required=True,
type=str,
help=('The name of the Swift module to generate. Please exclude the .swift '
'file extension.'),
)
_cmdline_parser.add_argument(
'-c',
'--class-name',
required=True,
type=str,
help=('The name of the Swift class that contains an object for each namespace, '
'and in each namespace object, a method for each route.')
)
_cmdline_parser.add_argument(
'-t',
'--transport-client-name',
required=True,
type=str,
help='The name of the Swift class that manages network API calls.',
)
_cmdline_parser.add_argument(
'-y',
'--client-args',
required=True,
type=str,
help='The client-side route arguments to append to each route by style type.',
)
_cmdline_parser.add_argument(
'-z',
'--style-to-request',
required=True,
type=str,
help='The dict that maps a style type to a Swift request object name.',
)
class SwiftBackend(SwiftBaseBackend):
"""
Generates Swift client base that implements route interfaces.
Examples:
```
open class ExampleClientBase {
/// Routes within the namespace1 namespace. See Namespace1 for details.
open var namespace1: Namespace1!
/// Routes within the namespace2 namespace. See Namespace2 for details.
open var namespace2: Namespace2!
public init(client: ExampleTransportClient) {
self.namespace1 = Namespace1(client: client)
self.namespace2 = Namespace2(client: client)
}
}
```
Here, `ExampleTransportClient` would contain the implementation of a handwritten,
project-specific networking client. Additionally, the `Namespace1` object would
have as its methods all routes in the `Namespace1` namespace. A hypothetical 'copy'
enpoding might be implemented like:
```
open func copy(fromPath fromPath: String, toPath: String) ->
ExampleRequestType<Namespace1.CopySerializer, Namespace1.CopyErrorSerializer> {
let route = Namespace1.copy
let serverArgs = Namespace1.CopyArg(fromPath: fromPath, toPath: toPath)
return client.request(route, serverArgs: serverArgs)
}
```
Here, ExampleRequestType is a project-specific request type, parameterized by response and
error serializers.
"""
cmdline_parser = _cmdline_parser
def generate(self, api):
for namespace in api.namespaces.values():
ns_class = fmt_class(namespace.name)
if namespace.routes:
with self.output_to_relative_path('{}Routes.swift'.format(ns_class)):
self._generate_routes(namespace)
with self.output_to_relative_path('{}.swift'.format(self.args.module_name)):
self._generate_client(api)
def _generate_client(self, api):
self.emit_raw(base)
self.emit('import Alamofire')
self.emit()
with self.block('open class {}'.format(self.args.class_name)):
namespace_fields = []
for namespace in api.namespaces.values():
if namespace.routes:
namespace_fields.append((namespace.name,
fmt_class(namespace.name)))
for var, typ in namespace_fields:
self.emit('/// Routes within the {} namespace. '
'See {}Routes for details.'.format(var, typ))
self.emit('open var {}: {}Routes!'.format(var, typ))
self.emit()
with self.function_block('public init', args=self._func_args(
[('client', '{}'.format(self.args.transport_client_name))])):
for var, typ in namespace_fields:
self.emit('self.{} = {}Routes(client: client)'.format(var, typ))
def _generate_routes(self, namespace):
check_route_name_conflict(namespace)
ns_class = fmt_class(namespace.name)
self.emit_raw(base)
self.emit('/// Routes for the {} namespace'.format(namespace.name))
with self.block('open class {}Routes'.format(ns_class)):
self.emit('public let client: {}'.format(self.args.transport_client_name))
args = [('client', '{}'.format(self.args.transport_client_name))]
with self.function_block('init', self._func_args(args)):
self.emit('self.client = client')
self.emit()
for route in namespace.routes:
self._generate_route(namespace, route)
def _get_route_args(self, namespace, route):
data_type = route.arg_data_type
arg_type = fmt_type(data_type)
if is_struct_type(data_type):
arg_list = self._struct_init_args(data_type, namespace=namespace)
doc_list = [(fmt_var(f.name), self.process_doc(f.doc, self._docf)
if f.doc else undocumented) for f in data_type.fields if f.doc]
elif is_union_type(data_type):
arg_list = [(fmt_var(data_type.name), '{}.{}'.format(
fmt_class(namespace.name), fmt_class(data_type.name)))]
doc_list = [(fmt_var(data_type.name),
self.process_doc(data_type.doc, self._docf)
if data_type.doc else 'The {} union'.format(fmt_class(data_type.name)))]
else:
arg_list = [] if is_void_type(data_type) else [('request', arg_type)]
doc_list = []
return arg_list, doc_list
def _emit_route(self, namespace, route, req_obj_name, extra_args=None, extra_docs=None):
arg_list, doc_list = self._get_route_args(namespace, route)
extra_args = extra_args or []
extra_docs = extra_docs or []
arg_type = fmt_type(route.arg_data_type)
func_name = fmt_func(route.name, route.version)
if route.doc:
route_doc = self.process_doc(route.doc, self._docf)
else:
route_doc = 'The {} route'.format(func_name)
self.emit_wrapped_text(route_doc, prefix='/// ', width=120)
self.emit('///')
for name, doc in doc_list + extra_docs:
param_doc = '- parameter {}: {}'.format(name, doc if doc is not None else undocumented)
self.emit_wrapped_text(param_doc, prefix='/// ', width=120)
self.emit('///')
output = (' - returns: Through the response callback, the caller will ' +
'receive a `{}` object on success or a `{}` object on failure.')
output = output.format(fmt_type(route.result_data_type),
fmt_type(route.error_data_type))
self.emit_wrapped_text(output, prefix='/// ', width=120)
func_args = [
('route', '{}.{}'.format(fmt_class(namespace.name), func_name)),
]
client_args = []
return_args = [('route', 'route')]
for name, value, typ in extra_args:
arg_list.append((name, typ))
func_args.append((name, value))
client_args.append((name, value))
rtype = fmt_serial_type(route.result_data_type)
etype = fmt_serial_type(route.error_data_type)
self._maybe_generate_deprecation_warning(route)
with self.function_block('@discardableResult open func {}'.format(func_name),
args=self._func_args(arg_list, force_first=False),
return_type='{}<{}, {}>'.format(req_obj_name, rtype, etype)):
self.emit('let route = {}.{}'.format(fmt_class(namespace.name), func_name))
if is_struct_type(route.arg_data_type):
args = [(name, name) for name, _ in self._struct_init_args(route.arg_data_type)]
func_args += [('serverArgs', '{}({})'.format(arg_type, self._func_args(args)))]
self.emit('let serverArgs = {}({})'.format(arg_type, self._func_args(args)))
elif is_union_type(route.arg_data_type):
self.emit('let serverArgs = {}'.format(fmt_var(route.arg_data_type.name)))
if not is_void_type(route.arg_data_type):
return_args += [('serverArgs', 'serverArgs')]
return_args += client_args
txt = 'return client.request({})'.format(
self._func_args(return_args, not_init=True)
)
self.emit(txt)
self.emit()
def _maybe_generate_deprecation_warning(self, route):
if route.deprecated:
msg = '{} is deprecated.'.format(fmt_func(route.name, route.version))
if route.deprecated.by:
msg += ' Use {}.'.format(
fmt_func(route.deprecated.by.name, route.deprecated.by.version))
self.emit('@available(*, unavailable, message:"{}")'.format(msg))
def _generate_route(self, namespace, route):
route_type = route.attrs.get('style')
client_args = json.loads(self.args.client_args)
style_to_request = json.loads(self.args.style_to_request)
if route_type not in client_args.keys():
self._emit_route(namespace, route, style_to_request[route_type])
else:
for args_data in client_args[route_type]:
req_obj_key, type_data_list = tuple(args_data)
req_obj_name = style_to_request[req_obj_key]
extra_args = [tuple(type_data[:-1]) for type_data in type_data_list]
extra_docs = [(type_data[0], type_data[-1]) for type_data in type_data_list]
self._emit_route(namespace, route, req_obj_name, extra_args, extra_docs)

View file

@ -0,0 +1,173 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import pprint
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_boolean_type,
is_list_type,
is_numeric_type,
is_string_type,
is_tag_ref,
is_user_defined_type,
unwrap_nullable,
)
from .helpers import split_words
# This file defines *stylistic* choices for Swift
# (ie, that class names are UpperCamelCase and that variables are lowerCamelCase)
_type_table = {
Boolean: 'Bool',
Bytes: 'Data',
Float32: 'Float',
Float64: 'Double',
Int32: 'Int32',
Int64: 'Int64',
List: 'Array',
String: 'String',
Timestamp: 'Date',
UInt32: 'UInt32',
UInt64: 'UInt64',
Void: 'Void',
}
_reserved_words = {
'description',
'bool',
'nsdata'
'float',
'double',
'int32',
'int64',
'list',
'string',
'timestamp',
'uint32',
'uint64',
'void',
'associatedtype',
'class',
'deinit',
'enum',
'extension',
'func',
'import',
'init',
'inout',
'internal',
'let',
'operator',
'private',
'protocol',
'public',
'static',
'struct',
'subscript',
'typealias',
'var',
'default',
}
def fmt_obj(o):
assert not isinstance(o, dict), "Only use for base type literals"
if o is True:
return 'true'
if o is False:
return 'false'
if o is None:
return 'nil'
if o == u'':
return '""'
return pprint.pformat(o, width=1)
def _format_camelcase(name, lower_first=True):
words = [word.capitalize() for word in split_words(name)]
if lower_first:
words[0] = words[0].lower()
ret = ''.join(words)
if ret.lower() in _reserved_words:
ret += '_'
return ret
def fmt_class(name):
return _format_camelcase(name, lower_first=False)
def fmt_func(name, version):
if version > 1:
name = '{}_v{}'.format(name, version)
name = _format_camelcase(name)
return name
def fmt_type(data_type):
data_type, nullable = unwrap_nullable(data_type)
if is_user_defined_type(data_type):
result = '{}.{}'.format(fmt_class(data_type.namespace.name),
fmt_class(data_type.name))
else:
result = _type_table.get(data_type.__class__, fmt_class(data_type.name))
if is_list_type(data_type):
result = result + '<{}>'.format(fmt_type(data_type.data_type))
return result if not nullable else result + '?'
def fmt_var(name):
return _format_camelcase(name)
def fmt_default_value(namespace, field):
if is_tag_ref(field.default):
return '{}.{}Serializer().serialize(.{})'.format(
fmt_class(namespace.name),
fmt_class(field.default.union_data_type.name),
fmt_var(field.default.tag_name))
elif is_list_type(field.data_type):
return '.array({})'.format(field.default)
elif is_numeric_type(field.data_type):
return '.number({})'.format(field.default)
elif is_string_type(field.data_type):
return '.str({})'.format(fmt_obj(field.default))
elif is_boolean_type(field.data_type):
if field.default:
bool_str = '1'
else:
bool_str = '0'
return '.number({})'.format(bool_str)
else:
raise TypeError('Can\'t handle default value type %r' %
type(field.data_type))
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route

View file

@ -0,0 +1,486 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
import shutil
import six
from contextlib import contextmanager
from stone.ir import (
is_list_type,
is_numeric_type,
is_string_type,
is_struct_type,
is_union_type,
is_void_type,
unwrap_nullable,
)
from stone.backends.swift_helpers import (
check_route_name_conflict,
fmt_class,
fmt_default_value,
fmt_func,
fmt_var,
fmt_type,
)
from stone.backends.swift import (
base,
fmt_serial_obj,
SwiftBaseBackend,
undocumented,
)
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
_cmdline_parser = argparse.ArgumentParser(prog='swift-types-backend')
_cmdline_parser.add_argument(
'-r',
'--route-method',
help=('A string used to construct the location of a Swift method for a '
'given route; use {ns} as a placeholder for namespace name and '
'{route} for the route name.'),
)
class SwiftTypesBackend(SwiftBaseBackend):
"""
Generates Swift modules to represent the input Stone spec.
Examples for a hypothetical 'copy' enpoint:
Endpoint argument (struct):
```
open class CopyArg: CustomStringConvertible {
open let fromPath: String
open let toPath: String
public init(fromPath: String, toPath: String) {
stringValidator(pattern: "/(.|[\\r\\n])*")(value: fromPath)
self.fromPath = fromPath
stringValidator(pattern: "/(.|[\\r\\n])*")(value: toPath)
self.toPath = toPath
}
open var description: String {
return "\\(SerializeUtil.prepareJSONForSerialization(
CopyArgSerializer().serialize(self)))"
}
}
```
Endpoint error (union):
```
open enum CopyError: CustomStringConvertible {
case TooManyFiles
case Other
open var description: String {
return "\\(SerializeUtil.prepareJSONForSerialization(
CopyErrorSerializer().serialize(self)))"
}
}
```
Argument serializer (error serializer not listed):
```
open class CopyArgSerializer: JSONSerializer {
public init() { }
open func serialize(value: CopyArg) -> JSON {
let output = [
"from_path": Serialization.serialize(value.fromPath),
"to_path": Serialization.serialize(value.toPath),
]
return .Dictionary(output)
}
open func deserialize(json: JSON) -> CopyArg {
switch json {
case .Dictionary(let dict):
let fromPath = Serialization.deserialize(dict["from_path"] ?? .Null)
let toPath = Serialization.deserialize(dict["to_path"] ?? .Null)
return CopyArg(fromPath: fromPath, toPath: toPath)
default:
fatalError("Type error deserializing")
}
}
}
```
"""
cmdline_parser = _cmdline_parser
def generate(self, api):
rsrc_folder = os.path.join(os.path.dirname(__file__), 'swift_rsrc')
self.logger.info('Copying StoneValidators.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneValidators.swift'),
self.target_folder_path)
self.logger.info('Copying StoneSerializers.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneSerializers.swift'),
self.target_folder_path)
self.logger.info('Copying StoneBase.swift to output folder')
shutil.copy(os.path.join(rsrc_folder, 'StoneBase.swift'),
self.target_folder_path)
jazzy_cfg_path = os.path.join('../Format', 'jazzy.json')
with open(jazzy_cfg_path) as jazzy_file:
jazzy_cfg = json.load(jazzy_file)
for namespace in api.namespaces.values():
ns_class = fmt_class(namespace.name)
with self.output_to_relative_path('{}.swift'.format(ns_class)):
self._generate_base_namespace_module(api, namespace)
jazzy_cfg['custom_categories'][1]['children'].append(ns_class)
if namespace.routes:
jazzy_cfg['custom_categories'][0]['children'].append(ns_class + 'Routes')
with self.output_to_relative_path('../../../../.jazzy.json'):
self.emit_raw(json.dumps(jazzy_cfg, indent=2) + '\n')
def _generate_base_namespace_module(self, api, namespace):
self.emit_raw(base)
routes_base = 'Datatypes and serializers for the {} namespace'.format(namespace.name)
self.emit_wrapped_text(routes_base, prefix='/// ', width=120)
with self.block('open class {}'.format(fmt_class(namespace.name))):
for data_type in namespace.linearize_data_types():
if is_struct_type(data_type):
self._generate_struct_class(namespace, data_type)
self.emit()
elif is_union_type(data_type):
self._generate_union_type(namespace, data_type)
self.emit()
if namespace.routes:
self._generate_route_objects(api.route_schema, namespace)
def _generate_struct_class(self, namespace, data_type):
if data_type.doc:
doc = self.process_doc(data_type.doc, self._docf)
else:
doc = 'The {} struct'.format(fmt_class(data_type.name))
self.emit_wrapped_text(doc, prefix='/// ', width=120)
protocols = []
if not data_type.parent_type:
protocols.append('CustomStringConvertible')
with self.class_block(data_type, protocols=protocols):
for field in data_type.fields:
fdoc = self.process_doc(field.doc,
self._docf) if field.doc else undocumented
self.emit_wrapped_text(fdoc, prefix='/// ', width=120)
self.emit('public let {}: {}'.format(
fmt_var(field.name),
fmt_type(field.data_type),
))
self._generate_struct_init(namespace, data_type)
decl = 'open var' if not data_type.parent_type else 'open override var'
with self.block('{} description: String'.format(decl)):
cls = fmt_class(data_type.name) + 'Serializer'
self.emit('return "\\(SerializeUtil.prepareJSONForSerialization' +
'({}().serialize(self)))"'.format(cls))
self._generate_struct_class_serializer(namespace, data_type)
def _generate_struct_init(self, namespace, data_type): # pylint: disable=unused-argument
# init method
args = self._struct_init_args(data_type)
if data_type.parent_type and not data_type.fields:
return
with self.function_block('public init', self._func_args(args)):
for field in data_type.fields:
v = fmt_var(field.name)
validator = self._determine_validator_type(field.data_type, v)
if validator:
self.emit('{}({})'.format(validator, v))
self.emit('self.{0} = {0}'.format(v))
if data_type.parent_type:
func_args = [(fmt_var(f.name),
fmt_var(f.name))
for f in data_type.parent_type.all_fields]
self.emit('super.init({})'.format(self._func_args(func_args)))
def _determine_validator_type(self, data_type, value):
data_type, nullable = unwrap_nullable(data_type)
if is_list_type(data_type):
item_validator = self._determine_validator_type(data_type.data_type, value)
if item_validator:
v = "arrayValidator({})".format(
self._func_args([
("minItems", data_type.min_items),
("maxItems", data_type.max_items),
("itemValidator", item_validator),
])
)
else:
return None
elif is_numeric_type(data_type):
v = "comparableValidator({})".format(
self._func_args([
("minValue", data_type.min_value),
("maxValue", data_type.max_value),
])
)
elif is_string_type(data_type):
pat = data_type.pattern if data_type.pattern else None
pat = pat.encode('unicode_escape').replace(six.ensure_binary("\""),
six.ensure_binary("\\\"")) if pat else pat
v = "stringValidator({})".format(
self._func_args([
("minLength", data_type.min_length),
("maxLength", data_type.max_length),
("pattern", '"{}"'.format(six.ensure_str(pat)) if pat else None),
])
)
else:
return None
if nullable:
v = "nullableValidator({})".format(v)
return v
def _generate_enumerated_subtype_serializer(self, namespace, # pylint: disable=unused-argument
data_type):
with self.block('switch value'):
for tags, subtype in data_type.get_all_subtypes_with_tags():
assert len(tags) == 1, tags
tag = tags[0]
tagvar = fmt_var(tag)
self.emit('case let {} as {}:'.format(
tagvar,
fmt_type(subtype)
))
with self.indent():
block_txt = 'for (k, v) in Serialization.getFields({}.serialize({}))'.format(
fmt_serial_obj(subtype),
tagvar,
)
with self.block(block_txt):
self.emit('output[k] = v')
self.emit('output[".tag"] = .str("{}")'.format(tag))
self.emit('default: fatalError("Tried to serialize unexpected subtype")')
def _generate_struct_base_class_deserializer(self, namespace, data_type):
args = []
for field in data_type.all_fields:
var = fmt_var(field.name)
value = 'dict["{}"]'.format(field.name)
self.emit('let {} = {}.deserialize({} ?? {})'.format(
var,
fmt_serial_obj(field.data_type),
value,
fmt_default_value(namespace, field) if field.has_default else '.null'
))
args.append((var, var))
self.emit('return {}({})'.format(
fmt_class(data_type.name),
self._func_args(args)
))
def _generate_enumerated_subtype_deserializer(self, namespace, data_type):
self.emit('let tag = Serialization.getTag(dict)')
with self.block('switch tag'):
for tags, subtype in data_type.get_all_subtypes_with_tags():
assert len(tags) == 1, tags
tag = tags[0]
self.emit('case "{}":'.format(tag))
with self.indent():
self.emit('return {}.deserialize(json)'.format(fmt_serial_obj(subtype)))
self.emit('default:')
with self.indent():
if data_type.is_catch_all():
self._generate_struct_base_class_deserializer(namespace, data_type)
else:
self.emit('fatalError("Unknown tag \\(tag)")')
def _generate_struct_class_serializer(self, namespace, data_type):
with self.serializer_block(data_type):
with self.serializer_func(data_type):
if not data_type.all_fields:
self.emit('let output = [String: JSON]()')
else:
intro = 'var' if data_type.has_enumerated_subtypes() else 'let'
self.emit("{} output = [ ".format(intro))
for field in data_type.all_fields:
self.emit('"{}": {}.serialize(value.{}),'.format(
field.name,
fmt_serial_obj(field.data_type),
fmt_var(field.name)
))
self.emit(']')
if data_type.has_enumerated_subtypes():
self._generate_enumerated_subtype_serializer(namespace, data_type)
self.emit('return .dictionary(output)')
with self.deserializer_func(data_type):
with self.block("switch json"):
dict_name = "let dict" if data_type.all_fields else "_"
self.emit("case .dictionary({}):".format(dict_name))
with self.indent():
if data_type.has_enumerated_subtypes():
self._generate_enumerated_subtype_deserializer(namespace, data_type)
else:
self._generate_struct_base_class_deserializer(namespace, data_type)
self.emit("default:")
with self.indent():
self.emit('fatalError("Type error deserializing")')
def _format_tag_type(self, namespace, data_type): # pylint: disable=unused-argument
if is_void_type(data_type):
return ''
else:
return '({})'.format(fmt_type(data_type))
def _generate_union_type(self, namespace, data_type):
if data_type.doc:
doc = self.process_doc(data_type.doc, self._docf)
else:
doc = 'The {} union'.format(fmt_class(data_type.name))
self.emit_wrapped_text(doc, prefix='/// ', width=120)
class_type = fmt_class(data_type.name)
with self.block('public enum {}: CustomStringConvertible'.format(class_type)):
for field in data_type.all_fields:
typ = self._format_tag_type(namespace, field.data_type)
fdoc = self.process_doc(field.doc,
self._docf) if field.doc else 'An unspecified error.'
self.emit_wrapped_text(fdoc, prefix='/// ', width=120)
self.emit('case {}{}'.format(fmt_var(field.name), typ))
self.emit()
with self.block('public var description: String'):
cls = class_type + 'Serializer'
self.emit('return "\\(SerializeUtil.prepareJSONForSerialization' +
'({}().serialize(self)))"'.format(cls))
self._generate_union_serializer(data_type)
def _tag_type(self, data_type, field):
return "{}.{}".format(
fmt_class(data_type.name),
fmt_var(field.name)
)
def _generate_union_serializer(self, data_type):
with self.serializer_block(data_type):
with self.serializer_func(data_type), self.block('switch value'):
for field in data_type.all_fields:
field_type = field.data_type
case = '.{}{}'.format(fmt_var(field.name),
'' if is_void_type(field_type) else '(let arg)')
self.emit('case {}:'.format(case))
with self.indent():
if is_void_type(field_type):
self.emit('var d = [String: JSON]()')
elif (is_struct_type(field_type) and
not field_type.has_enumerated_subtypes()):
self.emit('var d = Serialization.getFields({}.serialize(arg))'.format(
fmt_serial_obj(field_type)))
else:
self.emit('var d = ["{}": {}.serialize(arg)]'.format(
field.name,
fmt_serial_obj(field_type)))
self.emit('d[".tag"] = .str("{}")'.format(field.name))
self.emit('return .dictionary(d)')
with self.deserializer_func(data_type):
with self.block("switch json"):
self.emit("case .dictionary(let d):")
with self.indent():
self.emit('let tag = Serialization.getTag(d)')
with self.block('switch tag'):
for field in data_type.all_fields:
field_type = field.data_type
self.emit('case "{}":'.format(field.name))
tag_type = self._tag_type(data_type, field)
with self.indent():
if is_void_type(field_type):
self.emit('return {}'.format(tag_type))
else:
if (is_struct_type(field_type) and
not field_type.has_enumerated_subtypes()):
subdict = 'json'
else:
subdict = 'd["{}"] ?? .null'.format(field.name)
self.emit('let v = {}.deserialize({})'.format(
fmt_serial_obj(field_type), subdict
))
self.emit('return {}(v)'.format(tag_type))
self.emit('default:')
with self.indent():
if data_type.catch_all_field:
self.emit('return {}'.format(
self._tag_type(data_type, data_type.catch_all_field)
))
else:
self.emit('fatalError("Unknown tag \\(tag)")')
self.emit("default:")
with self.indent():
self.emit('fatalError("Failed to deserialize")')
@contextmanager
def serializer_block(self, data_type):
with self.class_block(fmt_class(data_type.name) + 'Serializer',
protocols=['JSONSerializer']):
self.emit("public init() { }")
yield
@contextmanager
def serializer_func(self, data_type):
with self.function_block('open func serialize',
args=self._func_args([('_ value', fmt_class(data_type.name))]),
return_type='JSON'):
yield
@contextmanager
def deserializer_func(self, data_type):
with self.function_block('open func deserialize',
args=self._func_args([('_ json', 'JSON')]),
return_type=fmt_class(data_type.name)):
yield
def _generate_route_objects(self, route_schema, namespace):
check_route_name_conflict(namespace)
self.emit()
self.emit('/// Stone Route Objects')
self.emit()
for route in namespace.routes:
var_name = fmt_func(route.name, route.version)
with self.block('static let {} = Route('.format(var_name),
delim=(None, None), after=')'):
self.emit('name: \"{}\",'.format(route.name))
self.emit('version: {},'.format(route.version))
self.emit('namespace: \"{}\",'.format(namespace.name))
self.emit('deprecated: {},'.format('true' if route.deprecated
is not None else 'false'))
self.emit('argSerializer: {},'.format(fmt_serial_obj(route.arg_data_type)))
self.emit('responseSerializer: {},'.format(fmt_serial_obj(route.result_data_type)))
self.emit('errorSerializer: {},'.format(fmt_serial_obj(route.error_data_type)))
attrs = []
for field in route_schema.fields:
attr_key = field.name
attr_val = ("\"{}\"".format(route.attrs.get(attr_key))
if route.attrs.get(attr_key) else 'nil')
attrs.append('\"{}\": {}'.format(attr_key, attr_val))
self.generate_multiline_list(
attrs, delim=('attrs: [', ']'), compact=True)

View file

@ -0,0 +1,152 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import re
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.backend import CodeBackend
from stone.backends.tsd_helpers import (
check_route_name_conflict,
fmt_error_type,
fmt_func,
fmt_tag,
fmt_type,
)
from stone.ir import Void
_cmdline_parser = argparse.ArgumentParser(prog='tsd-client-backend')
_cmdline_parser.add_argument(
'template',
help=('A template to use when generating the TypeScript definition file.')
)
_cmdline_parser.add_argument(
'filename',
help=('The name to give the single TypeScript definition file to contain '
'all of the emitted types.'),
)
_cmdline_parser.add_argument(
'-t',
'--template-string',
type=str,
default='ROUTES',
help=('The name of the template string to replace with route definitions. '
'Defaults to ROUTES, which replaces the string /*ROUTES*/ with route '
'definitions.')
)
_cmdline_parser.add_argument(
'-i',
'--indent-level',
type=int,
default=1,
help=('Indentation level to emit types at. Routes are automatically '
'indented one level further than this.')
)
_cmdline_parser.add_argument(
'-s',
'--spaces-per-indent',
type=int,
default=2,
help=('Number of spaces to use per indentation level.')
)
_cmdline_parser.add_argument(
'--wrap-response-in',
type=str,
default='',
help=('Wraps the response in a response class')
)
_header = """\
// Auto-generated by Stone, do not modify.
"""
class TSDClientBackend(CodeBackend):
"""Generates a TypeScript definition file with routes defined."""
cmdline_parser = _cmdline_parser
preserve_aliases = True
def generate(self, api):
spaces_per_indent = self.args.spaces_per_indent
indent_level = self.args.indent_level
template_path = os.path.join(self.target_folder_path, self.args.template)
template_string = self.args.template_string
with self.output_to_relative_path(self.args.filename):
if os.path.isfile(template_path):
with open(template_path, 'r') as template_file:
template = template_file.read()
else:
raise AssertionError('TypeScript template file does not exist.')
# /*ROUTES*/
r_match = re.search("/\\*%s\\*/" % (template_string), template)
if not r_match:
raise AssertionError(
'Missing /*%s*/ in TypeScript template file.' % template_string)
r_start = r_match.start()
r_end = r_match.end()
r_ends_with_newline = template[r_end - 1] == '\n'
t_end = len(template)
t_ends_with_newline = template[t_end - 1] == '\n'
self.emit_raw(template[0:r_start] + ('\n' if not r_ends_with_newline else ''))
self._generate_routes(api, spaces_per_indent, indent_level)
self.emit_raw(template[r_end + 1:t_end] + ('\n' if not t_ends_with_newline else ''))
def _generate_routes(self, api, spaces_per_indent, indent_level):
with self.indent(dent=spaces_per_indent * (indent_level + 1)):
for namespace in api.namespaces.values():
# first check for route name conflict
check_route_name_conflict(namespace)
for route in namespace.routes:
self._generate_route(
namespace, route)
def _generate_route(self, namespace, route):
function_name = fmt_func(namespace.name + '_' + route.name, route.version)
self.emit()
self.emit('/**')
if route.doc:
self.emit_wrapped_text(self.process_doc(route.doc, self._docf), prefix=' * ')
self.emit(' *')
self.emit_wrapped_text('When an error occurs, the route rejects the promise with type %s.'
% fmt_error_type(route.error_data_type), prefix=' * ')
if route.deprecated:
self.emit(' * @deprecated')
if route.arg_data_type.__class__ != Void:
self.emit(' * @param arg The request parameters.')
self.emit(' */')
return_type = None
if self.args.wrap_response_in:
return_type = 'Promise<%s<%s>>;' % (self.args.wrap_response_in,
fmt_type(route.result_data_type))
else:
return_type = 'Promise<%s>;' % (fmt_type(route.result_data_type))
arg = ''
if route.arg_data_type.__class__ != Void:
arg = 'arg: %s' % fmt_type(route.arg_data_type)
self.emit('public %s(%s): %s' % (function_name, arg, return_type))
def _docf(self, tag, val):
"""
Callback to process documentation references.
"""
return fmt_tag(None, tag, val)

View file

@ -0,0 +1,178 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from stone.backend import Backend
from stone.ir.api import ApiNamespace
from stone.ir import (
Boolean,
Bytes,
Float32,
Float64,
Int32,
Int64,
List,
String,
Timestamp,
UInt32,
UInt64,
Void,
is_alias,
is_list_type,
is_struct_type,
is_map_type,
is_user_defined_type,
)
from stone.backends.helpers import (
fmt_camel,
)
_base_type_table = {
Boolean: 'boolean',
Bytes: 'string',
Float32: 'number',
Float64: 'number',
Int32: 'number',
Int64: 'number',
List: 'Array',
String: 'string',
UInt32: 'number',
UInt64: 'number',
Timestamp: 'Timestamp',
Void: 'void',
}
def fmt_error_type(data_type, inside_namespace=None):
"""
Converts the error type into a TypeScript type.
inside_namespace should be set to the namespace that the reference
occurs in, or None if this parameter is not relevant.
"""
return 'Error<%s>' % fmt_type(data_type, inside_namespace)
def fmt_type_name(data_type, inside_namespace=None):
"""
Produces a TypeScript type name for the given data type.
inside_namespace should be set to the namespace that the reference
occurs in, or None if this parameter is not relevant.
"""
if is_user_defined_type(data_type) or is_alias(data_type):
if data_type.namespace == inside_namespace:
return data_type.name
else:
return '%s.%s' % (data_type.namespace.name, data_type.name)
else:
fmted_type = _base_type_table.get(data_type.__class__, 'Object')
if is_list_type(data_type):
fmted_type += '<' + fmt_type(data_type.data_type, inside_namespace) + '>'
elif is_map_type(data_type):
key_data_type = _base_type_table.get(data_type.key_data_type, 'string')
value_data_type = fmt_type_name(data_type.value_data_type, inside_namespace)
fmted_type = '{[key: %s]: %s}' % (key_data_type, value_data_type)
return fmted_type
def fmt_polymorphic_type_reference(data_type, inside_namespace=None):
"""
Produces a TypeScript type name for the meta-type that refers to the given
struct, which belongs to an enumerated subtypes tree. This meta-type contains the
.tag field that lets developers discriminate between subtypes.
"""
# NOTE: These types are not properly namespaced, so there could be a conflict
# with other user-defined types. If this ever surfaces as a problem, we
# can defer emitting these types until the end, and emit them in a
# nested namespace (e.g., files.references.MetadataReference).
return fmt_type_name(data_type, inside_namespace) + "Reference"
def fmt_type(data_type, inside_namespace=None):
"""
Returns a TypeScript type annotation for a data type.
May contain a union of enumerated subtypes.
inside_namespace should be set to the namespace that the type reference
occurs in, or None if this parameter is not relevant.
"""
if is_struct_type(data_type) and data_type.has_enumerated_subtypes():
possible_types = []
possible_subtypes = data_type.get_all_subtypes_with_tags()
for _, subtype in possible_subtypes:
possible_types.append(fmt_polymorphic_type_reference(subtype, inside_namespace))
if data_type.is_catch_all():
possible_types.append(fmt_polymorphic_type_reference(data_type, inside_namespace))
return fmt_union(possible_types)
else:
return fmt_type_name(data_type, inside_namespace)
def fmt_union(type_strings):
"""
Returns a union type of the given types.
"""
return '|'.join(type_strings) if len(type_strings) > 1 else type_strings[0]
def fmt_func(name, version):
if version == 1:
return fmt_camel(name)
return fmt_camel(name) + 'V{}'.format(version)
def fmt_var(name):
return fmt_camel(name)
def fmt_tag(cur_namespace, tag, val):
"""
Processes a documentation reference.
"""
if tag == 'type':
fq_val = val
if '.' not in val and cur_namespace is not None:
fq_val = cur_namespace.name + '.' + fq_val
return fq_val
elif tag == 'route':
if ':' in val:
val, version = val.split(':', 1)
version = int(version)
else:
version = 1
return fmt_func(val, version) + "()"
elif tag == 'link':
anchor, link = val.rsplit(' ', 1)
# There's no way to have links in TSDoc, so simply use JSDoc's formatting.
# It's entirely possible some editors support this.
return '[%s]{@link %s}' % (anchor, link)
elif tag == 'val':
# Value types seem to match JavaScript (true, false, null)
return val
elif tag == 'field':
return val
else:
raise RuntimeError('Unknown doc ref tag %r' % tag)
def check_route_name_conflict(namespace):
"""
Check name conflicts among generated route definitions. Raise a runtime exception when a
conflict is encountered.
"""
route_by_name = {}
for route in namespace.routes:
route_name = fmt_func(route.name, route.version)
if route_name in route_by_name:
other_route = route_by_name[route_name]
raise RuntimeError(
'There is a name conflict between {!r} and {!r}'.format(other_route, route))
route_by_name[route_name] = route
def generate_imports_for_referenced_namespaces(backend, namespace, module_name_prefix):
# type: (Backend, ApiNamespace, str) -> None
imported_namespaces = namespace.get_imported_namespaces()
if not imported_namespaces:
return
for ns in imported_namespaces:
backend.emit(
"import * as {namespace_name} from '{module_name_prefix}{namespace_name}';".format(
module_name_prefix=module_name_prefix,
namespace_name=ns.name
)
)
backend.emit()

View file

@ -0,0 +1,517 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import json
import os
import re
import six
import sys
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
from stone.ir import ApiNamespace
from stone.ir import (
is_alias,
is_struct_type,
is_union_type,
is_user_defined_type,
is_void_type,
unwrap_nullable,
)
from stone.backend import CodeBackend
from stone.backends.helpers import (
fmt_pascal,
)
from stone.backends.tsd_helpers import (
fmt_polymorphic_type_reference,
fmt_tag,
fmt_type,
fmt_type_name,
fmt_union,
generate_imports_for_referenced_namespaces,
)
_cmdline_parser = argparse.ArgumentParser(prog='tsd-types-backend')
_cmdline_parser.add_argument(
'template',
help=('A template to use when generating the TypeScript definition file. '
'Replaces the string /*TYPES*/ with stone type definitions.')
)
_cmdline_parser.add_argument(
'filename',
nargs='?',
help=('The name of the generated typeScript definition file that contains '
'all of the emitted types.'),
)
_cmdline_parser.add_argument(
'--exclude_error_types',
default=False,
action='store_true',
help='If true, the output will exclude the interface for Error type.',
)
_cmdline_parser.add_argument(
'-e',
'--extra-arg',
action='append',
type=str,
default=[],
help=("Additional argument to add to a route's argument based "
"on if the route has a certain attribute set. Format (JSON): "
'{"match": ["ROUTE_ATTR", ROUTE_VALUE_TO_MATCH], '
'"arg_name": "ARG_NAME", "arg_type": "ARG_TYPE", '
'"arg_docstring": "ARG_DOCSTRING"}'),
)
_cmdline_parser.add_argument(
'-i',
'--indent-level',
type=int,
default=1,
help=('Indentation level to emit types at. Routes are automatically '
'indented one level further than this.')
)
_cmdline_parser.add_argument(
'-s',
'--spaces-per-indent',
type=int,
default=2,
help=('Number of spaces to use per indentation level.')
)
_cmdline_parser.add_argument(
'-p',
'--module-name-prefix',
type=str,
default='',
help=('Prefix for data type module names. '
'This is useful for repo which requires absolute path as '
'module name')
)
_cmdline_parser.add_argument(
'--export-namespaces',
default=False,
action='store_true',
help=('Adds the export tag to each namespace.'
'This is useful is you are not placing each namespace '
'inside of a module and want to export each namespace individually')
)
_header = """\
// Auto-generated by Stone, do not modify.
"""
_types_header = """\
/**
* An Error object returned from a route.
*/
interface Error<T> {
\t// Text summary of the error.
\terror_summary: string;
\t// The error object.
\terror: T;
\t// User-friendly error message.
\tuser_message: UserMessage;
}
/**
* User-friendly error message.
*/
interface UserMessage {
\t// The message.
\ttext: string;
\t// The locale of the message.
\tlocale: string;
}
"""
_timestamp_definition = "type Timestamp = string;"
class TSDTypesBackend(CodeBackend):
"""
Generates a single TypeScript definition file with all of the types defined, organized
as namespaces, if a filename is provided in input arguments. Otherwise generates one
declaration file for each namespace with the corresponding typescript definitions.
If a single output file is generated, a top level type definition will be added for the
Timestamp data type. Otherwise, each namespace will have the type definition for Timestamp.
Also, note that namespace definitions are emitted as declaration files. Hence any template
provided as argument must not have a top level declare statement. If namespaces are emitted
into a single file, the template file can be used to wrap them around a declare statement.
"""
cmdline_parser = _cmdline_parser
preserve_aliases = True
# Instance var of the current namespace being generated
cur_namespace = None # type: typing.Optional[ApiNamespace]
# Instance var to denote if one file is output for each namespace.
split_by_namespace = False
def generate(self, api):
extra_args = self._parse_extra_args(api, self.args.extra_arg)
template = self._read_template()
if self.args.filename:
self._generate_base_namespace_module(api.namespaces.values(), self.args.filename,
template, extra_args,
exclude_error_types=self.args.exclude_error_types)
else:
self.split_by_namespace = True
for namespace in api.namespaces.values():
filename = '{}.d.ts'.format(namespace.name)
self._generate_base_namespace_module(
[namespace], filename, template,
extra_args,
exclude_error_types=self.args.exclude_error_types)
def _read_template(self):
template_path = os.path.join(self.target_folder_path, self.args.template)
if os.path.isfile(template_path):
with open(template_path, 'r') as template_file:
return template_file.read()
else:
raise AssertionError('TypeScript template file does not exist.')
def _get_data_types(self, namespace):
return namespace.data_types + namespace.aliases
def _generate_base_namespace_module(self, namespace_list, filename,
template, extra_args,
exclude_error_types=False):
# Skip namespaces that do not contain types.
if all([len(self._get_data_types(ns)) == 0 for ns in namespace_list]):
return
spaces_per_indent = self.args.spaces_per_indent
indent_level = self.args.indent_level
with self.output_to_relative_path(filename):
# /*TYPES*/
t_match = re.search("/\\*TYPES\\*/", template)
if not t_match:
raise AssertionError('Missing /*TYPES*/ in TypeScript template file.')
t_start = t_match.start()
t_end = t_match.end()
t_ends_with_newline = template[t_end - 1] == '\n'
temp_end = len(template)
temp_ends_with_newline = template[temp_end - 1] == '\n'
self.emit_raw(template[0:t_start] + ("\n" if not t_ends_with_newline else ''))
indent = spaces_per_indent * indent_level
indent_spaces = (' ' * indent)
with self.indent(dent=indent):
if not exclude_error_types:
indented_types_header = indent_spaces + (
('\n' + indent_spaces)
.join(_types_header.split('\n'))
.replace('\t', ' ' * spaces_per_indent)
)
self.emit_raw(indented_types_header + '\n')
if not self.split_by_namespace:
self.emit(_timestamp_definition)
self.emit()
for namespace in namespace_list:
self._generate_types(namespace, spaces_per_indent, extra_args)
self.emit_raw(template[t_end + 1:temp_end] +
("\n" if not temp_ends_with_newline else ''))
def _generate_types(self, namespace, spaces_per_indent, extra_args):
self.cur_namespace = namespace
# Count aliases as data types too!
data_types = self._get_data_types(namespace)
# Skip namespaces that do not contain types.
if len(data_types) == 0:
return
if self.split_by_namespace:
generate_imports_for_referenced_namespaces(
backend=self, namespace=namespace, module_name_prefix=self.args.module_name_prefix
)
if namespace.doc:
self._emit_tsdoc_header(namespace.doc)
self.emit_wrapped_text(self._get_top_level_declaration(namespace.name))
with self.indent(dent=spaces_per_indent):
for data_type in data_types:
self._generate_type(data_type, spaces_per_indent,
extra_args.get(data_type, []))
if self.split_by_namespace:
with self.indent(dent=spaces_per_indent):
# TODO(Pranay): May avoid adding an unused definition if needed.
self.emit(_timestamp_definition)
self.emit('}')
self.emit()
def _get_top_level_declaration(self, name):
if self.split_by_namespace:
# Use module for when emitting declaration files.
return "declare module '%s%s' {" % (self.args.module_name_prefix, name)
else:
if self.args.export_namespaces:
return "export namespace %s {" % name
else:
# Use namespace for organizing code with-in the file.
return "namespace %s {" % name
def _parse_extra_args(self, api, extra_args_raw):
"""
Parses extra arguments into a map keyed on particular data types.
"""
extra_args = {}
def invalid(msg, extra_arg_raw):
print('Invalid --extra-arg:%s: %s' % (msg, extra_arg_raw),
file=sys.stderr)
sys.exit(1)
for extra_arg_raw in extra_args_raw:
try:
extra_arg = json.loads(extra_arg_raw)
except ValueError as e:
invalid(str(e), extra_arg_raw)
# Validate extra_arg JSON blob
if 'match' not in extra_arg:
invalid('No match key', extra_arg_raw)
elif (not isinstance(extra_arg['match'], list) or
len(extra_arg['match']) != 2):
invalid('match key is not a list of two strings', extra_arg_raw)
elif (not isinstance(extra_arg['match'][0], six.text_type) or
not isinstance(extra_arg['match'][1], six.text_type)):
print(type(extra_arg['match'][0]))
invalid('match values are not strings', extra_arg_raw)
elif 'arg_name' not in extra_arg:
invalid('No arg_name key', extra_arg_raw)
elif not isinstance(extra_arg['arg_name'], six.text_type):
invalid('arg_name is not a string', extra_arg_raw)
elif 'arg_type' not in extra_arg:
invalid('No arg_type key', extra_arg_raw)
elif not isinstance(extra_arg['arg_type'], six.text_type):
invalid('arg_type is not a string', extra_arg_raw)
elif ('arg_docstring' in extra_arg and
not isinstance(extra_arg['arg_docstring'], six.text_type)):
invalid('arg_docstring is not a string', extra_arg_raw)
attr_key, attr_val = extra_arg['match'][0], extra_arg['match'][1]
extra_args.setdefault(attr_key, {})[attr_val] = \
(extra_arg['arg_name'], extra_arg['arg_type'],
extra_arg.get('arg_docstring'))
# Extra arguments, keyed on data type objects.
extra_args_for_types = {}
# Locate data types that contain extra arguments
for namespace in api.namespaces.values():
for route in namespace.routes:
extra_parameters = []
if is_user_defined_type(route.arg_data_type):
for attr_key in route.attrs:
if attr_key not in extra_args:
continue
attr_val = route.attrs[attr_key]
if attr_val in extra_args[attr_key]:
extra_parameters.append(extra_args[attr_key][attr_val])
if len(extra_parameters) > 0:
extra_args_for_types[route.arg_data_type] = extra_parameters
return extra_args_for_types
def _emit_tsdoc_header(self, docstring):
self.emit('/**')
self.emit_wrapped_text(self.process_doc(docstring, self._docf), prefix=' * ')
self.emit(' */')
def _generate_type(self, data_type, indent_spaces, extra_args):
"""
Generates a TypeScript type for the given type.
"""
if is_alias(data_type):
self._generate_alias_type(data_type)
elif is_struct_type(data_type):
self._generate_struct_type(data_type, indent_spaces, extra_args)
elif is_union_type(data_type):
self._generate_union_type(data_type, indent_spaces)
def _generate_alias_type(self, alias_type):
"""
Generates a TypeScript type for a stone alias.
"""
namespace = alias_type.namespace
self.emit('export type %s = %s;' % (fmt_type_name(alias_type, namespace),
fmt_type_name(alias_type.data_type, namespace)))
self.emit()
def _generate_struct_type(self, struct_type, indent_spaces, extra_parameters):
"""
Generates a TypeScript interface for a stone struct.
"""
namespace = struct_type.namespace
if struct_type.doc:
self._emit_tsdoc_header(struct_type.doc)
parent_type = struct_type.parent_type
extends_line = ' extends %s' % fmt_type_name(parent_type, namespace) if parent_type else ''
self.emit('export interface %s%s {' % (fmt_type_name(struct_type, namespace), extends_line))
with self.indent(dent=indent_spaces):
for param_name, param_type, param_docstring in extra_parameters:
if param_docstring:
self._emit_tsdoc_header(param_docstring)
self.emit('%s: %s;' % (param_name, param_type))
for field in struct_type.fields:
doc = field.doc
field_type, nullable = unwrap_nullable(field.data_type)
field_ts_type = fmt_type(field_type, namespace)
optional = nullable or field.has_default
if field.has_default:
# doc may be None. If it is not empty, add newlines
# before appending to it.
doc = doc + '\n\n' if doc else ''
doc = "Defaults to %s." % field.default
if doc:
self._emit_tsdoc_header(doc)
# Translate nullable types into optional properties.
field_name = '%s?' % field.name if optional else field.name
self.emit('%s: %s;' % (field_name, field_ts_type))
self.emit('}')
self.emit()
# Some structs can explicitly list their subtypes. These structs have a .tag field that
# indicate which subtype they are, which is only present when a type reference is
# ambiguous.
# Emit a special interface that contains this extra field, and refer to it whenever we
# encounter a reference to a type with enumerated subtypes.
if struct_type.is_member_of_enumerated_subtypes_tree():
if struct_type.has_enumerated_subtypes():
# This struct is the parent to multiple subtypes. Determine all of the possible
# values of the .tag property.
tag_values = []
for tags, _ in struct_type.get_all_subtypes_with_tags():
for tag in tags:
tag_values.append('"%s"' % tag)
tag_union = fmt_union(tag_values)
self._emit_tsdoc_header('Reference to the %s polymorphic type. Contains a .tag '
'property to let you discriminate between possible '
'subtypes.' % fmt_type_name(struct_type, namespace))
self.emit('export interface %s extends %s {' %
(fmt_polymorphic_type_reference(struct_type, namespace),
fmt_type_name(struct_type, namespace)))
with self.indent(dent=indent_spaces):
self._emit_tsdoc_header('Tag identifying the subtype variant.')
self.emit('\'.tag\': %s;' % tag_union)
self.emit('}')
self.emit()
else:
# This struct is a particular subtype. Find the applicable .tag value from the
# parent type, which may be an arbitrary number of steps up the inheritance
# hierarchy.
parent = struct_type.parent_type
while not parent.has_enumerated_subtypes():
parent = parent.parent_type
# parent now contains the closest parent type in the inheritance hierarchy that has
# enumerated subtypes. Determine which subtype this is.
for subtype in parent.get_enumerated_subtypes():
if subtype.data_type == struct_type:
self._emit_tsdoc_header('Reference to the %s type, identified by the '
'value of the .tag property.' %
fmt_type_name(struct_type, namespace))
self.emit('export interface %s extends %s {' %
(fmt_polymorphic_type_reference(struct_type, namespace),
fmt_type_name(struct_type, namespace)))
with self.indent(dent=indent_spaces):
self._emit_tsdoc_header('Tag identifying this subtype variant. This '
'field is only present when needed to '
'discriminate between multiple possible '
'subtypes.')
self.emit_wrapped_text('\'.tag\': \'%s\';' % subtype.name)
self.emit('}')
self.emit()
break
def _generate_union_type(self, union_type, indent_spaces):
"""
Generates a TypeScript interface for a stone union.
"""
# Emit an interface for each variant. TypeScript 2.0 supports these tagged unions.
# https://github.com/Microsoft/TypeScript/wiki/What%27s-new-in-TypeScript#tagged-union-types
parent_type = union_type.parent_type
namespace = union_type.namespace
union_type_name = fmt_type_name(union_type, namespace)
variant_type_names = []
if parent_type:
variant_type_names.append(fmt_type_name(parent_type, namespace))
def _is_struct_without_enumerated_subtypes(data_type):
"""
:param data_type: any data type.
:return: True if the given data type is a struct which has no enumerated subtypes.
"""
return is_struct_type(data_type) and (
not data_type.has_enumerated_subtypes())
for variant in union_type.fields:
if variant.doc:
self._emit_tsdoc_header(variant.doc)
variant_name = '%s%s' % (union_type_name, fmt_pascal(variant.name))
variant_type_names.append(variant_name)
is_struct_without_enumerated_subtypes = _is_struct_without_enumerated_subtypes(
variant.data_type)
if is_struct_without_enumerated_subtypes:
self.emit('export interface %s extends %s {' % (
variant_name, fmt_type(variant.data_type, namespace)))
else:
self.emit('export interface %s {' % variant_name)
with self.indent(dent=indent_spaces):
# Since field contains non-alphanumeric character, we need to enclose
# it in quotation marks.
self.emit("'.tag': '%s';" % variant.name)
if is_void_type(variant.data_type) is False and (
not is_struct_without_enumerated_subtypes
):
self.emit("%s: %s;" % (variant.name, fmt_type(variant.data_type, namespace)))
self.emit('}')
self.emit()
if union_type.doc:
self._emit_tsdoc_header(union_type.doc)
self.emit('export type %s = %s;' % (union_type_name, ' | '.join(variant_type_names)))
self.emit()
def _docf(self, tag, val):
"""
Callback to process documentation references.
"""
return fmt_tag(self.cur_namespace, tag, val)

View file

@ -0,0 +1,380 @@
"""
A command-line interface for StoneAPI.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import codecs
import imp
import io
import json
import logging
import os
import six
import sys
import traceback
from .cli_helpers import parse_route_attr_filter
from .compiler import (
BackendException,
Compiler,
)
from .frontend.exception import InvalidSpec
from .frontend.frontend import specs_to_ir
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
# Hack to get around some of Python 2's standard library modules that
# accept ascii-encodable unicode literals in lieu of strs, but where
# actually passing such literals results in errors with mypy --py2. See
# <https://github.com/python/typeshed/issues/756> and
# <https://github.com/python/mypy/issues/2536>.
import importlib
argparse = importlib.import_module(str('argparse')) # type: typing.Any
# These backends come by default
_builtin_backends = (
'obj_c_client',
'obj_c_types',
'obj_c_tests',
'js_client',
'js_types',
'tsd_client',
'tsd_types',
'python_types',
'python_type_stubs',
'python_client',
'swift_types',
'swift_client',
)
# The parser for command line arguments
_cmdline_description = (
'Write your APIs in Stone. Use backends to translate your specification '
'into a target language or format. The following describes arguments to '
'the Stone CLI. To specify arguments that are specific to a backend, '
'add "--" followed by arguments. For example, "stone python_client . '
'example.spec -- -h".'
)
_cmdline_parser = argparse.ArgumentParser(description=_cmdline_description)
_cmdline_parser.add_argument(
'-v',
'--verbose',
action='count',
help='Print debugging statements.',
)
_backend_help = (
'Either the name of a built-in backend or the path to a backend '
'module. Paths to backend modules must end with a .stoneg.py extension. '
'The following backends are built-in: ' + ', '.join(_builtin_backends))
_cmdline_parser.add_argument(
'backend',
type=six.text_type,
help=_backend_help,
)
_cmdline_parser.add_argument(
'output',
type=six.text_type,
help='The folder to save generated files to.',
)
_cmdline_parser.add_argument(
'spec',
nargs='*',
type=six.text_type,
help=('Path to API specifications. Each must have a .stone extension. '
'If omitted or set to "-", the spec is read from stdin. Multiple '
'namespaces can be provided over stdin by concatenating multiple '
'specs together.'),
)
_cmdline_parser.add_argument(
'--clean-build',
action='store_true',
help='The path to the template SDK for the target language.',
)
_cmdline_parser.add_argument(
'-f',
'--filter-by-route-attr',
type=six.text_type,
help=('Removes routes that do not match the expression. The expression '
'must specify a route attribute on the left-hand side and a value '
'on the right-hand side. Use quotes for strings and bytes. The only '
'supported operators are "=" and "!=". For example, if "hide" is a '
'route attribute, we can use this filter: "hide!=true". You can '
'combine multiple expressions with "and"/"or" and use parentheses '
'to enforce precedence.'),
)
_cmdline_parser.add_argument(
'-r',
'--route-whitelist-filter',
type=six.text_type,
help=('Restrict datatype generation to only the routes specified in the whitelist '
'and their dependencies. Input should be a file containing a JSON dict with '
'the following form: {"route_whitelist": {}, "datatype_whitelist": {}} '
'where each object maps namespaces to lists of routes or datatypes to whitelist.'),
)
_cmdline_parser.add_argument(
'-a',
'--attribute',
action='append',
type=str,
default=[],
help=('Route attributes that the backend will have access to and '
'presumably expose in generated code. Use ":all" to select all '
'attributes defined in stone_cfg.Route. Note that you can filter '
'(-f) by attributes that are not listed here.'),
)
_filter_ns_group = _cmdline_parser.add_mutually_exclusive_group()
_filter_ns_group.add_argument(
'-w',
'--whitelist-namespace-routes',
action='append',
type=str,
default=[],
help='If set, backends will only see the specified namespaces as having routes.',
)
_filter_ns_group.add_argument(
'-b',
'--blacklist-namespace-routes',
action='append',
type=str,
default=[],
help='If set, backends will not see any routes for the specified namespaces.',
)
def main():
"""The entry point for the program."""
if '--' in sys.argv:
cli_args = sys.argv[1:sys.argv.index('--')]
backend_args = sys.argv[sys.argv.index('--') + 1:]
else:
cli_args = sys.argv[1:]
backend_args = []
args = _cmdline_parser.parse_args(cli_args)
debug = False
if args.verbose is None:
logging_level = logging.WARNING
elif args.verbose == 1:
logging_level = logging.INFO
elif args.verbose == 2:
logging_level = logging.DEBUG
debug = True
else:
print('error: I can only be so garrulous, try -vv.', file=sys.stderr)
sys.exit(1)
logging.basicConfig(level=logging_level)
if args.spec and args.spec[0].startswith('+') and args.spec[0].endswith('.py'):
# Hack: Special case for defining a spec in Python for testing purposes
# Use this if you want to define a Stone spec using a Python module.
# The module should should contain an api variable that references a
# :class:`stone.api.Api` object.
try:
api = imp.load_source('api', args.api[0]).api # pylint: disable=redefined-outer-name
except ImportError as e:
print('error: Could not import API description due to:',
e, file=sys.stderr)
sys.exit(1)
else:
if args.spec:
specs = []
read_from_stdin = False
for spec_path in args.spec:
if spec_path == '-':
read_from_stdin = True
elif not spec_path.endswith('.stone'):
print("error: Specification '%s' must have a .stone extension."
% spec_path,
file=sys.stderr)
sys.exit(1)
elif not os.path.exists(spec_path):
print("error: Specification '%s' cannot be found." % spec_path,
file=sys.stderr)
sys.exit(1)
else:
with open(spec_path) as f:
specs.append((spec_path, f.read()))
if read_from_stdin and specs:
print("error: Do not specify stdin and specification files "
"simultaneously.", file=sys.stderr)
sys.exit(1)
if not args.spec or read_from_stdin:
specs = []
if debug:
print('Reading specification from stdin.')
if six.PY2:
UTF8Reader = codecs.getreader('utf8')
sys.stdin = UTF8Reader(sys.stdin)
stdin_text = sys.stdin.read()
else:
stdin_buffer = sys.stdin.buffer # pylint: disable=no-member,useless-suppression
stdin_text = io.TextIOWrapper(stdin_buffer, encoding='utf-8').read()
parts = stdin_text.split('namespace')
if len(parts) == 1:
specs.append(('stdin.1', parts[0]))
else:
specs.append(
('stdin.1', '%snamespace%s' % (parts.pop(0), parts.pop(0))))
while parts:
specs.append(('stdin.%s' % (len(specs) + 1),
'namespace%s' % parts.pop(0)))
if args.filter_by_route_attr:
route_filter, route_filter_errors = parse_route_attr_filter(
args.filter_by_route_attr, debug)
if route_filter_errors:
print('Error(s) in route filter:', file=sys.stderr)
for err in route_filter_errors:
print(err, file=sys.stderr)
sys.exit(1)
else:
route_filter = None
if args.route_whitelist_filter:
with open(args.route_whitelist_filter) as f:
route_whitelist_filter = json.loads(f.read())
else:
route_whitelist_filter = None
try:
# TODO: Needs version
api = specs_to_ir(specs, debug=debug,
route_whitelist_filter=route_whitelist_filter)
except InvalidSpec as e:
print('%s:%s: error: %s' % (e.path, e.lineno, e.msg), file=sys.stderr)
if debug:
print('A traceback is included below in case this is a bug in '
'Stone.\n', traceback.format_exc(), file=sys.stderr)
sys.exit(1)
if api is None:
print('You must fix the above parsing errors for generation to '
'continue.', file=sys.stderr)
sys.exit(1)
if args.whitelist_namespace_routes:
for namespace_name in args.whitelist_namespace_routes:
if namespace_name not in api.namespaces:
print('error: Whitelisted namespace missing from spec: %s' %
namespace_name, file=sys.stderr)
sys.exit(1)
for namespace in api.namespaces.values():
if namespace.name not in args.whitelist_namespace_routes:
namespace.routes = []
namespace.route_by_name = {}
namespace.routes_by_name = {}
if args.blacklist_namespace_routes:
for namespace_name in args.blacklist_namespace_routes:
if namespace_name not in api.namespaces:
print('error: Blacklisted namespace missing from spec: %s' %
namespace_name, file=sys.stderr)
sys.exit(1)
else:
namespace = api.namespaces[namespace_name]
namespace.routes = []
namespace.route_by_name = {}
namespace.routes_by_name = {}
if route_filter:
for namespace in api.namespaces.values():
filtered_routes = []
for route in namespace.routes:
if route_filter.eval(route):
filtered_routes.append(route)
namespace.routes = []
namespace.route_by_name = {}
namespace.routes_by_name = {}
for route in filtered_routes:
namespace.add_route(route)
if args.attribute:
attrs = set(args.attribute)
if ':all' in attrs:
attrs = {field.name for field in api.route_schema.fields}
else:
attrs = set()
for namespace in api.namespaces.values():
for route in namespace.routes:
for k in list(route.attrs.keys()):
if k not in attrs:
del route.attrs[k]
# Remove attrs that weren't specified from the route schema
for field in api.route_schema.fields[:]:
if field.name not in attrs:
api.route_schema.fields.remove(field)
del api.route_schema._fields_by_name[field.name]
else:
attrs.remove(field.name)
# Error if specified attr isn't even a field in the route schema
if attrs:
attr = attrs.pop()
print('error: Attribute not defined in stone_cfg.Route: %s' %
attr, file=sys.stderr)
sys.exit(1)
if args.backend in _builtin_backends:
backend_module = __import__(
'stone.backends.%s' % args.backend, fromlist=[''])
elif not os.path.exists(args.backend):
print("error: Backend '%s' cannot be found." % args.backend,
file=sys.stderr)
sys.exit(1)
elif not os.path.isfile(args.backend):
print("error: Backend '%s' must be a file." % args.backend,
file=sys.stderr)
sys.exit(1)
elif not Compiler.is_stone_backend(args.backend):
print("error: Backend '%s' must have a .stoneg.py extension." %
args.backend, file=sys.stderr)
sys.exit(1)
else:
# A bit hacky, but we add the folder that the backend is in to our
# python path to support the case where the backend imports other
# files in its local directory.
new_python_path = os.path.dirname(args.backend)
if new_python_path not in sys.path:
sys.path.append(new_python_path)
try:
backend_module = imp.load_source('user_backend', args.backend)
except Exception:
print("error: Importing backend '%s' module raised an exception:" %
args.backend, file=sys.stderr)
raise
c = Compiler(
api,
backend_module,
backend_args,
args.output,
clean_build=args.clean_build,
)
try:
c.build()
except BackendException as e:
print('%s: error: %s raised an exception:\n%s' %
(args.backend, e.backend_name, e.traceback),
file=sys.stderr)
sys.exit(1)
if not sys.argv[0].endswith('stone'):
# If we aren't running from an entry_point, then return api to make it
# easier to do debugging.
return api
if __name__ == '__main__':
# Assign api variable for easy debugging from a Python console
api = main()

View file

@ -0,0 +1,237 @@
import abc
import six
from ply import lex, yacc
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class FilterExprLexer(object):
tokens = (
'ID',
'LPAR',
'RPAR',
) # type: typing.Tuple[str, ...]
# Conjunctions
tokens += (
'AND',
'OR',
)
# Comparison operators
tokens += (
'NEQ',
'EQ',
)
# Primitive types
tokens += (
'BOOLEAN',
'FLOAT',
'INTEGER',
'NULL',
'STRING',
)
t_LPAR = r'\('
t_RPAR = r'\)'
t_NEQ = r'!='
t_EQ = r'='
t_ignore = ' '
KEYWORDS = {
'and': 'AND',
'or': 'OR',
}
def __init__(self, debug=False):
self.lexer = lex.lex(module=self, debug=debug)
self.errors = []
def get_yacc_compat_lexer(self):
return self.lexer
def t_BOOLEAN(self, token):
r'\btrue\b|\bfalse\b'
token.value = (token.value == 'true')
return token
def t_NULL(self, token):
r'\bnull\b'
token.value = None
return token
def t_FLOAT(self, token):
r'-?\d+(\.\d*(e-?\d+)?|e-?\d+)'
token.value = float(token.value)
return token
def t_INTEGER(self, token):
r'-?\d+'
token.value = int(token.value)
return token
def t_STRING(self, token):
r'\"([^\\"]|(\\.))*\"'
token.value = token.value[1:-1]
return token
def t_ID(self, token):
r'[a-zA-Z_][a-zA-Z0-9_-]*'
if token.value in self.KEYWORDS:
token.type = self.KEYWORDS[token.value]
return token
else:
return token
# Error handling rule
def t_error(self, token):
self.errors.append(
('Illegal character %s.' % repr(token.value[0]).lstrip('u')))
token.lexer.skip(1)
# Test output
def test(self, data):
self.lexer.input(data)
while True:
tok = self.lexer.token()
if not tok:
break
print(tok)
class FilterExprParser(object):
# Ply parser requiment: Tokens must be re-specified in parser
tokens = FilterExprLexer.tokens
# Ply wants a 'str' instance; this makes it work in Python 2 and 3
start = str('expr')
# To match most languages, give logical conjunctions a higher precedence
# than logical disjunctions.
precedence = (
('left', 'OR'),
('left', 'AND'),
)
def __init__(self, debug=False):
self.debug = debug
self.yacc = yacc.yacc(module=self, debug=debug, write_tables=debug)
self.lexer = FilterExprLexer(debug)
self.errors = []
def parse(self, data):
"""
Args:
data (str): Raw filter expression.
"""
parsed_data = self.yacc.parse(
data, lexer=self.lexer.get_yacc_compat_lexer(), debug=self.debug)
self.errors = self.lexer.errors + self.errors
return parsed_data, self.errors
def p_expr(self, p):
'expr : pred'
p[0] = p[1]
def p_expr_parens(self, p):
'expr : LPAR expr RPAR'
p[0] = p[2]
def p_expr_group(self, p):
"""expr : expr OR expr
| expr AND expr"""
p[0] = FilterExprConjunction(p[2], p[1], p[3])
def p_pred(self, p):
'pred : ID op primitive'
p[0] = FilterExprPredicate(p[2], p[1], p[3])
def p_op(self, p):
"""op : NEQ
| EQ"""
p[0] = p[1]
def p_primitive(self, p):
"""primitive : BOOLEAN
| FLOAT
| INTEGER
| NULL
| STRING"""
p[0] = p[1]
def p_error(self, token):
if token:
self.errors.append(
("Unexpected %s with value %s." %
(token.type, repr(token.value).lstrip('u'))))
else:
self.errors.append('Unexpected end of expression.')
class FilterExpr(object):
__metaclass__ = abc.ABCMeta
@abc.abstractmethod
def eval(self, route):
pass
class FilterExprConjunction(object):
def __init__(self, conj, lhs, rhs):
self.conj = conj
self.lhs = lhs
self.rhs = rhs
def eval(self, route):
if self.conj == 'and':
return self.lhs.eval(route) and self.rhs.eval(route)
elif self.conj == 'or':
return self.lhs.eval(route) or self.rhs.eval(route)
else:
assert False
def __repr__(self):
return 'EvalConj(%r, %r, %r)' % (self.conj, self.lhs, self.rhs)
class FilterExprPredicate(object):
def __init__(self, op, lhs, rhs):
self.op = op
self.lhs = lhs
self.rhs = rhs
def eval(self, route):
val = route.attrs.get(self.lhs, None)
if self.op == '=':
return val == self.rhs
elif self.op == '!=':
return val != self.rhs
else:
assert False
def __repr__(self):
return 'EvalPred(%r, %r, %r)' % (self.op, self.lhs, self.rhs)
def parse_route_attr_filter(route_attr_filter, debug=False):
"""
Args:
route_attr_filter (str): The raw command-line input of the route
filter.
Returns:
Tuple[FilterExpr, List[str]]: The second element is a list of errors.
"""
assert isinstance(route_attr_filter, six.text_type), type(route_attr_filter)
parser = FilterExprParser(debug)
return parser.parse(route_attr_filter)

View file

@ -0,0 +1,126 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import inspect
import os
import shutil
import traceback
from stone.backend import (
Backend,
remove_aliases_from_api,
)
class BackendException(Exception):
"""Saves the traceback of an exception raised by a backend."""
def __init__(self, backend_name, tb):
"""
:type backend_name: str
:type tb: str
"""
super(BackendException, self).__init__()
self.backend_name = backend_name
self.traceback = tb
class Compiler(object):
"""
Applies a collection of backends found in a single backend module to an
API specification.
"""
backend_extension = '.stoneg'
def __init__(self,
api,
backend_module,
backend_args,
build_path,
clean_build=False):
"""
Creates a Compiler.
:param stone.ir.Api api: A Stone description of the API.
:param backend_module: Python module that contains at least one
top-level class definition that descends from a
:class:`stone.backend.Backend`.
:param list(str) backend_args: A list of command-line arguments to
pass to the backend.
:param str build_path: Location to save compiled sources to. If None,
source files are compiled into the same directories.
:param bool clean_build: If True, the build_path is removed before
source files are compiled into them.
"""
self._logger = logging.getLogger('stone.compiler')
self.api = api
self.backend_module = backend_module
self.backend_args = backend_args
self.build_path = build_path
# Remove existing build directory if it's a clean build
if clean_build and os.path.exists(self.build_path):
logging.info('Cleaning existing build directory %s...',
self.build_path)
shutil.rmtree(self.build_path)
def build(self):
"""Creates outputs. Outputs are files made by a backend."""
if os.path.exists(self.build_path) and not os.path.isdir(self.build_path):
self._logger.error('Output path must be a folder if it already exists')
return
Compiler._mkdir(self.build_path)
self._execute_backend_on_spec()
@staticmethod
def _mkdir(path):
"""
Creates a directory at path if it doesn't exist. If it does exist,
this function does nothing. Note that if path is a file, it will not
be converted to a directory.
"""
try:
os.makedirs(path)
except OSError as e:
if e.errno != 17:
raise
@classmethod
def is_stone_backend(cls, path):
"""
Returns True if the file name matches the format of a stone backend,
ie. its inner extension of "stoneg". For example: xyz.stoneg.py
"""
path_without_ext, _ = os.path.splitext(path)
_, second_ext = os.path.splitext(path_without_ext)
return second_ext == cls.backend_extension
def _execute_backend_on_spec(self):
"""Renders a source file into its final form."""
api_no_aliases_cache = None
for attr_key in dir(self.backend_module):
attr_value = getattr(self.backend_module, attr_key)
if (inspect.isclass(attr_value) and
issubclass(attr_value, Backend) and
not inspect.isabstract(attr_value)):
self._logger.info('Running backend: %s', attr_value.__name__)
backend = attr_value(self.build_path, self.backend_args)
if backend.preserve_aliases:
api = self.api
else:
if not api_no_aliases_cache:
api_no_aliases_cache = remove_aliases_from_api(self.api)
api = api_no_aliases_cache
try:
backend.generate(api)
except Exception:
# Wrap this exception so that it isn't thought of as a bug
# in the stone parser, but rather a bug in the backend.
# Remove the last char of the traceback b/c it's a newline.
raise BackendException(
attr_value.__name__, traceback.format_exc()[:-1])

View file

@ -0,0 +1,443 @@
from collections import OrderedDict
import six
class ASTNode(object):
def __init__(self, path, lineno, lexpos):
"""
Args:
lineno (int): The line number where the start of this element
occurs.
lexpos (int): The character offset into the file where this element
occurs.
"""
self.path = path
self.lineno = lineno
self.lexpos = lexpos
class AstNamespace(ASTNode):
def __init__(self, path, lineno, lexpos, name, doc):
"""
Args:
name (str): The namespace of the spec.
doc (Optional[str]): The docstring for this namespace.
"""
super(AstNamespace, self).__init__(path, lineno, lexpos)
self.name = name
self.doc = doc
def __str__(self):
return self.__repr__()
def __repr__(self):
return 'AstNamespace({!r})'.format(self.name)
class AstImport(ASTNode):
def __init__(self, path, lineno, lexpos, target):
"""
Args:
target (str): The name of the namespace to import.
"""
super(AstImport, self).__init__(path, lineno, lexpos)
self.target = target
def __str__(self):
return self.__repr__()
def __repr__(self):
return 'AstImport({!r})'.format(self.target)
class AstAlias(ASTNode):
def __init__(self, path, lineno, lexpos, name, type_ref, doc):
"""
Args:
name (str): The name of the alias.
type_ref (AstTypeRef): The data type of the field.
doc (Optional[str]): Documentation string for the alias.
"""
super(AstAlias, self).__init__(path, lineno, lexpos)
self.name = name
self.type_ref = type_ref
self.doc = doc
self.annotations = []
def set_annotations(self, annotations):
self.annotations = annotations
def __repr__(self):
return 'AstAlias({!r}, {!r})'.format(self.name, self.type_ref)
class AstTypeDef(ASTNode):
def __init__(self, path, lineno, lexpos, name, extends, doc, fields,
examples):
"""
Args:
name (str): Name assigned to the type.
extends (Optional[str]); Name of the type this inherits from.
doc (Optional[str]): Docstring for the type.
fields (List[AstField]): Fields of a type, not including
inherited ones.
examples (Optional[OrderedDict[str, AstExample]]): Map from label
to example.
"""
super(AstTypeDef, self).__init__(path, lineno, lexpos)
self.name = name
assert isinstance(extends, (AstTypeRef, type(None))), type(extends)
self.extends = extends
assert isinstance(doc, (six.text_type, type(None)))
self.doc = doc
assert isinstance(fields, list)
self.fields = fields
assert isinstance(examples, (OrderedDict, type(None))), type(examples)
self.examples = examples
def __str__(self):
return self.__repr__()
def __repr__(self):
return 'AstTypeDef({!r}, {!r}, {!r})'.format(
self.name,
self.extends,
self.fields,
)
class AstStructDef(AstTypeDef):
def __init__(self, path, lineno, lexpos, name, extends, doc, fields,
examples, subtypes=None):
"""
Args:
subtypes (Tuple[List[AstSubtypeField], bool]): Inner list
enumerates subtypes. The bool indicates whether this struct
is a catch-all.
See AstTypeDef for other constructor args.
"""
super(AstStructDef, self).__init__(
path, lineno, lexpos, name, extends, doc, fields, examples)
assert isinstance(subtypes, (tuple, type(None))), type(subtypes)
self.subtypes = subtypes
def __repr__(self):
return 'AstStructDef({!r}, {!r}, {!r})'.format(
self.name,
self.extends,
self.fields,
)
class AstStructPatch(ASTNode):
def __init__(self, path, lineno, lexpos, name, fields, examples):
super(AstStructPatch, self).__init__(path, lineno, lexpos)
self.name = name
assert isinstance(fields, list)
self.fields = fields
assert isinstance(examples, (OrderedDict, type(None))), type(examples)
self.examples = examples
def __repr__(self):
return 'AstStructPatch({!r}, {!r})'.format(
self.name,
self.fields,
)
class AstUnionDef(AstTypeDef):
def __init__(self, path, lineno, lexpos, name, extends, doc, fields,
examples, closed=False):
"""
Args:
closed (bool): Set if this is a closed union.
See AstTypeDef for other constructor args.
"""
super(AstUnionDef, self).__init__(
path, lineno, lexpos, name, extends, doc, fields, examples)
self.closed = closed
def __repr__(self):
return 'AstUnionDef({!r}, {!r}, {!r}, {!r})'.format(
self.name,
self.extends,
self.fields,
self.closed,
)
class AstUnionPatch(ASTNode):
def __init__(self, path, lineno, lexpos, name, fields, examples, closed):
super(AstUnionPatch, self).__init__(path, lineno, lexpos)
self.name = name
assert isinstance(fields, list)
self.fields = fields
assert isinstance(examples, (OrderedDict, type(None))), type(examples)
self.examples = examples
self.closed = closed
def __repr__(self):
return 'AstUnionPatch({!r}, {!r}, {!r})'.format(
self.name,
self.fields,
self.closed,
)
class AstTypeRef(ASTNode):
def __init__(self, path, lineno, lexpos, name, args, nullable, ns):
"""
Args:
name (str): Name of the referenced type.
args (tuple[list, dict]): Arguments to type.
nullable (bool): Whether the type is nullable (can be null)
ns (Optional[str]): Namespace that referred type is a member of.
If none, then refers to the current namespace.
"""
super(AstTypeRef, self).__init__(path, lineno, lexpos)
self.name = name
self.args = args
self.nullable = nullable
self.ns = ns
def __repr__(self):
return 'AstTypeRef({!r}, {!r}, {!r}, {!r})'.format(
self.name,
self.args,
self.nullable,
self.ns,
)
class AstTagRef(ASTNode):
def __init__(self, path, lineno, lexpos, tag):
"""
Args:
tag (str): Name of the referenced type.
"""
super(AstTagRef, self).__init__(path, lineno, lexpos)
self.tag = tag
def __repr__(self):
return 'AstTagRef({!r})'.format(
self.tag,
)
class AstAnnotationRef(ASTNode):
def __init__(self, path, lineno, lexpos, annotation, ns):
"""
Args:
annotation (str): Name of the referenced annotation.
"""
super(AstAnnotationRef, self).__init__(path, lineno, lexpos)
self.annotation = annotation
self.ns = ns
def __repr__(self):
return 'AstAnnotationRef({!r}, {!r})'.format(
self.annotation, self.ns
)
class AstAnnotationDef(ASTNode):
def __init__(self, path, lineno, lexpos, name, annotation_type,
annotation_type_ns, args, kwargs):
"""
Args:
name (str): Name of the defined annotation.
annotation_type (str): Type of annotation to define.
annotation_type_ns (Optional[str]): Namespace where the annotation
type was defined. If None, current namespace or builtin.
args (str): Arguments to define annotation.
kwargs (str): Keyword Arguments to define annotation.
"""
super(AstAnnotationDef, self).__init__(path, lineno, lexpos)
self.name = name
self.annotation_type = annotation_type
self.annotation_type_ns = annotation_type_ns
self.args = args
self.kwargs = kwargs
def __repr__(self):
return 'AstAnnotationDef({!r}, {!r}, {!r}, {!r}, {!r})'.format(
self.name,
self.annotation_type,
self.annotation_type_ns,
self.args,
self.kwargs,
)
class AstAnnotationTypeDef(ASTNode):
def __init__(self, path, lineno, lexpos, name, doc, params):
"""
Args:
name (str): Name of the defined annotation type.
doc (str): Docstring for the defined annotation type.
params (List[AstField]): Parameters that can be passed to the
annotation type.
"""
super(AstAnnotationTypeDef, self).__init__(path, lineno, lexpos)
self.name = name
self.doc = doc
self.params = params
def __repr__(self):
return 'AstAnnotationTypeDef({!r}, {!r}, {!r})'.format(
self.name,
self.doc,
self.params,
)
class AstField(ASTNode):
"""
Represents both a field of a struct and a field of a union.
TODO(kelkabany): Split this into two different classes.
"""
def __init__(self, path, lineno, lexpos, name, type_ref):
"""
Args:
name (str): The name of the field.
type_ref (AstTypeRef): The data type of the field.
"""
super(AstField, self).__init__(path, lineno, lexpos)
self.name = name
self.type_ref = type_ref
self.doc = None
self.has_default = False
self.default = None
self.annotations = []
def set_doc(self, docstring):
self.doc = docstring
def set_default(self, default):
self.has_default = True
self.default = default
def set_annotations(self, annotations):
self.annotations = annotations
def __repr__(self):
return 'AstField({!r}, {!r}, {!r})'.format(
self.name,
self.type_ref,
self.annotations,
)
class AstVoidField(ASTNode):
def __init__(self, path, lineno, lexpos, name):
super(AstVoidField, self).__init__(path, lineno, lexpos)
self.name = name
self.doc = None
self.annotations = []
def set_doc(self, docstring):
self.doc = docstring
def set_annotations(self, annotations):
self.annotations = annotations
def __str__(self):
return self.__repr__()
def __repr__(self):
return 'AstVoidField({!r}, {!r})'.format(
self.name,
self.annotations,
)
class AstSubtypeField(ASTNode):
def __init__(self, path, lineno, lexpos, name, type_ref):
super(AstSubtypeField, self).__init__(path, lineno, lexpos)
self.name = name
self.type_ref = type_ref
def __repr__(self):
return 'AstSubtypeField({!r}, {!r})'.format(
self.name,
self.type_ref,
)
class AstRouteDef(ASTNode):
def __init__(self, path, lineno, lexpos, name, version, deprecated,
arg_type_ref, result_type_ref, error_type_ref=None):
super(AstRouteDef, self).__init__(path, lineno, lexpos)
self.name = name
self.version = version
self.deprecated = deprecated
self.arg_type_ref = arg_type_ref
self.result_type_ref = result_type_ref
self.error_type_ref = error_type_ref
self.doc = None
self.attrs = {}
def set_doc(self, docstring):
self.doc = docstring
def set_attrs(self, attrs):
self.attrs = attrs
class AstAttrField(ASTNode):
def __init__(self, path, lineno, lexpos, name, value):
super(AstAttrField, self).__init__(path, lineno, lexpos)
self.name = name
self.value = value
def __repr__(self):
return 'AstAttrField({!r}, {!r})'.format(
self.name,
self.value,
)
class AstExample(ASTNode):
def __init__(self, path, lineno, lexpos, label, text, fields):
super(AstExample, self).__init__(path, lineno, lexpos)
self.label = label
self.text = text
self.fields = fields
def __repr__(self):
return 'AstExample({!r}, {!r}, {!r})'.format(
self.label,
self.text,
self.fields,
)
class AstExampleField(ASTNode):
def __init__(self, path, lineno, lexpos, name, value):
super(AstExampleField, self).__init__(path, lineno, lexpos)
self.name = name
self.value = value
def __repr__(self):
return 'AstExampleField({!r}, {!r})'.format(
self.name,
self.value,
)
class AstExampleRef(ASTNode):
def __init__(self, path, lineno, lexpos, label):
super(AstExampleRef, self).__init__(path, lineno, lexpos)
self.label = label
def __repr__(self):
return 'AstExampleRef({!r})'.format(self.label)

View file

@ -0,0 +1,28 @@
import six
class InvalidSpec(Exception):
"""Raise this to indicate there was an error in a specification."""
def __init__(self, msg, lineno, path=None):
"""
Args:
msg: Error message intended for the spec writer to read.
lineno: The line number the error occurred on.
path: Path to the spec file with the error.
"""
super(InvalidSpec, self).__init__()
assert isinstance(msg, six.text_type), type(msg)
assert isinstance(lineno, (six.integer_types, type(None))), type(lineno)
self.msg = msg
self.lineno = lineno
self.path = path
def __str__(self):
return repr(self)
def __repr__(self):
return 'InvalidSpec({!r}, {!r}, {!r})'.format(
self.msg,
self.lineno,
self.path,
)

View file

@ -0,0 +1,55 @@
import logging
from .exception import InvalidSpec
from .parser import (
ParserFactory,
)
from .ir_generator import IRGenerator
logger = logging.getLogger('stone.frontend.frontend')
# FIXME: Version should not have a default.
def specs_to_ir(specs, version='0.1b1', debug=False, route_whitelist_filter=None):
"""
Converts a collection of Stone specifications into the intermediate
representation used by Stone backends.
The process is: Lexer -> Parser -> Semantic Analyzer -> IR Generator.
The code is structured as:
1. Parser (Lexer embedded within)
2. IR Generator (Semantic Analyzer embedded within)
:type specs: List[Tuple[path: str, text: str]]
:param specs: `path` is never accessed and is only used to report the
location of a bad spec to the user. `spec` is the text contents of
a spec (.stone) file.
:raises: InvalidSpec
:returns: stone.ir.Api
"""
parser_factory = ParserFactory(debug=debug)
partial_asts = []
for path, text in specs:
logger.info('Parsing spec %s', path)
parser = parser_factory.get_parser()
if debug:
parser.test_lexing(text)
partial_ast = parser.parse(text, path)
if parser.got_errors_parsing():
# TODO(kelkabany): Show more than one error at a time.
msg, lineno, path = parser.get_errors()[0]
raise InvalidSpec(msg, lineno, path)
elif len(partial_ast) == 0:
logger.info('Empty spec: %s', path)
else:
partial_asts.append(partial_ast)
return IRGenerator(partial_asts, version, debug=debug,
route_whitelist_filter=route_whitelist_filter).generate_IR()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,446 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import ply.lex as lex
_MYPY = False
if _MYPY:
import typing # noqa: F401 # pylint: disable=import-error,unused-import,useless-suppression
class MultiToken(object):
"""Object used to monkeypatch ply.lex so that we can return multiple
tokens from one lex operation."""
def __init__(self, tokens):
self.type = tokens[0].type
self.tokens = tokens
# Represents a null value. We want to differentiate between the Python "None"
# and null in several places.
NullToken = object()
class Lexer(object):
"""
Lexer. Tokenizes stone files.
"""
states = (
('WSIGNORE', 'inclusive'),
)
def __init__(self):
self.lex = None
self.tokens_queue = None
# The current indentation "level" rather than a count of spaces.
self.cur_indent = None
self._logger = logging.getLogger('stone.stone.lexer')
self.last_token = None
# [(character, line number), ...]
self.errors = []
def input(self, file_data, **kwargs):
"""
Required by ply.yacc for this to quack (duck typing) like a ply lexer.
:param str file_data: Contents of the file to lex.
"""
self.lex = lex.lex(module=self, **kwargs)
self.tokens_queue = []
self.cur_indent = 0
# Hack to avoid tokenization bugs caused by files that do not end in a
# new line.
self.lex.input(file_data + '\n')
def token(self):
"""
Returns the next LexToken. Returns None when all tokens have been
exhausted.
"""
if self.tokens_queue:
self.last_token = self.tokens_queue.pop(0)
else:
r = self.lex.token()
if isinstance(r, MultiToken):
self.tokens_queue.extend(r.tokens)
self.last_token = self.tokens_queue.pop(0)
else:
if r is None and self.cur_indent > 0:
if (self.last_token and
self.last_token.type not in ('NEWLINE', 'LINE')):
newline_token = _create_token(
'NEWLINE', '\n', self.lex.lineno, self.lex.lexpos)
self.tokens_queue.append(newline_token)
dedent_count = self.cur_indent
dedent_token = _create_token(
'DEDENT', '\t', self.lex.lineno, self.lex.lexpos)
self.tokens_queue.extend([dedent_token] * dedent_count)
self.cur_indent = 0
self.last_token = self.tokens_queue.pop(0)
else:
self.last_token = r
return self.last_token
def test(self, data):
"""Logs all tokens for human inspection. Useful for debugging."""
self.input(data)
while True:
token = self.token()
if not token:
break
self._logger.debug('Token %r', token)
# List of token names
tokens = (
'ID',
'KEYWORD',
'PATH',
'DOT',
) # type: typing.Tuple[typing.Text, ...]
# Whitespace tokens
tokens += (
'DEDENT',
'INDENT',
'NEWLINE',
)
# Attribute lists, aliases
tokens += (
'COMMA',
'EQ',
'LPAR',
'RPAR',
)
# Primitive types
tokens += (
'BOOLEAN',
'FLOAT',
'INTEGER',
'NULL',
'STRING',
)
# List notation
tokens += (
'LBRACKET',
'RBRACKET',
)
# Map notation
tokens += (
'LBRACE',
'RBRACE',
'COLON',
)
tokens += (
'Q',
)
# Annotation notation
tokens += (
'AT',
)
# Regular expression rules for simple tokens
t_DOT = r'\.'
t_LBRACKET = r'\['
t_RBRACKET = r'\]'
t_EQ = r'='
t_COMMA = r','
t_Q = r'\?'
t_LBRACE = r'\{'
t_RBRACE = r'\}'
t_COLON = r'\:'
t_AT = r'@'
# TODO(kelkabany): Use scoped/conditional lexing to restrict where keywords
# are identified as such.
KEYWORDS = [
'alias',
'annotation',
'annotation_type',
'attrs',
'by',
'deprecated',
'doc',
'example',
'error',
'extends',
'import',
'namespace',
'patch',
'route',
'struct',
'union',
'union_closed',
]
RESERVED = {
'annotation': 'ANNOTATION',
'annotation_type': 'ANNOTATION_TYPE',
'attrs': 'ATTRS',
'deprecated': 'DEPRECATED',
'by': 'BY',
'extends': 'EXTENDS',
'import': 'IMPORT',
'patch': 'PATCH',
'route': 'ROUTE',
'struct': 'STRUCT',
'union': 'UNION',
'union_closed': 'UNION_CLOSED',
}
tokens += tuple(RESERVED.values())
def t_LPAR(self, token):
r'\('
token.lexer.push_state('WSIGNORE')
return token
def t_RPAR(self, token):
r'\)'
token.lexer.pop_state()
return token
def t_ANY_BOOLEAN(self, token):
r'\btrue\b|\bfalse\b'
token.value = (token.value == 'true')
return token
def t_ANY_NULL(self, token):
r'\bnull\b'
token.value = NullToken
return token
# No leading digits
def t_ANY_ID(self, token):
r'[a-zA-Z_][a-zA-Z0-9_-]*'
if token.value in self.KEYWORDS:
if (token.value == 'annotation_type') and self.cur_indent:
# annotation_type was added as a reserved keyword relatively
# late, when there could be identifers with the same name
# in existing specs. because annotation_type-the-keyword can
# only be used at the beginning of a non-indented line, this
# check lets both the keyword and the identifer coexist and
# maintains backward compatibility.
# Note: this is kind of a hack, and we should get rid of it if
# the lexer gets better at telling keywords from identifiers in general.
return token
token.type = self.RESERVED.get(token.value, 'KEYWORD')
return token
else:
return token
def t_ANY_PATH(self, token):
r'\/[/a-zA-Z0-9_-]*'
return token
def t_ANY_FLOAT(self, token):
r'-?\d+(\.\d*(e-?\d+)?|e-?\d+)'
token.value = float(token.value)
return token
def t_ANY_INTEGER(self, token):
r'-?\d+'
token.value = int(token.value)
return token
# Read in a string while respecting the following escape sequences:
# \", \\, \n, and \t.
def t_ANY_STRING(self, t):
r'\"([^\\"]|(\\.))*\"'
escaped = 0
t.lexer.lineno += t.value.count('\n')
s = t.value[1:-1]
new_str = ""
for i in range(0, len(s)):
c = s[i]
if escaped:
if c == 'n':
c = '\n'
elif c == 't':
c = '\t'
new_str += c
escaped = 0
else:
if c == '\\':
escaped = 1
else:
new_str += c
# remove current indentation
indentation_str = ' ' * _indent_level_to_spaces_count(self.cur_indent)
lines_without_indentation = [
line.replace(indentation_str, '', 1)
for line in new_str.splitlines()]
t.value = '\n'.join(lines_without_indentation)
return t
# Ignore comments.
# There are two types of comments.
# 1. Comments that take up a full line. These lines are ignored entirely.
# 2. Comments that come after tokens in the same line. These comments
# are ignored, but, we still need to emit a NEWLINE since this rule
# takes all trailing newlines.
# Regardless of comment type, the following line must be checked for a
# DEDENT or INDENT.
def t_INITIAL_comment(self, token):
r'[#][^\n]*\n+'
token.lexer.lineno += token.value.count('\n')
# Scan backwards from the comment hash to figure out which type of
# comment this is. If we find an non-ws character, we know it was a
# partial line. But, if we find a newline before a non-ws character,
# then we know the entire line was a comment.
i = token.lexpos - 1
while i >= 0:
is_full_line_comment = token.lexer.lexdata[i] == '\n'
is_partial_line_comment = (not is_full_line_comment and
token.lexer.lexdata[i] != ' ')
if is_full_line_comment or is_partial_line_comment:
newline_token = _create_token('NEWLINE', '\n',
token.lineno, token.lexpos + len(token.value) - 1)
newline_token.lexer = token.lexer
dent_tokens = self._create_tokens_for_next_line_dent(
newline_token)
if is_full_line_comment:
# Comment takes the full line so ignore entirely.
return dent_tokens
elif is_partial_line_comment:
# Comment is only a partial line. Preserve newline token.
if dent_tokens:
dent_tokens.tokens.insert(0, newline_token)
return dent_tokens
else:
return newline_token
i -= 1
def t_WSIGNORE_comment(self, token):
r'[#][^\n]*\n+'
token.lexer.lineno += token.value.count('\n')
newline_token = _create_token('NEWLINE', '\n',
token.lineno, token.lexpos + len(token.value) - 1)
newline_token.lexer = token.lexer
self._check_for_indent(newline_token)
# Define a rule so we can track line numbers
def t_INITIAL_NEWLINE(self, newline_token):
r'\n+'
newline_token.lexer.lineno += newline_token.value.count('\n')
dent_tokens = self._create_tokens_for_next_line_dent(newline_token)
if dent_tokens:
dent_tokens.tokens.insert(0, newline_token)
return dent_tokens
else:
return newline_token
def t_WSIGNORE_NEWLINE(self, newline_token):
r'\n+'
newline_token.lexer.lineno += newline_token.value.count('\n')
self._check_for_indent(newline_token)
def _create_tokens_for_next_line_dent(self, newline_token):
"""
Starting from a newline token that isn't followed by another newline
token, returns any indent or dedent tokens that immediately follow.
If indentation doesn't change, returns None.
"""
indent_delta = self._get_next_line_indent_delta(newline_token)
if indent_delta is None or indent_delta == 0:
# Next line's indent isn't relevant OR there was no change in
# indentation.
return None
dent_type = 'INDENT' if indent_delta > 0 else 'DEDENT'
dent_token = _create_token(
dent_type, '\t', newline_token.lineno + 1,
newline_token.lexpos + len(newline_token.value))
tokens = [dent_token] * abs(indent_delta)
self.cur_indent += indent_delta
return MultiToken(tokens)
def _check_for_indent(self, newline_token):
"""
Checks that the line following a newline is indented, otherwise a
parsing error is generated.
"""
indent_delta = self._get_next_line_indent_delta(newline_token)
if indent_delta is None or indent_delta == 1:
# Next line's indent isn't relevant (e.g. it's a comment) OR
# next line is correctly indented.
return None
else:
self.errors.append(
('Line continuation must increment indent by 1.',
newline_token.lexer.lineno))
def _get_next_line_indent_delta(self, newline_token):
"""
Returns the change in indentation. The return units are in
indentations rather than spaces/tabs.
If the next line's indent isn't relevant (e.g. it's a comment),
returns None. Since the return value might be 0, the caller should
explicitly check the return type, rather than rely on truthiness.
"""
assert newline_token.type == 'NEWLINE', \
'Can only search for a dent starting from a newline.'
next_line_pos = newline_token.lexpos + len(newline_token.value)
if next_line_pos == len(newline_token.lexer.lexdata):
# Reached end of file
return None
line = newline_token.lexer.lexdata[next_line_pos:].split(os.linesep, 1)[0]
if not line:
return None
lstripped_line = line.lstrip()
lstripped_line_length = len(lstripped_line)
if lstripped_line_length == 0:
# If the next line is composed of only spaces, ignore indentation.
return None
if lstripped_line[0] == '#':
# If it's a comment line, ignore indentation.
return None
indent = len(line) - lstripped_line_length
if indent % 4 > 0:
self.errors.append(
('Indent is not divisible by 4.', newline_token.lexer.lineno))
return None
indent_delta = indent - _indent_level_to_spaces_count(self.cur_indent)
return indent_delta // 4
# A string containing ignored characters (spaces and tabs)
t_ignore = ' \t'
# Error handling rule
def t_ANY_error(self, token):
self._logger.debug('Illegal character %r at line %d',
token.value[0], token.lexer.lineno)
self.errors.append(
('Illegal character %s.' % repr(token.value[0]).lstrip('u'),
token.lexer.lineno))
token.lexer.skip(1)
def _create_token(token_type, value, lineno, lexpos):
"""
Helper for creating ply.lex.LexToken objects. Unfortunately, LexToken
does not have a constructor defined to make settings these values easy.
"""
token = lex.LexToken()
token.type = token_type
token.value = value
token.lineno = lineno
token.lexpos = lexpos
return token
def _indent_level_to_spaces_count(indent):
return indent * 4

View file

@ -0,0 +1,880 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import OrderedDict
import logging
import ply.yacc as yacc
from .lexer import (
Lexer,
NullToken,
)
from .ast import (
AstAlias,
AstAnnotationDef,
AstAnnotationRef,
AstAnnotationTypeDef,
AstAttrField,
AstExample,
AstExampleField,
AstExampleRef,
AstField,
AstNamespace,
AstImport,
AstRouteDef,
AstStructDef,
AstStructPatch,
AstSubtypeField,
AstTagRef,
AstTypeRef,
AstUnionDef,
AstUnionPatch,
AstVoidField,
)
logger = logging.getLogger(str('stone.frontend.parser'))
class ParserFactory(object):
"""
After instantiating a ParserFactory, call get_parser() to get an object
with a parse() method. It so happens that the object is also a
ParserFactory. The purpose of get_parser() is to reset the internal state
of the fatory. The details for why these aren't cleanly separated have to
do with the inability to separate out the yacc.yacc BNF definition parser
from the class methods that implement the parser handling logic.
Due to how ply.yacc works, the docstring of each parser method is a BNF
rule. Comments that would normally be docstrings for each parser rule
method are kept before the method definition.
"""
# Ply parser requiment: Tokens must be re-specified in parser
tokens = Lexer.tokens
# Ply feature: Starting grammar rule
start = str('spec') # PLY wants a 'str' instance; this makes it work in Python 2 and 3
def __init__(self, debug=False):
self.debug = debug
self.yacc = yacc.yacc(module=self, debug=self.debug, write_tables=self.debug)
self.lexer = Lexer()
# [(token type, token value, line number), ...]
self.errors = []
# Path to file being parsed. This is added to each token for its
# utility in error reporting. But the path is never accessed, so this
# is optional.
self.path = None
self.anony_defs = []
self.exhausted = True
def get_parser(self):
"""
Returns a ParserFactory with the state reset so it can be used to
parse again.
:return: ParserFactory
"""
self.path = None
self.anony_defs = []
self.exhausted = False
return self
def parse(self, data, path=None):
"""
Args:
data (str): Raw specification text.
path (Optional[str]): Path to specification on filesystem. Only
used to tag tokens with the file they originated from.
"""
assert not self.exhausted, 'Must call get_parser() to reset state.'
self.path = path
parsed_data = self.yacc.parse(data, lexer=self.lexer, debug=self.debug)
# It generally makes sense for lexer errors to come first, because
# those can be the root of parser errors. Also, since we only show one
# error max right now, it's best to show the lexing one.
for err_msg, lineno in self.lexer.errors[::-1]:
self.errors.insert(0, (err_msg, lineno, self.path))
parsed_data.extend(self.anony_defs)
self.exhausted = True
return parsed_data
def test_lexing(self, data):
self.lexer.test(data)
def got_errors_parsing(self):
"""Whether the lexer or parser had errors."""
return self.errors
def get_errors(self):
"""
If got_errors_parsing() returns True, call this to get the errors.
Returns:
list[tuple[msg: str, lineno: int, path: str]]
"""
return self.errors[:]
# --------------------------------------------------------------
# Spec := Namespace Import* Definition*
def p_spec_init(self, p):
"""spec : NL
| empty"""
p[0] = []
def p_spec_init_decl(self, p):
"""spec : namespace
| import
| definition"""
p[0] = [p[1]]
def p_spec_iter(self, p):
"""spec : spec namespace
| spec import
| spec definition"""
p[0] = p[1]
p[0].append(p[2])
# This covers the case where we have garbage characters in a file that
# splits a NL token into two separate tokens.
def p_spec_ignore_newline(self, p):
'spec : spec NL'
p[0] = p[1]
def p_definition(self, p):
"""definition : alias
| annotation
| annotation_type
| struct
| struct_patch
| union
| union_patch
| route"""
p[0] = p[1]
def p_namespace(self, p):
"""namespace : KEYWORD ID NL
| KEYWORD ID NL INDENT docsection DEDENT"""
if p[1] == 'namespace':
doc = None
if len(p) > 4:
doc = p[5]
p[0] = AstNamespace(
self.path, p.lineno(1), p.lexpos(1), p[2], doc)
else:
raise ValueError('Expected namespace keyword')
def p_import(self, p):
'import : IMPORT ID NL'
p[0] = AstImport(self.path, p.lineno(1), p.lexpos(1), p[2])
def p_alias(self, p):
"""alias : KEYWORD ID EQ type_ref NL
| KEYWORD ID EQ type_ref NL INDENT annotation_ref_list docsection DEDENT"""
if p[1] == 'alias':
has_annotations = len(p) > 6 and p[7] is not None
doc = p[8] if len(p) > 6 else None
p[0] = AstAlias(
self.path, p.lineno(1), p.lexpos(1), p[2], p[4], doc)
if has_annotations:
p[0].set_annotations(p[7])
else:
raise ValueError('Expected alias keyword')
def p_nl(self, p):
'NL : NEWLINE'
p[0] = p[1]
# Sometimes we'll have multiple consecutive newlines that the lexer has
# trouble combining, so we do it in the parser.
def p_nl_combine(self, p):
'NL : NL NEWLINE'
p[0] = p[1]
# --------------------------------------------------------------
# Primitive Types
def p_primitive(self, p):
"""primitive : BOOLEAN
| FLOAT
| INTEGER
| NULL
| STRING"""
p[0] = p[1]
# --------------------------------------------------------------
# References to Types
#
# There are several places references to types are made:
# 1. Alias sources
# alias x = TypeRef
# 2. Field data types
# struct S
# f TypeRef
# 3. In arguments to type references
# struct S
# f TypeRef(key=TypeRef)
#
# A type reference can have positional and keyword arguments:
# TypeRef(value1, ..., kwarg1=kwvalue1)
# If it has no arguments, the parentheses can be omitted.
#
# If a type reference has a '?' suffix, it is a nullable type.
def p_pos_arg(self, p):
"""pos_arg : primitive
| type_ref"""
p[0] = p[1]
def p_pos_args_list_create(self, p):
"""pos_args_list : pos_arg"""
p[0] = [p[1]]
def p_pos_args_list_extend(self, p):
"""pos_args_list : pos_args_list COMMA pos_arg"""
p[0] = p[1]
p[0].append(p[3])
def p_kw_arg(self, p):
"""kw_arg : ID EQ primitive
| ID EQ type_ref"""
p[0] = {p[1]: p[3]}
def p_kw_args(self, p):
"""kw_args : kw_arg"""
p[0] = p[1]
def p_kw_args_update(self, p):
"""kw_args : kw_args COMMA kw_arg"""
p[0] = p[1]
for key in p[3]:
if key in p[1]:
msg = "Keyword argument '%s' defined more than once." % key
self.errors.append((msg, p.lineno(2), self.path))
p[0].update(p[3])
def p_args(self, p):
"""args : LPAR pos_args_list COMMA kw_args RPAR
| LPAR pos_args_list RPAR
| LPAR kw_args RPAR
| LPAR RPAR
| empty"""
if len(p) > 3:
if p[3] == ',':
p[0] = (p[2], p[4])
elif isinstance(p[2], dict):
p[0] = ([], p[2])
else:
p[0] = (p[2], {})
else:
p[0] = ([], {})
def p_field_nullable(self, p):
"""nullable : Q
| empty"""
p[0] = p[1] == '?'
def p_type_ref(self, p):
'type_ref : ID args nullable'
p[0] = AstTypeRef(
path=self.path,
lineno=p.lineno(1),
lexpos=p.lexpos(1),
name=p[1],
args=p[2],
nullable=p[3],
ns=None,
)
# A reference to a type in another namespace.
def p_foreign_type_ref(self, p):
'type_ref : ID DOT ID args nullable'
p[0] = AstTypeRef(
path=self.path,
lineno=p.lineno(1),
lexpos=p.lexpos(1),
name=p[3],
args=p[4],
nullable=p[5],
ns=p[1],
)
# --------------------------------------------------------------
# Annotation types
#
# An example annotation type:
#
# annotation_type Sensitive
# "This is a docstring for the annotation type"
#
# sensitivity Int32
#
# reason String?
# "This is a docstring for the field"
#
def p_annotation_type(self, p):
"""annotation_type : ANNOTATION_TYPE ID NL \
INDENT docsection field_list DEDENT"""
p[0] = AstAnnotationTypeDef(
path=self.path,
lineno=p.lineno(1),
lexpos=p.lexpos(1),
name=p[2],
doc=p[5],
params=p[6])
# --------------------------------------------------------------
# Structs
#
# An example struct looks as follows:
#
# struct S extends P
# "This is a docstring for the struct"
#
# typed_field String
# "This is a docstring for the field"
#
# An example struct that enumerates subtypes looks as follows:
#
# struct P
# union
# t1 S1
# t2 S2
# field String
#
# struct S1 extends P
# ...
#
# struct S2 extends P
# ...
#
def p_enumerated_subtypes(self, p):
"""enumerated_subtypes : uniont NL INDENT subtypes_list DEDENT
| empty"""
if len(p) > 2:
p[0] = (p[4], p[1][0] == 'union')
def p_struct(self, p):
"""struct : STRUCT ID inheritance NL \
INDENT docsection enumerated_subtypes field_list examples DEDENT"""
self.make_struct(p)
def p_anony_struct(self, p):
"""anony_def : STRUCT empty inheritance NL \
INDENT docsection enumerated_subtypes field_list examples DEDENT"""
self.make_struct(p)
def make_struct(self, p):
p[0] = AstStructDef(
path=self.path,
lineno=p.lineno(1),
lexpos=p.lexpos(1),
name=p[2],
extends=p[3],
doc=p[6],
subtypes=p[7],
fields=p[8],
examples=p[9])
def p_struct_patch(self, p):
"""struct_patch : PATCH STRUCT ID NL INDENT field_list examples DEDENT"""
p[0] = AstStructPatch(
path=self.path,
lineno=p.lineno(1),
lexpos=p.lexpos(1),
name=p[3],
fields=p[6],
examples=p[7])
def p_inheritance(self, p):
"""inheritance : EXTENDS type_ref
| empty"""
if p[1]:
if p[2].nullable:
msg = 'Reference cannot be nullable.'
self.errors.append((msg, p.lineno(1), self.path))
else:
p[0] = p[2]
def p_enumerated_subtypes_list_create(self, p):
"""subtypes_list : subtype_field
| empty"""
if p[1] is not None:
p[0] = [p[1]]
def p_enumerated_subtypes_list_extend(self, p):
'subtypes_list : subtypes_list subtype_field'
p[0] = p[1]
p[0].append(p[2])
def p_enumerated_subtype_field(self, p):
'subtype_field : ID type_ref NL'
p[0] = AstSubtypeField(
self.path, p.lineno(1), p.lexpos(1), p[1], p[2])
# --------------------------------------------------------------
# Fields
#
# Each struct has zero or more fields. A field has a name, type,
# and docstring.
#
# TODO(kelkabany): Split fields into struct fields and union fields
# since they differ in capabilities rather significantly now.
def p_field_list_create(self, p):
"""field_list : field
| empty"""
if p[1] is None:
p[0] = []
else:
p[0] = [p[1]]
def p_field_list_extend(self, p):
'field_list : field_list field'
p[0] = p[1]
p[0].append(p[2])
def p_default_option(self, p):
"""default_option : EQ primitive
| EQ tag_ref
| empty"""
if p[1]:
if isinstance(p[2], AstTagRef):
p[0] = p[2]
else:
p[0] = p[2]
def p_field(self, p):
"""field : ID type_ref default_option NL \
INDENT annotation_ref_list docsection anony_def_option DEDENT
| ID type_ref default_option NL"""
has_annotations = len(p) > 5 and p[6] is not None
has_docstring = len(p) > 5 and p[7] is not None
has_anony_def = len(p) > 5 and p[8] is not None
p[0] = AstField(
self.path, p.lineno(1), p.lexpos(1), p[1], p[2])
if p[3] is not None:
if p[3] is NullToken:
p[0].set_default(None)
else:
p[0].set_default(p[3])
if has_annotations:
p[0].set_annotations(p[6])
if has_docstring:
p[0].set_doc(p[7])
if has_anony_def:
p[8].name = p[2].name
self.anony_defs.append(p[8])
def p_anony_def_option(self, p):
"""anony_def_option : anony_def
| empty"""
p[0] = p[1]
def p_tag_ref(self, p):
'tag_ref : ID'
p[0] = AstTagRef(self.path, p.lineno(1), p.lexpos(1), p[1])
def p_annotation(self, p):
"""annotation : ANNOTATION ID EQ ID args NL
| ANNOTATION ID EQ ID DOT ID args NL"""
if len(p) < 8:
args, kwargs = p[5]
p[0] = AstAnnotationDef(
self.path, p.lineno(1), p.lexpos(1), p[2], p[4], None, args, kwargs)
else:
args, kwargs = p[7]
p[0] = AstAnnotationDef(
self.path, p.lineno(1), p.lexpos(1), p[2], p[6], p[4], args, kwargs)
def p_annotation_ref_list_create(self, p):
"""annotation_ref_list : annotation_ref
| empty"""
if p[1] is not None:
p[0] = [p[1]]
else:
p[0] = None
def p_annotation_ref_list_extend(self, p):
"""annotation_ref_list : annotation_ref_list annotation_ref"""
p[0] = p[1]
p[0].append(p[2])
def p_annotation_ref(self, p):
"""annotation_ref : AT ID NL
| AT ID DOT ID NL"""
if len(p) < 5:
p[0] = AstAnnotationRef(self.path, p.lineno(1), p.lexpos(1), p[2], None)
else:
p[0] = AstAnnotationRef(self.path, p.lineno(1), p.lexpos(1), p[4], p[2])
# --------------------------------------------------------------
# Unions
#
# An example union looks as follows:
#
# union U
# "This is a docstring for the union"
#
# void_field*
# "Docstring for field with type Void"
# typed_field String
#
# void_field demonstrates the notation for a catch all variant.
def p_union(self, p):
"""union : uniont ID inheritance NL \
INDENT docsection field_list examples DEDENT"""
self.make_union(p)
def p_anony_union(self, p):
"""anony_def : uniont empty inheritance NL \
INDENT docsection field_list examples DEDENT"""
self.make_union(p)
def make_union(self, p):
p[0] = AstUnionDef(
path=self.path,
lineno=p[1][1],
lexpos=p[1][2],
name=p[2],
extends=p[3],
doc=p[6],
fields=p[7],
examples=p[8],
closed=p[1][0] == 'union_closed')
def p_union_patch(self, p):
"""union_patch : PATCH uniont ID NL INDENT field_list examples DEDENT"""
p[0] = AstUnionPatch(
path=self.path,
lineno=p[2][1],
lexpos=p[2][2],
name=p[3],
fields=p[6],
examples=p[7],
closed=p[2][0] == 'union_closed')
def p_uniont(self, p):
"""uniont : UNION
| UNION_CLOSED"""
p[0] = (p[1], p.lineno(1), p.lexpos(1))
def p_field_void(self, p):
"""field : ID NL
| ID NL INDENT annotation_ref_list docsection DEDENT"""
p[0] = AstVoidField(self.path, p.lineno(1), p.lexpos(1), p[1])
if len(p) > 3:
if p[4] is not None:
p[0].set_annotations(p[4])
if p[5] is not None:
p[0].set_doc(p[5])
# --------------------------------------------------------------
# Routes
#
# An example route looks as follows:
#
# route sample-route/sub-path:2 (arg, result, error)
# "This is a docstring for the route"
#
# attrs
# key="value"
#
# The error type is optional.
def p_route(self, p):
"""route : ROUTE route_name route_version route_io route_deprecation NL \
INDENT docsection attrssection DEDENT
| ROUTE route_name route_version route_io route_deprecation NL"""
p[0] = AstRouteDef(self.path, p.lineno(1), p.lexpos(1), p[2], p[3], p[5], *p[4])
if len(p) > 7:
p[0].set_doc(p[8])
if p[9]:
keys = set()
for attr in p[9]:
if attr.name in keys:
msg = "Attribute '%s' defined more than once." % attr.name
self.errors.append((msg, attr.lineno, attr.path))
keys.add(attr.name)
p[0].set_attrs(p[9])
def p_route_name(self, p):
'route_name : ID route_path'
if p[2]:
p[0] = p[1] + p[2]
else:
p[0] = p[1]
def p_route_path_suffix(self, p):
"""route_path : PATH
| empty"""
p[0] = p[1]
def p_route_version(self, p):
"""route_version : COLON INTEGER
| empty"""
if len(p) > 2:
if p[2] <= 0:
msg = "Version number should be a positive integer."
self.errors.append((msg, p.lineno(2), self.path))
p[0] = p[2]
else:
p[0] = 1
def p_route_io(self, p):
"""route_io : LPAR type_ref COMMA type_ref RPAR
| LPAR type_ref COMMA type_ref COMMA type_ref RPAR"""
if len(p) > 6:
p[0] = (p[2], p[4], p[6])
else:
p[0] = (p[2], p[4], None)
def p_route_deprecation(self, p):
"""route_deprecation : DEPRECATED
| DEPRECATED BY route_name route_version
| empty"""
if len(p) == 5:
p[0] = (True, p[3], p[4])
elif p[1]:
p[0] = (True, None, None)
def p_attrs_section(self, p):
"""attrssection : ATTRS NL INDENT attr_fields DEDENT
| empty"""
if p[1]:
p[0] = p[4]
def p_attr_fields_create(self, p):
'attr_fields : attr_field'
p[0] = [p[1]]
def p_attr_fields_add(self, p):
'attr_fields : attr_fields attr_field'
p[0] = p[1]
p[0].append(p[2])
def p_attr_field(self, p):
"""attr_field : ID EQ primitive NL
| ID EQ tag_ref NL"""
if p[3] is NullToken:
p[0] = AstAttrField(
self.path, p.lineno(1), p.lexpos(1), p[1], None)
else:
p[0] = AstAttrField(
self.path, p.lineno(1), p.lexpos(1), p[1], p[3])
# --------------------------------------------------------------
# Doc sections
#
# Doc sections appear after struct, union, and route signatures;
# also after field declarations.
#
# They're represented by text (multi-line supported) enclosed by
# quotations.
#
# struct S
# "This is a docstring
# for struct S"
#
# number Int64
# "This is a docstring for this field"
def p_docsection(self, p):
"""docsection : docstring NL
| empty"""
if p[1] is not None:
p[0] = p[1]
def p_docstring_string(self, p):
'docstring : STRING'
# Remove trailing whitespace on every line.
p[0] = '\n'.join([line.rstrip() for line in p[1].split('\n')])
# --------------------------------------------------------------
# Examples
#
# Examples appear at the bottom of struct definitions to give
# illustrative examples of what struct values may look like.
#
# struct S
# number Int64
#
# example default "This is a label"
# number=42
def p_examples_create(self, p):
"""examples : example
| empty"""
p[0] = OrderedDict()
if p[1] is not None:
p[0][p[1].label] = p[1]
def p_examples_add(self, p):
'examples : examples example'
p[0] = p[1]
if p[2].label in p[0]:
existing_ex = p[0][p[2].label]
self.errors.append(
("Example with label '%s' already defined on line %d." %
(existing_ex.label, existing_ex.lineno),
p[2].lineno, p[2].path))
p[0][p[2].label] = p[2]
# It's possible for no example fields to be specified.
def p_example(self, p):
"""example : KEYWORD ID NL INDENT docsection example_fields DEDENT
| KEYWORD ID NL"""
if len(p) > 4:
seen_fields = set()
for example_field in p[6]:
if example_field.name in seen_fields:
self.errors.append(
("Example with label '%s' defines field '%s' more "
"than once." % (p[2], example_field.name),
p.lineno(1), self.path))
seen_fields.add(example_field.name)
p[0] = AstExample(
self.path, p.lineno(1), p.lexpos(1), p[2], p[5],
OrderedDict((f.name, f) for f in p[6]))
else:
p[0] = AstExample(
self.path, p.lineno(1), p.lexpos(1), p[2], None, OrderedDict())
def p_example_fields_create(self, p):
'example_fields : example_field'
p[0] = [p[1]]
def p_example_fields_add(self, p):
'example_fields : example_fields example_field'
p[0] = p[1]
p[0].append(p[2])
def p_example_field(self, p):
"""example_field : ID EQ primitive NL
| ID EQ ex_list NL
| ID EQ ex_map NL"""
if p[3] is NullToken:
p[0] = AstExampleField(
self.path, p.lineno(1), p.lexpos(1), p[1], None)
else:
p[0] = AstExampleField(
self.path, p.lineno(1), p.lexpos(1), p[1], p[3])
def p_example_multiline(self, p):
"""example_field : ID EQ NL INDENT ex_map NL DEDENT"""
p[0] = AstExampleField(
self.path, p.lineno(1), p.lexpos(1), p[1], p[5])
def p_example_field_ref(self, p):
'example_field : ID EQ ID NL'
p[0] = AstExampleField(self.path, p.lineno(1), p.lexpos(1),
p[1], AstExampleRef(self.path, p.lineno(3), p.lexpos(3), p[3]))
# --------------------------------------------------------------
# Example of list
def p_ex_list(self, p):
"""ex_list : LBRACKET ex_list_items RBRACKET
| LBRACKET empty RBRACKET"""
if p[2] is None:
p[0] = []
else:
p[0] = p[2]
def p_ex_list_item_primitive(self, p):
'ex_list_item : primitive'
if p[1] is NullToken:
p[0] = None
else:
p[0] = p[1]
def p_ex_list_item_id(self, p):
'ex_list_item : ID'
p[0] = AstExampleRef(self.path, p.lineno(1), p.lexpos(1), p[1])
def p_ex_list_item_list(self, p):
'ex_list_item : ex_list'
p[0] = p[1]
def p_ex_list_items_create(self, p):
"""ex_list_items : ex_list_item"""
p[0] = [p[1]]
def p_ex_list_items_extend(self, p):
"""ex_list_items : ex_list_items COMMA ex_list_item"""
p[0] = p[1]
p[0].append(p[3])
# --------------------------------------------------------------
# Maps
#
def p_ex_map(self, p):
"""ex_map : LBRACE ex_map_pairs RBRACE
| LBRACE empty RBRACE"""
p[0] = p[2] or {}
def p_ex_map_multiline(self, p):
"""ex_map : LBRACE NL INDENT ex_map_pairs NL DEDENT RBRACE"""
p[0] = p[4] or {}
def p_ex_map_elem_primitive(self, p):
"""ex_map_elem : primitive"""
p[0] = None if p[1] == NullToken else p[1]
def p_ex_map_elem_composit(self, p):
"""ex_map_elem : ex_map
| ex_list"""
p[0] = p[1]
def p_ex_map_elem_id(self, p):
"""ex_map_elem : ID"""
p[0] = AstExampleRef(self.path, p.lineno(1), p.lexpos(1), p[1])
def p_ex_map_pair(self, p):
"""ex_map_pair : ex_map_elem COLON ex_map_elem"""
try:
p[0] = {p[1]: p[3]}
except TypeError:
msg = u"%s is an invalid hash key because it cannot be hashed." % repr(p[1])
self.errors.append((msg, p.lineno(2), self.path))
p[0] = {}
def p_ex_map_pairs_create(self, p):
"""ex_map_pairs : ex_map_pair """
p[0] = p[1]
def p_ex_map_pairs_extend(self, p):
"""ex_map_pairs : ex_map_pairs COMMA ex_map_pair"""
p[0] = p[1]
p[0].update(p[3])
def p_ex_map_pairs_multiline(self, p):
"""ex_map_pairs : ex_map_pairs COMMA NL ex_map_pair"""
p[0] = p[1]
p[0].update(p[4])
# --------------------------------------------------------------
# In ply, this is how you define an empty rule. This is used when we want
# the parser to treat a rule as optional.
def p_empty(self, p):
'empty :'
# Called by the parser whenever a token doesn't match any rule.
def p_error(self, token):
assert token is not None, "Unknown error, please report this."
logger.debug('Unexpected %s(%r) at line %d',
token.type,
token.value,
token.lineno)
self.errors.append(
("Unexpected %s with value %s." %
(token.type, repr(token.value).lstrip('u')),
token.lineno, self.path))

View file

@ -0,0 +1,2 @@
from .api import * # noqa: F401,F403
from .data_types import * # noqa: F401,F403

View file

@ -0,0 +1,440 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from collections import OrderedDict
# See <https://github.com/PyCQA/pylint/issues/73>
from distutils.version import StrictVersion
import six
from .data_types import (
doc_unwrap,
is_alias,
is_composite_type,
is_list_type,
is_nullable_type,
)
_MYPY = False
if _MYPY:
import typing # pylint: disable=import-error,useless-suppression
from .data_types import ( # noqa: F401 # pylint: disable=unused-import
Alias,
Annotation,
AnnotationType,
DataType,
List as DataTypeList,
Nullable,
Struct,
UserDefined,
)
from stone.frontend.ast import AstRouteDef # noqa: F401 # pylint: disable=unused-import
# TODO: This can be changed back to a single declaration with a
# unicode literal after <https://github.com/python/mypy/pull/2516>
# makes it into a PyPi release
if six.PY3:
NamespaceDict = typing.Dict[typing.Text, 'ApiNamespace']
else:
NamespaceDict = typing.Dict[typing.Text, b'ApiNamespace']
class Api(object):
"""
A full description of an API's namespaces, data types, and routes.
"""
def __init__(self, version):
# type: (str) -> None
self.version = StrictVersion(version)
self.namespaces = OrderedDict() # type: NamespaceDict
self.route_schema = None # type: typing.Optional[Struct]
def ensure_namespace(self, name):
# type: (str) -> ApiNamespace
"""
Only creates a namespace if it hasn't yet been defined.
:param str name: Name of the namespace.
:return ApiNamespace:
"""
if name not in self.namespaces:
self.namespaces[name] = ApiNamespace(name)
return self.namespaces[name]
def normalize(self):
# type: () -> None
"""
Alphabetizes namespaces and routes to make spec parsing order mostly
irrelevant.
"""
ordered_namespaces = OrderedDict() # type: NamespaceDict
# self.namespaces is currently ordered by declaration order.
for namespace_name in sorted(self.namespaces.keys()):
ordered_namespaces[namespace_name] = self.namespaces[namespace_name]
self.namespaces = ordered_namespaces
for namespace in self.namespaces.values():
namespace.normalize()
def add_route_schema(self, route_schema):
# type: (Struct) -> None
assert self.route_schema is None
self.route_schema = route_schema
class _ImportReason(object):
"""
Tracks the reason a namespace was imported.
"""
def __init__(self):
# type: () -> None
self.alias = False
self.data_type = False
self.annotation = False
self.annotation_type = False
class ApiNamespace(object):
"""
Represents a category of API endpoints and their associated data types.
"""
def __init__(self, name):
# type: (typing.Text) -> None
self.name = name
self.doc = None # type: typing.Optional[six.text_type]
self.routes = [] # type: typing.List[ApiRoute]
# TODO (peichao): route_by_name is deprecated by routes_by_name and should be removed.
self.route_by_name = {} # type: typing.Dict[typing.Text, ApiRoute]
self.routes_by_name = {} # type: typing.Dict[typing.Text, ApiRoutesByVersion]
self.data_types = [] # type: typing.List[UserDefined]
self.data_type_by_name = {} # type: typing.Dict[str, UserDefined]
self.aliases = [] # type: typing.List[Alias]
self.alias_by_name = {} # type: typing.Dict[str, Alias]
self.annotations = [] # type: typing.List[Annotation]
self.annotation_by_name = {} # type: typing.Dict[str, Annotation]
self.annotation_types = [] # type: typing.List[AnnotationType]
self.annotation_type_by_name = {} # type: typing.Dict[str, AnnotationType]
self._imported_namespaces = {} # type: typing.Dict[ApiNamespace, _ImportReason]
def add_doc(self, docstring):
# type: (six.text_type) -> None
"""Adds a docstring for this namespace.
The input docstring is normalized to have no leading whitespace and
no trailing whitespace except for a newline at the end.
If a docstring already exists, the new normalized docstring is appended
to the end of the existing one with two newlines separating them.
"""
assert isinstance(docstring, six.text_type), type(docstring)
normalized_docstring = doc_unwrap(docstring) + '\n'
if self.doc is None:
self.doc = normalized_docstring
else:
self.doc += normalized_docstring
def add_route(self, route):
# type: (ApiRoute) -> None
self.routes.append(route)
if route.version == 1:
self.route_by_name[route.name] = route
if route.name not in self.routes_by_name:
self.routes_by_name[route.name] = ApiRoutesByVersion()
self.routes_by_name[route.name].at_version[route.version] = route
def add_data_type(self, data_type):
# type: (UserDefined) -> None
self.data_types.append(data_type)
self.data_type_by_name[data_type.name] = data_type
def add_alias(self, alias):
# type: (Alias) -> None
self.aliases.append(alias)
self.alias_by_name[alias.name] = alias
def add_annotation(self, annotation):
# type: (Annotation) -> None
self.annotations.append(annotation)
self.annotation_by_name[annotation.name] = annotation
def add_annotation_type(self, annotation_type):
# type: (AnnotationType) -> None
self.annotation_types.append(annotation_type)
self.annotation_type_by_name[annotation_type.name] = annotation_type
def add_imported_namespace(self,
namespace,
imported_alias=False,
imported_data_type=False,
imported_annotation=False,
imported_annotation_type=False):
# type: (ApiNamespace, bool, bool, bool, bool) -> None
"""
Keeps track of namespaces that this namespace imports.
Args:
namespace (Namespace): The imported namespace.
imported_alias (bool): Set if this namespace references an alias
in the imported namespace.
imported_data_type (bool): Set if this namespace references a
data type in the imported namespace.
imported_annotation (bool): Set if this namespace references a
annotation in the imported namespace.
imported_annotation_type (bool): Set if this namespace references an
annotation in the imported namespace, possibly indirectly (by
referencing an annotation elsewhere that has this type).
"""
assert self.name != namespace.name, \
'Namespace cannot import itself.'
reason = self._imported_namespaces.setdefault(namespace, _ImportReason())
if imported_alias:
reason.alias = True
if imported_data_type:
reason.data_type = True
if imported_annotation:
reason.annotation = True
if imported_annotation_type:
reason.annotation_type = True
def linearize_data_types(self):
# type: () -> typing.List[UserDefined]
"""
Returns a list of all data types used in the namespace. Because the
inheritance of data types can be modeled as a DAG, the list will be a
linearization of the DAG. It's ideal to generate data types in this
order so that composite types that reference other composite types are
defined in the correct order.
"""
linearized_data_types = []
seen_data_types = set() # type: typing.Set[UserDefined]
def add_data_type(data_type):
# type: (UserDefined) -> None
if data_type in seen_data_types:
return
elif data_type.namespace != self:
# We're only concerned with types defined in this namespace.
return
if is_composite_type(data_type) and data_type.parent_type:
add_data_type(data_type.parent_type)
linearized_data_types.append(data_type)
seen_data_types.add(data_type)
for data_type in self.data_types:
add_data_type(data_type)
return linearized_data_types
def linearize_aliases(self):
# type: () -> typing.List[Alias]
"""
Returns a list of all aliases used in the namespace. The aliases are
ordered to ensure that if they reference other aliases those aliases
come earlier in the list.
"""
linearized_aliases = []
seen_aliases = set() # type: typing.Set[Alias]
def add_alias(alias):
# type: (Alias) -> None
if alias in seen_aliases:
return
elif alias.namespace != self:
return
if is_alias(alias.data_type):
add_alias(alias.data_type)
linearized_aliases.append(alias)
seen_aliases.add(alias)
for alias in self.aliases:
add_alias(alias)
return linearized_aliases
def get_route_io_data_types(self):
# type: () -> typing.List[UserDefined]
"""
Returns a list of all user-defined data types that are referenced as
either an argument, result, or error of a route. If a List or Nullable
data type is referenced, then the contained data type is returned
assuming it's a user-defined type.
"""
data_types = set() # type: typing.Set[UserDefined]
for route in self.routes:
data_types |= self.get_route_io_data_types_for_route(route)
return sorted(data_types, key=lambda dt: dt.name)
def get_route_io_data_types_for_route(self, route):
# type: (ApiRoute) -> typing.Set[UserDefined]
"""
Given a route, returns a set of its argument/result/error datatypes.
"""
data_types = set() # type: typing.Set[UserDefined]
for dtype in (route.arg_data_type, route.result_data_type, route.error_data_type):
while is_list_type(dtype) or is_nullable_type(dtype):
data_list_type = dtype # type: typing.Any
dtype = data_list_type.data_type
if is_composite_type(dtype) or is_alias(dtype):
data_user_type = dtype # type: typing.Any
data_types.add(data_user_type)
return data_types
def get_imported_namespaces(self,
must_have_imported_data_type=False,
consider_annotations=False,
consider_annotation_types=False):
# type: (bool, bool, bool) -> typing.List[ApiNamespace]
"""
Returns a list of Namespace objects. A namespace is a member of this
list if it is imported by the current namespace and a data type is
referenced from it. Namespaces are in ASCII order by name.
Args:
must_have_imported_data_type (bool): If true, result does not
include namespaces that were not imported for data types.
consider_annotations (bool): If false, result does not include
namespaces that were only imported for annotations
consider_annotation_types (bool): If false, result does not
include namespaces that were only imported for annotation types.
Returns:
List[Namespace]: A list of imported namespaces.
"""
imported_namespaces = []
for imported_namespace, reason in self._imported_namespaces.items():
if must_have_imported_data_type and not reason.data_type:
continue
if (not consider_annotations) and not (
reason.data_type or reason.alias or reason.annotation_type
):
continue
if (not consider_annotation_types) and not (
reason.data_type or reason.alias or reason.annotation
):
continue
imported_namespaces.append(imported_namespace)
imported_namespaces.sort(key=lambda n: n.name)
return imported_namespaces
def get_namespaces_imported_by_route_io(self):
# type: () -> typing.List[ApiNamespace]
"""
Returns a list of Namespace objects. A namespace is a member of this
list if it is imported by the current namespace and has a data type
from it referenced as an argument, result, or error of a route.
Namespaces are in ASCII order by name.
"""
namespace_data_types = sorted(self.get_route_io_data_types(),
key=lambda dt: dt.name)
referenced_namespaces = set()
for data_type in namespace_data_types:
if data_type.namespace != self:
referenced_namespaces.add(data_type.namespace)
return sorted(referenced_namespaces, key=lambda n: n.name)
def normalize(self):
# type: () -> None
"""
Alphabetizes routes to make route declaration order irrelevant.
"""
self.routes.sort(key=lambda route: route.name)
self.data_types.sort(key=lambda data_type: data_type.name)
self.aliases.sort(key=lambda alias: alias.name)
self.annotations.sort(key=lambda annotation: annotation.name)
def __repr__(self):
# type: () -> str
return str('ApiNamespace({!r})').format(self.name)
class ApiRoute(object):
"""
Represents an API endpoint.
"""
def __init__(self,
name,
version,
ast_node):
# type: (typing.Text, int, typing.Optional[AstRouteDef]) -> None
"""
:param str name: Designated name of the endpoint.
:param int version: Designated version of the endpoint.
:param ast_node: Raw route definition from the parser.
"""
self.name = name
self.version = version
self._ast_node = ast_node
# These attributes are set later by set_attributes()
self.deprecated = None # type: typing.Optional[DeprecationInfo]
self.raw_doc = None # type: typing.Optional[typing.Text]
self.doc = None # type: typing.Optional[typing.Text]
self.arg_data_type = None # type: typing.Optional[DataType]
self.result_data_type = None # type: typing.Optional[DataType]
self.error_data_type = None # type: typing.Optional[DataType]
self.attrs = None # type: typing.Optional[typing.Mapping[typing.Text, typing.Any]]
def set_attributes(self, deprecated, doc, arg_data_type, result_data_type,
error_data_type, attrs):
"""
Converts a forward reference definition of a route into a full
definition.
:param DeprecationInfo deprecated: Set if this route is deprecated.
:param str doc: Description of the endpoint.
:type arg_data_type: :class:`stone.data_type.DataType`
:type result_data_type: :class:`stone.data_type.DataType`
:type error_data_type: :class:`stone.data_type.DataType`
:param dict attrs: Map of string keys to values that are either int,
float, bool, str, or None. These are the route attributes assigned
in the spec.
"""
self.deprecated = deprecated
self.raw_doc = doc
self.doc = doc_unwrap(doc)
self.arg_data_type = arg_data_type
self.result_data_type = result_data_type
self.error_data_type = error_data_type
self.attrs = attrs
def name_with_version(self):
"""
Get user-friendly representation of the route.
:return: Route name with version suffix. The version suffix is omitted for version 1.
"""
if self.version == 1:
return self.name
else:
return '{}:{}'.format(self.name, self.version)
def __repr__(self):
return 'ApiRoute({})'.format(self.name_with_version())
class DeprecationInfo(object):
def __init__(self, by=None):
# type: (typing.Optional[ApiRoute]) -> None
"""
:param ApiRoute by: The route that replaces this deprecated one.
"""
assert by is None or isinstance(by, ApiRoute), repr(by)
self.by = by
class ApiRoutesByVersion(object):
"""
Represents routes of different versions for a common name.
"""
def __init__(self):
# type: () -> None
"""
:param at_version: The dict mapping a version number to a route.
"""
self.at_version = {} # type: typing.Dict[int, ApiRoute]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,6 @@
MYPY = False
if MYPY:
from typing import cast # noqa # pylint: disable=unused-import,useless-suppression,import-error
else:
def cast(typ, obj): # pylint: disable=unused-argument
return obj