@@ -0,0 +1,173 @@ | |||
import tensorflow as tf | |||
import os | |||
from model_net_v3 import Manifold_Net | |||
from dataset_tfrecord import get_dataset | |||
import argparse | |||
import scipy.io as scio | |||
import mat73 | |||
import numpy as np | |||
from datetime import datetime | |||
import time | |||
from tools.tools import video_summary | |||
from tools.tools import tempfft, mse | |||
#tf.debugging.set_log_device_placement(True) | |||
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' | |||
# tf.debugging.set_log_device_placement(True) | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--num_epoch', metavar='int', nargs=1, default=['50'], help='number of epochs') | |||
parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size') | |||
parser.add_argument('--learning_rate', metavar='float', nargs=1, default=['0.001'], help='initial learning rate') | |||
parser.add_argument('--niter', metavar='int', nargs=1, default=['5'], help='number of network iterations') | |||
parser.add_argument('--nconv', metavar='int', nargs=1, default=['3'], help='number of convolutional layers on CNNLayer') | |||
parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate') | |||
parser.add_argument('--dc_type', metavar='str', nargs=1, default=['v1'], help='v1: kspace; v2: image') | |||
parser.add_argument('--mask_pattern', metavar='str', nargs=1, default=['spiral'], help='mask pattern: cartesian, radial, spiral, vista') | |||
parser.add_argument('--net', metavar='str', nargs=1, default=['Manifold_Net'], help='Manifold_Net') | |||
parser.add_argument('--gpu', metavar='int', nargs=1, default=['2'], help='GPU No.') | |||
parser.add_argument('--data', metavar='str', nargs=1, default=['DYNAMIC_V2'], help='dataset name') | |||
parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not') | |||
parser.add_argument('--SVT_favtor', metavar='float', nargs=1, default=['1.3'], help='SVT factor') | |||
args = parser.parse_args() | |||
# GPU setup | |||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0] | |||
GPUs = tf.config.experimental.list_physical_devices('GPU') | |||
tf.config.experimental.set_memory_growth(GPUs[0], True) | |||
mode = 'training' | |||
dataset_name = args.data[0].upper() | |||
dc_type = args.dc_type[0] | |||
batch_size = int(args.batch_size[0]) | |||
num_epoch = int(args.num_epoch[0]) | |||
learning_rate = float(args.learning_rate[0]) | |||
acc = int(args.acc[0]) | |||
mask_pattern = args.mask_pattern[0] | |||
net_name = args.net[0] | |||
niter = int(args.niter[0]) | |||
nconv = int(args.nconv[0]) | |||
learnedSVT = bool(args.learnedSVT[0]) | |||
#N_factor = int(args.SVT_favtor[0]) | |||
N_factor = float(args.SVT_favtor[0]) | |||
logdir = './logs' | |||
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) | |||
model_id = TIMESTAMP + '_'+ net_name + '_v3_correct_' + 'dc_' + dc_type +'_d'+str(nconv)+'c'+str(niter)+'_acc_'+ str(acc) + '_lr_' + str(learning_rate) + '_N_factor_' + str(N_factor) + '_rank_' + str(int(18/N_factor)) +'_'+ mask_pattern | |||
summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, model_id + '/')) | |||
modeldir = os.path.join('models/stable/', model_id) | |||
os.makedirs(modeldir) | |||
# prepare undersampling mask | |||
if dataset_name == 'DYNAMIC_V2': | |||
multi_coil = False | |||
mask_size = '18_192_192' | |||
elif dataset_name == 'DYNAMIC_V2_MULTICOIL': | |||
multi_coil = True | |||
mask_size = '18_192_192' | |||
elif dataset_name == 'FLOW': | |||
multi_coil = False | |||
mask_size = '20_180_180' | |||
if acc == 8: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc8.mat')['mask'] | |||
elif acc == 10: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acc10.mat')['mask'] | |||
elif acc == 12: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc12.mat')['mask'] | |||
mask = tf.cast(tf.constant(mask), tf.complex64) | |||
# prepare dataset | |||
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=True, full=True) | |||
#dataset = get_dataset('test', dataset_name, batch_size, shuffle=True, full=True) | |||
tf.print('dataset loaded.') | |||
# initialize network | |||
if net_name == 'Manifold_Net': | |||
net = Manifold_Net(mask, niter, learnedSVT, N_factor) | |||
tf.print('network initialized.') | |||
learning_rate_org = learning_rate | |||
learning_rate_decay = 0.95 | |||
optimizer = tf.optimizers.Adam(learning_rate_org) | |||
# Iterate over epochs. | |||
total_step = 0 | |||
param_num = 0 | |||
loss = 0 | |||
for epoch in range(num_epoch): | |||
for step, sample in enumerate(dataset): | |||
# forward | |||
t0 = time.time() | |||
k0 = None | |||
csm = None | |||
with tf.GradientTape() as tape: | |||
if multi_coil: | |||
k0, label, csm = sample | |||
if k0 == None: | |||
continue | |||
else: | |||
k0, label = sample | |||
if k0.shape[0] < batch_size: | |||
continue | |||
label_abs = tf.abs(label) | |||
k0 = k0 * mask | |||
recon = net(k0, csm) | |||
recon_abs = tf.abs(recon) | |||
loss_mse = mse(recon, label) | |||
# backward | |||
grads = tape.gradient(loss_mse, net.trainable_weights)#################################### | |||
optimizer.apply_gradients(zip(grads, net.trainable_weights))################################# | |||
# record loss | |||
with summary_writer.as_default(): | |||
tf.summary.scalar('loss/total', loss_mse.numpy(), step=total_step) | |||
# record gif | |||
if step % 20 == 0: | |||
with summary_writer.as_default(): | |||
combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy() | |||
combine_video = np.expand_dims(combine_video, -1) | |||
video_summary('result', combine_video, step=total_step, fps=10) | |||
# calculate parameter number | |||
if total_step == 0: | |||
param_num = np.sum([np.prod(v.get_shape()) for v in net.trainable_variables]) | |||
# log output | |||
tf.print('Epoch', epoch+1, '/', num_epoch, 'Step', step, 'loss =', loss_mse.numpy(), 'time', time.time() - t0, 'lr = ', learning_rate, 'param_num', param_num) | |||
total_step += 1 | |||
# learning rate decay for each epoch | |||
learning_rate = learning_rate_org * learning_rate_decay ** (epoch + 1)#(total_step / decay_steps) | |||
optimizer = tf.optimizers.Adam(learning_rate) | |||
# save model each epoch | |||
#if epoch in [0, num_epoch-1, num_epoch]: | |||
model_epoch_dir = os.path.join(modeldir,'epoch-'+str(epoch+1), 'ckpt') | |||
net.save_weights(model_epoch_dir, save_format='tf') | |||
@@ -0,0 +1,578 @@ | |||
import tensorflow as tf | |||
from tensorflow.keras import layers | |||
import os | |||
import numpy as np | |||
import time | |||
from tools.tools import tempfft, fft2c_mri, ifft2c_mri, Emat_xyt | |||
class CNNLayer(tf.keras.layers.Layer): | |||
def __init__(self, n_f=32, n_out=2): | |||
super(CNNLayer, self).__init__() | |||
self.mylayers = [] | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False)) | |||
self.mylayers.append(tf.keras.layers.LeakyReLU()) | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False)) | |||
self.mylayers.append(tf.keras.layers.LeakyReLU()) | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False)) | |||
self.mylayers.append(tf.keras.layers.LeakyReLU()) | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False)) | |||
self.mylayers.append(tf.keras.layers.LeakyReLU()) | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_out, 3, strides=1, padding='same', use_bias=False)) | |||
self.seq = tf.keras.Sequential(self.mylayers) | |||
def call(self, input): | |||
if len(input.shape) == 4: | |||
input2c = tf.stack([tf.math.real(input), tf.math.imag(input)], axis=-1) | |||
else: | |||
input2c = tf.concat([tf.math.real(input), tf.math.imag(input)], axis=-1) | |||
res = self.seq(input2c) | |||
res = tf.complex(res[:,:,:,:,0], res[:,:,:,:,1]) | |||
return res | |||
class CONV_OP(tf.keras.layers.Layer): | |||
def __init__(self, n_f=32, ifactivate=False): | |||
super(CONV_OP, self).__init__() | |||
self.mylayers = [] | |||
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False)) | |||
if ifactivate == True: | |||
self.mylayers.append(tf.keras.layers.ReLU()) | |||
self.seq = tf.keras.Sequential(self.mylayers) | |||
def call(self, input): | |||
res = self.seq(input) | |||
return res | |||
class SLR_Net(tf.keras.Model): | |||
def __init__(self, mask, niter, learned_topk=False): | |||
super(SLR_Net, self).__init__(name='SLR_Net') | |||
self.niter = niter | |||
self.E = Emat_xyt(mask) | |||
self.learned_topk = learned_topk | |||
self.celllist = [] | |||
def build(self, input_shape): | |||
for i in range(self.niter-1): | |||
self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk)) | |||
self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk, is_last=True)) | |||
def call(self, d, csm): | |||
""" | |||
d: undersampled k-space | |||
csm: coil sensitivity map | |||
""" | |||
if csm == None: | |||
nb, nt, nx, ny = d.shape | |||
else: | |||
nb, nc, nt, nx, ny = d.shape | |||
X_SYM = [] | |||
x_rec = self.E.mtimes(d, inv=True, csm=csm) | |||
t = tf.zeros_like(x_rec) | |||
beta = tf.zeros_like(x_rec) | |||
x_sym = tf.zeros_like(x_rec) | |||
data = [x_rec, x_sym, beta, t, d, csm] | |||
for i in range(self.niter): | |||
data = self.celllist[i](data, d.shape) | |||
x_sym = data[1] | |||
X_SYM.append(x_sym) | |||
x_rec = data[0] | |||
return x_rec, X_SYM | |||
class SLRCell(layers.Layer): | |||
def __init__(self, input_shape, E, learned_topk=False, is_last=False): | |||
super(SLRCell, self).__init__() | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
self.E = E | |||
self.learned_topk = learned_topk | |||
if self.learned_topk: | |||
if is_last: | |||
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef') | |||
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=False, name='eta') | |||
else: | |||
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef') | |||
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta') | |||
self.conv_1 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_2 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_3 = CONV_OP(n_f=16, ifactivate=False) | |||
self.conv_4 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_5 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_6 = CONV_OP(n_f=2, ifactivate=False) | |||
#self.conv_7 = CONV_OP(n_f=16, ifactivate=True) | |||
#self.conv_8 = CONV_OP(n_f=16, ifactivate=True) | |||
#self.conv_9 = CONV_OP(n_f=16, ifactivate=True) | |||
#self.conv_10 = CONV_OP(n_f=16, ifactivate=True) | |||
self.lambda_step = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_1') | |||
self.lambda_step_2 = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_2') | |||
self.soft_thr = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='soft_thr') | |||
def call(self, data, input_shape): | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
x_rec, x_sym, beta, t, d, csm = data | |||
x_rec, x_sym = self.sparse(x_rec, d, t, beta, csm) | |||
t = self.lowrank(x_rec) | |||
beta = self.beta_mid(beta, x_rec, t) | |||
data[0] = x_rec | |||
data[1] = x_sym | |||
data[2] = beta | |||
data[3] = t | |||
return data | |||
def sparse(self, x_rec, d, t, beta, csm): | |||
lambda_step = tf.cast(tf.nn.relu(self.lambda_step), tf.complex64) | |||
lambda_step_2 = tf.cast(tf.nn.relu(self.lambda_step_2), tf.complex64) | |||
ATAX_cplx = self.E.mtimes(self.E.mtimes(x_rec, inv=False, csm=csm) - d, inv=True, csm=csm) | |||
r_n = x_rec - tf.math.scalar_mul(lambda_step, ATAX_cplx) +\ | |||
tf.math.scalar_mul(lambda_step_2, x_rec + beta - t) | |||
# D_T(soft(D_r_n)) | |||
if len(r_n.shape) == 4: | |||
r_n = tf.stack([tf.math.real(r_n), tf.math.imag(r_n)], axis=-1) | |||
x_1 = self.conv_1(r_n) | |||
x_2 = self.conv_2(x_1) | |||
x_3 = self.conv_3(x_2) | |||
x_soft = tf.math.multiply(tf.math.sign(x_3), tf.nn.relu(tf.abs(x_3) - self.soft_thr)) | |||
x_4 = self.conv_4(x_soft) | |||
x_5 = self.conv_5(x_4) | |||
x_6 = self.conv_6(x_5) | |||
x_rec = x_6 + r_n | |||
x_1_sym = self.conv_4(x_3) | |||
x_1_sym = self.conv_5(x_1_sym) | |||
x_1_sym = self.conv_6(x_1_sym) | |||
#x_sym_1 = self.conv_10(x_1_sym) | |||
x_sym = x_1_sym - r_n | |||
x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1]) | |||
return x_rec, x_sym | |||
def lowrank(self, x_rec): | |||
[batch, Nt, Nx, Ny] = x_rec.get_shape() | |||
M = tf.reshape(x_rec, [batch, Nt, Nx*Ny]) | |||
St, Ut, Vt = tf.linalg.svd(M) | |||
if self.learned_topk: | |||
#tf.print(tf.sigmoid(self.thres_coef)) | |||
thres = tf.sigmoid(self.thres_coef) * St[:, 0] | |||
thres = tf.expand_dims(thres, -1) | |||
St = tf.nn.relu(St - thres) | |||
else: | |||
top1_mask = np.concatenate( | |||
[np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1) | |||
top1_mask = tf.constant(top1_mask) | |||
St = St * top1_mask | |||
St = tf.linalg.diag(St) | |||
St = tf.dtypes.cast(St, tf.complex64) | |||
Vt_conj = tf.transpose(Vt, perm=[0, 2, 1]) | |||
Vt_conj = tf.math.conj(Vt_conj) | |||
US = tf.linalg.matmul(Ut, St) | |||
M = tf.linalg.matmul(US, Vt_conj) | |||
x_rec = tf.reshape(M, [batch, Nt, Nx, Ny]) | |||
return x_rec | |||
def beta_mid(self, beta, x_rec, t): | |||
eta = tf.cast(tf.nn.relu(self.eta), tf.complex64) | |||
return beta + tf.multiply(eta, x_rec - t) | |||
###### DC-CNN ###### | |||
class DC_CNN_LR(tf.keras.Model): | |||
def __init__(self, mask, niter, learned_topk=False): | |||
super(DC_CNN_LR, self).__init__(name='DC_CNN_LR') | |||
self.niter = niter | |||
self.E = Emat_xyt(mask) | |||
self.mask = mask | |||
self.learned_topk = learned_topk | |||
self.celllist = [] | |||
def build(self, input_shape): | |||
for i in range(self.niter-1): | |||
self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk)) | |||
self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, is_last=True)) | |||
def call(self, d, csm): | |||
""" | |||
d: undersampled k-space | |||
csm: coil sensitivity map | |||
""" | |||
if csm == None: | |||
nb, nt, nx, ny = d.shape | |||
else: | |||
nb, nc, nt, nx, ny = d.shape | |||
x_rec = self.E.mtimes(d, inv=True, csm=csm) | |||
for i in range(self.niter): | |||
x_rec = self.celllist[i](x_rec, d, d.shape) | |||
return x_rec | |||
class DNCell(layers.Layer): | |||
def __init__(self, input_shape, E, mask, learned_topk=False, is_last=False): | |||
super(DNCell, self).__init__() | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
self.E = E | |||
self.mask = mask | |||
self.learned_topk = learned_topk | |||
if self.learned_topk: | |||
if is_last: | |||
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef') | |||
else: | |||
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef') | |||
self.conv_1 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_2 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_3 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_4 = CONV_OP(n_f=16, ifactivate=True) | |||
self.conv_5 = CONV_OP(n_f=2, ifactivate=False) | |||
def call(self, x_rec, d, input_shape): | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
x_rec = self.sparse(x_rec, d) | |||
return x_rec | |||
def sparse(self, x_rec, d): | |||
r_n = tf.stack([tf.math.real(x_rec), tf.math.imag(x_rec)], axis=-1) | |||
x_1 = self.conv_1(r_n) | |||
x_2 = self.conv_2(x_1) | |||
x_3 = self.conv_3(x_2) | |||
x_4 = self.conv_4(x_3) | |||
x_5 = self.conv_5(x_4) | |||
x_rec = x_5 + r_n | |||
x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1]) | |||
if self.learned_topk: | |||
x_rec = self.lowrank(x_rec) | |||
x_rec = self.dc_layer(x_rec, d) | |||
return x_rec | |||
def lowrank(self, x_rec): | |||
[batch, Nt, Nx, Ny] = x_rec.get_shape() | |||
M = tf.reshape(x_rec, [batch, Nt, Nx*Ny]) | |||
St, Ut, Vt = tf.linalg.svd(M) | |||
if self.learned_topk: | |||
#tf.print(tf.sigmoid(self.thres_coef)) | |||
thres = tf.sigmoid(self.thres_coef) * St[:, 0] | |||
thres = tf.expand_dims(thres, -1) | |||
St = tf.nn.relu(St - thres) | |||
else: | |||
top1_mask = np.concatenate( | |||
[np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1) | |||
top1_mask = tf.constant(top1_mask) | |||
St = St * top1_mask | |||
St = tf.linalg.diag(St) | |||
St = tf.dtypes.cast(St, tf.complex64) | |||
Vt_conj = tf.transpose(Vt, perm=[0, 2, 1]) | |||
Vt_conj = tf.math.conj(Vt_conj) | |||
US = tf.linalg.matmul(Ut, St) | |||
M = tf.linalg.matmul(US, Vt_conj) | |||
x_rec = tf.reshape(M, [batch, Nt, Nx, Ny]) | |||
return x_rec | |||
def dc_layer(self, x_rec, d): | |||
k_rec = fft2c_mri(x_rec) | |||
k_rec = (1 - self.mask) * k_rec + self.mask * d | |||
x_rec = ifft2c_mri(k_rec) | |||
return x_rec | |||
###### Manifold_Net ###### | |||
class Manifold_Net(tf.keras.Model): | |||
def __init__(self, mask, niter, learned_topk=False, N_factor=1): | |||
super(Manifold_Net, self).__init__(name='Manifold_Net') | |||
self.niter = niter | |||
self.E = Emat_xyt(mask) | |||
self.mask = mask | |||
self.learned_topk = learned_topk | |||
self.N_factor = N_factor | |||
self.celllist = [] | |||
def build(self, input_shape): | |||
for i in range(self.niter-1): | |||
self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor)) | |||
self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor, is_last=True)) | |||
def call(self, d, csm): | |||
""" | |||
d: undersampled k-space | |||
csm: coil sensitivity map | |||
""" | |||
if csm == None: | |||
nb, nt, nx, ny = d.shape | |||
else: | |||
nb, nc, nt, nx, ny = d.shape | |||
x_rec = self.E.mtimes(d, inv=True, csm=csm) | |||
for i in range(self.niter): | |||
x_rec = self.celllist[i](x_rec, d, d.shape) | |||
return x_rec | |||
class ManifoldCell(layers.Layer): | |||
def __init__(self, input_shape, E, mask, learned_topk=False, N_factor=1, is_last=False): | |||
super(ManifoldCell, self).__init__() | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
self.E = E | |||
self.mask = mask | |||
self.Nx_factor = N_factor | |||
self.Ny_factor = N_factor | |||
self.Nt_factor = N_factor | |||
self.learned_topk = learned_topk | |||
if self.learned_topk: | |||
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta') | |||
#self.lambda_sparse = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='lambda') | |||
self.conv_1 = CNNLayer(n_f=16, n_out=2) | |||
#self.conv_2 = CNNLayer(n_f=16, n_out=2) | |||
#self.conv_3 = CNNLayer(n_f=16, n_out=2) | |||
#self.conv_D = CNNLayer(n_f=16, n_out=2) | |||
#self.conv_transD = CNNLayer(n_f=16, n_out=2) | |||
def call(self, x_rec, d, input_shape): | |||
if len(input_shape) == 4: | |||
self.nb, self.nt, self.nx, self.ny = input_shape | |||
else: | |||
self.nb, nc, self.nt, self.nx, self.ny = input_shape | |||
x_k = self.conv_1(x_rec) | |||
#grad_sparse = self.conv_transD(self.conv_D(x_k)) | |||
#grad_sparse = tf.stack([tf.math.real(grad_sparse), tf.math.imag(grad_sparse)], axis=-1) | |||
#grad_sparse = tf.multiply(self.lambda_sparse, grad_sparse) | |||
#grad_sparse = tf.complex(grad_sparse[..., 0], grad_sparse[..., 1]) | |||
#grad_dc = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask) | |||
#g_k = grad_dc + grad_sparse | |||
g_k = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask) | |||
#g_k = self.E.mtimes(self.E.mtimes(x_k, inv=False, csm=csm) - d, inv=True, csm=csm) | |||
t_k = self.Tangent_Module(g_k, x_k) | |||
x_k = self.Retraction_Module(x_k, t_k) | |||
#x_k = self.conv_3(x_k) | |||
x_k = self.dc_layer(x_k, d) | |||
return x_k | |||
def Tangent_Module(self, g_k, x_k): | |||
batch, Nt, Nx, Ny = x_k.shape | |||
x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt | |||
g_k = tf.transpose(g_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt | |||
Ux, Uy, Ut = self.Mode(x_k) | |||
first_term = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1) | |||
first_term = self.Mode_Multiply(first_term, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2) | |||
first_term = self.Mode_Multiply(first_term, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3) | |||
first_term = self.Mode_Multiply(first_term, Ux, mode_n=1) | |||
first_term = self.Mode_Multiply(first_term, Uy, mode_n=2) | |||
first_term = self.Mode_Multiply(first_term, Ut, mode_n=3) | |||
C_mode_x, C_mode_y, C_mode_t = self.Core_C(x_k, Ux, Uy, Ut) | |||
second_term_1 = self.Mode_Multiply(g_k, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2) | |||
second_term_1 = self.Mode_Multiply(second_term_1, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3) | |||
second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny * Nt]) | |||
second_term_1 = self.Projector(second_term_1, Ux) | |||
second_term_1 = self.Core_Multiply(second_term_1, C_mode_x) | |||
second_term_1 = tf.linalg.matmul(second_term_1, C_mode_x) | |||
second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny, Nt]) | |||
second_term_1 = self.Mode_Multiply(second_term_1, Uy, mode_n=2) | |||
second_term_1 = self.Mode_Multiply(second_term_1, Ut, mode_n=3) | |||
second_term_2 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1) | |||
second_term_2 = self.Mode_Multiply(second_term_2, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3) | |||
second_term_2 = tf.reshape(tf.transpose(second_term_2, [0, 2, 1, 3]), [batch, Ny, Nx*Nt]) | |||
second_term_2 = self.Projector(second_term_2, Uy) | |||
second_term_2 = self.Core_Multiply(second_term_2, C_mode_y) | |||
second_term_2 = tf.linalg.matmul(second_term_2, C_mode_y) | |||
second_term_2 = tf.transpose(tf.reshape(second_term_2, [batch, Ny, Nx, Nt]), [0, 2, 1, 3]) | |||
second_term_2 = self.Mode_Multiply(second_term_2, Ux, mode_n=1) | |||
second_term_2 = self.Mode_Multiply(second_term_2, Ut, mode_n=3) | |||
second_term_3 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1) | |||
second_term_3 = self.Mode_Multiply(second_term_3, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2) | |||
second_term_3 = tf.reshape(tf.transpose(second_term_3, [0, 3, 1, 2]), [batch, Nt, Nx * Ny]) | |||
second_term_3 = self.Projector(second_term_3, Ut) | |||
second_term_3 = self.Core_Multiply(second_term_3, C_mode_t) | |||
second_term_3 = tf.linalg.matmul(second_term_3, C_mode_t) | |||
second_term_3 = tf.transpose(tf.reshape(second_term_3, [batch, Nt, Nx, Ny]), [0, 2, 3, 1]) | |||
second_term_3 = self.Mode_Multiply(second_term_3, Ux, mode_n=1) | |||
second_term_3 = self.Mode_Multiply(second_term_3, Uy, mode_n=2) | |||
t_k = first_term + second_term_1 + second_term_2 + second_term_3 | |||
t_k = tf.transpose(t_k, [0, 3, 1, 2]) | |||
return t_k | |||
def Retraction_Module(self, x_k, t_k): | |||
x_k = tf.stack([tf.math.real(x_k), tf.math.imag(x_k)], axis=-1) | |||
t_k = tf.stack([tf.math.real(t_k), tf.math.imag(t_k)], axis=-1) | |||
x_k = x_k - tf.multiply(self.eta, t_k) | |||
x_k = tf.complex(x_k[..., 0], x_k[..., 1]) | |||
batch, Nt, Nx, Ny = x_k.shape | |||
x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt | |||
Ux, Uy, Ut = self.Mode(x_k) | |||
Ux = self.SVT_U(Ux, top_kth= int(Nx / self.Nx_factor)) | |||
Uy = self.SVT_U(Uy, top_kth= int(Ny / self.Ny_factor)) | |||
Ut = self.SVT_U(Ut, top_kth= int(Nt / self.Nt_factor)) | |||
""" | |||
Ux = self.SVT_U(Ux, top_kth= Nx // self.Nx_factor) | |||
Uy = self.SVT_U(Uy, top_kth= Ny // self.Ny_factor) | |||
Ut = self.SVT_U(Ut, top_kth= Nt // self.Nt_factor) | |||
""" | |||
C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1) | |||
C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2) | |||
C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3) | |||
x_k = self.Mode_Multiply(C, Ux, mode_n=1) | |||
x_k = self.Mode_Multiply(x_k, Uy, mode_n=2) | |||
x_k = self.Mode_Multiply(x_k, Ut, mode_n=3) | |||
x_k = tf.transpose(x_k, [0, 3, 1, 2]) | |||
return x_k | |||
def Mode(self, x_k): | |||
batch, Nx, Ny, Nt = x_k.shape | |||
mode_x = tf.reshape(x_k, [batch, Nx, Ny*Nt]) | |||
mode_y = tf.reshape(tf.transpose(x_k, [0, 2, 1, 3]), [batch, Ny, Nx*Nt]) | |||
mode_t = tf.reshape(tf.transpose(x_k, [0, 3, 1, 2]), [batch, Nt, Nx*Ny]) | |||
Sx, Ux, Vx = tf.linalg.svd(mode_x) # Ux: batch, 192, 192 | |||
Sy, Uy, Vy = tf.linalg.svd(mode_y) # Uy: batch, 192, 192 | |||
St, Ut, Vt = tf.linalg.svd(mode_t) | |||
return Ux, Uy, Ut | |||
def Mode_Multiply(self, A, U, mode_n=1): | |||
""" | |||
A: batch, Nx, Ny, Nt | |||
U: batch, Nx, Ny | |||
return: batch, Nx, Ny, Nt | |||
""" | |||
batch, Nx, Ny, Nt = A.shape | |||
if mode_n == 1: | |||
out = tf.linalg.matmul(U, tf.reshape(A, [batch, Nx, Ny*Nt])) # batch, Nx, Ny*Nt | |||
out= tf.reshape(out, [batch, Nx, Ny, Nt]) | |||
elif mode_n == 2: | |||
out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 2, 1, 3]), [batch, Ny, Nx * Nt])) # batch, Ny, Nx*Nt | |||
out = tf.transpose(tf.reshape(out, [batch, Ny, Nx, Nt]), [0, 2, 1, 3]) | |||
elif mode_n == 3: | |||
out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])) # batch, Nt, Nx*Ny | |||
out = tf.transpose(tf.reshape(out, [batch, Nt, Nx, Ny]), [0, 2, 3, 1]) | |||
return out | |||
def Core_C(self, x_k, Ux, Uy, Ut): | |||
batch, Nx, Ny, Nt = x_k.shape | |||
C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1) | |||
C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2) | |||
C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3) | |||
C_mode_x = tf.reshape(C, [batch, Nx, Ny * Nt]) | |||
C_mode_y = tf.reshape(tf.transpose(C, [0, 2, 1, 3]), [batch, Ny, Nx * Nt]) | |||
C_mode_t = tf.reshape(tf.transpose(C, [0, 3, 1, 2]), [batch, Nt, Nx * Ny]) | |||
return C_mode_x, C_mode_y, C_mode_t | |||
def Projector(self, second_term, U): | |||
second_term = second_term - tf.linalg.matmul( | |||
tf.linalg.matmul(U, | |||
tf.transpose(U, [0, 2, 1], conjugate=True)), | |||
second_term) | |||
return second_term | |||
def Core_Multiply(self, second_term, C_mode): | |||
second_term = tf.linalg.matmul(second_term, | |||
tf.linalg.matmul(tf.transpose(C_mode, [0, 2, 1], conjugate=True), | |||
tf.linalg.inv(tf.linalg.matmul(C_mode, | |||
tf.transpose(C_mode, [0, 2, 1], conjugate=True))))) | |||
return second_term | |||
def SVT_U(self, Uk, top_kth): | |||
[batch, Nx, Ny] = Uk.get_shape() | |||
mask_1 = tf.ones([batch, Nx, top_kth]) | |||
mask_2 = tf.zeros([batch, Nx, Ny - top_kth]) | |||
mask_top_k = tf.concat([mask_1, mask_2], axis=-1) | |||
mask_top_k = tf.cast(mask_top_k, dtype=Uk.dtype) | |||
Uk = tf.multiply(Uk, mask_top_k) | |||
return Uk | |||
def dc_layer(self, x_rec, d): | |||
k_rec = fft2c_mri(x_rec) | |||
k_rec = (1 - self.mask) * k_rec + self.mask * d | |||
x_rec = ifft2c_mri(k_rec) | |||
return x_rec | |||
def dc_layer_v2(self, x_rec, d): | |||
x_rec = x_rec - ifft2c_mri(fft2c_mri(x_rec) * self.mask - d) | |||
return x_rec | |||
@@ -0,0 +1,2 @@ | |||
model_checkpoint_path: "ckpt" | |||
all_model_checkpoint_paths: "ckpt" |
@@ -0,0 +1,142 @@ | |||
import tensorflow as tf | |||
import os | |||
from model_net_v3 import Manifold_Net | |||
from dataset_tfrecord import get_dataset | |||
import argparse | |||
import scipy.io as scio | |||
import mat73 | |||
import numpy as np | |||
from datetime import datetime | |||
import time | |||
from tools.tools import video_summary, mse, tempfft | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--mode', metavar='str', nargs=1, default=['test'], help='training or test') | |||
parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size') | |||
parser.add_argument('--niter', metavar='int', nargs=1, default=['5'], help='number of network iterations') | |||
parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate') | |||
parser.add_argument('--mask_pattern', metavar='str', nargs=1, default=['cartesian'], help='mask pattern: cartesian, radial, spiral, vsita') | |||
parser.add_argument('--net', metavar='str', nargs=1, default=['Manifold_Net'], help='Manifold_Net') | |||
parser.add_argument('--weight', metavar='str', nargs=1, default=['models/stable/2021-02-28T13-44-00_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.05_rank_17_cartesian/epoch-60/ckpt'], help='modeldir in ./models') | |||
parser.add_argument('--gpu', metavar='int', nargs=1, default=['2'], help='GPU No.') | |||
parser.add_argument('--data', metavar='str', nargs=1, default=['DYNAMIC_V2'], help='dataset name') | |||
parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not') | |||
parser.add_argument('--SVT_favtor', metavar='float', nargs=1, default=['1.05'], help='SVT factor') | |||
args = parser.parse_args() | |||
# GPU setup | |||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0] | |||
GPUs = tf.config.experimental.list_physical_devices('GPU') | |||
tf.config.experimental.set_memory_growth(GPUs[0], True) | |||
dataset_name = args.data[0].upper() | |||
mode = args.mode[0] | |||
batch_size = int(args.batch_size[0]) | |||
niter = int(args.niter[0]) | |||
acc = int(args.acc[0]) | |||
mask_pattern = args.mask_pattern[0] | |||
net_name = args.net[0] | |||
weight_file = args.weight[0] | |||
learnedSVT = bool(args.learnedSVT[0]) | |||
N_factor = float(args.SVT_favtor[0]) | |||
print('network: ', net_name) | |||
print('acc: ', acc) | |||
print('load weight file from: ', weight_file) | |||
result_dir = os.path.join('results/stable', weight_file.split('/')[2]) | |||
if not os.path.isdir(result_dir): | |||
os.makedirs(result_dir) | |||
logdir = './logs' | |||
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) | |||
summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, TIMESTAMP + net_name + str(acc) + '/')) | |||
# prepare undersampling mask | |||
if dataset_name == 'DYNAMIC_V2': | |||
multi_coil = False | |||
mask_size = '18_192_192' | |||
elif dataset_name == 'DYNAMIC_V2_MULTICOIL': | |||
multi_coil = True | |||
mask_size = '18_192_192' | |||
elif dataset_name == 'FLOW': | |||
multi_coil = False | |||
mask_size = '20_180_180' | |||
if acc == 8: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc8.mat')['mask'] | |||
elif acc == 10: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc10.mat')['mask'] | |||
elif acc == 12: | |||
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc12.mat')['mask'] | |||
""" | |||
if acc == 8: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_8.mat')['mask'] | |||
elif acc == 10: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_10.mat')['mask'] | |||
elif acc == 12: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_12.mat')['mask'] | |||
elif acc == 16: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_16.mat')['mask'] | |||
elif acc == 20: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_20.mat')['mask'] | |||
elif acc == 24: | |||
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_24.mat')['mask'] | |||
""" | |||
mask = tf.cast(tf.constant(mask), tf.complex64) | |||
# prepare dataset | |||
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=False) | |||
# initialize network | |||
if net_name == 'Manifold_Net': | |||
net = Manifold_Net(mask, niter, learnedSVT, N_factor) | |||
net.load_weights(weight_file) | |||
# Iterate over epochs. | |||
for i, sample in enumerate(dataset): | |||
# forward | |||
k0 = None | |||
csm = None | |||
#with tf.GradientTape() as tape: | |||
if multi_coil: | |||
k0, label, csm = sample | |||
else: | |||
k0, label = sample | |||
label_abs = tf.abs(label) | |||
k0 = k0 * mask | |||
t0 = time.time() | |||
recon = net(k0, csm) | |||
t1 = time.time() | |||
recon_abs = tf.abs(recon) | |||
loss_total = mse(recon, label) | |||
tf.print(i, 'mse =', loss_total.numpy(), 'time = ', t1-t0) | |||
result_file = os.path.join(result_dir, 'recon_'+str(i+1)+'.mat') | |||
datadict = {'recon': np.squeeze(tf.transpose(recon, [0,2,3,1]).numpy())} | |||
scio.savemat(result_file, datadict) | |||
# record gif | |||
with summary_writer.as_default(): | |||
combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy() | |||
combine_video = np.expand_dims(combine_video, -1) | |||
video_summary('convin-'+str(i+1), combine_video, step=1, fps=10) | |||
@@ -0,0 +1,224 @@ | |||
import numpy as np | |||
import tools.mymath | |||
from numpy.lib.stride_tricks import as_strided | |||
def soft_thresh(u, lmda): | |||
"""Soft-threshing operator for complex valued input""" | |||
Su = (abs(u) - lmda) / abs(u) * u | |||
Su[abs(u) < lmda] = 0 | |||
return Su | |||
def normal_pdf(length, sensitivity): | |||
return np.exp(-sensitivity * (np.arange(length) - length / 2)**2) | |||
def var_dens_mask(shape, ivar, sample_high_freq=True): | |||
"""Variable Density Mask (2D undersampling)""" | |||
if len(shape) == 4: | |||
Num, Nt, Nx, Ny = shape | |||
else: | |||
Nx, Ny = shape | |||
Nt = 1 | |||
pdf_x = normal_pdf(Nx, ivar) | |||
pdf_y = normal_pdf(Ny, ivar) | |||
pdf = np.outer(pdf_x, pdf_y) | |||
size = pdf.itemsize | |||
strided_pdf = as_strided(pdf, (Nt, Nx, Ny), (0, Ny * size, size)) | |||
# this must be false if undersampling rate is very low (around 90%~ish) | |||
if sample_high_freq: | |||
strided_pdf = strided_pdf / 1.25 + 0.02 | |||
mask = np.random.binomial(1, strided_pdf) | |||
xc = Nx / 2 | |||
yc = Ny / 2 | |||
mask[:, xc - 4:xc + 5, yc - 4:yc + 5] = True | |||
if Nt == 1: | |||
return mask.reshape((Nx, Ny)) | |||
mask_4D = mask[np.newaxis, :, :, :] | |||
mask_4D = np.tile(mask_4D, (Num, 1, 1, 1)) | |||
return mask_4D | |||
def cartesian_mask(shape, ivar, centred=False, | |||
sample_high_freq=True, sample_centre=True, sample_n=4): | |||
""" | |||
undersamples along Nx | |||
:param shape: tuple - [nt, nx, ny] | |||
:param ivar: sensitivity parameter for Gaussian distribution | |||
""" | |||
if len(shape) == 4: | |||
Num, Nt, Nx, Ny = shape | |||
else: | |||
Nx, Ny = shape | |||
Nt = 1 | |||
pdf_x = normal_pdf(Nx, ivar) | |||
if sample_high_freq: | |||
pdf_x = pdf_x / 1.25 + 0.02 | |||
size = pdf_x.itemsize | |||
stride_pdf = as_strided(pdf_x, (Nt, Nx, 1), (0, size, 0)) | |||
mask = np.random.binomial(1, stride_pdf) | |||
size = mask.itemsize | |||
mask = as_strided(mask, (Nt, Nx, Ny), (size * Nx, size, 0)) | |||
if sample_centre: | |||
s = sample_n / 2 | |||
xc = Nx / 2 | |||
mask[:, xc - s: xc + s, :] = True | |||
if not centred: | |||
mask = mymath.ifftshift(mask, axes=(-1, -2)) | |||
mask_4D = mask[np.newaxis, :, :, :] | |||
mask_4D = np.tile(mask_4D, (Num, 1, 1, 1)) | |||
return mask_4D | |||
def shear_grid_mask(shape, acceleration_rate, sample_low_freq=True, | |||
centred=False, sample_n=10): | |||
''' | |||
Creates undersampling mask which samples in sheer grid | |||
Parameters | |||
---------- | |||
shape: (nt, nx, ny) | |||
acceleration_rate: int | |||
Returns | |||
------- | |||
array | |||
''' | |||
Nt, Nx, Ny = shape | |||
start = np.random.randint(0, acceleration_rate) | |||
mask = np.zeros((Nt, Nx)) | |||
for t in xrange(Nt): | |||
mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1 | |||
xc = Nx / 2 | |||
xl = sample_n / 2 | |||
if sample_low_freq and centred: | |||
xh = xl | |||
if sample_n % 2 == 0: | |||
xh += 1 | |||
mask[:, xc - xl:xc + xh+1] = 1 | |||
elif sample_low_freq: | |||
xh = xl | |||
if sample_n % 2 == 1: | |||
xh -= 1 | |||
if xl > 0: | |||
mask[:, :xl] = 1 | |||
if xh > 0: | |||
mask[:, -xh:] = 1 | |||
mask_rep = np.repeat(mask[..., np.newaxis], Ny, axis=-1) | |||
return mask_rep | |||
def perturbed_shear_grid_mask(shape, acceleration_rate, sample_low_freq=True, | |||
centred=False, | |||
sample_n=10): | |||
Nt, Nx, Ny = shape | |||
start = np.random.randint(0, acceleration_rate) | |||
mask = np.zeros((Nt, Nx)) | |||
for t in xrange(Nt): | |||
mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1 | |||
# brute force | |||
rand_code = np.random.randint(0, 3, size=Nt*Nx) | |||
shift = np.array([-1, 0, 1])[rand_code] | |||
new_mask = np.zeros_like(mask) | |||
for t in xrange(Nt): | |||
for x in xrange(Nx): | |||
if mask[t, x]: | |||
new_mask[t, (x + shift[t*x])%Nx] = 1 | |||
xc = Nx / 2 | |||
xl = sample_n / 2 | |||
if sample_low_freq and centred: | |||
xh = xl | |||
if sample_n % 2 == 0: | |||
xh += 1 | |||
new_mask[:, xc - xl:xc + xh+1] = 1 | |||
elif sample_low_freq: | |||
xh = xl | |||
if sample_n % 2 == 1: | |||
xh -= 1 | |||
new_mask[:, :xl] = 1 | |||
new_mask[:, -xh:] = 1 | |||
mask_rep = np.repeat(new_mask[..., np.newaxis], Ny, axis=-1) | |||
return mask_rep | |||
def undersample(x, mask, centred=False, norm='ortho'): | |||
''' | |||
Undersample x. FFT2 will be applied to the last 2 axis | |||
Parameters | |||
---------- | |||
x: array_like | |||
data | |||
mask: array_like | |||
undersampling mask in fourier domain | |||
Returns | |||
------- | |||
xu: array_like | |||
undersampled image in image domain. Note that it is complex valued | |||
x_fu: array_like | |||
undersampled data in kspace | |||
''' | |||
assert x.shape == mask.shape | |||
if centred: | |||
x_f = mymath.fft2c(x, norm=norm) | |||
x_fu = x_f * mask | |||
x_u = mymath.ifft2c(x_fu, norm=norm) | |||
return x_u, x_fu | |||
else: | |||
x_f = mymath.fft2(x, norm=norm) | |||
x_fu = x_f * mask | |||
x_u = mymath.ifft2(x_fu, norm=norm) | |||
return x_u, x_fu, x_f | |||
def data_consistency(x, y, mask, centered=False, norm='ortho'): | |||
''' | |||
x is in image space, | |||
y is in k-space | |||
''' | |||
if centered: | |||
xf = mymath.fft2c(x, norm=norm) | |||
xm = (1 - mask) * xf + y | |||
xd = mymath.ifft2c(xm, norm=norm) | |||
else: | |||
xf = mymath.fft2(x, norm=norm) | |||
xm = (1 - mask) * xf + y | |||
xd = mymath.ifft2(xm, norm=norm) | |||
return xd | |||
def get_phase(x): | |||
xr = np.real(x) | |||
xi = np.imag(x) | |||
phase = np.arctan(xi / (xr + 1e-12)) | |||
return phase |
@@ -0,0 +1,148 @@ | |||
__author__ = 'Jo Schlemper' | |||
import numpy as np | |||
sqrt = np.sqrt | |||
from numpy.fft import fft, fft2, ifft2, ifft, ifftshift, fftshift | |||
def fftc(x, axis=-1, norm='ortho'): | |||
''' expect x as m*n matrix ''' | |||
return fftshift(fft(ifftshift(x, axes=axis), axis=axis, norm=norm), axes=axis) | |||
def ifftc(x, axis=-1, norm='ortho'): | |||
''' expect x as m*n matrix ''' | |||
return fftshift(ifft(ifftshift(x, axes=axis), axis=axis, norm=norm), axes=axis) | |||
def fft2c(x): | |||
''' | |||
Centered fft | |||
Note: fft2 applies fft to last 2 axes by default | |||
:param x: 2D onwards. e.g: if its 3d, x.shape = (n,row,col). 4d:x.shape = (n,slice,row,col) | |||
:return: | |||
''' | |||
# axes = (len(x.shape)-2, len(x.shape)-1) # get last 2 axes | |||
axes = (-2, -1) # get last 2 axes | |||
res = fftshift(fft2(ifftshift(x, axes=axes), norm='ortho'), axes=axes) | |||
return res | |||
def ifft2c(x): | |||
''' | |||
Centered ifft | |||
Note: fft2 applies fft to last 2 axes by default | |||
:param x: 2D onwards. e.g: if its 3d, x.shape = (n,row,col). 4d:x.shape = (n,slice,row,col) | |||
:return: | |||
''' | |||
axes = (-2, -1) # get last 2 axes | |||
res = fftshift(ifft2(ifftshift(x, axes=axes), norm='ortho'), axes=axes) | |||
return res | |||
def fourier_matrix(rows, cols): | |||
''' | |||
parameters: | |||
rows: number or rows | |||
cols: number of columns | |||
return unitary (rows x cols) fourier matrix | |||
''' | |||
# from scipy.linalg import dft | |||
# return dft(rows,scale='sqrtn') | |||
col_range = np.arange(cols) | |||
row_range = np.arange(rows) | |||
scale = 1 / np.sqrt(cols) | |||
coeffs = np.outer(row_range, col_range) | |||
fourier_matrix = np.exp(coeffs * (-2. * np.pi * 1j / cols)) * scale | |||
return fourier_matrix | |||
def inverse_fourier_matrix(rows, cols): | |||
return np.array(np.matrix(fourier_matrix(rows, cols)).getH()) | |||
def flip(m, axis): | |||
""" | |||
==== > Only in numpy 1.12 < ===== | |||
Reverse the order of elements in an array along the given axis. | |||
The shape of the array is preserved, but the elements are reordered. | |||
.. versionadded:: 1.12.0 | |||
Parameters | |||
---------- | |||
m : array_like | |||
Input array. | |||
axis : integer | |||
Axis in array, which entries are reversed. | |||
Returns | |||
------- | |||
out : array_like | |||
A view of `m` with the entries of axis reversed. Since a view is | |||
returned, this operation is done in constant time. | |||
See Also | |||
-------- | |||
flipud : Flip an array vertically (axis=0). | |||
fliplr : Flip an array horizontally (axis=1). | |||
Notes | |||
----- | |||
flip(m, 0) is equivalent to flipud(m). | |||
flip(m, 1) is equivalent to fliplr(m). | |||
flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n. | |||
Examples | |||
-------- | |||
>>> A = np.arange(8).reshape((2,2,2)) | |||
>>> A | |||
array([[[0, 1], | |||
[2, 3]], | |||
[[4, 5], | |||
[6, 7]]]) | |||
>>> flip(A, 0) | |||
array([[[4, 5], | |||
[6, 7]], | |||
[[0, 1], | |||
[2, 3]]]) | |||
>>> flip(A, 1) | |||
array([[[2, 3], | |||
[0, 1]], | |||
[[6, 7], | |||
[4, 5]]]) | |||
>>> A = np.random.randn(3,4,5) | |||
>>> np.all(flip(A,2) == A[:,:,::-1,...]) | |||
True | |||
""" | |||
if not hasattr(m, 'ndim'): | |||
m = np.asarray(m) | |||
indexer = [slice(None)] * m.ndim | |||
try: | |||
indexer[axis] = slice(None, None, -1) | |||
except IndexError: | |||
raise ValueError("axis=%i is invalid for the %i-dimensional input array" | |||
% (axis, m.ndim)) | |||
return m[tuple(indexer)] | |||
def rot90_nd(x, axes=(-2, -1), k=1): | |||
"""Rotates selected axes""" | |||
def flipud(x): | |||
return flip(x, axes[0]) | |||
def fliplr(x): | |||
return flip(x, axes[1]) | |||
x = np.asanyarray(x) | |||
if x.ndim < 2: | |||
raise ValueError("Input must >= 2-d.") | |||
k = k % 4 | |||
if k == 0: | |||
return x | |||
elif k == 1: | |||
return fliplr(x).swapaxes(*axes) | |||
elif k == 2: | |||
return fliplr(flipud(x)) | |||
else: | |||
# k == 3 | |||
return fliplr(x.swapaxes(*axes)) |
@@ -0,0 +1,341 @@ | |||
import tempfile | |||
import os | |||
import tensorflow as tf | |||
import numpy as np | |||
from numpy.lib.stride_tricks import as_strided | |||
import scipy.io as scio | |||
def video_summary(name, video, step=None, fps=10): | |||
name = tf.constant(name).numpy().decode('utf-8') | |||
video = np.array(video) | |||
if video.dtype in (np.float32, np.float64): | |||
video = np.clip(255 * video, 0, 255).astype(np.uint8) | |||
B, T, H, W, C = video.shape | |||
try: | |||
frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C)) | |||
summary = tf.compat.v1.Summary() | |||
image = tf.compat.v1.Summary.Image( | |||
height=B * H, width=T * W, colorspace=C) | |||
image.encoded_image_string = encode_gif(frames, fps) | |||
summary.value.add(tag=name + '/gif', image=image) | |||
tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step) | |||
except (IOError, OSError) as e: | |||
print('GIF summaries require ffmpeg in $PATH.', e) | |||
frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C)) | |||
tf.summary.image(name + '/grid', frames, step) | |||
def encode_gif(frames, fps): | |||
from subprocess import Popen, PIPE | |||
h, w, c = frames[0].shape | |||
pxfmt = {1: 'gray', 3: 'rgb24'}[c] | |||
cmd = ' '.join([ | |||
f'ffmpeg -y -f rawvideo -vcodec rawvideo', | |||
f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex', | |||
f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse', | |||
f'-r {fps:.02f} -f gif -']) | |||
proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE) | |||
for image in frames: | |||
proc.stdin.write(image.tostring()) | |||
out, err = proc.communicate() | |||
if proc.returncode: | |||
raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')])) | |||
del proc | |||
return out | |||
def normal_pdf(length, sensitivity): | |||
return np.exp(-sensitivity * (np.arange(length) - length / 2)**2) | |||
def cartesian_mask(shape, acc, sample_n=10, centred=False): | |||
""" | |||
Sampling density estimated from implementation of kt FOCUSS | |||
shape: tuple - of form (..., nx, ny) | |||
acc: float - doesn't have to be integer 4, 8, etc.. | |||
""" | |||
N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1] | |||
pdf_x = normal_pdf(Nx, 0.5/(Nx/10.)**2) | |||
lmda = Nx/(2.*acc) | |||
n_lines = int(Nx / acc) | |||
# add uniform distribution | |||
pdf_x += lmda * 1./Nx | |||
if sample_n: | |||
pdf_x[Nx//2-sample_n//2:Nx//2+sample_n//2] = 0 | |||
pdf_x /= np.sum(pdf_x) | |||
n_lines -= sample_n | |||
mask = np.zeros((N, Nx)) | |||
for i in range(N): | |||
idx = np.random.choice(Nx, n_lines, False, pdf_x) | |||
mask[i, idx] = 1 | |||
if sample_n: | |||
mask[:, Nx//2-sample_n//2:Nx//2+sample_n//2] = 1 | |||
size = mask.itemsize | |||
mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0)) | |||
mask = mask.reshape(shape) | |||
if not centred: | |||
mask = mymath.ifftshift(mask, axes=(-1, -2)) | |||
return mask | |||
def loss_function_ISTA(y, y_, y_sym, n_iter): | |||
pred = tf.stack([tf.math.real(y), tf.math.imag(y)], axis=-1) | |||
label = tf.stack([tf.math.real(y_), tf.math.imag(y_)], axis=-1) | |||
cost = tf.reduce_mean(tf.math.square(pred - label)) | |||
cost_sym = 0 | |||
for k in range(n_iter): | |||
#pred_sym = tf.stack([tf.math.real(y_sym[k]), tf.math.imag(y_sym[k])], axis=-1) | |||
cost_sym += tf.reduce_mean(tf.square(y_sym)) | |||
loss = cost + 0.01 * cost_sym | |||
return loss | |||
def tempfft(input, inv): | |||
if len(input.shape) == 4: | |||
nb, nt, nx, ny = np.float32(input.shape) | |||
nt = tf.constant(np.complex64(nt + 0j)) | |||
if inv: | |||
x = tf.transpose(input, perm=[0,2,3,1]) | |||
#x = tf.signal.fftshift(x, 3) | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,3,1,2]) | |||
x = x * tf.sqrt(nt) | |||
else: | |||
x = tf.transpose(input, perm=[0,2,3,1]) | |||
x = tf.signal.fft(x) | |||
#x = tf.signal.fftshift(x, 3) | |||
x = tf.transpose(x, perm=[0,3,1,2]) | |||
x = x / tf.sqrt(nt) | |||
else: | |||
nb, nt, nx, ny, _ = np.float32(input.shape) | |||
nt = tf.constant(np.complex64(nt + 0j)) | |||
if inv: | |||
x = tf.transpose(input, perm=[0,2,3,4,1]) | |||
#x = tf.signal.fftshift(x, 4) | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,4,1,2,3]) | |||
x = x * tf.sqrt(nt) | |||
else: | |||
x = tf.transpose(input, perm=[0,2,3,4,1]) | |||
x = tf.signal.fft(x) | |||
#x = tf.signal.fftshift(x, 4) | |||
x = tf.transpose(x, perm=[0,4,1,2,3]) | |||
x = x / tf.sqrt(nt) | |||
return x | |||
def mse(recon, label): | |||
if recon.dtype == tf.complex64: | |||
residual_cplx = recon - label | |||
residual = tf.stack([tf.math.real(residual_cplx), tf.math.imag(residual_cplx)], axis=-1) | |||
mse = tf.reduce_mean(residual**2) | |||
else: | |||
residual = recon - label | |||
mse = tf.reduce_mean(residual**2) | |||
return mse | |||
def fft2c_mri(x): | |||
# nb nx ny nt | |||
X = tf.signal.fftshift(x, 2) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one. | |||
X = tf.signal.fft(X) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(x.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx) | |||
X = tf.signal.fftshift(X, 3) | |||
X = tf.signal.fft(X) | |||
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny) | |||
return X | |||
def ifft2c_mri(X): | |||
# nb nx ny nt | |||
x = tf.signal.fftshift(X, 2) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one. | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(X.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
x = tf.signal.fftshift(x, 2) * tf.sqrt(nx) | |||
x = tf.signal.fftshift(x, 3) | |||
x = tf.signal.ifft(x) | |||
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny) | |||
return x | |||
def sos(x): | |||
# x: nb, ncoil, nt, nx, ny; complex64 | |||
x = tf.math.reduce_sum(tf.abs(x**2), axis=1) | |||
x = x**(1.0/2) | |||
return x | |||
def softthres(x, thres): | |||
x_abs = tf.abs(x) | |||
coef = tf.nn.relu(x_abs - thres) / (x_abs + 1e-10) | |||
coef = tf.cast(coef, tf.complex64) | |||
return coef * x | |||
""" | |||
class Emat_xyt(): | |||
def __init__(self, mask): | |||
super(Emat_xyt, self).__init__() | |||
self.mask = mask | |||
def mtimes(self, b, inv): | |||
if inv: | |||
# this is for single channel reconstruction only. | |||
x = self._ifft2c_mri(b * self.mask) | |||
else: | |||
x = self._fft2c_mri(b) * self.mask | |||
return x | |||
def _fft2c_mri(self, x): | |||
# nb nx ny nt | |||
X = tf.signal.fftshift(x, 2) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one. | |||
X = tf.signal.fft(X) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(x.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx) | |||
X = tf.signal.fftshift(X, 3) | |||
X = tf.signal.fft(X) | |||
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny) | |||
return X | |||
def _ifft2c_mri(self, X): | |||
# nb nx ny nt | |||
x = tf.signal.fftshift(X, 2) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one. | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(X.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
x = tf.signal.fftshift(x, 2) * tf.sqrt(nx) | |||
x = tf.signal.fftshift(x, 3) | |||
x = tf.signal.ifft(x) | |||
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny) | |||
return x | |||
""" | |||
class Emat_xyt(): | |||
def __init__(self, mask): | |||
super(Emat_xyt, self).__init__() | |||
self.mask = mask | |||
def mtimes(self, b, inv, csm): | |||
if csm == None: | |||
if inv: | |||
x = self._ifft2c_mri_singlecoil(b * self.mask) | |||
else: | |||
x = self._fft2c_mri_singlecoil(b) * self.mask | |||
else: | |||
if len(self.mask.shape) > 3: | |||
if inv: | |||
x = self._ifft2c_mri_multicoil(b * self.mask[:,:,0:b.shape[2],:,:]) | |||
x = x * tf.math.conj(csm) | |||
x = tf.reduce_sum(x, 1) #/ tf.cast(tf.reduce_sum(tf.abs(csm)**2, 1), dtype=tf.complex64) | |||
else: | |||
b = tf.expand_dims(b, 1) * csm | |||
x = self._fft2c_mri_multicoil(b) * self.mask[:,:,0:b.shape[2],:,:] | |||
else: | |||
if inv: | |||
x = self._ifft2c_mri_multicoil(b * self.mask) | |||
x = x * tf.math.conj(csm) | |||
x = tf.reduce_sum(x, 1) #/ tf.cast(tf.reduce_sum(tf.abs(csm)**2, 1), dtype=tf.complex64) | |||
else: | |||
b = tf.expand_dims(b, 1) * csm | |||
x = self._fft2c_mri_multicoil(b) * self.mask | |||
return x | |||
def _fft2c_mri_multicoil(self, x): | |||
# nb nt nx ny -> nb, nc, nt, nx, ny | |||
X = tf.signal.fftshift(x, 3) | |||
X = tf.transpose(X, perm=[0,1,2,4,3]) # permute to make nx dimension the last one. | |||
X = tf.signal.fft(X) | |||
X = tf.transpose(X, perm=[0,1,2,4,3]) # permute back to original order. | |||
nb, nc, nt, nx, ny = np.float32(x.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
X = tf.signal.fftshift(X, 3) / tf.sqrt(nx) | |||
X = tf.signal.fftshift(X, 4) | |||
X = tf.signal.fft(X) | |||
X = tf.signal.fftshift(X, 4) / tf.sqrt(ny) | |||
return X | |||
def _ifft2c_mri_multicoil(self, X): | |||
# nb nt nx ny -> nb, nc, nt, nx, ny | |||
x = tf.signal.fftshift(X, 3) | |||
x = tf.transpose(x, perm=[0,1,2,4,3]) # permute a to make nx dimension the last one. | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,1,2,4,3]) # permute back to original order. | |||
nb, nc, nt, nx, ny = np.float32(X.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
x = tf.signal.fftshift(x, 3) * tf.sqrt(nx) | |||
x = tf.signal.fftshift(x, 4) | |||
x = tf.signal.ifft(x) | |||
x = tf.signal.fftshift(x, 4) * tf.sqrt(ny) | |||
return x | |||
def _fft2c_mri_singlecoil(self, x): | |||
# nb nx ny nt | |||
X = tf.signal.fftshift(x, 2) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one. | |||
X = tf.signal.fft(X) | |||
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(x.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx) | |||
X = tf.signal.fftshift(X, 3) | |||
X = tf.signal.fft(X) | |||
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny) | |||
return X | |||
def _ifft2c_mri_singlecoil(self, X): | |||
# nb nx ny nt | |||
x = tf.signal.fftshift(X, 2) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one. | |||
x = tf.signal.ifft(x) | |||
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order. | |||
nb, nt, nx, ny = np.float32(X.shape) | |||
nx = tf.constant(np.complex64(nx + 0j)) | |||
ny = tf.constant(np.complex64(ny + 0j)) | |||
x = tf.signal.fftshift(x, 2) * tf.sqrt(nx) | |||
x = tf.signal.fftshift(x, 3) | |||
x = tf.signal.ifft(x) | |||
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny) | |||
return x |
@@ -0,0 +1,187 @@ | |||
# -*- coding: utf-8 -*- | |||
""" | |||
Created on Mon Dec 9 15:50:03 2019 | |||
@author: wmy | |||
""" | |||
import numpy as np | |||
#import matplotlib.pyplot as plt | |||
import pywt | |||
import tensorflow as tf | |||
#import pylab | |||
#pylab.rcParams['figure.figsize'] = (10.0, 10.0) | |||
def dwt2d(x, wave='haar'): | |||
# shape x: (b, h, w, c) | |||
nc = int(x.shape.dims[3]) | |||
# 小波波形 | |||
w = pywt.Wavelet(wave) | |||
# 水平低频 垂直低频 | |||
ll = np.outer(w.dec_lo, w.dec_lo) | |||
# 水平低频 垂直高频 | |||
lh = np.outer(w.dec_hi, w.dec_lo) | |||
# 水平高频 垂直低频 | |||
hl = np.outer(w.dec_lo, w.dec_hi) | |||
# 水平高频 垂直高频 | |||
hh = np.outer(w.dec_hi, w.dec_hi) | |||
# 卷积核 | |||
core = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4)) | |||
core[:, :, 0, 0] = ll[::-1, ::-1] | |||
core[:, :, 0, 1] = lh[::-1, ::-1] | |||
core[:, :, 0, 2] = hl[::-1, ::-1] | |||
core[:, :, 0, 3] = hh[::-1, ::-1] | |||
core = core.astype(np.float32) | |||
kernel = np.array([core], dtype=np.float32) | |||
kernel = tf.convert_to_tensor(kernel) | |||
p = 2 * (len(w.dec_lo) // 2 - 1) | |||
with tf.compat.v1.variable_scope('dwt2d'): | |||
# padding odd length | |||
x = tf.pad(x, tf.constant([[0, 0], [p, p+1], [p, p+1], [0, 0]])) | |||
xh = tf.shape(x)[1] - tf.shape(x)[1]%2 | |||
xw = tf.shape(x)[2] - tf.shape(x)[2]%2 | |||
x = x[:, 0:xh, 0:xw, :] | |||
# convert to 3d data | |||
x3d = tf.expand_dims(x, 1) | |||
# 切开通道 | |||
x3d = tf.split(x3d, int(x3d.shape.dims[4]), 4) | |||
# 贴到维度一 | |||
x3d = tf.concat([a for a in x3d], 1) | |||
# 三维卷积 | |||
y3d = tf.nn.conv3d(x3d, kernel, padding='VALID', strides=[1, 1, 2, 2, 1]) | |||
# 切开维度一 | |||
y = tf.split(y3d, int(y3d.shape.dims[1]), 1) | |||
# 贴到通道维 | |||
y = tf.concat([a for a in y], 4) | |||
y = tf.reshape(y, (tf.shape(y)[0], tf.shape(y)[2], tf.shape(y)[3], 4*nc)) | |||
# 拼贴通道 | |||
channels = tf.split(y, nc, 3) | |||
outputs = [] | |||
for channel in channels: | |||
(cA, cH, cV, cD) = tf.split(channel, 4, 3) | |||
AH = tf.concat([cA, cH], axis=2) | |||
VD = tf.concat([cV, cD], axis=2) | |||
outputs.append(tf.concat([AH, VD], axis=1)) | |||
pass | |||
outputs = tf.concat(outputs, axis=-1) | |||
pass | |||
return outputs | |||
def wavedec2d(x, level=1, wave='haar'): | |||
if level == 0: | |||
return x | |||
y = dwt2d(x, wave=wave) | |||
hcA = tf.math.floordiv(tf.shape(y)[1], 2) | |||
wcA = tf.math.floordiv(tf.shape(y)[2], 2) | |||
cA = y[:, 0:hcA, 0:wcA, :] | |||
cA = wavedec2d(cA, level=level-1, wave=wave) | |||
cA = cA[:, 0:hcA, 0:wcA, :] | |||
hcA = tf.shape(cA)[1] | |||
wcA = tf.shape(cA)[2] | |||
cH = y[:, 0:hcA, wcA:, :] | |||
cV = y[:, hcA:, 0:wcA, :] | |||
cD = y[:, hcA:, wcA:, :] | |||
AH = tf.concat([cA, cH], axis=2) | |||
VD = tf.concat([cV, cD], axis=2) | |||
outputs = tf.concat([AH, VD], axis=1) | |||
return outputs | |||
def idwt2d(x, wave='haar'): | |||
# shape x: (b, h, w, c) | |||
nc = int(x.shape.dims[3]) | |||
# 小波波形 | |||
w = pywt.Wavelet(wave) | |||
# 水平低频 垂直低频 | |||
ll = np.outer(w.dec_lo, w.dec_lo) | |||
# 水平低频 垂直高频 | |||
lh = np.outer(w.dec_hi, w.dec_lo) | |||
# 水平高频 垂直低频 | |||
hl = np.outer(w.dec_lo, w.dec_hi) | |||
# 水平高频 垂直高频 | |||
hh = np.outer(w.dec_hi, w.dec_hi) | |||
# 卷积核 | |||
core = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4)) | |||
core[:, :, 0, 0] = ll[::-1, ::-1] | |||
core[:, :, 0, 1] = lh[::-1, ::-1] | |||
core[:, :, 0, 2] = hl[::-1, ::-1] | |||
core[:, :, 0, 3] = hh[::-1, ::-1] | |||
core = core.astype(np.float32) | |||
kernel = np.array([core], dtype=np.float32) | |||
kernel = tf.convert_to_tensor(kernel) | |||
s = 2 * (len(w.dec_lo) // 2 - 1) | |||
# 反变换 | |||
with tf.compat.v1.variable_scope('idwt2d'): | |||
hcA = tf.math.floordiv(tf.shape(x)[1], 2) | |||
wcA = tf.math.floordiv(tf.shape(x)[2], 2) | |||
y = [] | |||
for c in range(nc): | |||
channel = x[:, :, :, c] | |||
channel = tf.expand_dims(channel, -1) | |||
cA = channel[:, 0:hcA, 0:wcA, :] | |||
cH = channel[:, 0:hcA, wcA:, :] | |||
cV = channel[:, hcA:, 0:wcA, :] | |||
cD = channel[:, hcA:, wcA:, :] | |||
temp = tf.concat([cA, cH, cV, cD], axis=-1) | |||
y.append(temp) | |||
pass | |||
# nc * 4 | |||
y = tf.concat(y, axis=-1) | |||
y3d = tf.expand_dims(y, 1) | |||
y3d = tf.split(y3d, nc, 4) | |||
y3d = tf.concat([a for a in y3d], 1) | |||
output_shape = [tf.shape(y)[0], tf.shape(y3d)[1], \ | |||
2*(tf.shape(y)[1]-1)+np.shape(ll)[0], \ | |||
2*(tf.shape(y)[2]-1)+np.shape(ll)[1], 1] | |||
x3d = tf.nn.conv3d_transpose(y3d, kernel, output_shape=output_shape, padding='VALID', strides=[1, 1, 2, 2, 1]) | |||
outputs = tf.split(x3d, nc, 1) | |||
outputs = tf.concat([x for x in outputs], 4) | |||
outputs = tf.reshape(outputs, (tf.shape(outputs)[0], tf.shape(outputs)[2], tf.shape(outputs)[3], nc)) | |||
outputs = outputs[:, s:2*(tf.shape(y)[1]-1)+np.shape(ll)[0]-s, \ | |||
s:2*(tf.shape(y)[2]-1)+np.shape(ll)[1]-s, :] | |||
pass | |||
return outputs | |||
def dwt2dc(x, wave='haar'): | |||
# b, t, h, w | |||
nt = x.shape[1] | |||
x_2c = tf.transpose(x, perm=[0,2,3,1]) | |||
x_2c = tf.concat([tf.math.real(x_2c), tf.math.imag(x_2c)], axis=-1) | |||
dwtx_2c = dwt2d(x_2c) | |||
dwtx = tf.complex(dwtx_2c[:,:,:,0:nt], dwtx_2c[:,:,:,nt:2*nt]) | |||
dwtx = tf.transpose(dwtx, perm=[0,3,1,2]) | |||
return dwtx | |||
def idwt2dc(x, wave='haar'): | |||
# b, t, h, w | |||
nt = x.shape[1] | |||
x_2c = tf.transpose(x, perm=[0,2,3,1]) | |||
x_2c = tf.concat([tf.math.real(x_2c), tf.math.imag(x_2c)], axis=-1) | |||
idwtx_2c = idwt2d(x_2c) | |||
idwtx = tf.complex(idwtx_2c[:,:,:,0:nt], idwtx_2c[:,:,:,nt:2*nt]) | |||
idwtx = tf.transpose(idwtx, perm=[0,3,1,2]) | |||
return idwtx | |||
""" | |||
tf.reset_default_graph() | |||
inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs') | |||
image = plt.imread('test.jpg') | |||
plt.imshow(image) | |||
plt.show() | |||
x = np.array([image, image[:, ::-1, :]]) | |||
dec = wavedec2d(inputs, level=5, wave='sym4') | |||
dwt = dwt2d(inputs, wave='sym4') | |||
idwt = idwt2d(dwt, wave='sym4') | |||
with tf.Session() as sess: | |||
sess.run(tf.global_variables_initializer()) | |||
result = sess.run(dec, feed_dict={inputs:x}) | |||
plt.imshow(np.array(result[0], dtype=np.uint8)) | |||
plt.show() | |||
trans = sess.run(dwt, feed_dict={inputs:x}) | |||
plt.imshow(np.array(trans[0], dtype=np.uint8)) | |||
plt.show() | |||
recons = sess.run(idwt, feed_dict={inputs:x}) | |||
plt.imshow(np.array(recons[0], dtype=np.uint8)) | |||
plt.show() | |||
pass | |||
""" |