#!/usr/bin/env python3
#
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import argparse
import collections
import copy
import os
import pathlib
import sys
import typing
import re
import dataclasses
def _GetDirAbove(dirname: str):
"""Returns the directory "above" this file containing |dirname| (which must
also be "above" this file)."""
path = os.path.abspath(__file__)
while True:
path, tail = os.path.split(path)
if not tail:
return None
if tail == dirname:
return path
# Directory that contains the `testing` directory this script lives under.
# NOTE(review): _GetDirAbove() returns None when no `testing` component exists,
# in which case the os.path.join() calls below raise TypeError.
SOURCE_DIR = _GetDirAbove('testing')
# Extend the module search path; presumably this is what makes the
# `action_helpers`, `jinja2` and `grammar` imports below resolvable. Both
# insertions happen at index 1, so the domato entry ends up ahead of the
# generic third_party one.
sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party'))
sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party/domato/src'))
sys.path.append(os.path.join(SOURCE_DIR, 'build'))
import action_helpers
import jinja2
import grammar
# TODO(crbug.com/361369290): Remove this disable once DomatoLPM development is
# finished and upstream changes can be made to expose the relevant protected
# fields.
# pylint: disable=protected-access
def to_snake_case(name):
  """Converts a CamelCase or mixedCase identifier to snake_case.

  Args:
    name: the identifier to convert.

  Returns:
    the lower-cased identifier with underscores inserted at word boundaries.
  """
  # Detach the last word from a run of capitals: 'HTMLElement' ->
  # 'HTML_Element'.
  name = re.sub(r'([A-Z]{2,})([A-Z][a-z])', r'\1_\2', name)
  # Split every lower-case/digit -> upper-case transition: 'fooBar' ->
  # 'foo_Bar'. re.sub() replaces all occurrences by default, so no explicit
  # count is needed (passing it positionally is deprecated).
  return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()
# Maps each Domato integer tag name to the C++ integer type of the same width
# and signedness. _int_handler() uses it to spell
# `std::numeric_limits<T>::min()` when a tag only specifies a `max` bound, so
# the signedness must be exact ('int64' must not map to an unsigned type).
DOMATO_INT_TYPE_TO_CPP_INT_TYPE = {
    'int': 'int',
    'int32': 'int32_t',
    'uint32': 'uint32_t',
    'int8': 'int8_t',
    'uint8': 'uint8_t',
    'int16': 'int16_t',
    'uint16': 'uint16_t',
    'int64': 'int64_t',
    'uint64': 'uint64_t',
}
# Maps each Domato built-in tag name to the protobuf type of the field that
# stores its value. Protobuf has no 8-bit or 16-bit scalar types, so those tags
# are widened to the 32-bit type of matching signedness.
DOMATO_TO_PROTO_BUILT_IN = {
    'int': 'int32',
    'int32': 'int32',
    'uint32': 'uint32',
    'int8': 'int32',
    'uint8': 'uint32',
    'int16': 'int32',
    'uint16': 'uint32',
    'int64': 'int64',
    'uint64': 'uint64',
    'float': 'float',
    'double': 'double',
    'char': 'int32',
    'string': 'string',
    'htmlsafestring': 'string',
    'hex': 'int32',
    'lines': 'repeated lines',
}
# Maps each Domato built-in tag name to the C++ handler that renders it. For
# the integer tags the template arguments follow the existing 8-bit entries:
# the first one matches the proto field type above and the second one the
# Domato type (see DOMATO_INT_TYPE_TO_CPP_INT_TYPE).
# NOTE(review): handle_int_conversion is defined outside this file; confirm the
# <proto type, target type> reading of its template parameters.
DOMATO_TO_CPP_HANDLERS = {
    'int': 'handle_int_conversion<int32_t, int>',
    'int32': 'handle_int_conversion<int32_t, int32_t>',
    'uint32': 'handle_int_conversion<uint32_t, uint32_t>',
    'int8': 'handle_int_conversion<int32_t, int8_t>',
    'uint8': 'handle_int_conversion<uint32_t, uint8_t>',
    'int16': 'handle_int_conversion<int32_t, int16_t>',
    'uint16': 'handle_int_conversion<uint32_t, uint16_t>',
    'int64': 'handle_int_conversion<int64_t, int64_t>',
    'uint64': 'handle_int_conversion<uint64_t, uint64_t>',
    'float': 'handle_float',
    'double': 'handle_double',
    'char': 'handle_char',
    'string': 'handle_string',
    'htmlsafestring': 'handle_string',
    'hex': 'handle_hex',
}
_C_STR_TRANS = str.maketrans({
'\n': '\\n',
'\r': '\\r',
'\t': '\\t',
'\"': '\\\"',
'\\': '\\\\'
})
# Proto package prefix shared by every generated grammar; render_proto() and
# render_cpp() append the grammar name to it.
BASE_PROTO_NS = 'domatolpm.generated'
def to_cpp_ns(proto_ns: str) -> str:
  """Converts a dotted proto package name to a C++ namespace path.

  Args:
    proto_ns: the proto package, e.g. 'a.b.c'.

  Returns:
    the same components joined with '::', e.g. 'a::b::c'.
  """
  return '::'.join(proto_ns.split('.'))
# Prefix prepended to a proto type name to form the name of the C++ function
# that handles it.
CPP_HANDLER_PREFIX = 'handle_'
def to_proto_field_name(name: str) -> str:
  """Derives a proto field name from a creator or rule name.

  Dashes become underscores and the result is converted to snake case, which
  is the protobuf convention for field names. A few names are suffixed with
  '_proto' instead of being used as-is.

  Args:
    name: the name of the creator or the rule.

  Returns:
    the proto field name to use.
  """
  field_name = to_snake_case(name.replace('-', '_'))
  needs_suffix = field_name in ('short', 'class', 'bool', 'boolean', 'long',
                                'void')
  return f'{field_name}_proto' if needs_suffix else field_name
