@@ -29,6 +29,8 @@ class SolutionTestBase:
self.init_success_flag = False
self.base_dir = os.getcwd()
self.case_name = case_name
self.args_dict = dict()
self.input_dict = dict()
self.params_dict = dict()
# 从环境获取执行模式
self.execute_mode = os.environ.get("TRAIN_MODE", None)
@@ -80,8 +82,6 @@ class SolutionTestBase:
return True
def generate_case_params(self):
args_dict = dict()
input_dict = dict()
yaml_content = yaml_read(self.yaml_path)
args = yaml_content.get("args")
variation = yaml_content.get("variation")
@@ -106,13 +106,13 @@ class SolutionTestBase:
if args is not None:
for key in args.keys():
if self.params_dict.get(key) is not None:
args_dict[key] = self.params_dict.get(key)
self. args_dict[key] = self.params_dict.get(key)
if input_shape is not None:
for key in input_shape.keys():
if self.params_dict.get(key) is not None:
input_dict[key] = self.params_dict.get(key)
self.ms_log.info("params_dict: %s, args_dict: %s", self.params_dict, args_dict)
return self.params_dict, args_dict, input_shape
self. input_dict[key] = self.params_dict.get(key)
self.ms_log.info("params_dict: %s, args_dict: %s", self.params_dict, self. args_dict)
return self.params_dict, self.args_dict, self.input_dict
def generate_input_shape(self):
"""
@@ -120,33 +120,29 @@ class SolutionTestBase:
"""
ms_input_shape_list = list()
torch_input_shape_list = list()
_, args _dict , input_dict = self.generate_case_params()
_, _, input_dict = self.generate_case_params()
for key, value in input_dict.items():
random_shape = np.random.randn(*input_dict.get(key).get("input_shape"))
torch_input_shape = torch.tensor(random_shape, dtype=input_dict.get("torch_dtype"))
ms_input_shape = ms_pytorch.tensor(random_shape, dtype=input_dict.get("ms_dtype"))
ms_input_shape_list.append(ms_input_shape)
torch_input_shape_list.append(torch_input_shape)
return torch_input_shape_list, ms_input_shape_list
def generate_args_with_tensor(self, args):
"""
生成参数中的tensor类型数据
:param args: 设为tesnor数据的参数名
:param frame_name: 使用的框架名,取值分别为:"torch"或"ms_pytorch"
"""
_, args_dict, _ = self.generate_case_params()
numpy_shape = args_dict.get(args).get("numpy_shape")
numpy_type = args_dict.get(args).get("numpy_type")
random_shape = np.random.randn(*numpy_shape)
numpy_shape = self.args_dict.get(args).get("numpy_shape")
numpy_type = self.args_dict.get(args).get("numpy_type")
random_shape = np.random.randn(numpy_shape)
torch_tensor = torch.tensor(random_shape, dtype=eval(".".join(["torch", numpy_type])))
args_dict["weight"] = torch_tensor
torch_args = args_dict.copy( )
self.args_dict["weight"] = torch_tensor
torch_args = copy.deepcopy(self. args_dict)
ms_tensor = ms_pytorch.tensor(random_shape, dtype=eval(".".join(["ms_pytorch", numpy_type])))
args_dict["weight"] = ms_tensor
ms_args = args_dict
self. args_dict["weight"] = ms_tensor
ms_args = self. args_dict
self.ms_log.info("func generate_args_with_tensor get torch_args: %s\n ms_args: %s\n", torch_args, ms_args)
return torch_args, ms_args