|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
-
- import argparse
- import itertools
- import os
-
# Kernel families to instantiate, keyed by the --type command-line value.
# Each entry is (kernel name prefix, has_workspace): the instantiated function
# is megdnn::cuda::batch_conv_bias::do_<prefix><suffix>, and has_workspace
# says whether its signature carries an extra ``int* d_workspace`` argument.
_DP4A_KERNELS = [
    ("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True),
    ("batch_conv_bias_int8_gemm_ncdiv4hw4", False),
    ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False),
]
PREFIXES = {"dp4a": _DP4A_KERNELS}

# Nonlinearity variants as (BatchConvBias::NonlineMode enumerator, file-name
# tag), keyed 1..3 in the order listed.
ACTIVATIONS = dict(
    enumerate(
        [
            ("IDENTITY", "_id"),
            ("RELU", "_relu"),
            ("H_SWISH", "_hswish"),
        ],
        start=1,
    )
)

# Bias visitors as (C++ visitor type, file-name tag). main() currently
# instantiates only key 2, the per-channel visitor.
BIASES = {
    1: ("PerElementBiasVisitor", "_per_elem"),
    2: ("PerChannelBiasVisitor", "_per_chan"),
}

# Kernel-name suffixes per --type; every type uses only the empty suffix.
# NOTE(review): "imma" appears here but has no PREFIXES entry, so no imma
# kernels are defined in this script.
SUFFIXES = {kind: [""] for kind in ("dp4a", "imma")}
-
-
# Explicit-instantiation template written into every generated .cu file.
# The upper-case placeholders PREFIX, SUFFIX, BIAS, ACTIVATION and WORKSPACE
# are filled in by _render_instantiation() using sequential str.replace, so a
# substituted value must not itself contain a placeholder that is replaced
# after it.
_INST_TEMPLATE = """
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS,
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>(
const int8_t* d_src,
const int8_t* d_filter, WORKSPACE
BIAS bias,
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>> epilogue,
const ConvParam& param,
float alpha,
float beta,
cudaStream_t stream);"""

# First line of every generated file.
_GENERATED_BANNER = "// generated by gen_batch_cuda_conv_bias_kern_impls.py"


def _render_instantiation(kernel, suffix, bias_visitor, activation, has_workspace):
    """Return the C++ explicit template instantiation for one kernel variant.

    Args:
        kernel: kernel name prefix; the instantiated function is
            ``do_<kernel><suffix>``.
        suffix: kernel name suffix (may be empty).
        bias_visitor: C++ bias visitor type, used both as the first template
            argument and as the type of the ``bias`` parameter.
        activation: ``BatchConvBias::NonlineMode`` enumerator name.
        has_workspace: if true, an ``int* d_workspace`` parameter is inserted
            after ``d_filter``; otherwise the placeholder is dropped.

    Returns:
        The instantiation text (it begins with a newline).
    """
    workspace = "\nint* d_workspace, " if has_workspace else ""
    # Substitution order matches the placeholder layout: PREFIX precedes
    # SUFFIX in "do_PREFIXSUFFIX", and WORKSPACE is filled in last.
    return (
        _INST_TEMPLATE.replace("PREFIX", kernel)
        .replace("SUFFIX", suffix)
        .replace("BIAS", bias_visitor)
        .replace("ACTIVATION", activation)
        .replace("WORKSPACE", workspace)
    )


def _write_kernel_file(path, kernel, suffix, instantiation):
    """Write one generated .cu file at ``path``.

    The file contains the banner line, an ``#include`` of
    ``../<kernel><suffix>.cuinl``, and then the instantiation text.
    """
    with open(path, "w", encoding="utf-8") as fout:
        fout.write(_GENERATED_BANNER + "\n")
        fout.write('#include "../{}{}.cuinl"\n'.format(kernel, suffix))
        fout.write(instantiation + "\n")


def main():
    """Generate CUDA batch conv bias kernel instantiation files.

    Command line: ``[--type {dp4a,imma}] output``. One ``.cu`` file is written
    into ``output`` (created if missing) for every
    (kernel, suffix, activation) combination of the selected type. Each file
    is named ``<kernel><suffix><bias tag><activation tag>.cu``. Only the
    per-channel bias visitor (``BIASES[2]``) is instantiated.

    Exits through ``argparse`` (usage error, status 2) when no kernels are
    defined for the selected ``--type``.
    """
    parser = argparse.ArgumentParser(
        description="generate cuda batch conv bias (dp4a/imma) kern impl files",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["dp4a", "imma"],
        default="dp4a",
        help="generate cuda conv bias kernel file",
    )
    parser.add_argument("output", help="output directory")
    args = parser.parse_args()

    # "imma" is an accepted --type, but PREFIXES has no entry for it.
    # Report this as a usage error rather than a KeyError traceback, and do
    # so before touching the filesystem.
    kernels = PREFIXES.get(args.type)
    if kernels is None:
        parser.error("no kernels are defined for --type {}".format(args.type))
    suffixes = SUFFIXES[args.type]

    # exist_ok avoids the check-then-create race. makedirs still raises if
    # the path exists as a non-directory.
    os.makedirs(args.output, exist_ok=True)

    # Only the per-channel bias visitor is generated; this is loop-invariant.
    bias_visitor, bias_tag = BIASES[2]
    for kernel, has_workspace in kernels:
        for suffix in suffixes:
            for activation, act_tag in ACTIVATIONS.values():
                fname = os.path.join(
                    args.output,
                    "{}{}{}{}.cu".format(kernel, suffix, bias_tag, act_tag),
                )
                instantiation = _render_instantiation(
                    kernel, suffix, bias_visitor, activation, has_workspace
                )
                _write_kernel_file(fname, kernel, suffix, instantiation)
                print("generated {}".format(fname))

    # Set the output directory's access/modification times to "now".
    # Presumably this lets timestamp-based build rules see the regeneration;
    # confirm against the build system.
    os.utime(args.output)
-
-
# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|