def to_proto_type(creator_name: str) -> str:
  """Derives a proto type name from a creator name.

  The transformation is kept minimal on purpose to avoid naming conflicts:
  dashes become underscores, and a few names are suffixed with '_proto'
  instead of being used as-is.

  Args:
    creator_name: the name of the creator.

  Returns:
    the name of the proto type.
  """
  type_name = creator_name.replace('-', '_')
  needs_suffix = type_name in ('short', 'class', 'bool', 'boolean', 'long',
                               'void')
  return f'{type_name}_proto' if needs_suffix else type_name
def c_escape(v: str) -> str:
  """Escapes |v| so that it can be placed inside a C++ string literal.

  Args:
    v: the raw text.

  Returns:
    |v| with the characters listed in _C_STR_TRANS replaced by their escape
    sequences.
  """
  return str.translate(v, _C_STR_TRANS)
@dataclasses.dataclass
class ProtoType:
  """A protobuf type, identified by its name only."""
  # Name of the type as it is referenced from proto fields.
  name: str

  def is_one_of(self) -> bool:
    """Returns whether this type is a message made of a oneof; False here,
    subclasses override it when relevant."""
    return False
@dataclasses.dataclass
class ProtoField:
  """A single field of a proto message."""
  # Type of the field.
  type: ProtoType
  # Name of the field inside its message.
  name: str
  # Field number of the field inside its message.
  proto_id: int
@dataclasses.dataclass
class ProtoMessage(ProtoType):
  """A proto message: a named type that owns a list of fields."""
  # Fields of the message, in declaration order.
  fields: typing.List[ProtoField]
@dataclasses.dataclass
class OneOfProtoMessage(ProtoMessage):
  """A proto message whose fields all live inside a single oneof."""
  # Name given to the oneof that groups the fields.
  oneofname: str

  def is_one_of(self) -> bool:
    """Returns True: this message is a oneof wrapper."""
    return True
