from __future__ import absolute_import, division, print_function, unicode_literals from contextlib import contextmanager from stone.ir import ( is_list_type, is_map_type, is_struct_type, is_union_type, is_nullable_type, is_user_defined_type, is_void_type, unwrap_nullable, ) from stone.backend import CodeBackend from stone.backends.obj_c_helpers import ( fmt_camel_upper, fmt_class, fmt_class_prefix, fmt_import, ) stone_warning = """\ /// /// Copyright (c) 2016 Dropbox, Inc. All rights reserved. /// /// Auto-generated by Stone, do not modify. /// """ # This will be at the top of the generated file. base_file_comment = """\ {}\ """.format(stone_warning) undocumented = '(no description).' comment_prefix = '/// ' class ObjCBaseBackend(CodeBackend): """Wrapper class over Stone generator for Obj C logic.""" # pylint: disable=abstract-method @contextmanager def block_m(self, class_name): with self.block( '@implementation {}'.format(class_name), delim=('', '@end'), dent=0): self.emit() yield @contextmanager def block_h_from_data_type(self, data_type, protocol=None): assert is_user_defined_type(data_type), \ 'Expected user-defined type, got %r' % type(data_type) if not protocol: extensions = [] if data_type.parent_type and is_struct_type(data_type): extensions.append(fmt_class_prefix(data_type.parent_type)) else: if is_union_type(data_type): # Use a handwritten base class extensions.append('NSObject') else: extensions.append('NSObject') extend_suffix = ' : {}'.format( ', '.join(extensions)) if extensions else '' else: base = fmt_class_prefix(data_type.parent_type) if ( data_type.parent_type and not is_union_type(data_type)) else 'NSObject' extend_suffix = ' : {} <{}>'.format(base, ', '.join(protocol)) with self.block( '@interface {}{}'.format( fmt_class_prefix(data_type), extend_suffix), delim=('', '@end'), dent=0): self.emit() yield @contextmanager def block_h(self, class_name, protocol=None, extensions=None, protected=None): if not extensions: extensions = ['NSObject'] if not protocol: extend_suffix = ' : {}'.format(', '.join(extensions)) else: extend_suffix = ' : {} <{}>'.format(', '.join(extensions), fmt_class(protocol)) base_interface_str = '@interface {}{} {{' if protected else '@interface {}{}' with self.block( base_interface_str.format(class_name, extend_suffix), delim=('', '@end'), dent=0): if protected: with self.block('', delim=('', '')): self.emit('@protected') for field_name, field_type in protected: self.emit('{} _{};'.format(field_type, field_name)) self.emit('}') self.emit() yield @contextmanager def block_init(self): with self.block('if (self)'): yield self.emit('return self;') @contextmanager def block_func(self, func, args=None, return_type='void', class_func=False): args = args if args is not None else [] modifier = '-' if not class_func else '+' base_string = '{} ({}){}:{}' if args else '{} ({}){}' signature = base_string.format(modifier, return_type, func, args) with self.block(signature): yield def _get_imports_m(self, data_types, default_imports): """Emits all necessary implementation file imports for the given Stone data type.""" if not isinstance(data_types, list): data_types = [data_types] import_classes = default_imports for data_type in data_types: import_classes.append(fmt_class_prefix(data_type)) if data_type.parent_type: import_classes.append(fmt_class_prefix(data_type.parent_type)) if is_struct_type( data_type) and data_type.has_enumerated_subtypes(): for _, subtype in data_type.get_all_subtypes_with_tags(): import_classes.append(fmt_class_prefix(subtype)) for field in data_type.all_fields: data_type, _ = unwrap_nullable(field.data_type) # unpack list or map while is_list_type(data_type) or is_map_type(data_type): data_type = (data_type.value_data_type if is_map_type(data_type) else data_type.data_type) if is_user_defined_type(data_type): import_classes.append(fmt_class_prefix(data_type)) if import_classes: import_classes = list(set(import_classes)) import_classes.sort() return import_classes def _get_imports_h(self, data_types): """Emits all necessary header file imports for the given Stone data type.""" if not isinstance(data_types, list): data_types = [data_types] import_classes = [] for data_type in data_types: if is_user_defined_type(data_type): import_classes.append(fmt_class_prefix(data_type)) for field in data_type.all_fields: data_type, _ = unwrap_nullable(field.data_type) # unpack list or map while is_list_type(data_type) or is_map_type(data_type): data_type = (data_type.value_data_type if is_map_type(data_type) else data_type.data_type) if is_user_defined_type(data_type): import_classes.append(fmt_class_prefix(data_type)) import_classes = list(set(import_classes)) import_classes.sort() return import_classes def _generate_imports_h(self, import_classes): import_classes = list(set(import_classes)) import_classes.sort() for import_class in import_classes: self.emit('@class {};'.format(import_class)) if import_classes: self.emit() def _generate_imports_m(self, import_classes): import_classes = list(set(import_classes)) import_classes.sort() for import_class in import_classes: self.emit(fmt_import(import_class)) self.emit() def _generate_init_imports_h(self, data_type): self.emit('#import ') self.emit() self.emit('#import "DBSerializableProtocol.h"') if data_type.parent_type and not is_union_type(data_type): self.emit(fmt_import(fmt_class_prefix(data_type.parent_type))) self.emit() def _get_namespace_route_imports(self, namespace, include_route_args=True, include_route_deep_args=False): result = [] def _unpack_and_store_data_type(data_type): data_type, _ = unwrap_nullable(data_type) if is_list_type(data_type): while is_list_type(data_type): data_type, _ = unwrap_nullable(data_type.data_type) if not is_void_type(data_type) and is_user_defined_type(data_type): result.append(data_type) for route in namespace.routes: if include_route_args: data_type, _ = unwrap_nullable(route.arg_data_type) _unpack_and_store_data_type(data_type) elif include_route_deep_args: data_type, _ = unwrap_nullable(route.arg_data_type) if is_union_type(data_type) or is_list_type(data_type): _unpack_and_store_data_type(data_type) elif not is_void_type(data_type): for field in data_type.all_fields: data_type, _ = unwrap_nullable(field.data_type) if (is_struct_type(data_type) or is_union_type(data_type) or is_list_type(data_type)): _unpack_and_store_data_type(data_type) _unpack_and_store_data_type(route.result_data_type) _unpack_and_store_data_type(route.error_data_type) return result def _cstor_name_from_fields(self, fields): """Returns an Obj C appropriate name for a constructor based on the name of the first argument.""" if fields: return self._cstor_name_from_field(fields[0]) else: return 'initDefault' def _cstor_name_from_field(self, field): """Returns an Obj C appropriate name for a constructor based on the name of the supplied argument.""" return 'initWith{}'.format(fmt_camel_upper(field.name)) def _cstor_name_from_fields_names(self, fields_names): """Returns an Obj C appropriate name for a constructor based on the name of the first argument.""" if fields_names: return 'initWith{}'.format(fmt_camel_upper(fields_names[0][0])) else: return 'initDefault' def _struct_has_defaults(self, struct): """Returns whether the given struct has any default values.""" return [ f for f in struct.all_fields if f.has_default or is_nullable_type(f.data_type) ]