|
- """Tests for import pytorch"""
-
- import os
- import unittest
-
- from model.import_model import import_pytorch
-
- # pylint: disable=missing-class-docstring
class TestImportPyTorch(unittest.TestCase):
    """Tests for ``import_pytorch``.

    Success cases expect ``import_pytorch`` to return ``None`` and to write a
    ``.json`` and a ``.data`` file, named after the input model file, into the
    same directory. Failure cases expect specific negative integer return
    codes.
    """

    # NOTE(review): this path is relative to the current working directory, so
    # the fixtures are only found when the tests are run from the directory
    # these paths were written against (presumably the one containing this
    # file) -- confirm, or anchor it on os.path.dirname(__file__).
    model_dir = "../model_file/pytorch"

    # Output files expected from converting ``test.pth``.
    model_json_path = os.path.join(model_dir, "test.json")
    model_data_path = os.path.join(model_dir, "test.data")
    # Output files expected from converting ``test_correct.onnx``.
    model_file_postfix_error_gene_json = os.path.join(
        model_dir, "test_correct.json")
    model_file_postfix_error_gene_data = os.path.join(
        model_dir, "test_correct.data")

    # Input fixtures.
    model_file_path = os.path.join(model_dir, "test.pth")
    model_file_postfix_error = os.path.join(model_dir, "test_correct.onnx")
    model_file_content_error_1 = os.path.join(model_dir, "test_error.pth")
    model_file_content_error_2 = os.path.join(model_dir, "test_error.h5")
    error_path = os.path.join(model_dir, "test_error.txt")

    def setUp(self):
        """Start each test without output files left over from a prior run."""
        self._remove_generated_files()

    def tearDown(self):
        """Delete any output files the test caused ``import_pytorch`` to write."""
        self._remove_generated_files()

    def _remove_generated_files(self):
        """Delete every output file these tests check for, if present."""
        for path in (self.model_json_path,
                     self.model_data_path,
                     self.model_file_postfix_error_gene_json,
                     self.model_file_postfix_error_gene_data):
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already absent is the desired state; EAFP also avoids the
                # exists()/remove() check-then-act race.
                pass

    def _assert_files_exist(self, *paths):
        """Fail, naming the offending path, if any of *paths* does not exist."""
        for path in paths:
            self.assertTrue(os.path.exists(path),
                            f"expected output file not found: {path}")

    def test_import_pytorch_success(self):
        """A ``.pth`` model converts: returns None and writes .json/.data."""
        ret = import_pytorch(self.model_file_path, '3,224,224')
        self.assertIsNone(ret)
        self._assert_files_exist(self.model_json_path, self.model_data_path)

    def test_import_pytorch_model_file_postfix_error_but_content_correct(self):
        """A valid model under a non-``.pth`` extension still converts."""
        ret = import_pytorch(self.model_file_postfix_error, '3,224,224')
        self.assertIsNone(ret)
        self._assert_files_exist(self.model_file_postfix_error_gene_json,
                                 self.model_file_postfix_error_gene_data)

    def test_import_pytorch_model_file_postfix_correct_but_content_error(self):
        """A ``.pth`` file with invalid content is rejected with -8."""
        ret = import_pytorch(self.model_file_content_error_1, '3,224,224')
        self.assertEqual(ret, -8)

    def test_import_pytorch_model_file_postfix_error_and_content_error(self):
        """A non-``.pth`` file with invalid content is rejected with -8."""
        ret = import_pytorch(self.model_file_content_error_2, '3,224,224')
        self.assertEqual(ret, -8)

    def test_import_pytorch_input_size_list_content_miss(self):
        """An empty input-size string is rejected with -2."""
        ret = import_pytorch(self.model_file_path, '')
        self.assertEqual(ret, -2)

    def test_import_pytorch_input_size_list_content_error(self):
        """A single-number input-size string ('3456') is rejected with -6."""
        ret = import_pytorch(self.model_file_path, '3456')
        self.assertEqual(ret, -6)

    def test_import_pytorch_input_size_list_content_error_1(self):
        """A non-numeric input-size string ('a,b,c') is rejected with -2."""
        ret = import_pytorch(self.model_file_path, 'a,b,c')
        self.assertEqual(ret, -2)

    def test_import_pytorch_no_file(self):
        """A model path that does not exist is rejected with -2."""
        ret = import_pytorch(self.error_path, '3,224,224')
        self.assertEqual(ret, -2)
|