class CppExpression:
  """Base class of the C++ expressions emitted by the generator."""

  def repr(self):
    """Returns the C++ source text of this expression.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    # NotImplementedError is still an Exception, so existing handlers keep
    # working while the abstract nature of the method becomes explicit.
    raise NotImplementedError('Not implemented.')
@dataclasses.dataclass
class CppTxtExpression(CppExpression):
  """A C++ expression given as raw text and emitted verbatim."""
  # The C++ source text.
  content: str

  def repr(self):
    """Returns the raw text unchanged."""
    text = self.content
    return text
@dataclasses.dataclass
class CppCallExpr(CppExpression):
  """A C++ function call expression."""
  # Name of the called function.
  fct_name: str
  # Argument expressions, in call order.
  args: typing.List[CppExpression]
  # Text emitted right before the function name (empty by default).
  ns: str = ''

  def repr(self):
    """Returns the call as `<ns><fct_name>(<arg>, <arg>, ...)`."""
    rendered_args = ', '.join(arg.repr() for arg in self.args)
    return self.ns + self.fct_name + '(' + rendered_args + ')'
class CppHandlerCallExpr(CppCallExpr):
  """A call to a generated handler: `handler(ctx, arg.<field>(), extra...)`."""

  def __init__(self,
               handler: str,
               field_name: str,
               extra_args: typing.Optional[typing.List[CppExpression]] = None):
    """Builds the call expression.

    Args:
      handler: name of the C++ handler function to call.
      field_name: proto field whose accessor is passed as second argument.
      extra_args: optional expressions appended after the two fixed arguments.
    """
    call_args = [
        CppTxtExpression('ctx'),
        CppTxtExpression(f'arg.{field_name}()'),
    ]
    if extra_args:
      call_args.extend(extra_args)
    super().__init__(fct_name=handler, args=call_args)
    # Keep the constructor inputs around so that the call can be rebuilt with
    # another field name later on.
    self.handler = handler
    self.field_name = field_name
    self.extra_args = extra_args
@dataclasses.dataclass
class CppStringExpr(CppExpression):
  """A C++ string literal."""
  # The unescaped text of the literal.
  content: str

  def repr(self):
    """Returns the escaped content wrapped in double quotes."""
    return '"' + c_escape(self.content) + '"'
@dataclasses.dataclass
class CppFunctionHandler:
  """A generated C++ function.

  The three predicates below all answer False here; each specialised handler
  overrides the one that describes it.
  """
  # Name of the C++ function.
  name: str
  # Expressions making up the function body, in order.
  exprs: typing.List[CppExpression]

  def is_oneof_handler(self) -> bool:
    """Returns whether this function dispatches over a oneof."""
    return False

  def is_string_table_handler(self) -> bool:
    """Returns whether this function implements a string table."""
    return False

  def is_message_handler(self) -> bool:
    """Returns whether this function handles a plain proto message."""
    return False
class CppStringTableHandler(CppFunctionHandler):
  """A C++ function implementing a string table and returning one of the
  strings it holds.
  """

  def __init__(self, name: str, var_name: str,
               strings: typing.List[CppStringExpr]):
    """Builds the handler.

    Args:
      name: proto type name; the function is named after it with the handler
        prefix.
      var_name: name of the proto field used by the string table.
      strings: the string literals of the table.
    """
    function_name = f'{CPP_HANDLER_PREFIX}{name}'
    super().__init__(name=function_name, exprs=[])
    self.var_name = var_name
    self.strings = strings
    # Text of the form `<name>& arg`, presumably the C++ parameter declaration
    # consumed by the templates.
    self.proto_type = f'{name}& arg'

  def is_string_table_handler(self) -> bool:
    """Returns True: this function is a string table."""
    return True
class CppProtoMessageFunctionHandler(CppFunctionHandler):
  """A C++ function that handles a ProtoMessage."""

  def __init__(self,
               name: str,
               exprs: typing.List[CppExpression],
               creator: typing.Optional[typing.Dict[str, str]] = None):
    """Builds the handler.

    Args:
      name: proto type name; the function is named after it with the handler
        prefix.
      exprs: the expressions making up the function body.
      creator: optional description of the variable created by the function;
        None when the function creates nothing.
    """
    function_name = f'{CPP_HANDLER_PREFIX}{name}'
    super().__init__(name=function_name, exprs=exprs)
    self.creator = creator
    # Text of the form `<name>& arg`, presumably the C++ parameter declaration
    # consumed by the templates.
    self.proto_type = f'{name}& arg'

  def creates_new(self):
    """Returns whether a creator description was supplied."""
    return not self.creator is None

  def is_message_handler(self) -> bool:
    """Returns True: this function handles a plain proto message."""
    return True
class CppOneOfMessageFunctionHandler(CppFunctionHandler):
  """A C++ function that handles a OneOfProtoMessage."""

  def __init__(self, name: str, switch_name: str,
               cases: typing.Dict[str, typing.List[CppExpression]]):
    """Builds the handler.

    Args:
      name: proto type name; the function is named after it with the handler
        prefix.
      switch_name: name of the oneof the function switches on.
      cases: expressions to emit for each oneof field, keyed by field name.
    """
    function_name = f'{CPP_HANDLER_PREFIX}{name}'
    super().__init__(name=function_name, exprs=[])
    self.switch_name = switch_name
    self.cases = cases
    # Text of the form `<name>& arg`, presumably the C++ parameter declaration
    # consumed by the templates.
    self.proto_type = f'{name}& arg'

  def all_except_last(self):
    """Returns a dict holding every case but the last inserted one."""
    leading_names = list(self.cases)[:-1]
    return {case_name: self.cases[case_name] for case_name in leading_names}

  def last(self):
    """Returns the expressions of the last inserted case."""
    final_name = list(self.cases)[-1]
    return self.cases[final_name]

  def is_oneof_handler(self) -> bool:
    """Returns True: this function dispatches over a oneof."""
    return True
class DomatoBuilder:
"""DomatoBuilder takes a Domato grammar and models it as a protobuf
representation along with the corresponding C++ parsing code.
"""
@dataclasses.dataclass
class Entry:
  """Pairs a proto message with the C++ function that handles it."""
  # The proto message.
  msg: ProtoMessage
  # The C++ function generated for |msg|.
  func: CppFunctionHandler
def __init__(self, g: grammar.Grammar):
  """Initialises the builder and determines the root of the grammar.

  Args:
    g: the parsed Domato grammar to model.
  """
  self.handlers: typing.Dict[str, DomatoBuilder.Entry] = {}
  self.backrefs: typing.Dict[str, typing.List[str]] = {}
  self.grammar = g
  declared_root = self.grammar._root
  # A root explicitly named something other than 'root' is used as-is;
  # otherwise the grammar is rooted at `lines`.
  if declared_root and declared_root != 'root':
    self.root = declared_root
  else:
    self.root = 'lines'
  if declared_root == 'root':
    # Multiple roots don't make sense, so only the last defined rule is
    # considered. Its first counted <lines> tag, if any, becomes the root.
    last_rule = self.grammar._creators[declared_root][-1]
    counted_lines = next(
        (p for p in last_rule['parts']
         if p['type'] == 'tag' and p['tagname'] == 'lines' and 'count' in p),
        None)
    if counted_lines is not None:
      self.root = f'lines_{counted_lines["count"]}'
  # Dispatch table from built-in tag name to the method modelling it.
  int_tags = ('int', 'int32', 'uint32', 'int8', 'uint8', 'int16', 'uint16',
              'int64', 'uint64')
  default_tags = ('float', 'double', 'char', 'string', 'htmlsafestring', 'hex')
  parsers = {tag: self._int_handler for tag in int_tags}
  parsers.update({tag: self._default_handler for tag in default_tags})
  parsers['lines'] = self._lines_handler
  self._built_in_types_parser = parsers
def parse_grammar(self):
  """Models every creator of the grammar.

  Each creator becomes a OneOfProtoMessage with one alternative per rule,
  along with a C++ oneof handler dispatching to the handler of each rule.
  """
  for creator, rules in self.grammar._creators.items():
    field_prefix = to_proto_field_name(creator)
    oneof_type = to_proto_type(creator)
    rule_messages = self._parse_rule(creator, rules)
    alternatives: typing.List[ProtoField] = [
        ProtoField(type=ProtoType(name=rule_msg.name),
                   name=f'{field_prefix}_{index}',
                   proto_id=index)
        for index, rule_msg in enumerate(rule_messages, start=1)
    ]
    oneof_msg = OneOfProtoMessage(name=oneof_type,
                                  oneofname='oneoffield',
                                  fields=alternatives)
    # One switch case per alternative, calling the handler of its rule.
    dispatch = {}
    for alt in alternatives:
      rule_handler = f'{CPP_HANDLER_PREFIX}{alt.type.name}'
      dispatch[alt.name] = [
          CppHandlerCallExpr(handler=rule_handler, field_name=alt.name)
      ]
    oneof_func = CppOneOfMessageFunctionHandler(name=oneof_type,
                                                switch_name='oneoffield',
                                                cases=dispatch)
    self._add(oneof_msg, oneof_func)
def all_proto_messages(self):
  """Returns the proto message of every registered entry, in insertion
  order."""
  return [self.handlers[key].msg for key in self.handlers]
def all_cpp_functions(self):
  """Returns the C++ function of every registered entry, in insertion
  order."""
  return [self.handlers[key].func for key in self.handlers]
def get_line_prefix(self) -> str:
  """Returns the part of the grammar's line guard located before the first
  '<line>' marker, or '' when the grammar has no line guard."""
  guard = self.grammar._line_guard
  return guard.split('<line>')[0] if guard else ''
def get_line_suffix(self) -> str:
  """Returns the part of the grammar's line guard located right after the
  first '<line>' marker, or '' when the grammar has no line guard."""
  guard = self.grammar._line_guard
  return guard.split('<line>')[1] if guard else ''
def should_generate_repeated_lines(self):
  """Returns whether the root is exactly `lines`, i.e. not a counted
  `lines_<N>` root and not a custom one."""
  return 'lines' == self.root
def should_generate_one_line_handler(self):
  """Returns whether the root is `lines` or a counted `lines_<N>` variant."""
  root_name = self.root
  return root_name.startswith('lines')
def maybe_add_lines_handler(self, number: int) -> bool:
  """Registers the `lines_<number>` message and handler if not present yet.

  The message holds |number| fields of type `line`, and the handler calls
  `handle_one_line` once per field.

  Args:
    number: how many lines the message must contain.

  Returns:
    True if the entry was added, False if it already existed.
  """
  name = f'lines_{number}'
  if name in self.handlers:
    return False
  indices = range(1, number + 1)
  line_fields = [ProtoField(ProtoType('line'), f'line_{i}', i) for i in indices]
  line_calls = [
      CppHandlerCallExpr('handle_one_line', f'line_{i}') for i in indices
  ]
  self.handlers[name] = DomatoBuilder.Entry(
      ProtoMessage(name, fields=line_fields),
      CppProtoMessageFunctionHandler(name, exprs=line_calls))
  return True
def get_roots(self) -> typing.Tuple[ProtoMessage, CppFunctionHandler]:
  """Builds the top-level `fuzzcase` message and its handler.

  Returns:
    the `fuzzcase` message, which has a single `root` field of the root type,
    and the function that forwards that field to the root handler.
  """
  root_type = self.root
  root_field = ProtoField(type=ProtoType(name=root_type),
                          name='root',
                          proto_id=1)
  root_call = CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{root_type}',
                                 field_name='root')
  fuzz_case = ProtoMessage(name='fuzzcase', fields=[root_field])
  fuzz_fct = CppProtoMessageFunctionHandler(name='fuzzcase', exprs=[root_call])
  return fuzz_case, fuzz_fct
def get_protos(
    self
) -> typing.Tuple[typing.List[ProtoMessage], typing.List[ProtoMessage]]:
  """Splits the proto messages into root and non-root messages.

  Returns:
    a (roots, non_roots) pair of message lists. For a code grammar (root
    starting with `lines`), the roots are every entry whose key starts with
    `line` plus the `fuzzcase` message, and the non-roots are all the other
    entries. Otherwise `fuzzcase` is the only root and every registered
    message is a non-root.
  """
  fuzz_case = self.get_roots()[0]
  if not self.should_generate_one_line_handler():
    return [fuzz_case], self.all_proto_messages()
  # We're handling a code grammar.
  roots = []
  non_roots = []
  for key, entry in self.handlers.items():
    (roots if key.startswith('line') else non_roots).append(entry.msg)
  roots.append(fuzz_case)
  return roots, non_roots
def simplify(self):
  """Simplifies the proto messages and the C++ functions.

  The merging passes are run in a fixed order, over and over, until a full
  round reports no change. The oneofs are then reordered and all the fields
  renamed.
  """
  passes = (
      self._merge_unary_oneofs,
      self._merge_strings,
      self._merge_multistrings_oneofs,
      self._remove_unlinked_nodes,
      self._merge_proto_messages,
      self._merge_oneofs,
  )
  changed = True
  while changed:
    # Every pass of a round runs, even once an earlier one reported a change.
    outcomes = [run_pass() for run_pass in passes]
    changed = any(outcomes)
  self._oneofs_reorderer()
  self._oneof_message_renamer()
  self._message_renamer()
def _add(self, message: ProtoMessage,
         handler: CppProtoMessageFunctionHandler):
  """Registers |message| with its |handler| and records the back-references.

  Args:
    message: the proto message to register, keyed by its name.
    handler: the C++ function handling |message|.
  """
  owner = message.name
  self.handlers[owner] = DomatoBuilder.Entry(message, handler)
  # One back-reference per field, so a type used by several fields of the same
  # message lists that message several times.
  for field in message.fields:
    self.backrefs.setdefault(field.type.name, []).append(owner)
def _int_handler(
    self, part,
    field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]:
  """Models an integer built-in tag, honouring its optional bounds.

  Args:
    part: the grammar part describing the tag; its 'min' and 'max' entries,
      when present, are forwarded to the handler as extra arguments.
    field_name: the proto field holding the value.

  Returns:
    the proto type name of the field and the C++ handler call.
  """
  tag = part['tagname']
  proto_type = DOMATO_TO_PROTO_BUILT_IN[tag]
  handler = DOMATO_TO_CPP_HANDLERS[tag]
  bounds = []
  if 'min' in part:
    bounds.append(part['min'])
  elif 'max' in part:
    # A maximum alone still needs a minimum argument in front of it: use the
    # lowest value of the C++ type.
    cpp_type = DOMATO_INT_TYPE_TO_CPP_INT_TYPE[tag]
    bounds.append(f'std::numeric_limits<{cpp_type}>::min()')
  if 'max' in part:
    bounds.append(part['max'])
  call = CppHandlerCallExpr(
      handler=handler,
      field_name=field_name,
      extra_args=[CppTxtExpression(bound) for bound in bounds])
  return proto_type, call
def _lines_handler(
    self, part,
    field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]:
  """Models a <lines> tag, with or without a `count` attribute.

  Args:
    part: the grammar part describing the tag.
    field_name: the proto field holding the lines.

  Returns:
    the proto type name (`lines` or `lines_<count>`) and the C++ handler call.
  """
  if 'count' in part:
    proto_type = f'lines_{part["count"]}'
    # Counted lines need their dedicated message and handler.
    self.maybe_add_lines_handler(int(part['count']))
  else:
    proto_type = 'lines'
  call = CppHandlerCallExpr(handler=f'handle_{proto_type}',
                            field_name=field_name)
  return proto_type, call
def _default_handler(
    self, part,
    field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]:
  """Models a built-in tag that takes no extra argument.

  Args:
    part: the grammar part describing the tag.
    field_name: the proto field holding the value.

  Returns:
    the proto type name of the field and the C++ handler call.
  """
  tag = part['tagname']
  proto_type = DOMATO_TO_PROTO_BUILT_IN[tag]
  call = CppHandlerCallExpr(handler=DOMATO_TO_CPP_HANDLERS[tag],
                            field_name=field_name)
  return proto_type, call
def _parse_rule(self, creator_name, rules):
  """Models every rule of a creator as a proto message and a C++ function.

  Each rule yields a ProtoMessage named `<proto type>_<rule index>` and a
  CppProtoMessageFunctionHandler of the same name, both registered through
  _add().

  Args:
    creator_name: the name of the creator owning the rules.
    rules: the rules of the creator; each one is read through its 'type' and
      'parts' entries.

  Returns:
    the list of rule messages, in rule order.

  Raises:
    Exception: if a rule uses a <call> tag, or if a code rule has more than
      one part carrying a 'new' entry.
  """
  messages = []
  for rule_id, rule in enumerate(rules, start=1):
    rule_msg_field_name = f'{to_proto_field_name(creator_name)}_{rule_id}'
    proto_fields = []
    cpp_contents = []
    # Number of parts of this rule that carry a 'new' entry.
    ret_vars = 0
    for part_id, part in enumerate(rule['parts'], start=1):
      field_name = f'{rule_msg_field_name}_{part_id}'
      # Stays None for parts that only emit text and need no proto field.
      proto_type = None
      if rule['type'] == 'code' and 'new' in part:
        # Such a part gets an optional `old` field placed first in the message
        # and emits no C++ expression of its own.
        proto_fields.insert(
            0,
            ProtoField(type=ProtoType('optional int32'),
                       name='old',
                       proto_id=part_id))
        ret_vars += 1
        continue
      if part['type'] == 'text':
        contents = CppStringExpr(part['text'])
      elif part['tagname'] == 'import':
        # The current domato project is currently not handling that either in
        # its built-in rules, and I do not plan on using the feature with
        # newly written rules, as I think this directive has a lot of
        # constraints with not much added value.
        continue
      elif part['tagname'] == 'call':
        raise Exception(
            'DomatoLPM does not implement <call> and <import> tags.')
      elif part['tagname'] in self.grammar._constant_types.keys():
        # Constant types are emitted as plain string literals.
        contents = CppStringExpr(
            self.grammar._constant_types[part['tagname']])
      elif part['tagname'] in self._built_in_types_parser:
        handler = self._built_in_types_parser[part['tagname']]
        proto_type, contents = handler(part, field_name)
      elif part['type'] == 'tag':
        # Any other tag refers to a creator: call the handler of its type.
        proto_type = to_proto_type(part['tagname'])
        contents = CppHandlerCallExpr(
            handler=f'{CPP_HANDLER_PREFIX}{proto_type}',
            field_name=field_name)
      if proto_type:
        proto_fields.append(
            ProtoField(type=ProtoType(name=proto_type),
                       name=field_name,
                       proto_id=part_id))
      cpp_contents.append(contents)
    if ret_vars > 1:
      raise Exception('Not implemented.')
    creator = None
    if rule['type'] == 'code' and ret_vars > 0:
      creator = {'var_type': creator_name, 'var_prefix': 'var'}
    # From here on |proto_type| designates the type of the creator itself.
    proto_type = to_proto_type(creator_name)
    rule_msg = ProtoMessage(name=f'{proto_type}_{rule_id}',
                            fields=proto_fields)
    rule_func = CppProtoMessageFunctionHandler(name=f'{proto_type}_{rule_id}',
                                               exprs=cpp_contents,
                                               creator=creator)
    self._add(rule_msg, rule_func)
    messages.append(rule_msg)
  return messages
def _remove(self, name: str):
assert name in self.handlers
for field in self.handlers[name].msg.fields:
if field.type.name in self.backrefs:
self.backrefs[field.type.name].remove(name)
if name in self.backrefs:
self.backrefs.pop(name)
self.handlers.pop(name)
def _update(self, name: str):
assert name in self.handlers
for field in self.handlers[name].msg.fields:
if not field.type.name in self.backrefs:
self.backrefs[field.type.name] = []
self.backrefs[field.type.name].append(name)
def _count_backref(self, proto_name: str) -> int:
"""Counts the number of backreference a given proto message has.
Args:
proto_name: the proto message name.
Returns:
the number of backreferences.
"""
return len(self.backrefs[proto_name])
def _merge_proto_messages(self) -> bool:
  """Merges messages referencing other messages into the same message. This
  allows to tremendously reduce the number of protobuf messages that will be
  generated.

  Returns:
    whether at least one merge candidate was found.
  """
  # Maps each parent message name to the names of the children to inline.
  to_merge = collections.defaultdict(set)
  for name in self.handlers:
    msg = self.handlers[name].msg
    func = self.handlers[name].func
    # Only plain messages that create nothing and are not roots get inlined.
    if msg.is_one_of() or not func.is_message_handler() or func.creates_new(
    ) or self._is_root_node(name):
      continue
    if name not in self.backrefs:
      continue
    for elt in self.backrefs[name]:
      # Self references and references from unregistered or oneof messages
      # are left alone.
      if elt == name or elt not in self.handlers:
        continue
      if self.handlers[elt].msg.is_one_of():
        continue
      to_merge[elt].add(name)
  for parent, childs in to_merge.items():
    msg = self.handlers[parent].msg
    fct = self.handlers[parent].func
    for child in childs:
      new_contents = []
      for expr in fct.exprs:
        if isinstance(expr, CppStringExpr):
          new_contents.append(expr)
          continue
        assert isinstance(expr, CppHandlerCallExpr)
        # First remaining field of the parent whose type is |child|.
        field: ProtoField = next(
            (f for f in msg.fields if f.type.name == child), None)
        if not field or not expr.field_name == field.name:
          new_contents.append(expr)
          continue
        # |expr| is the call handling |field|: replace the field by the
        # child's fields and the call by the child's expressions.
        self.backrefs[field.type.name].remove(msg.name)
        idx = msg.fields.index(field)
        field_msg = self.handlers[child].msg
        field_fct = self.handlers[child].func
        # The following deepcopy is required because we might change the
        # child's messages fields at some point, and we don't want those
        # changes to affect this current's message fields.
        fields_copy = copy.deepcopy(field_msg.fields)
        msg.fields = msg.fields[:idx] + fields_copy + msg.fields[idx + 1:]
        new_contents += copy.deepcopy(field_fct.exprs)
        # The parent now references everything the child references.
        for f in field_msg.fields:
          self.backrefs[f.type.name].append(msg.name)
      fct.exprs = new_contents
  return len(to_merge) > 0
def _message_renamer(self):
  """Renames ProtoMessage fields that might have been merged. This ensures
  proto field naming remains consistent with the current rule being
  generated.
  """
  for entry in self.handlers.values():
    if entry.msg.is_one_of() or entry.func.is_string_table_handler():
      continue
    # Renumber the fields and name them `<message name>_<field number>`.
    for proto_id, field in enumerate(entry.msg.fields, start=1):
      field.proto_id = proto_id
      # The `old` field of a creating function keeps its name.
      if entry.func.creates_new() and field.name == 'old':
        continue
      field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
    # Handler calls are matched to fields by position; for a creating function
    # numbering starts at 2 since the `old` field has no handler call.
    index = 2 if entry.func.creates_new() else 1
    new_contents = []
    for expr in entry.func.exprs:
      if not isinstance(expr, CppHandlerCallExpr):
        new_contents.append(expr)
        continue
      # Rebuild the call so that it reads the renamed field.
      new_contents.append(
          CppHandlerCallExpr(expr.handler,
                             to_proto_field_name(f'{entry.msg.name}_{index}'),
                             expr.extra_args))
      index += 1
    entry.func.exprs = new_contents
def _oneof_message_renamer(self):
  """Renames OneOfProtoMessage fields that might have been merged. This
  ensures proto field naming remains consistent with the current rule being
  generated.
  """
  for entry in self.handlers.values():
    if not entry.msg.is_one_of():
      continue
    # Cases rebuilt under the new field names, in field order.
    cases = {}
    for proto_id, field in enumerate(entry.msg.fields, start=1):
      field.proto_id = proto_id
      # Fetch the case under its old name before the field is renamed.
      exprs = entry.func.cases.pop(field.name)
      field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
      new_contents = []
      for expr in exprs:
        if not isinstance(expr, CppHandlerCallExpr):
          new_contents.append(expr)
          continue
        # Every handler call of the case reads the renamed oneof field.
        new_contents.append(
            CppHandlerCallExpr(expr.handler, field.name, expr.extra_args))
      cases[field.name] = new_contents
    entry.func.cases = cases
def _merge_multistrings_oneofs(self) -> bool:
  """Merges multiple strings into a string table function.

  A oneof message qualifies when each of its alternatives is a registered,
  non-oneof message without any field whose function has exactly one
  expression. It is then replaced by a message with a single uint32 `val`
  field and a CppStringTableHandler holding the strings.

  Returns:
    whether at least one oneof was replaced.
  """
  has_made_changes = False
  # Iterate over a snapshot of the keys since entries get replaced on the way.
  for name in list(self.handlers.keys()):
    msg = self.handlers[name].msg
    if not msg.is_one_of():
      continue
    if not all(f.type.name in self.handlers and len(self.handlers[
        f.type.name].msg.fields) == 0 and not self.handlers[f.type.name].msg.
               is_one_of() and len(self.handlers[f.type.name].func.exprs) == 1
               for f in msg.fields):
      continue
    fields = [ProtoField(type=ProtoType('uint32'), name='val', proto_id=1)]
    new_msg = ProtoMessage(name=msg.name, fields=fields)
    strings = []
    for field in msg.fields:
      # The oneof no longer references its alternatives.
      self.backrefs[field.type.name].remove(name)
      for expr in self.handlers[field.type.name].func.exprs:
        assert isinstance(expr, CppStringExpr)
        strings += [expr]
    new_func = CppStringTableHandler(name=msg.name,
                                     var_name='val',
                                     strings=strings)
    self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
    self._update(name)
    has_made_changes = True
  return has_made_changes
def _oneofs_reorderer(self):
  """Reorders the OneOfProtoMessage so that the last element can be extracted
  out of the protobuf oneof's field in order to always have a correct
  path to be generated. This requires having at least one terminal path in
  the grammar.
  """
  # Names of the messages known to have a terminating path.
  _terminal_messages = set()
  # Names of the messages on the current recursion path, to detect cycles.
  _being_visited = set()

  def recursive_terminal_marker(name: str):
    """Returns whether |name| has a terminating path, reordering the oneofs
    met on the way."""
    # Names that are not registered messages are considered terminal.
    if name in _terminal_messages or name not in self.handlers:
      return True
    # Coming back to a message under visit means this path loops.
    if name in _being_visited:
      return False
    _being_visited.add(name)
    msg = self.handlers[name].msg
    func = self.handlers[name].func
    if len(msg.fields) == 0:
      _terminal_messages.add(name)
      _being_visited.remove(name)
      return True
    if msg.is_one_of():
      # First alternative with a terminating path, if any.
      f = next(
          (f for f in msg.fields if recursive_terminal_marker(f.type.name)),
          None)
      if not f:
        #FIXME: for testing purpose only, we're not hard-failing on this.
        _being_visited.remove(name)
        return False
      # Move that alternative to the end of both the fields and the cases.
      msg.fields.remove(f)
      msg.fields.append(f)
      m = next(k for k in func.cases.keys() if k == f.name)
      func.cases[m] = func.cases.pop(m)
      _terminal_messages.add(name)
      _being_visited.remove(name)
      return True
    # A plain message terminates only if all of its fields do.
    res = all(recursive_terminal_marker(f.type.name) for f in msg.fields)
    #FIXME: for testing purpose only, we're not hard-failing on this.
    _being_visited.remove(name)
    return res

  for name in self.handlers:
    recursive_terminal_marker(name)
def _merge_oneofs(self) -> bool:
  """Inlines single-field messages into the oneofs referencing them.

  A oneof alternative pointing to a registered, non-oneof, non-creating
  message with exactly one field is replaced by that field, and the case of
  the alternative by a copy of the expressions of the message's function.

  Returns:
    whether at least one alternative was replaced.
  """
  has_made_changes = False
  for name in list(self.handlers.keys()):
    msg = self.handlers[name].msg
    func = self.handlers[name].func
    if not msg.is_one_of():
      continue
    for field in msg.fields:
      if not field.type.name in self.handlers:
        continue
      field_msg = self.handlers[field.type.name].msg
      field_func = self.handlers[field.type.name].func
      if field_msg.is_one_of() or len(
          field_msg.fields) != 1 or not field_func.is_message_handler(
          ) or field_func.creates_new():
        continue
      func.cases.pop(field.name)
      # The alternative takes over the name and type of the only field of the
      # message it pointed to.
      field.name = field_msg.fields[0].name
      field.type = field_msg.fields[0].type
      # Keep the case names unique within the oneof.
      while field.name in func.cases:
        field.name += '_1'
      func.cases[field.name] = copy.deepcopy(field_func.exprs)
      # The oneof now references the inlined type instead of the message.
      self.backrefs[field_msg.name].remove(name)
      self.backrefs[field.type.name].append(name)
      has_made_changes = True
  return has_made_changes
def _merge_unary_oneofs(self) -> bool:
  """Transforms OneOfProtoMessage messages containing only one field into a
  ProtoMessage containing the fields of the contained message. E.g.:
  message B {
    int field1 = 1;
    Whatever field2 = 2;
  }
  message A {
    oneof field {
      B b = 1;
    }
  }
  Into:
  message A {
    int field1 = 1;
    Whatever field2 = 2;
  }

  Returns:
    whether at least one oneof was transformed.
  """
  has_made_changes = False
  for name in list(self.handlers.keys()):
    msg = self.handlers[name].msg
    func = self.handlers[name].func
    if not msg.is_one_of() or len(msg.fields) > 1:
      continue
    # The message is a unary oneof. Let's make sure its only child doesn't
    # have other backrefs.
    if self._count_backref(msg.fields[0].type.name) > 1:
      continue
    # The only backref should really only be us. If not we screwed up
    # somewhere else.
    assert name in self.backrefs[msg.fields[0].type.name]
    field_msg: ProtoMessage = self.handlers[msg.fields[0].type.name].msg
    if field_msg.is_one_of():
      continue
    field_func = self.handlers[msg.fields[0].type.name].func
    # The child disappears; this message takes over its fields, expressions
    # and creator under its own name.
    self._remove(msg.fields[0].type.name)
    msg = ProtoMessage(name=msg.name, fields=field_msg.fields)
    func = CppProtoMessageFunctionHandler(name=msg.name,
                                          exprs=field_func.exprs,
                                          creator=field_func.creator)
    self.handlers[name] = DomatoBuilder.Entry(msg, func)
    # Re-register the back-references of the inherited fields.
    self._update(name)
    has_made_changes = True
  return has_made_changes
def _merge_strings(self) -> bool:
  """Concatenates adjacent string literals of message handlers, e.g.
    [ CppString("<first>"), CppString("<second>")]
  Into:
    [ CppString("<first><second>")]

  Returns:
    whether at least two literals were concatenated.
  """
  changed = False
  for entry in self.handlers.values():
    func = entry.func
    if not func.is_message_handler() or len(func.exprs) <= 1:
      continue
    merged = []
    for expr in func.exprs:
      extends_previous = (merged and isinstance(merged[-1], CppStringExpr)
                          and isinstance(expr, CppStringExpr))
      if extends_previous:
        # Fold the literal into the one emitted right before it.
        merged[-1] = CppStringExpr(merged[-1].content + expr.content)
        changed = True
      else:
        merged.append(expr)
    func.exprs = merged
  return changed
def _is_root_node(self, name: str):
# If there is no existing root, we set it to `lines`, since this will
# be picked as the default root.
if 'line' not in self.root:
return self.root == name
return re.match('^line(s)?(_[0-9]*)?$', name) is not None
def _remove_unlinked_nodes(self) -> bool:
  """Removes the proto messages that are neither roots nor reachable from the
  root definition. Such messages can be left behind by the other optimization
  passes.

  Returns:
    whether a change was made.
  """
  # Non-root messages that nobody references anymore.
  doomed = {
      name for name in self.handlers
      if not self.backrefs.get(name) and not self._is_root_node(name)
  }
  if self.should_generate_one_line_handler():
    start = 'line'
  else:
    start = self.root
  # Collect every message reachable from the starting message.
  reachable = set()
  pending = [self.handlers[start].msg]
  while pending:
    current = pending.pop()
    if current.name in reachable:
      continue
    reachable.add(current.name)
    pending.extend(self.handlers[field.type.name].msg
                   for field in current.fields
                   if field.type.name in self.handlers)
  unreachable = set(self.handlers.keys()) - reachable
  doomed.update(name for name in unreachable if not self._is_root_node(name))
  for name in doomed:
    self._remove(name)
  return bool(doomed)
def _render_internal(template: jinja2.Template,
                     context: typing.Dict[str, typing.Any], out_f: str):
  """Renders |template| with |context| and writes the result to |out_f|.

  Args:
    template: the jinja2 template to render.
    context: the variables made available to the template.
    out_f: the path of the file to write.
  """
  with action_helpers.atomic_output(out_f, mode='w') as output:
    rendered = template.render(context)
    output.write(rendered)
def _render_proto_internal(
    template: jinja2.Template, out_f: str,
    proto_messages: typing.List[typing.Union[ProtoMessage, OneOfProtoMessage]],
    should_generate_repeated_lines: bool, proto_ns: str,
    imports: typing.List[str]):
  """Renders one proto file.

  Args:
    template: the proto template.
    out_f: the path of the proto file to write.
    proto_messages: the messages to emit; oneof and plain messages are handed
      to the template separately.
    should_generate_repeated_lines: forwarded to the template as
      `generate_repeated_lines`.
    proto_ns: the proto package of the file.
    imports: the proto files to import.
  """
  plain_messages = []
  oneof_messages = []
  for message in proto_messages:
    bucket = oneof_messages if message.is_one_of() else plain_messages
    bucket.append(message)
  context = {
      'messages': plain_messages,
      'oneofmessages': oneof_messages,
      'generate_repeated_lines': should_generate_repeated_lines,
      'proto_ns': proto_ns,
      'imports': imports,
  }
  _render_internal(template, context, out_f=out_f)
def render_proto(environment: jinja2.Environment, generated_dir: str,
                 out_f: str, name: str, builder: DomatoBuilder):
  """Renders the two proto files of a grammar.

  `<out_f>.proto` receives the root messages and imports `<out_f>_sub.proto`,
  which receives all the other messages.

  Args:
    environment: the jinja2 environment providing the templates.
    generated_dir: the directory used to build the import path of the sub
      proto file.
    out_f: the path prefix of the files to write.
    name: the name of the grammar, appended to the base proto package.
    builder: the builder holding the modelled grammar.
  """
  proto_template = environment.get_template('domatolpm.proto.tmpl')
  root_messages, other_messages = builder.get_protos()
  proto_ns = f'{BASE_PROTO_NS}.{name}'
  main_proto = f'{out_f}.proto'
  sub_proto = f'{out_f}_sub.proto'
  # The main file imports the sub file through the generated directory.
  sub_import = pathlib.PurePosixPath(generated_dir) / pathlib.PurePosixPath(
      sub_proto).name
  _render_proto_internal(proto_template, main_proto, root_messages,
                         builder.should_generate_repeated_lines(), proto_ns,
                         [str(sub_import)])
  _render_proto_internal(proto_template, sub_proto, other_messages, False,
                         proto_ns, [])
def render_cpp(environment: jinja2.Environment, out_f: str, name: str,
               builder: DomatoBuilder):
  """Renders the C++ source and header of a grammar.

  Args:
    environment: the jinja2 environment providing the templates.
    out_f: the path prefix of the files to write (`.cc` and `.h` are appended).
    name: the name of the grammar, used in the proto and C++ namespaces.
    builder: the builder holding the modelled grammar.
  """
  all_functions = builder.all_cpp_functions()
  _, root_func = builder.get_roots()
  context = {
      'basename': os.path.basename(out_f),
      'functions': [f for f in all_functions if f.is_message_handler()],
      'oneoffunctions': [f for f in all_functions if f.is_oneof_handler()],
      'stfunctions': [f for f in all_functions if f.is_string_table_handler()],
      'root': root_func,
      'generate_repeated_lines': builder.should_generate_repeated_lines(),
      'generate_one_line_handler': builder.should_generate_one_line_handler(),
      'line_prefix': builder.get_line_prefix(),
      'line_suffix': builder.get_line_suffix(),
      'proto_ns': to_cpp_ns(f'{BASE_PROTO_NS}.{name}'),
      'cpp_ns': f'domatolpm::{name}',
  }
  # The source and the header are rendered from the same context.
  for extension in ('cc', 'h'):
    template = environment.get_template(f'domatolpm.{extension}.tmpl')
    _render_internal(template, context, f'{out_f}.{extension}')
def main():
  """Parses the command line, models the grammar and renders every file."""
  parser = argparse.ArgumentParser(
      description=
      'Generate the necessary files for DomatoLPM to function properly.')
  # All the options are mandatory.
  options = (
      ('-p', '--path', 'The path to a Domato grammar file.'),
      ('-n', '--name', 'The name of this grammar.'),
      ('-f', '--file-format',
       'The path prefix to which the files should be generated.'),
      ('-d', '--generated-dir', 'The path to the target gen directory.'),
  )
  for short_flag, long_flag, help_text in options:
    parser.add_argument(short_flag, long_flag, required=True, help=help_text)
  args = parser.parse_args()

  domato_grammar = grammar.Grammar()
  domato_grammar.parse_from_file(filename=args.path)

  # The templates live next to this script.
  script_dir = os.path.dirname(os.path.abspath(__file__))
  loader = jinja2.FileSystemLoader(os.path.join(script_dir, 'templates'))
  environment = jinja2.Environment(loader=loader)

  builder = DomatoBuilder(domato_grammar)
  builder.parse_grammar()
  builder.simplify()
  render_cpp(environment, args.file_format, args.name, builder)
  render_proto(environment, args.generated_dir, args.file_format, args.name,
               builder)


if __name__ == '__main__':
  main()