|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import collections
- import hashlib
- import io
- import os
- import struct
- import textwrap
-
- from gen_param_defs import IndentWriterBase, ParamDef, member_defs
-
-
class ConverterWriter(IndentWriterBase):
    """Generate C++ ``ParamConverter`` specializations for MegDNN params.

    For each param definition this writes a
    ``template<> struct ParamConverter<megdnn::param::NAME>`` containing:

    * ``to_param(const FlatBufferType* fb)`` -- brace-initializes the MegDNN
      param from the FlatBuffers table's field accessors;
    * ``to_flatbuffer(builder, param)`` -- calls
      ``fbs::param::CreateNAME(builder, ...)`` with every MegDNN field.

    The ``_on_*`` hooks collect one C++ expression per member for each of
    the two functions; they are presumably dispatched by
    ``IndentWriterBase._process`` (defined outside this file).
    """

    # When True, output for the current param is suppressed; reset in
    # _on_param_end. NOTE(review): nothing in this class sets it -- confirm
    # whether the base class or a subclass is expected to.
    _skip_current_param = False
    # Param definition currently being emitted (set in _on_param_begin).
    _last_param = None
    # C++ initializer expressions for ``to_param``; reset per param.
    _param_fields = None
    # C++ arguments for ``fbs::param::Create<NAME>(...)``; reset per param.
    # Kept as None rather than a class-level list so no mutable state is
    # shared between instances.
    _fb_fields = None

    def __call__(self, fout, defs):
        """Write the complete generated converter header for ``defs`` to ``fout``."""
        super().__call__(fout)
        self._write("// %s", self._get_header())
        self._write("#include <flatbuffers/flatbuffers.h>")
        self._write("namespace mgb {")
        self._write("namespace serialization {")
        self._write("namespace fbs {")
        self._process(defs)
        self._write("} // namespace fbs")
        self._write("} // namespace serialization")
        self._write("} // namespace mgb")

    def _append_field(self, to_param_expr, to_fb_expr):
        """Record one member's ``to_param`` initializer and ``Create`` argument.

        Appending to both lists together keeps them in definition order.
        """
        self._param_fields.append(to_param_expr)
        self._fb_fields.append(to_fb_expr)

    def _on_param_begin(self, p):
        """Open the ``ParamConverter`` specialization for param ``p``."""
        self._last_param = p
        self._param_fields = []
        # ``builder`` is always the first argument of fbs::param::Create<NAME>.
        self._fb_fields = ["builder"]
        self._write(
            "template<>\nstruct ParamConverter<megdnn::param::%s> {", p.name, indent=1
        )
        self._write("using MegDNNType = megdnn::param::%s;", p.name)
        self._write("using FlatBufferType = fbs::param::%s;\n", p.name)

    def _on_param_end(self, p):
        """Emit ``to_param``/``to_flatbuffer`` and close the specialization.

        If the current param is flagged as skipped, only the flag is reset.
        """
        if self._skip_current_param:
            self._skip_current_param = False
            return

        # to_param: brace-initialize the MegDNN param from the fb accessors.
        self._write("static MegDNNType to_param(const FlatBufferType* fb) {", indent=1)
        self._write("return {" + ", ".join(self._param_fields) + "};")
        self._write("}\n", indent=-1)

        # to_flatbuffer: forward every field to the generated Create helper.
        self._write(
            "static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {",
            indent=1,
        )
        self._write(
            "return fbs::param::Create{}({});".format(
                str(p.name), ", ".join(self._fb_fields)
            )
        )
        self._write("}", indent=-1)

        self._write("};\n", indent=-1)

    def _on_member_enum(self, e):
        """Add an enum member declared inside the current param.

        The MegDNN enum is ``megdnn::param::<Param>::<Enum>`` and the
        FlatBuffers enum is ``fbs::param::<Param><Enum>``; values are
        ``static_cast`` between the two.
        """
        # Check the skip flag before touching _last_param, like the other hooks.
        if self._skip_current_param:
            return
        param_name = str(self._last_param.name)
        enum_name = str(e.name)
        self._append_field(
            "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
                param_name, enum_name, e.name_field
            ),
            "static_cast<fbs::param::{}>(param.{})".format(
                param_name + enum_name, e.name_field
            ),
        )

    def _on_member_field(self, f):
        """Add a plain field; ``DTypeEnum`` fields go through the intl converters."""
        if self._skip_current_param:
            return
        if f.dtype.cname == "DTypeEnum":
            self._append_field(
                "intl::convert_dtype_to_megdnn(fb->{}())".format(f.name),
                "intl::convert_dtype_to_fbs(param.{})".format(f.name),
            )
        else:
            self._append_field("fb->{}()".format(f.name), "param.{}".format(f.name))

    def _on_const_field(self, f):
        """Constant fields contribute nothing to either conversion."""
        pass

    def _on_member_enum_alias(self, e):
        """Add a member whose type aliases an enum declared in ``e.src_class``."""
        if self._skip_current_param:
            return
        self._append_field(
            "static_cast<megdnn::param::{}::{}>(fb->{}())".format(
                e.src_class, e.src_name, e.name_field
            ),
            "static_cast<fbs::param::{}>(param.{})".format(
                e.src_class + e.src_name, e.name_field
            ),
        )
-
-
def main():
    """Generate the FlatBuffers <-> MegDNN param converter source.

    Reads the param-definition script ``input``, executes it, and writes the
    C++ produced by ``ConverterWriter`` for ``ParamDef.all_param_defs`` to
    ``output``. A SHA-256 hex digest of the input text is handed to the
    writer via ``set_input_hash``.
    """
    # The text is the parser *description*; the first positional argument of
    # ArgumentParser is ``prog``, which would corrupt the usage line.
    parser = argparse.ArgumentParser(
        description="generate convert functions between FlatBuffers type and MegBrain type"
    )
    parser.add_argument("input", help="param definition file (executed as Python)")
    parser.add_argument("output", help="path of the generated C++ file")
    args = parser.parse_args()

    # Read as UTF-8 explicitly: the text is hashed as UTF-8 below, and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(args.input, encoding="utf-8") as fin:
        inputs = fin.read()
    # NOTE: the input is executed as arbitrary Python -- only run this on
    # trusted definition files. Presumably the script registers its params
    # through ``pdef`` into ParamDef.all_param_defs, which is consumed below.
    exec(inputs, {"pdef": ParamDef, "Doc": member_defs.Doc})
    input_hash = hashlib.sha256(inputs.encode(encoding="UTF-8")).hexdigest()

    writer = ConverterWriter()
    with open(args.output, "w", encoding="utf-8") as fout:
        writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs)


if __name__ == "__main__":
    main()
|