Browse Source

Create main_fine_tuning_ocmr.py

master
Ziwen Ke GitHub 2 years ago
parent
commit
9bb9e15f18
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 175 additions and 0 deletions
  1. +175
    -0
      main_fine_tuning_ocmr.py

+ 175
- 0
main_fine_tuning_ocmr.py View File

@@ -0,0 +1,175 @@
import tensorflow as tf
import os
from model import LplusS_Net, S_Net, SLR_Net
from dataset_ocmr_tfrecord import get_dataset
import argparse
import scipy.io as scio
import mat73
import numpy as np
from datetime import datetime
import time
from tools.tools import video_summary

from tools.tools import tempfft, mse, loss_function_ISTA


#tf.debugging.set_log_device_placement(True)
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# tf.debugging.set_log_device_placement(True)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_epoch', metavar='int', nargs=1, default=['50'], help='number of epochs')
parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size')
parser.add_argument('--learning_rate', metavar='float', nargs=1, default=['0.0001'], help='initial learning rate')
parser.add_argument('--niter', metavar='int', nargs=1, default=['10'], help='number of network iterations')
parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate')
parser.add_argument('--net', metavar='str', nargs=1, default=['SLRNET'], help='SLR Net or S Net')
parser.add_argument('--gpu', metavar='int', nargs=1, default=['0'], help='GPU No.')
parser.add_argument('--data', metavar='str', nargs=1, default=['OCMR'], help='dataset name')
parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not')

args = parser.parse_args()
# GPU setup
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0]
GPUs = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(GPUs[0], True)
mode = 'training'
dataset_name = args.data[0].upper()
batch_size = int(args.batch_size[0])
num_epoch = int(args.num_epoch[0])
learning_rate = float(args.learning_rate[0])

acc = int(args.acc[0])
net_name = args.net[0].upper()
niter = int(args.niter[0])
learnedSVT = bool(args.learnedSVT[0])


logdir = './logs'
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
model_id = TIMESTAMP + net_name + '_' + dataset_name + str(acc) + '_epoch_50' + '_lr_' + str(learning_rate) + '_ocmr_fine_tuning'
summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, model_id + '/'))

modeldir = os.path.join('models/stable/', model_id)
os.makedirs(modeldir)

# prepare undersampling mask
if dataset_name == 'DYNAMIC_V2':
multi_coil = False
mask_size = '18_192_192'
elif dataset_name == 'DYNAMIC_V2_MULTICOIL':
multi_coil = True
mask_size = '18_192_192'
elif dataset_name == 'OCMR':
multi_coil = True
mask_size = '18_192_192'
elif dataset_name == 'FLOW':
multi_coil = False
mask_size = '20_180_180'

"""
if acc == 8:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc8.mat')['mask']
elif acc == 10:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc10.mat')['mask']
elif acc == 12:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc12.mat')['mask']
"""
mask = mat73.loadmat('/data1/ziwenke/SLRNet/mask_newdata/mask_144.mat')['mask']
mask = np.transpose(mask, [1,0])
mask = np.reshape(mask, [1,1,mask.shape[0],1,mask.shape[1]])
mask = tf.cast(tf.constant(mask), tf.complex64)

# prepare dataset
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=True, full=True)
#dataset = get_dataset('test', dataset_name, batch_size, shuffle=True, full=True)
tf.print('dataset loaded.')

# initialize network
if net_name == 'SLRNET':
net = SLR_Net(mask, niter, learnedSVT)
weight_file = 'models/stable/2020-10-23T12-09-22SLRNET_DYNAMIC_V2_MULTICOIL8/epoch-50/ckpt'
net.load_weights(weight_file)


tf.print('network initialized.')

learning_rate_org = learning_rate
learning_rate_decay = 0.95

optimizer = tf.optimizers.Adam(learning_rate_org)
# Iterate over epochs.
total_step = 0
param_num = 0
loss = 0

for epoch in range(num_epoch):
for step, sample in enumerate(dataset):
# forward
t0 = time.time()
k0 = None
csm = None
with tf.GradientTape() as tape:
if multi_coil:
k0, label, csm = sample
if k0 == None:
continue
else:
k0, label = sample
if k0.shape[0] < batch_size:
continue

label_abs = tf.abs(label)

k0 = k0 * mask[:,:,0:k0.shape[2],:,:]

recon, X_SYM = net(k0, csm)
recon_abs = tf.abs(recon)

#loss = loss_function_ISTA(recon, label, X_SYM, niter)
loss_mse = mse(recon, label)


# backward
grads = tape.gradient(loss_mse, net.trainable_weights)####################################
optimizer.apply_gradients(zip(grads, net.trainable_weights))#################################

# record loss
with summary_writer.as_default():
tf.summary.scalar('loss/total', loss_mse.numpy(), step=total_step)

# record gif
if step % 20 == 0:
with summary_writer.as_default():
combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy()
combine_video = np.expand_dims(combine_video, -1)
video_summary('result', combine_video, step=total_step, fps=10)
# calculate parameter number
if total_step == 0:
param_num = np.sum([np.prod(v.get_shape()) for v in net.trainable_variables])

# log output
tf.print('Epoch', epoch+1, '/', num_epoch, 'Step', step, 'loss =', loss_mse.numpy(), 'time', time.time() - t0, 'lr = ', learning_rate, 'param_num', param_num)
total_step += 1

# learning rate decay for each epoch
learning_rate = learning_rate_org * learning_rate_decay ** (epoch + 1)#(total_step / decay_steps)
optimizer = tf.optimizers.Adam(learning_rate)

# save model each epoch
#if epoch in [0, num_epoch-1, num_epoch]:
model_epoch_dir = os.path.join(modeldir,'epoch-'+str(epoch+1), 'ckpt')
net.save_weights(model_epoch_dir, save_format='tf')


Loading…
Cancel
Save