|
- import numpy as np
- import mindspore
- from mindspore import Tensor, context
- from mindspore.train.model import Model
- from src.config import config
- from src.efficientnet import efficientnet_b0, efficientnet_b1
- from mindspore.dataset.vision import Inter
- import cv2
-
def data_test_handle(image_path, *, scale_size=256, crop_size=224, mean=None, std=None):
    """Read a single image and preprocess it for ``model.predict``.

    Reproduces with OpenCV/NumPy the evaluation steps listed in the original
    comments: resize -> center crop -> rescale to [0, 1] -> (optional)
    normalize -> HWC to CHW -> add a batch axis.

    Args:
        image_path: Path of the image file, read with ``cv2.imread``.
        scale_size: Side length of the square resize applied first.
        crop_size: Side length of the centered square crop taken after the
            resize. The defaults (256 -> 224) give the crop ``[16:240, 16:240]``.
        mean: Optional per-channel mean subtracted after rescaling to [0, 1].
            It is applied in the channel order of the loaded image. ``None``
            (the default) skips normalization, as before.
        std: Optional per-channel std that the mean-subtracted image is
            divided by. It must be given together with ``mean``.

    Returns:
        A float32 ``mindspore.Tensor`` of shape ``(1, 3, crop_size, crop_size)``.

    Raises:
        ValueError: If ``crop_size`` exceeds ``scale_size``, or only one of
            ``mean``/``std`` is given.
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    if crop_size > scale_size:
        raise ValueError(
            f"crop_size ({crop_size}) must not exceed scale_size ({scale_size})")
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be given together")

    img = cv2.imread(image_path)
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface later as an obscure cv2.resize error.
    if img is None:
        raise FileNotFoundError(f"cannot read image: {image_path}")
    # NOTE(review): cv2.imread yields BGR channel order, while MindSpore's
    # C.Decode yields RGB. If training decoded with C.Decode, a
    # cv2.cvtColor(img, cv2.COLOR_BGR2RGB) is probably needed here - confirm.

    # 1. Resize. Use OpenCV's own flag: mindspore.dataset.vision.Inter is a
    #    different enum, and its integer value does not select bicubic in cv2.
    # NOTE(review): C.Resize(size=<int>) keeps the aspect ratio (shorter side
    # -> size); this square resize does not - confirm it matches training.
    img = cv2.resize(img, (scale_size, scale_size), interpolation=cv2.INTER_CUBIC)

    # 2. Center crop to crop_size x crop_size.
    offset = (scale_size - crop_size) // 2
    img = img[offset:offset + crop_size, offset:offset + crop_size]

    # 3. Rescale pixel values from [0, 255] to [0, 1].
    img = img.astype(np.float32) / 255.0

    # 4. Optional per-channel normalization (broadcasts over the last, channel, axis).
    if mean is not None:
        img = (img - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)

    # 5. HWC -> CHW, then prepend a batch axis of size 1.
    img = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])
    return Tensor(img, dtype=mindspore.float32)
-
-
-
# Resize side length for test images.
# NOTE(review): not read by the visible code; data_test_handle has its own value.
scale_size = 256
# Directory holding the test images (read by the script as 0.jpg, 1.jpg, ...).
# A raw string keeps the Windows backslashes literal: in a normal string
# "\A", "\S", "\D" and "\c" are invalid escape sequences that emit warnings.
path_root = r'F:\Algorithm_Project_Code\ShengSi_challenge\Data\caltech_256\caltech_for_user\test'
# Checkpoint whose parameters are loaded into the network for prediction.
best_ckpt_path = "./mode_ckpt/efficientnetV6_b0-270_142.ckpt"
-
-
-
def main(num_images=5120, output_path='./result/predict_v3_270.txt'):
    """Predict a label for each test image and write the labels to a file.

    Builds an EfficientNet-B0 from ``config`` and loads ``best_ckpt_path``
    into it. It then runs ``model.predict`` on ``<path_root>/<i>.jpg`` for
    ``i = 0 .. num_images - 1``. Each line of ``output_path`` holds the
    1-based index of the highest score for one image, in order.

    Args:
        num_images: Number of consecutively numbered test images to process.
        output_path: Text file the predictions are written to. Its parent
            directory is created if missing.
    """
    import os

    net = efficientnet_b0(num_classes=config.num_classes,
                          cfg=config,
                          drop_rate=config.drop,
                          drop_connect_rate=config.drop_connect,
                          global_pool=config.gp,
                          bn_tf=config.bn_tf,
                          )

    # Load the trained parameters into the network.
    param_dict = mindspore.load_checkpoint(best_ckpt_path)
    mindspore.load_param_into_net(net, param_dict)
    model = Model(net)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # `with` guarantees the file is closed even if a prediction raises.
    with open(output_path, mode='w', encoding='utf-8') as out_file:
        for i in range(num_images):
            image_path = os.path.join(path_root, f'{i}.jpg')
            predictions = model.predict(data_test_handle(image_path)).asnumpy()
            # One image per batch, so the flattened argmax is the column of the
            # first maximum score (same result as np.where(...)[1][0]).
            label = int(np.argmax(predictions)) + 1  # written 1-based
            print("第", i, "个图像的预测值为:", label)
            out_file.write(f'{label}\n')


if __name__ == '__main__':
    main()
|