|
- import torch
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import save_checkpoint, Tensor
- from mindspore import load_checkpoint, load_param_into_net
- from model import ImageCompressor
- from PIL import Image
-
# Build the compression model and one normalized RGB image as its input batch.
model = ImageCompressor()

# Load the test image; ToTensor converts HWC uint8 -> CHW float32 in [0, 1].
img = Image.open('sea.jpg').convert('RGB')
totensor = ms.dataset.vision.ToTensor()
data = ms.Tensor(totensor(img)).unsqueeze(0)  # add the batch dim -> (1, C, H, W)
print(data.shape)
# NOTE(review): removed `zeros = ms.ops.Zeros()` — it was never used anywhere
# in this script.
-
# load pretrain: one-time conversion of the PyTorch checkpoint to MindSpore.
# pth_file = 'e2e_pretrain.pth'
# torch_param_dict = torch.load(pth_file)
# ms_params = []
# for key in torch_param_dict.keys():
#     value = torch_param_dict[key]
#     print(value.shape)
#     value = Tensor(value.detach().cpu().numpy())
#     # ConvTranspose3d: expand dims (insert the extra depth axis at dim 2)
#     if 'deconv' in key and 'weight' in key:
#         value = ms.ops.ExpandDims()(value, 2)
#     ms_params.append({"name": key, "data": value})
# save_checkpoint(ms_params, 'e2e_pretrain.ckpt')
-
# Load the converted MindSpore checkpoint into the network in place.
param_dir = 'e2e_pretrain.ckpt'
param_dict = load_checkpoint(param_dir)
load_param_into_net(model, param_dict)

# forward: run one inference with the loaded weights and inspect the output
# (the model appears to return a tuple; element 0 is used as the
# reconstruction further below — TODO confirm against ImageCompressor).
out = model(data)
print(out)

# Adam over the trainable parameters with a pixel-wise MSE reconstruction loss.
optimizer = nn.Adam(model.trainable_params(), lr=1e-4)
loss_fn = nn.MSELoss()
-
-
def forward_fn(data, label):
    """Compute the reconstruction loss for one batch.

    Args:
        data: input image batch, shape (N, C, H, W).
        label: reconstruction target (here the input image itself).

    Returns:
        (loss, logits): scalar MSE loss and the model's reconstruction,
        where logits is the auxiliary output for value_and_grad(has_aux=True).
    """
    # The model returns a tuple; element 0 is the reconstructed image
    # (the PSNR/MS-SSIM evaluation below also reads out[0]).
    logits = model(data)[0]
    # Compare the full (N, C, H, W) reconstruction against the label.
    # The original wrote loss_fn(logits[0], label), which drops the batch dim
    # and relies on implicit broadcasting against the (1, C, H, W) label; the
    # mean is identical for batch size 1, but this form is shape-safe for any N.
    loss = loss_fn(logits, label)
    return loss, logits
-
-
- grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
-
-
def train_step(data, label):
    """Run a single optimization step and return the scalar loss."""
    outputs, grads = grad_fn(data, label)
    loss = outputs[0]   # discard the auxiliary logits
    optimizer(grads)    # apply the gradient update in place
    return loss
-
-
# Switch the network to training mode (enables dropout/BatchNorm updates if the
# model has any) and run one autoencoding step: the image is its own target.
model.set_train()
loss = train_step(data, data)
print(loss.asnumpy())
-
def ms_compute_psnr(a, b):
    """Peak signal-to-noise ratio between two image batches (dB)."""
    metric = ms.nn.PSNR()
    return metric(a, b)
-
def ms_compute_msssim(a, b):
    """Multi-scale structural similarity (MS-SSIM) between two image batches.

    The original built ms.nn.SSIM(), which is single-scale SSIM, even though
    this function (and the log line that prints its result) is labelled
    MS-SSIM; nn.MSSSIM matches the advertised metric.
    NOTE(review): MS-SSIM downsamples the input 4 times, so inputs must be
    large enough for the default 11x11 window at every scale — confirm that
    sea.jpg satisfies this (roughly >= 160 px per side).
    """
    metric = ms.nn.MSSSIM()
    return metric(a, b)
-
- print(f'mindspore PSNR: {ms_compute_psnr(data, out[0])[0]}dB')
- print(f'mindspore MS-SSIM: {ms_compute_msssim(data, out[0])[0]}')
|