|
|
@@ -20,12 +20,11 @@ import mindspore.common.dtype as mstype |
|
|
|
from mindspore import Tensor, context, load_checkpoint, export |
|
|
|
|
|
|
|
from src.finetune_eval_config import bert_net_cfg |
|
|
|
from src.finetune_eval_model import BertCLSModel |
|
|
|
from src.bert_for_finetune import BertCLS |
|
|
|
parser = argparse.ArgumentParser(description="Bert export") |
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id") |
|
|
|
parser.add_argument("--batch_size", type=int, default=16, help="batch size") |
|
|
|
parser.add_argument("--number_labels", type=int, default=16, help="batch size") |
|
|
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size") |
|
|
|
parser.add_argument("--number_labels", type=int, default=26, help="batch size") |
|
|
|
parser.add_argument("--ckpt_file", type=str, required=True, help="Bert ckpt file.") |
|
|
|
parser.add_argument("--file_name", type=str, default="Bert", help="bert output air name.") |
|
|
|
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") |
|
|
|