From fc8b356e5d257b06a6207fe47daaa05bf362acab Mon Sep 17 00:00:00 2001 From: xiahb <940382006@qq.com> Date: Fri, 3 Feb 2023 16:37:03 +0800 Subject: [PATCH] Blank space modified --- app.py | 4 +- config/config.json | 2 +- model/import_model.py | 70 +++++++++++----------------------- model/model_export.py | 6 +-- model/model_inference.py | 6 ++- model/rewrite_yml.py | 2 +- test/test_gene_meta.py | 10 +++-- test/test_import_caffe.py | 14 ++++--- test/test_import_darknet.py | 20 ++++++---- test/test_import_keras.py | 10 +++-- test/test_import_onnx.py | 15 ++++---- test/test_import_pytorch.py | 16 ++++---- test/test_import_tensorflow.py | 6 ++- test/test_import_tflite.py | 14 +++---- test/test_model_export.py | 44 ++++++++++----------- test/test_model_inference.py | 20 +++++----- test/test_model_measure.py | 4 +- test/test_model_quantize.py | 22 +++++++---- utils/file_op.py | 32 ++++++++++------ utils/model_op.py | 39 ++++++++++--------- utils/quantize_op.py | 28 ++++++++------ 21 files changed, 199 insertions(+), 185 deletions(-) diff --git a/app.py b/app.py index dc5e50f..0a6b980 100644 --- a/app.py +++ b/app.py @@ -29,15 +29,15 @@ handler.setFormatter(logging_format) app.logger.addHandler(handler) # Add it to the built-in logger app.logger.setLevel(logging.DEBUG) + class get_guid(Resource): def post(self): args = get_guid_parser.parse_args() print(args) - res = {"status":"sucess","code":200, "guid": str(uuid.uuid4())} + res = {"status": "sucess", "code": 200, "guid": str(uuid.uuid4())} return jsonify(res) - api.add_resource(get_guid, '/v1/model/get_guid') diff --git a/config/config.json b/config/config.json index 935d00a..9d799d1 100644 --- a/config/config.json +++ b/config/config.json @@ -2,6 +2,6 @@ "UPLOAD_FOLDER":"/tmp/model", "MODEL_FILE_NAME": "input_model", "EXPORT_RELATIVE_PATH":"export_file", - "EXPORT_FILE_NAME":"export_file.zip" + "EXPORT_FILE_NAME":"export_file.zip", "INPUTMETA_FILE_NAME":"input_model_inputmeta.yml" } \ No newline at end of file diff --git a/model/import_model.py b/model/import_model.py index 2f5cfff..a18cfb7 100644 --- a/model/import_model.py +++ b/model/import_model.py @@ -9,12 +9,9 @@ def import_caffe(model, weights): parser.add_argument('--which', type=str, default='caffe') parser.add_argument('model', type=str, help='Input model') parser.add_argument('weights', type=str, help='The model weights') - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, - default=None, help='The path of model export') - parser.add_argument('--proto', type=str, - choices=['caffe', 'lstm_caffe'], default='caffe') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) + parser.add_argument('--proto', type=str, default='caffe') ret = 0 try: args = parser.parse_args([model, weights]) @@ -64,17 +61,12 @@ def import_tensorflow(model, inputs, input_size_list, outputs): def import_tflite(model): parser = ArgumentParser() - parser.add_argument('--import', type=str, - default='tflite', - help='The framework of neural network model') - parser.add_argument('--which', type=str, - default='tflite') + parser.add_argument('--import', type=str, default='tflite') + parser.add_argument('--which', type=str, default='tflite') parser.add_argument('model', type=str, help='Input model') parser.add_argument('--outputs', type=str, default=None) - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, default=None, - help='The path of model export') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) ret = 0 try: args = parser.parse_args([model]) @@ -91,15 +83,12 @@ def import_tflite(model): def import_darknet(model, weights): parser = ArgumentParser() - parser.add_argument('--import', type=str, default='darknet', - help='The framework of neural network model') + parser.add_argument('--import', type=str, default='darknet') parser.add_argument('--which', type=str, default='darknet') parser.add_argument('model', type=str, help='Input model') parser.add_argument('weights', type=str, default=None) - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, default=None, - help='The path of model export') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) ret = 0 try: args = parser.parse_args([model, weights]) @@ -117,21 +106,16 @@ def import_darknet(model, weights): def import_onnx(model): parser = ArgumentParser() - parser.add_argument('--import', type=str, - default='onnx', - help='The framework of neural network model') - parser.add_argument('--which', type=str, - default='onnx') + parser.add_argument('--import', type=str, default='onnx') + parser.add_argument('--which', type=str, default='onnx') parser.add_argument('model', type=str, help='Input model') parser.add_argument('--inputs', type=str, default=None) parser.add_argument('--input_size_list', type=str, default=None) parser.add_argument('--outputs', type=str, default=None) parser.add_argument('--size_with_batch', type=str, default=None) parser.add_argument('--input_dtype_list', type=str, default=None) - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, - default=None, help='The path of model export') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) ret = 0 try: args = parser.parse_args([model]) @@ -149,9 +133,7 @@ def import_onnx(model): def import_pytorch(model, input_size_list): parser = ArgumentParser() - parser.add_argument('--import', type=str, - default='pytorch', - help='The framework of neural network model') + parser.add_argument('--import', type=str, default='pytorch',) parser.add_argument('--which', type=str, default='pytorch') parser.add_argument('model', type=str, help='Input model') parser.add_argument('input_size_list', type=str, default=None) @@ -159,10 +141,8 @@ def import_pytorch(model, input_size_list): parser.add_argument('--outputs', type=str, default=None) parser.add_argument('--size_with_batch', type=str, default=None) parser.add_argument('--config', type=str, default=None) - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, default=None, - help='The path of model export') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) ret = 0 try: args = parser.parse_args([model, input_size_list]) @@ -180,21 +160,15 @@ def import_pytorch(model, input_size_list): def import_keras(model): parser = ArgumentParser() - parser.add_argument('--import', type=str, - default='keras', - help='The framework of neural network model') - parser.add_argument('--which', type=str, - default='keras') - parser.add_argument('model', type=str, - help='Input model') + parser.add_argument('--import', type=str, default='keras') + parser.add_argument('--which', type=str, default='keras') + parser.add_argument('model', type=str, help='Input model') parser.add_argument('--convert-engine', type=str, default='keras') parser.add_argument("--inputs", type=str, default=None) parser.add_argument('--input_size_list', type=str, default=None) parser.add_argument('--outputs', type=str, default=None) - parser.add_argument('--output_data', type=str, default=None, - help='The path of model_data export') - parser.add_argument('--output_model', type=str, default=None, - help='The path of model export') + parser.add_argument('--output_data', type=str, default=None) + parser.add_argument('--output_model', type=str, default=None) ret = 0 try: args = parser.parse_args([model]) diff --git a/model/model_export.py b/model/model_export.py index a33960a..c282574 100644 --- a/model/model_export.py +++ b/model/model_export.py @@ -29,9 +29,9 @@ def model_export_fun(model, model_data, output_path, with_input_meta): ret = 0 try: args = parser.parse_args([model, - model_data, - output_path, - with_input_meta]) + model_data, + output_path, + with_input_meta]) ret = exporter.execute(args) except Exception as e: if isinstance(e, ValueError): diff --git a/model/model_inference.py b/model/model_inference.py index d672461..11f3d31 100644 --- a/model/model_inference.py +++ b/model/model_inference.py @@ -14,8 +14,10 @@ def model_inference_fun(model, model_data, with_input_meta): parser.add_argument('model_data', type=str) parser.add_argument('--model_quantize') parser.add_argument('--output_dir') - parser.add_argument('--postprocess', type=str, - default='classification_classic') + parser.add_argument('--postprocess', + type=str, + default='classification_classic' + ) parser.add_argument('--postprocess_file') parser.add_argument('--which', type=str, default='inference') parser.add_argument('with_input_meta', type=str) diff --git a/model/rewrite_yml.py b/model/rewrite_yml.py index 655fdc4..073a46d 100644 --- a/model/rewrite_yml.py +++ b/model/rewrite_yml.py @@ -4,6 +4,7 @@ import yaml from flask import current_app as app + def set_yml_parameter(guid, mean_value, standard_value): model_dir = os.path.join(app.config["UPLOAD_FOLDER"], guid) yml_path = os.path.join(model_dir, app.config["INPUTMETA_FILE_NAME"]) @@ -18,4 +19,3 @@ def set_yml_parameter(guid, mean_value, standard_value): with open(yml_path, 'w') as f: yaml.safe_dump(doc, f, default_flow_style=False) - \ No newline at end of file diff --git a/test/test_gene_meta.py b/test/test_gene_meta.py index cb8bbcd..751d0c9 100644 --- a/test/test_gene_meta.py +++ b/test/test_gene_meta.py @@ -6,10 +6,12 @@ from model.gen_mata import generate_yml class test_generate_meta(unittest.TestCase): model_dir = "../model_file" - generate_correct_yml_file_path = os.path.join(model_dir, - "model_data_json/caffe_inputmeta.yml") - model_file_path = os.path.join(model_dir, - "model_data_json/caffe.json") + generate_correct_yml_file_path = os.path.join( + model_dir, + "model_data_json/caffe_inputmeta.yml") + model_file_path = os.path.join( + model_dir, + "model_data_json/caffe.json") error_path = os.path.join(model_dir, "caffe/test.pb") def setUp(self): diff --git a/test/test_import_caffe.py b/test/test_import_caffe.py index 4f6e475..d0cda6a 100644 --- a/test/test_import_caffe.py +++ b/test/test_import_caffe.py @@ -17,7 +17,7 @@ class test_import_caffe(unittest.TestCase): model_weights_file_postfix_error = os.path.join(model_dir, "test_w.onnx") model_file_content_error = os.path.join(model_dir, "test_error.prototxt") model_weights_file_content_error = os.path.join(model_dir, - "test_error.caffemodel") + "test_error.caffemodel") error_path = os.path.join(model_dir, "test.txt") def setUp(self): @@ -57,8 +57,10 @@ class test_import_caffe(unittest.TestCase): ret = import_caffe(self.model_file_postfix_errorr_2, self.model_weights_path) self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_caffe_weight_file_postfix_error_but_content_correct(self): ret = import_caffe(self.model_file_path, @@ -71,8 +73,10 @@ class test_import_caffe(unittest.TestCase): ret = import_caffe(self.model_file_postfix_error_1, self.model_weights_file_postfix_error) self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_caffe_weight_and_model_file_content_error(self): ret = import_caffe(self.model_file_content_error, diff --git a/test/test_import_darknet.py b/test/test_import_darknet.py index 8a6e758..83da5aa 100644 --- a/test/test_import_darknet.py +++ b/test/test_import_darknet.py @@ -8,17 +8,17 @@ class testimport_darknet(unittest.TestCase): model_dir = "../model_file/darknet" model_json_path = os.path.join(model_dir, "test.json") model_data_path = os.path.join(model_dir, "test.data") - model_file_postfix_error_gene_json = os.path.join(model_dir, - "test_correct.json") + model_file_postfix_error_gene_json = os.path.join(model_dir, + "test_correct.json") model_file_postfix_error_gene_data = os.path.join(model_dir, - "test_correct.data") + "test_correct.data") model_file_path = os.path.join(model_dir, "test.cfg") model_weights_path = os.path.join(model_dir, "test.weights") model_file_postfix_error = os.path.join(model_dir, "test_correct.pb") model_weights_file_postfix_error = os.path.join(model_dir, "test_w.h5") model_file_content_error = os.path.join(model_dir, "test_error.cfg") model_weights_file_content_error = os.path.join(model_dir, - "test_error.weights") + "test_error.weights") error_path = os.path.join(model_dir, "te.txt") def setUp(self): @@ -52,8 +52,10 @@ class testimport_darknet(unittest.TestCase): ret = import_darknet(self.model_file_postfix_error, self.model_weights_path) self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_darknet_weight_file_postfix_error(self): ret = import_darknet(self.model_file_path, @@ -66,8 +68,10 @@ class testimport_darknet(unittest.TestCase): ret = import_darknet(self.model_file_postfix_error, self.model_weights_file_postfix_error) self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_darknet_weight_and_model_content_error(self): ret = import_darknet(self.model_file_content_error, diff --git a/test/test_import_keras.py b/test/test_import_keras.py index 1ba4be2..fc5de2f 100644 --- a/test/test_import_keras.py +++ b/test/test_import_keras.py @@ -8,10 +8,12 @@ class test_import_keras(unittest.TestCase): model_dir = "../model_file/keras" model_json_path = os.path.join(model_dir, "test.json") model_data_path = os.path.join(model_dir, "test.data") - model_file_postfix_error_json = os.path.join(model_dir, - "test_correct.json") - model_file_postfix_error_data = os.path.join(model_dir, - "test_correct.data") + model_file_postfix_error_json = os.path.join( + model_dir, + "test_correct.json") + model_file_postfix_error_data = os.path.join( + model_dir, + "test_correct.data") model_file_path = os.path.join(model_dir, "test.h5") model_file_postfix_error = os.path.join(model_dir, "test_correct.pb") model_file_content_error_1 = os.path.join(model_dir, "test_error.h5") diff --git a/test/test_import_onnx.py b/test/test_import_onnx.py index b7c535b..de0f1ea 100644 --- a/test/test_import_onnx.py +++ b/test/test_import_onnx.py @@ -8,10 +8,10 @@ class test_import_onnx(unittest.TestCase): model_dir = "../model_file/onnx" model_json_path = os.path.join(model_dir, "test.json") model_data_path = os.path.join(model_dir, "test.data") - model_file_postfix_error_gene_json = os.path.join(model_dir, - "test_correct.json") - model_file_postfix_error_gene_data = os.path.join(model_dir, - "test_correct.data") + model_file_postfix_error_gene_json = os.path.join(model_dir, + "test_correct.json") + model_file_postfix_error_gene_data = os.path.join(model_dir, + "test_correct.data") model_file_path = os.path.join(model_dir, "test.onnx") model_file_postfix_error = os.path.join(model_dir, "test_correct.pb") model_file_content_error = os.path.join(model_dir, "test_error.onnx") @@ -19,7 +19,6 @@ class test_import_onnx(unittest.TestCase): model_file_error_2 = os.path.join(model_dir, "test_error_1.pb") error_path = os.path.join(model_dir, "te.pb") - def setUp(self): if os.path.exists(self.model_json_path): os.remove(self.model_json_path) @@ -49,8 +48,10 @@ class test_import_onnx(unittest.TestCase): def test_import_onnx_model_file_postfix_error_but_content_correct(self): ret = import_onnx(self.model_file_postfix_error) self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_onnx_model_file_postfix_correct_but_content_error_1(self): ret = import_onnx(self.model_file_content_error) diff --git a/test/test_import_pytorch.py b/test/test_import_pytorch.py index aa03af9..5e4254d 100644 --- a/test/test_import_pytorch.py +++ b/test/test_import_pytorch.py @@ -8,10 +8,10 @@ class test_import_pytorch(unittest.TestCase): model_dir = "../model_file/pytorch" model_json_path = os.path.join(model_dir, "test.json") model_data_path = os.path.join(model_dir, "test.data") - model_file_postfix_error_gene_json = os.path.join(model_dir, - "test_correct.json") - model_file_postfix_error_gene_data = os.path.join(model_dir, - "test_correct.data") + model_file_postfix_error_gene_json = os.path.join(model_dir, + "test_correct.json") + model_file_postfix_error_gene_data = os.path.join(model_dir, + "test_correct.data") model_file_path = os.path.join(model_dir, "test.pth") model_file_postfix_error = os.path.join(model_dir, "test_correct.onnx") model_file_content_error_1 = os.path.join(model_dir, "test_error.pth") @@ -19,7 +19,7 @@ class test_import_pytorch(unittest.TestCase): error_path = os.path.join(model_dir, "test_error.txt") def setUp(self): - if os.path.exists (self.model_json_path): + if os.path.exists(self.model_json_path): os.remove(self.model_json_path) if os.path.exists(self.model_data_path): os.remove(self.model_data_path) @@ -47,8 +47,10 @@ class test_import_pytorch(unittest.TestCase): def test_import_pytorch_model_file_postfix_error_but_content_correct(self): ret = import_pytorch(self.model_file_postfix_error, '3,224,224') self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_pytorch_model_file_postfix_correct_but_content_error(self): ret = import_pytorch(self.model_file_content_error_1, '3,224,224') diff --git a/test/test_import_tensorflow.py b/test/test_import_tensorflow.py index 376b664..b3cdd59 100644 --- a/test/test_import_tensorflow.py +++ b/test/test_import_tensorflow.py @@ -64,8 +64,10 @@ class test_import_tensorflow(unittest.TestCase): '300,300,3', 'concat') self.assertTrue(ret is None) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_json)) - self.assertTrue(os.path.exists(self.model_file_postfix_error_gene_data)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_json)) + self.assertTrue(os.path.exists( + self.model_file_postfix_error_gene_data)) def test_import_tensorflow_model_file_postfix_correct_but_content_error( self): diff --git a/test/test_import_tflite.py b/test/test_import_tflite.py index f2c1ab6..609e16e 100644 --- a/test/test_import_tflite.py +++ b/test/test_import_tflite.py @@ -9,19 +9,19 @@ class test_import_tflite(unittest.TestCase): model_json_path = os.path.join(model_dir, "test.json") model_data_path = os.path.join(model_dir, "test.data") model_file_postfix_error_json = os.path.join(model_dir, - "test_correct.json") + "test_correct.json") model_file_postfix_error_data = os.path.join(model_dir, - "test_correct.data") + "test_correct.data") model_file_path = os.path.join(model_dir, "test.tflite") model_file_postfix_error = os.path.join(model_dir, - "test_correct.pb") + "test_correct.pb") model_file_content_error_1 = os.path.join(model_dir, - "test_error.tflite") + "test_error.tflite") model_file_content_error_2 = os.path.join(model_dir, - "test_error_t.tflite") + "test_error_t.tflite") model_file_content_error_3 = os.path.join(model_dir, - "test_error_t.pb") - error_path = os.path.join(model_dir,"te.txt") + "test_error_t.pb") + error_path = os.path.join(model_dir, "te.txt") def setUp(self): if os.path.exists(self.model_json_path): diff --git a/test/test_model_export.py b/test/test_model_export.py index 0d9464b..1c534cf 100644 --- a/test/test_model_export.py +++ b/test/test_model_export.py @@ -7,41 +7,41 @@ from model.model_export import model_export_fun class test_model_export(unittest.TestCase): model_dir = "../model_file/model_data_json" model_dir_1 = "../model_file/model_yml" - export_model_json = os.path.join(model_dir,"caffe.json") - export_model_data = os.path.join(model_dir,"caffe.data") - model_export_path = os.path.join(model_dir,"export_file/") - model_inputmeta_path = os.path.join(model_dir_1,"test_inputmeta.yml") - model_main_c_path = os.path.join(model_dir,"export_file/main.c") - inputmeta_file_error_path = os.path.join(model_dir_1,"test_inpu.yml") - model_file_error_path = os.path.join(model_dir,"t.json") - error_frame_json = os.path.join(model_dir,"darknet/test.json") - error_frame_data = os.path.join(model_dir,"darknet/test.data") + export_model_json = os.path.join(model_dir, "caffe.json") + export_model_data = os.path.join(model_dir, "caffe.data") + model_export_path = os.path.join(model_dir, "export_file/") + model_inputmeta_path = os.path.join(model_dir_1, "test_inputmeta.yml") + model_main_c_path = os.path.join(model_dir, "export_file/main.c") + inputmeta_file_error_path = os.path.join(model_dir_1, "test_inpu.yml") + model_file_error_path = os.path.join(model_dir, "t.json") + error_frame_json = os.path.join(model_dir, "darknet/test.json") + error_frame_data = os.path.join(model_dir, "darknet/test.data") def test_model_export_success(self): - ret = model_export_fun(self.export_model_json, - self.export_model_data, - self.model_export_path, - self.model_inputmeta_path) + ret = model_export_fun(self.export_model_json, + self.export_model_data, + self.model_export_path, + self.model_inputmeta_path) self.assertTrue(ret is None) self.assertTrue(os.path.exists(self.model_main_c_path)) def test_model_export_no_file_1(self): ret = model_export_fun(self.export_model_json, - self.export_model_data, - self.model_export_path, - self.inputmeta_file_error_path) + self.export_model_data, + self.model_export_path, + self.inputmeta_file_error_path) self.assertTrue(ret == -2) def test_model_export_no_file_2(self): ret = model_export_fun(self.model_file_error_path, - self.export_model_data, - self.model_export_path, - self.model_inputmeta_path) + self.export_model_data, + self.model_export_path, + self.model_inputmeta_path) self.assertTrue(ret == -2) def test_model_export_other_model_data(self): ret = model_export_fun(self.error_frame_json, - self.error_frame_data, - self.model_export_path, - self.model_inputmeta_path) + self.error_frame_data, + self.model_export_path, + self.model_inputmeta_path) self.assertTrue(ret == -10) diff --git a/test/test_model_inference.py b/test/test_model_inference.py index ba622ab..f7184fd 100644 --- a/test/test_model_inference.py +++ b/test/test_model_inference.py @@ -13,29 +13,29 @@ class test_model_inference(unittest.TestCase): error_json_path = os.path.join(model_dir, "te.json") error_data_path = os.path.join(model_dir, "caf.data") error_inputmeta_path = os.path.join(model_dir_1, "test_inputme.yml") - meta_file_dataset_path_error = os.path.join(model_dir_1, - "test_inputmeta_c.yml") + meta_file_dataset_path_error = os.path.join(model_dir_1, + "test_inputmeta_c.yml") def test_model_inference_success(self): ret = model_inference_fun(self.inference_model_json, - self.inference_model_data, - self.inference_inputmeta_path) + self.inference_model_data, + self.inference_inputmeta_path) self.assertTrue(ret is None) def test_model_inference_no_file_1(self): ret = model_inference_fun(self.error_json_path, - self.inference_model_data, - self.inference_inputmeta_path) + self.inference_model_data, + self.inference_inputmeta_path) self.assertTrue(ret == -3) def test_model_inference_no_file_2(self): ret = model_inference_fun(self.inference_model_json, - self.error_data_path, - self.error_inputmeta_path) + self.error_data_path, + self.error_inputmeta_path) self.assertTrue(ret == -2) def test_model_inference_no_dataset_file(self): ret = model_inference_fun(self.inference_model_json, - self.inference_model_data, - self.meta_file_dataset_path_error) + self.inference_model_data, + self.meta_file_dataset_path_error) self.assertTrue(ret == -3) diff --git a/test/test_model_measure.py b/test/test_model_measure.py index ff56894..b7932d8 100644 --- a/test/test_model_measure.py +++ b/test/test_model_measure.py @@ -6,8 +6,8 @@ from model.model_measure import model_measure_fun class test_model_measure(unittest.TestCase): model_dir = "../model_file/model_data_json" - model_path = os.path.join(model_dir,"caffe.json") - error_model_path = os.path.join(model_dir,"te.json") + model_path = os.path.join(model_dir, "caffe.json") + error_model_path = os.path.join(model_dir, "te.json") def test_model_measure_success(self): ret = model_measure_fun(self.model_path) diff --git a/test/test_model_quantize.py b/test/test_model_quantize.py index 207c4ca..bf78ccc 100644 --- a/test/test_model_quantize.py +++ b/test/test_model_quantize.py @@ -13,7 +13,7 @@ class test_model_quantize(unittest.TestCase): quantize_inputmeta_path = os.path.join(model_dir_1, "test_inputmeta.yml") quantize_model_error_path = os.path.join(model_dir, "test.json") meta_file_dataset_path_error = os.path.join(model_dir_1, - "test_inputmeta_c.yml") + "test_inputmeta_c.yml") def setUp(self): if os.path.exists(self.quantize_gen_file): @@ -26,7 +26,8 @@ class test_model_quantize(unittest.TestCase): def test_model_quantize_success_int16(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'int16','dynamic_fixed_point', + 'int16', + 'dynamic_fixed_point', self.quantize_inputmeta_path) self.assertTrue(ret is None) self.assertTrue(os.path.exists(self.quantize_gen_file)) @@ -34,7 +35,8 @@ class test_model_quantize(unittest.TestCase): def test_model_quantize_success_int8(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'int8','dynamic_fixed_point', + 'int8', + 'dynamic_fixed_point', self.quantize_inputmeta_path) self.assertTrue(ret is None) self.assertTrue(os.path.exists(self.quantize_gen_file)) @@ -42,7 +44,8 @@ class test_model_quantize(unittest.TestCase): def test_model_quantize_success_uint8(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'uint8','asymmetric_affine', + 'uint8', + 'asymmetric_affine', self.quantize_inputmeta_path) self.assertTrue(ret is None) self.assertTrue(os.path.exists(self.quantize_gen_file)) @@ -66,20 +69,23 @@ class test_model_quantize(unittest.TestCase): def test_model_quantize_qtype_match_quantizer_error2(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'int8','asymmetric_affine', + 'int8', + 'asymmetric_affine', self.quantize_inputmeta_path) self.assertTrue(ret == -21) def test_model_quantize_qtype_match_quantizer_error3(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'uint8','dynamic_fixed_point', + 'uint8', + 'dynamic_fixed_point', self.quantize_inputmeta_path) self.assertTrue(ret == -21) def test_model_quantize_no_dataset_file(self): ret = model_quantize(self.quantize_json_path, self.quantize_data_path, - 'int8','dynamic_fixed_point', - self.meta_file_dataset_path_error) + 'int8', + 'dynamic_fixed_point', + self.meta_file_dataset_path_error) self.assertTrue(ret == -20) diff --git a/utils/file_op.py b/utils/file_op.py index edc1bdf..0c25b2b 100644 --- a/utils/file_op.py +++ b/utils/file_op.py @@ -9,6 +9,8 @@ api = Api(file_op) upload_file_parser = reqparse.RequestParser() upload_file_parser.add_argument('guid', type=str, required=True) + + class upload_file(Resource): def post(self): args = upload_file_parser.parse_args() @@ -27,9 +29,9 @@ class upload_file(Resource): os.makedirs(savePath) filename = '{}.{}'.format(app.config['MODEL_FILE_NAME'], - file.filename.rsplit('.', 1)[1].lower()) + file.filename.rsplit('.', 1)[1].lower()) file.save(os.path.join(savePath, filename)) - return {'savePath':savePath} + return {'savePath': savePath} def get(self): return make_response(''' @@ -44,11 +46,14 @@ class upload_file(Resource): - ''',200) + ''', 200) + upload_dataset_parser = reqparse.RequestParser() upload_dataset_parser.add_argument('guid', type=str, required=True) upload_dataset_parser.add_argument('relativePath', type=str, required=True) + + class upload_dataset(Resource): def post(self): @@ -72,7 +77,7 @@ class upload_dataset(Resource): filename = file.filename file.save(os.path.join(savePath, filename)) - return {'savePath':savePath} + return {'savePath': savePath} def get(self): return make_response(''' @@ -91,26 +96,29 @@ class upload_dataset(Resource): - ''',200) + ''', 200) download_file_parser = reqparse.RequestParser() download_file_parser.add_argument('guid', type=str, required=True) + class download_file(Resource): def get(self): - args = download_file_parser.parse_args() + args = download_file_parser.parse_args() file_path = os.path.join(app.config['UPLOAD_FOLDER'], - args.guid, - app.config['EXPORT_RELATIVE_PATH'], - app.config['EXPORT_FILE_NAME']) + args.guid, + app.config['EXPORT_RELATIVE_PATH'], + app.config['EXPORT_FILE_NAME'] + ) if os.path.isfile(file_path): return send_file(file_path, as_attachment=True) else: return "The downloaded file does not exist." -api.add_resource(upload_file,'/v1/upload') -api.add_resource(upload_dataset,'/v1/upload_dataset') -api.add_resource(download_file,'/v1/download') + +api.add_resource(upload_file, '/v1/upload') +api.add_resource(upload_dataset, '/v1/upload_dataset') +api.add_resource(download_file, '/v1/download') diff --git a/utils/model_op.py b/utils/model_op.py index c44dc7b..7699310 100644 --- a/utils/model_op.py +++ b/utils/model_op.py @@ -1,6 +1,6 @@ import os -from flask import Flask,Blueprint +from flask import Flask, Blueprint from flask import current_app as app from flask import jsonify from flask_restful import Resource, Api @@ -11,19 +11,19 @@ from model.import_model import import_caffe model_op = Blueprint('model_op', __name__) api = Api(model_op) -return_dict ={ - "-1":"The model file dose not match the current option.", - "-2":"The model file dose not match the current option.", - "-3":"No such file or directory.", - "-4":"The model file dose not match the current option.", - "-5":"The model file dose not match the current option.", - "-6":"The model file dose not match the current option.", - "-7":"The model file dose not match the current option.", - "-8":"The model file dose not match the current option.", - "-9":"please input correct model data.", - "-10":"The model data is Error.", - "-11":"import failure!" - } +return_dict = { + "-1": "The model file dose not match the current option.", + "-2": "The model file dose not match the current option.", + "-3": "No such file or directory.", + "-4": "The model file dose not match the current option.", + "-5": "The model file dose not match the current option.", + "-6": "The model file dose not match the current option.", + "-7": "The model file dose not match the current option.", + "-8": "The model file dose not match the current option.", + "-9": "please input correct model data.", + "-10": "The model data is Error.", + "-11": "import failure!" +} def return_info(ret, guid): @@ -40,22 +40,25 @@ def get_model_file_name(guid, suffix): app.config['UPLOAD_FOLDER'], guid, app.config['MODEL_FILE_NAME'] + suffix - ) + ) + import_parser = reqparse.RequestParser() import_parser.add_argument('guid', type=str, required=True) + class caffe(Resource): def post(self): args = import_parser.parse_args() app.logger.info(f"args: {args}") - model = get_model_file_name(args.guid,".prototxt") - weights = get_model_file_name(args.guid,".caffemodel") + model = get_model_file_name(args.guid, ".prototxt") + weights = get_model_file_name(args.guid, ".caffemodel") app.logger.info(f"model: {model}") app.logger.info(f"weights: {weights}") ret = import_caffe(model, weights) - result = return_info(ret,args.guid) + result = return_info(ret, args.guid) app.logger.info(f"result: {ret}") return jsonify(result) + api.add_resource(caffe, '/v1/model/import/caffe') diff --git a/utils/quantize_op.py b/utils/quantize_op.py index 7802d3a..bf492a6 100644 --- a/utils/quantize_op.py +++ b/utils/quantize_op.py @@ -14,10 +14,11 @@ from model.model_quantize import model_quantize quantize_op = Blueprint('quantize_op', __name__) api = Api(quantize_op) -return_dict ={ - "-20":"The path is not exist", - "-21":"Please input correct quantize data." - } +return_dict = { + "-20": "The path is not exist", + "-21": "Please input correct quantize data." +} + def return_info(ret, guid): if ret is None: @@ -27,23 +28,26 @@ def return_info(ret, guid): "error_msg": return_dict[str(ret)]} return result + def get_model_file_name(guid, suffix): return os.path.join( app.config['UPLOAD_FOLDER'], guid, app.config['MODEL_FILE_NAME'] + suffix - ) + ) + def get_mean_value(mean_value): mean_value_str = mean_value.split() mean_value = list(map(float, mean_value_str)) return mean_value + import_parser = reqparse.RequestParser() import_parser.add_argument('guid', type=str, required=True) import_parser.add_argument('qtype', type=str, required=True) import_parser.add_argument('quantizer', type=str, required=True) -import_parser.add_argument('mean_value',type=str, required=True) +import_parser.add_argument('mean_value', type=str, required=True) import_parser.add_argument('standard_value', type=str, required=True) @@ -53,15 +57,15 @@ class quantize(Resource): app.logger.info(f"args: {args}") model_dir = os.path.join(app.config['UPLOAD_FOLDER'], args.guid) with_input_meta = os.path.join(model_dir, - app.config["INPUTMETA_FILE_NAME"]) - model = get_model_file_name(args.guid,".json") - model_data = get_model_file_name(args.guid,".data") + app.config["INPUTMETA_FILE_NAME"]) + model = get_model_file_name(args.guid, ".json") + model_data = get_model_file_name(args.guid, ".data") qtype = args.qtype quantizer = args.quantizer standard_value = float(args.standard_value) mean_value = get_mean_value(args.mean_value) generate_yml(model) - set_yml_parameter(args.guid,mean_value,standard_value) + set_yml_parameter(args.guid, mean_value, standard_value) app.logger.info(f"model: {model}") app.logger.info(f"model_data: {model_data}") app.logger.info(f"qtype: {qtype}") @@ -74,8 +78,8 @@ class quantize(Resource): qtype, quantizer, with_input_meta - ) - result = return_info(ret,args.guid) + ) + result = return_info(ret, args.guid) app.logger.info(f"ret: {ret}") return jsonify(result) -- 2.34.1