Browse Source

Add files via upload

main
Ziwen Ke GitHub 1 year ago
parent
commit
0deade8265
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1795 additions and 0 deletions
  1. +173
    -0
      main_net_v3.py
  2. +578
    -0
      model_net_v3.py
  3. +2
    -0
      models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/checkpoint
  4. BIN
      models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.data-00000-of-00002
  5. BIN
      models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.data-00001-of-00002
  6. BIN
      models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.index
  7. +142
    -0
      test_net_v3.py
  8. +224
    -0
      tools/compressed_sensing.py
  9. +148
    -0
      tools/mymath.py
  10. +341
    -0
      tools/tools.py
  11. +187
    -0
      tools/wavelet.py

+ 173
- 0
main_net_v3.py View File

@@ -0,0 +1,173 @@
import tensorflow as tf
import os
from model_net_v3 import Manifold_Net
from dataset_tfrecord import get_dataset
import argparse
import scipy.io as scio
import mat73
import numpy as np
from datetime import datetime
import time
from tools.tools import video_summary

from tools.tools import tempfft, mse


#tf.debugging.set_log_device_placement(True)
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# tf.debugging.set_log_device_placement(True)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--num_epoch', metavar='int', nargs=1, default=['50'], help='number of epochs')
parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size')
parser.add_argument('--learning_rate', metavar='float', nargs=1, default=['0.001'], help='initial learning rate')
parser.add_argument('--niter', metavar='int', nargs=1, default=['5'], help='number of network iterations')
parser.add_argument('--nconv', metavar='int', nargs=1, default=['3'], help='number of convolutional layers on CNNLayer')
parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate')
parser.add_argument('--dc_type', metavar='str', nargs=1, default=['v1'], help='v1: kspace; v2: image')
parser.add_argument('--mask_pattern', metavar='str', nargs=1, default=['spiral'], help='mask pattern: cartesian, radial, spiral, vista')
parser.add_argument('--net', metavar='str', nargs=1, default=['Manifold_Net'], help='Manifold_Net')
parser.add_argument('--gpu', metavar='int', nargs=1, default=['2'], help='GPU No.')
parser.add_argument('--data', metavar='str', nargs=1, default=['DYNAMIC_V2'], help='dataset name')
parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not')
parser.add_argument('--SVT_favtor', metavar='float', nargs=1, default=['1.3'], help='SVT factor')

args = parser.parse_args()
# GPU setup
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0]
GPUs = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(GPUs[0], True)
mode = 'training'
dataset_name = args.data[0].upper()
dc_type = args.dc_type[0]
batch_size = int(args.batch_size[0])
num_epoch = int(args.num_epoch[0])
learning_rate = float(args.learning_rate[0])

acc = int(args.acc[0])
mask_pattern = args.mask_pattern[0]
net_name = args.net[0]
niter = int(args.niter[0])
nconv = int(args.nconv[0])
learnedSVT = bool(args.learnedSVT[0])
#N_factor = int(args.SVT_favtor[0])
N_factor = float(args.SVT_favtor[0])


logdir = './logs'
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
model_id = TIMESTAMP + '_'+ net_name + '_v3_correct_' + 'dc_' + dc_type +'_d'+str(nconv)+'c'+str(niter)+'_acc_'+ str(acc) + '_lr_' + str(learning_rate) + '_N_factor_' + str(N_factor) + '_rank_' + str(int(18/N_factor)) +'_'+ mask_pattern
summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, model_id + '/'))

modeldir = os.path.join('models/stable/', model_id)
os.makedirs(modeldir)

# prepare undersampling mask
if dataset_name == 'DYNAMIC_V2':
multi_coil = False
mask_size = '18_192_192'
elif dataset_name == 'DYNAMIC_V2_MULTICOIL':
multi_coil = True
mask_size = '18_192_192'
elif dataset_name == 'FLOW':
multi_coil = False
mask_size = '20_180_180'

if acc == 8:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc8.mat')['mask']
elif acc == 10:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acc10.mat')['mask']
elif acc == 12:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc12.mat')['mask']

mask = tf.cast(tf.constant(mask), tf.complex64)

# prepare dataset
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=True, full=True)
#dataset = get_dataset('test', dataset_name, batch_size, shuffle=True, full=True)
tf.print('dataset loaded.')

# initialize network
if net_name == 'Manifold_Net':
net = Manifold_Net(mask, niter, learnedSVT, N_factor)


tf.print('network initialized.')

learning_rate_org = learning_rate
learning_rate_decay = 0.95

optimizer = tf.optimizers.Adam(learning_rate_org)
# Iterate over epochs.
total_step = 0
param_num = 0
loss = 0

for epoch in range(num_epoch):
for step, sample in enumerate(dataset):
# forward
t0 = time.time()
k0 = None

csm = None
with tf.GradientTape() as tape:
if multi_coil:
k0, label, csm = sample
if k0 == None:
continue
else:
k0, label = sample
if k0.shape[0] < batch_size:
continue

label_abs = tf.abs(label)

k0 = k0 * mask

recon = net(k0, csm)
recon_abs = tf.abs(recon)

loss_mse = mse(recon, label)

# backward
grads = tape.gradient(loss_mse, net.trainable_weights)####################################
optimizer.apply_gradients(zip(grads, net.trainable_weights))#################################

# record loss
with summary_writer.as_default():
tf.summary.scalar('loss/total', loss_mse.numpy(), step=total_step)

# record gif
if step % 20 == 0:
with summary_writer.as_default():
combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy()
combine_video = np.expand_dims(combine_video, -1)
video_summary('result', combine_video, step=total_step, fps=10)
# calculate parameter number
if total_step == 0:
param_num = np.sum([np.prod(v.get_shape()) for v in net.trainable_variables])

# log output
tf.print('Epoch', epoch+1, '/', num_epoch, 'Step', step, 'loss =', loss_mse.numpy(), 'time', time.time() - t0, 'lr = ', learning_rate, 'param_num', param_num)
total_step += 1

# learning rate decay for each epoch
learning_rate = learning_rate_org * learning_rate_decay ** (epoch + 1)#(total_step / decay_steps)
optimizer = tf.optimizers.Adam(learning_rate)

# save model each epoch
#if epoch in [0, num_epoch-1, num_epoch]:
model_epoch_dir = os.path.join(modeldir,'epoch-'+str(epoch+1), 'ckpt')
net.save_weights(model_epoch_dir, save_format='tf')


+ 578
- 0
model_net_v3.py View File

@@ -0,0 +1,578 @@
import tensorflow as tf
from tensorflow.keras import layers
import os
import numpy as np
import time
from tools.tools import tempfft, fft2c_mri, ifft2c_mri, Emat_xyt


class CNNLayer(tf.keras.layers.Layer):
def __init__(self, n_f=32, n_out=2):
super(CNNLayer, self).__init__()
self.mylayers = []

self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
self.mylayers.append(tf.keras.layers.LeakyReLU())
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
self.mylayers.append(tf.keras.layers.LeakyReLU())
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
self.mylayers.append(tf.keras.layers.LeakyReLU())
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
self.mylayers.append(tf.keras.layers.LeakyReLU())
self.mylayers.append(tf.keras.layers.Conv3D(n_out, 3, strides=1, padding='same', use_bias=False))
self.seq = tf.keras.Sequential(self.mylayers)

def call(self, input):
if len(input.shape) == 4:
input2c = tf.stack([tf.math.real(input), tf.math.imag(input)], axis=-1)
else:
input2c = tf.concat([tf.math.real(input), tf.math.imag(input)], axis=-1)
res = self.seq(input2c)
res = tf.complex(res[:,:,:,:,0], res[:,:,:,:,1])
return res

class CONV_OP(tf.keras.layers.Layer):
def __init__(self, n_f=32, ifactivate=False):
super(CONV_OP, self).__init__()
self.mylayers = []
self.mylayers.append(tf.keras.layers.Conv3D(n_f, 3, strides=1, padding='same', use_bias=False))
if ifactivate == True:
self.mylayers.append(tf.keras.layers.ReLU())
self.seq = tf.keras.Sequential(self.mylayers)

def call(self, input):
res = self.seq(input)
return res

class SLR_Net(tf.keras.Model):
def __init__(self, mask, niter, learned_topk=False):
super(SLR_Net, self).__init__(name='SLR_Net')
self.niter = niter
self.E = Emat_xyt(mask)
self.learned_topk = learned_topk
self.celllist = []

def build(self, input_shape):
for i in range(self.niter-1):
self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk))
self.celllist.append(SLRCell(input_shape, self.E, learned_topk=self.learned_topk, is_last=True))

def call(self, d, csm):
"""
d: undersampled k-space
csm: coil sensitivity map
"""
if csm == None:
nb, nt, nx, ny = d.shape
else:
nb, nc, nt, nx, ny = d.shape
X_SYM = []
x_rec = self.E.mtimes(d, inv=True, csm=csm)
t = tf.zeros_like(x_rec)
beta = tf.zeros_like(x_rec)
x_sym = tf.zeros_like(x_rec)
data = [x_rec, x_sym, beta, t, d, csm]
for i in range(self.niter):
data = self.celllist[i](data, d.shape)
x_sym = data[1]
X_SYM.append(x_sym)

x_rec = data[0]
return x_rec, X_SYM


class SLRCell(layers.Layer):
def __init__(self, input_shape, E, learned_topk=False, is_last=False):
super(SLRCell, self).__init__()
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape

self.E = E
self.learned_topk = learned_topk
if self.learned_topk:
if is_last:
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef')
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=False, name='eta')
else:
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef')
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta')

self.conv_1 = CONV_OP(n_f=16, ifactivate=True)
self.conv_2 = CONV_OP(n_f=16, ifactivate=True)
self.conv_3 = CONV_OP(n_f=16, ifactivate=False)
self.conv_4 = CONV_OP(n_f=16, ifactivate=True)
self.conv_5 = CONV_OP(n_f=16, ifactivate=True)
self.conv_6 = CONV_OP(n_f=2, ifactivate=False)
#self.conv_7 = CONV_OP(n_f=16, ifactivate=True)
#self.conv_8 = CONV_OP(n_f=16, ifactivate=True)
#self.conv_9 = CONV_OP(n_f=16, ifactivate=True)
#self.conv_10 = CONV_OP(n_f=16, ifactivate=True)

self.lambda_step = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_1')
self.lambda_step_2 = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='lambda_2')
self.soft_thr = tf.Variable(tf.constant(0.1, dtype=tf.float32), trainable=True, name='soft_thr')


def call(self, data, input_shape):
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape
x_rec, x_sym, beta, t, d, csm = data

x_rec, x_sym = self.sparse(x_rec, d, t, beta, csm)
t = self.lowrank(x_rec)
beta = self.beta_mid(beta, x_rec, t)

data[0] = x_rec
data[1] = x_sym
data[2] = beta
data[3] = t

return data

def sparse(self, x_rec, d, t, beta, csm):
lambda_step = tf.cast(tf.nn.relu(self.lambda_step), tf.complex64)
lambda_step_2 = tf.cast(tf.nn.relu(self.lambda_step_2), tf.complex64)

ATAX_cplx = self.E.mtimes(self.E.mtimes(x_rec, inv=False, csm=csm) - d, inv=True, csm=csm)

r_n = x_rec - tf.math.scalar_mul(lambda_step, ATAX_cplx) +\
tf.math.scalar_mul(lambda_step_2, x_rec + beta - t)

# D_T(soft(D_r_n))
if len(r_n.shape) == 4:
r_n = tf.stack([tf.math.real(r_n), tf.math.imag(r_n)], axis=-1)

x_1 = self.conv_1(r_n)
x_2 = self.conv_2(x_1)
x_3 = self.conv_3(x_2)

x_soft = tf.math.multiply(tf.math.sign(x_3), tf.nn.relu(tf.abs(x_3) - self.soft_thr))

x_4 = self.conv_4(x_soft)
x_5 = self.conv_5(x_4)
x_6 = self.conv_6(x_5)

x_rec = x_6 + r_n

x_1_sym = self.conv_4(x_3)
x_1_sym = self.conv_5(x_1_sym)
x_1_sym = self.conv_6(x_1_sym)
#x_sym_1 = self.conv_10(x_1_sym)

x_sym = x_1_sym - r_n
x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1])

return x_rec, x_sym

def lowrank(self, x_rec):
[batch, Nt, Nx, Ny] = x_rec.get_shape()
M = tf.reshape(x_rec, [batch, Nt, Nx*Ny])
St, Ut, Vt = tf.linalg.svd(M)
if self.learned_topk:
#tf.print(tf.sigmoid(self.thres_coef))
thres = tf.sigmoid(self.thres_coef) * St[:, 0]
thres = tf.expand_dims(thres, -1)
St = tf.nn.relu(St - thres)
else:
top1_mask = np.concatenate(
[np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1)
top1_mask = tf.constant(top1_mask)
St = St * top1_mask
St = tf.linalg.diag(St)
St = tf.dtypes.cast(St, tf.complex64)
Vt_conj = tf.transpose(Vt, perm=[0, 2, 1])
Vt_conj = tf.math.conj(Vt_conj)
US = tf.linalg.matmul(Ut, St)
M = tf.linalg.matmul(US, Vt_conj)
x_rec = tf.reshape(M, [batch, Nt, Nx, Ny])

return x_rec

def beta_mid(self, beta, x_rec, t):
eta = tf.cast(tf.nn.relu(self.eta), tf.complex64)
return beta + tf.multiply(eta, x_rec - t)

###### DC-CNN ######
class DC_CNN_LR(tf.keras.Model):
def __init__(self, mask, niter, learned_topk=False):
super(DC_CNN_LR, self).__init__(name='DC_CNN_LR')
self.niter = niter
self.E = Emat_xyt(mask)
self.mask = mask
self.learned_topk = learned_topk
self.celllist = []

def build(self, input_shape):
for i in range(self.niter-1):
self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk))
self.celllist.append(DNCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, is_last=True))

def call(self, d, csm):
"""
d: undersampled k-space
csm: coil sensitivity map
"""
if csm == None:
nb, nt, nx, ny = d.shape
else:
nb, nc, nt, nx, ny = d.shape
x_rec = self.E.mtimes(d, inv=True, csm=csm)
for i in range(self.niter):
x_rec = self.celllist[i](x_rec, d, d.shape)
return x_rec


class DNCell(layers.Layer):

def __init__(self, input_shape, E, mask, learned_topk=False, is_last=False):
super(DNCell, self).__init__()
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape

self.E = E
self.mask = mask
self.learned_topk = learned_topk
if self.learned_topk:
if is_last:
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=False, name='thres_coef')
else:
self.thres_coef = tf.Variable(tf.constant(-2, dtype=tf.float32), trainable=True, name='thres_coef')

self.conv_1 = CONV_OP(n_f=16, ifactivate=True)
self.conv_2 = CONV_OP(n_f=16, ifactivate=True)
self.conv_3 = CONV_OP(n_f=16, ifactivate=True)
self.conv_4 = CONV_OP(n_f=16, ifactivate=True)
self.conv_5 = CONV_OP(n_f=2, ifactivate=False)

def call(self, x_rec, d, input_shape):
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape
x_rec = self.sparse(x_rec, d)
return x_rec

def sparse(self, x_rec, d):

r_n = tf.stack([tf.math.real(x_rec), tf.math.imag(x_rec)], axis=-1)

x_1 = self.conv_1(r_n)
x_2 = self.conv_2(x_1)
x_3 = self.conv_3(x_2)
x_4 = self.conv_4(x_3)
x_5 = self.conv_5(x_4)
x_rec = x_5 + r_n
x_rec = tf.complex(x_rec[:, :, :, :, 0], x_rec[:, :, :, :, 1])
if self.learned_topk:
x_rec = self.lowrank(x_rec)
x_rec = self.dc_layer(x_rec, d)

return x_rec

def lowrank(self, x_rec):
[batch, Nt, Nx, Ny] = x_rec.get_shape()
M = tf.reshape(x_rec, [batch, Nt, Nx*Ny])
St, Ut, Vt = tf.linalg.svd(M)
if self.learned_topk:
#tf.print(tf.sigmoid(self.thres_coef))
thres = tf.sigmoid(self.thres_coef) * St[:, 0]
thres = tf.expand_dims(thres, -1)
St = tf.nn.relu(St - thres)
else:
top1_mask = np.concatenate(
[np.ones([self.nb, 1], dtype=np.float32), np.zeros([self.nb, self.nt - 1], dtype=np.float32)], 1)
top1_mask = tf.constant(top1_mask)
St = St * top1_mask
St = tf.linalg.diag(St)
St = tf.dtypes.cast(St, tf.complex64)
Vt_conj = tf.transpose(Vt, perm=[0, 2, 1])
Vt_conj = tf.math.conj(Vt_conj)
US = tf.linalg.matmul(Ut, St)
M = tf.linalg.matmul(US, Vt_conj)
x_rec = tf.reshape(M, [batch, Nt, Nx, Ny])
return x_rec
def dc_layer(self, x_rec, d):

k_rec = fft2c_mri(x_rec)
k_rec = (1 - self.mask) * k_rec + self.mask * d
x_rec = ifft2c_mri(k_rec)

return x_rec
###### Manifold_Net ######
class Manifold_Net(tf.keras.Model):
def __init__(self, mask, niter, learned_topk=False, N_factor=1):
super(Manifold_Net, self).__init__(name='Manifold_Net')
self.niter = niter
self.E = Emat_xyt(mask)
self.mask = mask
self.learned_topk = learned_topk
self.N_factor = N_factor
self.celllist = []

def build(self, input_shape):
for i in range(self.niter-1):
self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor))
self.celllist.append(ManifoldCell(input_shape, self.E, self.mask, learned_topk=self.learned_topk, N_factor=self.N_factor, is_last=True))

def call(self, d, csm):
"""
d: undersampled k-space
csm: coil sensitivity map
"""
if csm == None:
nb, nt, nx, ny = d.shape
else:
nb, nc, nt, nx, ny = d.shape
x_rec = self.E.mtimes(d, inv=True, csm=csm)
for i in range(self.niter):
x_rec = self.celllist[i](x_rec, d, d.shape)
return x_rec

class ManifoldCell(layers.Layer):

def __init__(self, input_shape, E, mask, learned_topk=False, N_factor=1, is_last=False):
super(ManifoldCell, self).__init__()
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape

self.E = E
self.mask = mask
self.Nx_factor = N_factor
self.Ny_factor = N_factor
self.Nt_factor = N_factor
self.learned_topk = learned_topk
if self.learned_topk:
self.eta = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='eta')
#self.lambda_sparse = tf.Variable(tf.constant(0.01, dtype=tf.float32), trainable=True, name='lambda')

self.conv_1 = CNNLayer(n_f=16, n_out=2)
#self.conv_2 = CNNLayer(n_f=16, n_out=2)
#self.conv_3 = CNNLayer(n_f=16, n_out=2)
#self.conv_D = CNNLayer(n_f=16, n_out=2)
#self.conv_transD = CNNLayer(n_f=16, n_out=2)

def call(self, x_rec, d, input_shape):
if len(input_shape) == 4:
self.nb, self.nt, self.nx, self.ny = input_shape
else:
self.nb, nc, self.nt, self.nx, self.ny = input_shape
x_k = self.conv_1(x_rec)
#grad_sparse = self.conv_transD(self.conv_D(x_k))
#grad_sparse = tf.stack([tf.math.real(grad_sparse), tf.math.imag(grad_sparse)], axis=-1)
#grad_sparse = tf.multiply(self.lambda_sparse, grad_sparse)
#grad_sparse = tf.complex(grad_sparse[..., 0], grad_sparse[..., 1])
#grad_dc = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask)
#g_k = grad_dc + grad_sparse

g_k = ifft2c_mri((fft2c_mri(x_k) * self.mask - d) * self.mask)
#g_k = self.E.mtimes(self.E.mtimes(x_k, inv=False, csm=csm) - d, inv=True, csm=csm)
t_k = self.Tangent_Module(g_k, x_k)
x_k = self.Retraction_Module(x_k, t_k)
#x_k = self.conv_3(x_k)

x_k = self.dc_layer(x_k, d)
return x_k

def Tangent_Module(self, g_k, x_k):
batch, Nt, Nx, Ny = x_k.shape
x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt
g_k = tf.transpose(g_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt

Ux, Uy, Ut = self.Mode(x_k)
first_term = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
first_term = self.Mode_Multiply(first_term, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
first_term = self.Mode_Multiply(first_term, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
first_term = self.Mode_Multiply(first_term, Ux, mode_n=1)
first_term = self.Mode_Multiply(first_term, Uy, mode_n=2)
first_term = self.Mode_Multiply(first_term, Ut, mode_n=3)

C_mode_x, C_mode_y, C_mode_t = self.Core_C(x_k, Ux, Uy, Ut)

second_term_1 = self.Mode_Multiply(g_k, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
second_term_1 = self.Mode_Multiply(second_term_1, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny * Nt])
second_term_1 = self.Projector(second_term_1, Ux)
second_term_1 = self.Core_Multiply(second_term_1, C_mode_x)
second_term_1 = tf.linalg.matmul(second_term_1, C_mode_x)
second_term_1 = tf.reshape(second_term_1, [batch, Nx, Ny, Nt])
second_term_1 = self.Mode_Multiply(second_term_1, Uy, mode_n=2)
second_term_1 = self.Mode_Multiply(second_term_1, Ut, mode_n=3)

second_term_2 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
second_term_2 = self.Mode_Multiply(second_term_2, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
second_term_2 = tf.reshape(tf.transpose(second_term_2, [0, 2, 1, 3]), [batch, Ny, Nx*Nt])
second_term_2 = self.Projector(second_term_2, Uy)
second_term_2 = self.Core_Multiply(second_term_2, C_mode_y)
second_term_2 = tf.linalg.matmul(second_term_2, C_mode_y)
second_term_2 = tf.transpose(tf.reshape(second_term_2, [batch, Ny, Nx, Nt]), [0, 2, 1, 3])
second_term_2 = self.Mode_Multiply(second_term_2, Ux, mode_n=1)
second_term_2 = self.Mode_Multiply(second_term_2, Ut, mode_n=3)

second_term_3 = self.Mode_Multiply(g_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
second_term_3 = self.Mode_Multiply(second_term_3, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
second_term_3 = tf.reshape(tf.transpose(second_term_3, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])
second_term_3 = self.Projector(second_term_3, Ut)
second_term_3 = self.Core_Multiply(second_term_3, C_mode_t)
second_term_3 = tf.linalg.matmul(second_term_3, C_mode_t)
second_term_3 = tf.transpose(tf.reshape(second_term_3, [batch, Nt, Nx, Ny]), [0, 2, 3, 1])
second_term_3 = self.Mode_Multiply(second_term_3, Ux, mode_n=1)
second_term_3 = self.Mode_Multiply(second_term_3, Uy, mode_n=2)

t_k = first_term + second_term_1 + second_term_2 + second_term_3
t_k = tf.transpose(t_k, [0, 3, 1, 2])
return t_k
def Retraction_Module(self, x_k, t_k):
x_k = tf.stack([tf.math.real(x_k), tf.math.imag(x_k)], axis=-1)
t_k = tf.stack([tf.math.real(t_k), tf.math.imag(t_k)], axis=-1)
x_k = x_k - tf.multiply(self.eta, t_k)
x_k = tf.complex(x_k[..., 0], x_k[..., 1])

batch, Nt, Nx, Ny = x_k.shape
x_k = tf.transpose(x_k, [0, 2, 3, 1]) # batch, Nx, Ny, Nt

Ux, Uy, Ut = self.Mode(x_k)
Ux = self.SVT_U(Ux, top_kth= int(Nx / self.Nx_factor))
Uy = self.SVT_U(Uy, top_kth= int(Ny / self.Ny_factor))
Ut = self.SVT_U(Ut, top_kth= int(Nt / self.Nt_factor))
"""
Ux = self.SVT_U(Ux, top_kth= Nx // self.Nx_factor)
Uy = self.SVT_U(Uy, top_kth= Ny // self.Ny_factor)
Ut = self.SVT_U(Ut, top_kth= Nt // self.Nt_factor)
"""
C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)

x_k = self.Mode_Multiply(C, Ux, mode_n=1)
x_k = self.Mode_Multiply(x_k, Uy, mode_n=2)
x_k = self.Mode_Multiply(x_k, Ut, mode_n=3)

x_k = tf.transpose(x_k, [0, 3, 1, 2])

return x_k
def Mode(self, x_k):
batch, Nx, Ny, Nt = x_k.shape
mode_x = tf.reshape(x_k, [batch, Nx, Ny*Nt])
mode_y = tf.reshape(tf.transpose(x_k, [0, 2, 1, 3]), [batch, Ny, Nx*Nt])
mode_t = tf.reshape(tf.transpose(x_k, [0, 3, 1, 2]), [batch, Nt, Nx*Ny])

Sx, Ux, Vx = tf.linalg.svd(mode_x) # Ux: batch, 192, 192
Sy, Uy, Vy = tf.linalg.svd(mode_y) # Uy: batch, 192, 192
St, Ut, Vt = tf.linalg.svd(mode_t)

return Ux, Uy, Ut
def Mode_Multiply(self, A, U, mode_n=1):
"""
A: batch, Nx, Ny, Nt
U: batch, Nx, Ny
return: batch, Nx, Ny, Nt
"""
batch, Nx, Ny, Nt = A.shape

if mode_n == 1:
out = tf.linalg.matmul(U, tf.reshape(A, [batch, Nx, Ny*Nt])) # batch, Nx, Ny*Nt
out= tf.reshape(out, [batch, Nx, Ny, Nt])
elif mode_n == 2:
out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 2, 1, 3]), [batch, Ny, Nx * Nt])) # batch, Ny, Nx*Nt
out = tf.transpose(tf.reshape(out, [batch, Ny, Nx, Nt]), [0, 2, 1, 3])
elif mode_n == 3:
out = tf.linalg.matmul(U, tf.reshape(tf.transpose(A, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])) # batch, Nt, Nx*Ny
out = tf.transpose(tf.reshape(out, [batch, Nt, Nx, Ny]), [0, 2, 3, 1])

return out

def Core_C(self, x_k, Ux, Uy, Ut):
batch, Nx, Ny, Nt = x_k.shape

C = self.Mode_Multiply(x_k, tf.transpose(Ux, [0, 2, 1], conjugate=True), mode_n=1)
C = self.Mode_Multiply(C, tf.transpose(Uy, [0, 2, 1], conjugate=True), mode_n=2)
C = self.Mode_Multiply(C, tf.transpose(Ut, [0, 2, 1], conjugate=True), mode_n=3)
C_mode_x = tf.reshape(C, [batch, Nx, Ny * Nt])
C_mode_y = tf.reshape(tf.transpose(C, [0, 2, 1, 3]), [batch, Ny, Nx * Nt])
C_mode_t = tf.reshape(tf.transpose(C, [0, 3, 1, 2]), [batch, Nt, Nx * Ny])

return C_mode_x, C_mode_y, C_mode_t

def Projector(self, second_term, U):
second_term = second_term - tf.linalg.matmul(
tf.linalg.matmul(U,
tf.transpose(U, [0, 2, 1], conjugate=True)),
second_term)
return second_term

def Core_Multiply(self, second_term, C_mode):
second_term = tf.linalg.matmul(second_term,
tf.linalg.matmul(tf.transpose(C_mode, [0, 2, 1], conjugate=True),
tf.linalg.inv(tf.linalg.matmul(C_mode,
tf.transpose(C_mode, [0, 2, 1], conjugate=True)))))
return second_term

def SVT_U(self, Uk, top_kth):
[batch, Nx, Ny] = Uk.get_shape()
mask_1 = tf.ones([batch, Nx, top_kth])
mask_2 = tf.zeros([batch, Nx, Ny - top_kth])
mask_top_k = tf.concat([mask_1, mask_2], axis=-1)
mask_top_k = tf.cast(mask_top_k, dtype=Uk.dtype)
Uk = tf.multiply(Uk, mask_top_k)
return Uk

def dc_layer(self, x_rec, d):

k_rec = fft2c_mri(x_rec)
k_rec = (1 - self.mask) * k_rec + self.mask * d
x_rec = ifft2c_mri(k_rec)

return x_rec

def dc_layer_v2(self, x_rec, d):
x_rec = x_rec - ifft2c_mri(fft2c_mri(x_rec) * self.mask - d)
return x_rec





+ 2
- 0
models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/checkpoint View File

@@ -0,0 +1,2 @@
model_checkpoint_path: "ckpt"
all_model_checkpoint_paths: "ckpt"

BIN
models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.data-00000-of-00002 View File


BIN
models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.data-00001-of-00002 View File


BIN
models/stable/2021-02-28T13-38-56_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.3_rank_13_cartesian/epoch-50/ckpt.index View File


+ 142
- 0
test_net_v3.py View File

@@ -0,0 +1,142 @@
import tensorflow as tf
import os
from model_net_v3 import Manifold_Net
from dataset_tfrecord import get_dataset
import argparse
import scipy.io as scio
import mat73
import numpy as np
from datetime import datetime
import time
from tools.tools import video_summary, mse, tempfft



if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--mode', metavar='str', nargs=1, default=['test'], help='training or test')
parser.add_argument('--batch_size', metavar='int', nargs=1, default=['1'], help='batch size')
parser.add_argument('--niter', metavar='int', nargs=1, default=['5'], help='number of network iterations')
parser.add_argument('--acc', metavar='int', nargs=1, default=['8'], help='accelerate rate')
parser.add_argument('--mask_pattern', metavar='str', nargs=1, default=['cartesian'], help='mask pattern: cartesian, radial, spiral, vsita')
parser.add_argument('--net', metavar='str', nargs=1, default=['Manifold_Net'], help='Manifold_Net')
parser.add_argument('--weight', metavar='str', nargs=1, default=['models/stable/2021-02-28T13-44-00_Manifold_Net_v3_correct_dc_v1_d3c5_acc_8_lr_0.001_N_factor_1.05_rank_17_cartesian/epoch-60/ckpt'], help='modeldir in ./models')
parser.add_argument('--gpu', metavar='int', nargs=1, default=['2'], help='GPU No.')
parser.add_argument('--data', metavar='str', nargs=1, default=['DYNAMIC_V2'], help='dataset name')
parser.add_argument('--learnedSVT', metavar='bool', nargs=1, default=['True'], help='Learned SVT threshold or not')
parser.add_argument('--SVT_favtor', metavar='float', nargs=1, default=['1.05'], help='SVT factor')
args = parser.parse_args()

# GPU setup
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu[0]
GPUs = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(GPUs[0], True)

dataset_name = args.data[0].upper()
mode = args.mode[0]
batch_size = int(args.batch_size[0])
niter = int(args.niter[0])
acc = int(args.acc[0])
mask_pattern = args.mask_pattern[0]
net_name = args.net[0]
weight_file = args.weight[0]
learnedSVT = bool(args.learnedSVT[0])
N_factor = float(args.SVT_favtor[0])

print('network: ', net_name)
print('acc: ', acc)
print('load weight file from: ', weight_file)


result_dir = os.path.join('results/stable', weight_file.split('/')[2])
if not os.path.isdir(result_dir):
os.makedirs(result_dir)

logdir = './logs'
TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
summary_writer = tf.summary.create_file_writer(os.path.join(logdir, mode, TIMESTAMP + net_name + str(acc) + '/'))

# prepare undersampling mask
if dataset_name == 'DYNAMIC_V2':
multi_coil = False
mask_size = '18_192_192'
elif dataset_name == 'DYNAMIC_V2_MULTICOIL':
multi_coil = True
mask_size = '18_192_192'
elif dataset_name == 'FLOW':
multi_coil = False
mask_size = '20_180_180'

if acc == 8:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc8.mat')['mask']
elif acc == 10:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/cartesian_' + mask_size + '_acs4_acc10.mat')['mask']
elif acc == 12:
mask = scio.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/'+mask_pattern + '_' + mask_size + '_acc12.mat')['mask']
"""
if acc == 8:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_8.mat')['mask']
elif acc == 10:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_10.mat')['mask']
elif acc == 12:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_12.mat')['mask']
elif acc == 16:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_16.mat')['mask']
elif acc == 20:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_20.mat')['mask']
elif acc == 24:
mask = mat73.loadmat('/data1/wenqihuang/LplusSNet/mask_newdata/vista_' + mask_size + '_acc_24.mat')['mask']
"""
mask = tf.cast(tf.constant(mask), tf.complex64)

# prepare dataset
dataset = get_dataset(mode, dataset_name, batch_size, shuffle=False)
# initialize network
if net_name == 'Manifold_Net':
net = Manifold_Net(mask, niter, learnedSVT, N_factor)


net.load_weights(weight_file)
# Iterate over epochs.
for i, sample in enumerate(dataset):
# forward
k0 = None
csm = None
#with tf.GradientTape() as tape:
if multi_coil:
k0, label, csm = sample
else:
k0, label = sample
label_abs = tf.abs(label)

k0 = k0 * mask

t0 = time.time()
recon = net(k0, csm)
t1 = time.time()
recon_abs = tf.abs(recon)

loss_total = mse(recon, label)

tf.print(i, 'mse =', loss_total.numpy(), 'time = ', t1-t0)

result_file = os.path.join(result_dir, 'recon_'+str(i+1)+'.mat')
datadict = {'recon': np.squeeze(tf.transpose(recon, [0,2,3,1]).numpy())}
scio.savemat(result_file, datadict)

# record gif
with summary_writer.as_default():
combine_video = tf.concat([label_abs[0:1,:,:,:], recon_abs[0:1,:,:,:]], axis=0).numpy()
combine_video = np.expand_dims(combine_video, -1)
video_summary('convin-'+str(i+1), combine_video, step=1, fps=10)



+ 224
- 0
tools/compressed_sensing.py View File

@@ -0,0 +1,224 @@
import numpy as np
import tools.mymath
from numpy.lib.stride_tricks import as_strided


def soft_thresh(u, lmda):
"""Soft-threshing operator for complex valued input"""
Su = (abs(u) - lmda) / abs(u) * u
Su[abs(u) < lmda] = 0
return Su


def normal_pdf(length, sensitivity):
return np.exp(-sensitivity * (np.arange(length) - length / 2)**2)


def var_dens_mask(shape, ivar, sample_high_freq=True):
"""Variable Density Mask (2D undersampling)"""
if len(shape) == 4:
Num, Nt, Nx, Ny = shape
else:
Nx, Ny = shape
Nt = 1

pdf_x = normal_pdf(Nx, ivar)
pdf_y = normal_pdf(Ny, ivar)
pdf = np.outer(pdf_x, pdf_y)

size = pdf.itemsize
strided_pdf = as_strided(pdf, (Nt, Nx, Ny), (0, Ny * size, size))
# this must be false if undersampling rate is very low (around 90%~ish)
if sample_high_freq:
strided_pdf = strided_pdf / 1.25 + 0.02
mask = np.random.binomial(1, strided_pdf)

xc = Nx / 2
yc = Ny / 2
mask[:, xc - 4:xc + 5, yc - 4:yc + 5] = True

if Nt == 1:
return mask.reshape((Nx, Ny))

mask_4D = mask[np.newaxis, :, :, :]
mask_4D = np.tile(mask_4D, (Num, 1, 1, 1))

return mask_4D


def cartesian_mask(shape, ivar, centred=False,
sample_high_freq=True, sample_centre=True, sample_n=4):
"""
undersamples along Nx

:param shape: tuple - [nt, nx, ny]
:param ivar: sensitivity parameter for Gaussian distribution

"""
if len(shape) == 4:
Num, Nt, Nx, Ny = shape
else:
Nx, Ny = shape
Nt = 1

pdf_x = normal_pdf(Nx, ivar)

if sample_high_freq:
pdf_x = pdf_x / 1.25 + 0.02

size = pdf_x.itemsize
stride_pdf = as_strided(pdf_x, (Nt, Nx, 1), (0, size, 0))
mask = np.random.binomial(1, stride_pdf)
size = mask.itemsize
mask = as_strided(mask, (Nt, Nx, Ny), (size * Nx, size, 0))

if sample_centre:
s = sample_n / 2
xc = Nx / 2
mask[:, xc - s: xc + s, :] = True

if not centred:
mask = mymath.ifftshift(mask, axes=(-1, -2))

mask_4D = mask[np.newaxis, :, :, :]
mask_4D = np.tile(mask_4D, (Num, 1, 1, 1))

return mask_4D


def shear_grid_mask(shape, acceleration_rate, sample_low_freq=True,
centred=False, sample_n=10):
'''
Creates undersampling mask which samples in sheer grid

Parameters
----------

shape: (nt, nx, ny)

acceleration_rate: int

Returns
-------

array

'''
Nt, Nx, Ny = shape
start = np.random.randint(0, acceleration_rate)
mask = np.zeros((Nt, Nx))
for t in xrange(Nt):
mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1

xc = Nx / 2
xl = sample_n / 2
if sample_low_freq and centred:
xh = xl
if sample_n % 2 == 0:
xh += 1
mask[:, xc - xl:xc + xh+1] = 1

elif sample_low_freq:
xh = xl
if sample_n % 2 == 1:
xh -= 1

if xl > 0:
mask[:, :xl] = 1
if xh > 0:
mask[:, -xh:] = 1

mask_rep = np.repeat(mask[..., np.newaxis], Ny, axis=-1)
return mask_rep


def perturbed_shear_grid_mask(shape, acceleration_rate, sample_low_freq=True,
centred=False,
sample_n=10):
Nt, Nx, Ny = shape
start = np.random.randint(0, acceleration_rate)
mask = np.zeros((Nt, Nx))
for t in xrange(Nt):
mask[t, (start+t)%acceleration_rate::acceleration_rate] = 1

# brute force
rand_code = np.random.randint(0, 3, size=Nt*Nx)
shift = np.array([-1, 0, 1])[rand_code]
new_mask = np.zeros_like(mask)
for t in xrange(Nt):
for x in xrange(Nx):
if mask[t, x]:
new_mask[t, (x + shift[t*x])%Nx] = 1

xc = Nx / 2
xl = sample_n / 2
if sample_low_freq and centred:
xh = xl
if sample_n % 2 == 0:
xh += 1
new_mask[:, xc - xl:xc + xh+1] = 1

elif sample_low_freq:
xh = xl
if sample_n % 2 == 1:
xh -= 1

new_mask[:, :xl] = 1
new_mask[:, -xh:] = 1
mask_rep = np.repeat(new_mask[..., np.newaxis], Ny, axis=-1)

return mask_rep


def undersample(x, mask, centred=False, norm='ortho'):
'''
Undersample x. FFT2 will be applied to the last 2 axis
Parameters
----------
x: array_like
data
mask: array_like
undersampling mask in fourier domain
Returns
-------
xu: array_like
undersampled image in image domain. Note that it is complex valued

x_fu: array_like
undersampled data in kspace

'''
assert x.shape == mask.shape
if centred:
x_f = mymath.fft2c(x, norm=norm)
x_fu = x_f * mask
x_u = mymath.ifft2c(x_fu, norm=norm)
return x_u, x_fu
else:
x_f = mymath.fft2(x, norm=norm)
x_fu = x_f * mask
x_u = mymath.ifft2(x_fu, norm=norm)
return x_u, x_fu, x_f

def data_consistency(x, y, mask, centered=False, norm='ortho'):
'''
x is in image space,
y is in k-space
'''
if centered:
xf = mymath.fft2c(x, norm=norm)
xm = (1 - mask) * xf + y
xd = mymath.ifft2c(xm, norm=norm)
else:
xf = mymath.fft2(x, norm=norm)
xm = (1 - mask) * xf + y
xd = mymath.ifft2(xm, norm=norm)

return xd


def get_phase(x):
xr = np.real(x)
xi = np.imag(x)
phase = np.arctan(xi / (xr + 1e-12))
return phase

+ 148
- 0
tools/mymath.py View File

@@ -0,0 +1,148 @@
__author__ = 'Jo Schlemper'

import numpy as np
sqrt = np.sqrt
from numpy.fft import fft, fft2, ifft2, ifft, ifftshift, fftshift


def fftc(x, axis=-1, norm='ortho'):
''' expect x as m*n matrix '''
return fftshift(fft(ifftshift(x, axes=axis), axis=axis, norm=norm), axes=axis)


def ifftc(x, axis=-1, norm='ortho'):
''' expect x as m*n matrix '''
return fftshift(ifft(ifftshift(x, axes=axis), axis=axis, norm=norm), axes=axis)


def fft2c(x):
'''
Centered fft
Note: fft2 applies fft to last 2 axes by default
:param x: 2D onwards. e.g: if its 3d, x.shape = (n,row,col). 4d:x.shape = (n,slice,row,col)
:return:
'''
# axes = (len(x.shape)-2, len(x.shape)-1) # get last 2 axes
axes = (-2, -1) # get last 2 axes
res = fftshift(fft2(ifftshift(x, axes=axes), norm='ortho'), axes=axes)
return res


def ifft2c(x):
'''
Centered ifft
Note: fft2 applies fft to last 2 axes by default
:param x: 2D onwards. e.g: if its 3d, x.shape = (n,row,col). 4d:x.shape = (n,slice,row,col)
:return:
'''
axes = (-2, -1) # get last 2 axes
res = fftshift(ifft2(ifftshift(x, axes=axes), norm='ortho'), axes=axes)
return res


def fourier_matrix(rows, cols):
'''
parameters:
rows: number or rows
cols: number of columns

return unitary (rows x cols) fourier matrix
'''
# from scipy.linalg import dft
# return dft(rows,scale='sqrtn')

col_range = np.arange(cols)
row_range = np.arange(rows)
scale = 1 / np.sqrt(cols)

coeffs = np.outer(row_range, col_range)
fourier_matrix = np.exp(coeffs * (-2. * np.pi * 1j / cols)) * scale

return fourier_matrix


def inverse_fourier_matrix(rows, cols):
return np.array(np.matrix(fourier_matrix(rows, cols)).getH())


def flip(m, axis):
"""
==== > Only in numpy 1.12 < =====

Reverse the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered.
.. versionadded:: 1.12.0
Parameters
----------
m : array_like
Input array.
axis : integer
Axis in array, which entries are reversed.
Returns
-------
out : array_like
A view of `m` with the entries of axis reversed. Since a view is
returned, this operation is done in constant time.
See Also
--------
flipud : Flip an array vertically (axis=0).
fliplr : Flip an array horizontally (axis=1).
Notes
-----
flip(m, 0) is equivalent to flipud(m).
flip(m, 1) is equivalent to fliplr(m).
flip(m, n) corresponds to ``m[...,::-1,...]`` with ``::-1`` at position n.
Examples
--------
>>> A = np.arange(8).reshape((2,2,2))
>>> A
array([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> flip(A, 0)
array([[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]]])
>>> flip(A, 1)
array([[[2, 3],
[0, 1]],
[[6, 7],
[4, 5]]])
>>> A = np.random.randn(3,4,5)
>>> np.all(flip(A,2) == A[:,:,::-1,...])
True
"""
if not hasattr(m, 'ndim'):
m = np.asarray(m)
indexer = [slice(None)] * m.ndim
try:
indexer[axis] = slice(None, None, -1)
except IndexError:
raise ValueError("axis=%i is invalid for the %i-dimensional input array"
% (axis, m.ndim))
return m[tuple(indexer)]


def rot90_nd(x, axes=(-2, -1), k=1):
"""Rotates selected axes"""
def flipud(x):
return flip(x, axes[0])

def fliplr(x):
return flip(x, axes[1])

x = np.asanyarray(x)
if x.ndim < 2:
raise ValueError("Input must >= 2-d.")
k = k % 4
if k == 0:
return x
elif k == 1:
return fliplr(x).swapaxes(*axes)
elif k == 2:
return fliplr(flipud(x))
else:
# k == 3
return fliplr(x.swapaxes(*axes))

+ 341
- 0
tools/tools.py View File

@@ -0,0 +1,341 @@
import tempfile
import os
import tensorflow as tf
import numpy as np
from numpy.lib.stride_tricks import as_strided
import scipy.io as scio

def video_summary(name, video, step=None, fps=10):
name = tf.constant(name).numpy().decode('utf-8')
video = np.array(video)
if video.dtype in (np.float32, np.float64):
video = np.clip(255 * video, 0, 255).astype(np.uint8)
B, T, H, W, C = video.shape
try:
frames = video.transpose((1, 2, 0, 3, 4)).reshape((T, H, B * W, C))
summary = tf.compat.v1.Summary()
image = tf.compat.v1.Summary.Image(
height=B * H, width=T * W, colorspace=C)
image.encoded_image_string = encode_gif(frames, fps)
summary.value.add(tag=name + '/gif', image=image)
tf.summary.experimental.write_raw_pb(summary.SerializeToString(), step)
except (IOError, OSError) as e:
print('GIF summaries require ffmpeg in $PATH.', e)
frames = video.transpose((0, 2, 1, 3, 4)).reshape((1, B * H, T * W, C))
tf.summary.image(name + '/grid', frames, step)


def encode_gif(frames, fps):
from subprocess import Popen, PIPE
h, w, c = frames[0].shape
pxfmt = {1: 'gray', 3: 'rgb24'}[c]
cmd = ' '.join([
f'ffmpeg -y -f rawvideo -vcodec rawvideo',
f'-r {fps:.02f} -s {w}x{h} -pix_fmt {pxfmt} -i - -filter_complex',
f'[0:v]split[x][z];[z]palettegen[y];[x]fifo[x];[x][y]paletteuse',
f'-r {fps:.02f} -f gif -'])
proc = Popen(cmd.split(' '), stdin=PIPE, stdout=PIPE, stderr=PIPE)
for image in frames:
proc.stdin.write(image.tostring())
out, err = proc.communicate()
if proc.returncode:
raise IOError('\n'.join([' '.join(cmd), err.decode('utf8')]))
del proc
return out

def normal_pdf(length, sensitivity):
return np.exp(-sensitivity * (np.arange(length) - length / 2)**2)

def cartesian_mask(shape, acc, sample_n=10, centred=False):
"""
Sampling density estimated from implementation of kt FOCUSS
shape: tuple - of form (..., nx, ny)
acc: float - doesn't have to be integer 4, 8, etc..
"""
N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1]
pdf_x = normal_pdf(Nx, 0.5/(Nx/10.)**2)
lmda = Nx/(2.*acc)
n_lines = int(Nx / acc)

# add uniform distribution
pdf_x += lmda * 1./Nx

if sample_n:
pdf_x[Nx//2-sample_n//2:Nx//2+sample_n//2] = 0
pdf_x /= np.sum(pdf_x)
n_lines -= sample_n

mask = np.zeros((N, Nx))
for i in range(N):
idx = np.random.choice(Nx, n_lines, False, pdf_x)
mask[i, idx] = 1

if sample_n:
mask[:, Nx//2-sample_n//2:Nx//2+sample_n//2] = 1

size = mask.itemsize
mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0))

mask = mask.reshape(shape)

if not centred:
mask = mymath.ifftshift(mask, axes=(-1, -2))

return mask


def loss_function_ISTA(y, y_, y_sym, n_iter):
pred = tf.stack([tf.math.real(y), tf.math.imag(y)], axis=-1)
label = tf.stack([tf.math.real(y_), tf.math.imag(y_)], axis=-1)

cost = tf.reduce_mean(tf.math.square(pred - label))
cost_sym = 0
for k in range(n_iter):
#pred_sym = tf.stack([tf.math.real(y_sym[k]), tf.math.imag(y_sym[k])], axis=-1)
cost_sym += tf.reduce_mean(tf.square(y_sym))

loss = cost + 0.01 * cost_sym
return loss


def tempfft(input, inv):
if len(input.shape) == 4:
nb, nt, nx, ny = np.float32(input.shape)
nt = tf.constant(np.complex64(nt + 0j))

if inv:
x = tf.transpose(input, perm=[0,2,3,1])
#x = tf.signal.fftshift(x, 3)
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,3,1,2])
x = x * tf.sqrt(nt)
else:
x = tf.transpose(input, perm=[0,2,3,1])
x = tf.signal.fft(x)
#x = tf.signal.fftshift(x, 3)
x = tf.transpose(x, perm=[0,3,1,2])
x = x / tf.sqrt(nt)
else:
nb, nt, nx, ny, _ = np.float32(input.shape)
nt = tf.constant(np.complex64(nt + 0j))

if inv:
x = tf.transpose(input, perm=[0,2,3,4,1])
#x = tf.signal.fftshift(x, 4)
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,4,1,2,3])
x = x * tf.sqrt(nt)
else:
x = tf.transpose(input, perm=[0,2,3,4,1])
x = tf.signal.fft(x)
#x = tf.signal.fftshift(x, 4)
x = tf.transpose(x, perm=[0,4,1,2,3])
x = x / tf.sqrt(nt)
return x


def mse(recon, label):
if recon.dtype == tf.complex64:
residual_cplx = recon - label
residual = tf.stack([tf.math.real(residual_cplx), tf.math.imag(residual_cplx)], axis=-1)
mse = tf.reduce_mean(residual**2)
else:
residual = recon - label
mse = tf.reduce_mean(residual**2)
return mse

def fft2c_mri(x):
# nb nx ny nt
X = tf.signal.fftshift(x, 2)
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one.
X = tf.signal.fft(X)
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(x.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx)
X = tf.signal.fftshift(X, 3)
X = tf.signal.fft(X)
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny)
return X

def ifft2c_mri(X):
# nb nx ny nt
x = tf.signal.fftshift(X, 2)
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one.
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(X.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))

x = tf.signal.fftshift(x, 2) * tf.sqrt(nx)

x = tf.signal.fftshift(x, 3)
x = tf.signal.ifft(x)
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny)
return x

def sos(x):
# x: nb, ncoil, nt, nx, ny; complex64
x = tf.math.reduce_sum(tf.abs(x**2), axis=1)
x = x**(1.0/2)
return x

def softthres(x, thres):
x_abs = tf.abs(x)
coef = tf.nn.relu(x_abs - thres) / (x_abs + 1e-10)
coef = tf.cast(coef, tf.complex64)
return coef * x

"""
class Emat_xyt():
def __init__(self, mask):
super(Emat_xyt, self).__init__()
self.mask = mask

def mtimes(self, b, inv):
if inv:
# this is for single channel reconstruction only.
x = self._ifft2c_mri(b * self.mask)
else:
x = self._fft2c_mri(b) * self.mask
return x
def _fft2c_mri(self, x):
# nb nx ny nt
X = tf.signal.fftshift(x, 2)
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one.
X = tf.signal.fft(X)
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(x.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx)
X = tf.signal.fftshift(X, 3)
X = tf.signal.fft(X)
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny)
return X

def _ifft2c_mri(self, X):
# nb nx ny nt
x = tf.signal.fftshift(X, 2)
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one.
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(X.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))

x = tf.signal.fftshift(x, 2) * tf.sqrt(nx)

x = tf.signal.fftshift(x, 3)
x = tf.signal.ifft(x)
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny)
return x
"""
class Emat_xyt():
def __init__(self, mask):
super(Emat_xyt, self).__init__()
self.mask = mask

def mtimes(self, b, inv, csm):
if csm == None:
if inv:
x = self._ifft2c_mri_singlecoil(b * self.mask)
else:
x = self._fft2c_mri_singlecoil(b) * self.mask
else:
if len(self.mask.shape) > 3:
if inv:
x = self._ifft2c_mri_multicoil(b * self.mask[:,:,0:b.shape[2],:,:])
x = x * tf.math.conj(csm)
x = tf.reduce_sum(x, 1) #/ tf.cast(tf.reduce_sum(tf.abs(csm)**2, 1), dtype=tf.complex64)
else:
b = tf.expand_dims(b, 1) * csm
x = self._fft2c_mri_multicoil(b) * self.mask[:,:,0:b.shape[2],:,:]
else:
if inv:
x = self._ifft2c_mri_multicoil(b * self.mask)
x = x * tf.math.conj(csm)
x = tf.reduce_sum(x, 1) #/ tf.cast(tf.reduce_sum(tf.abs(csm)**2, 1), dtype=tf.complex64)
else:
b = tf.expand_dims(b, 1) * csm
x = self._fft2c_mri_multicoil(b) * self.mask
return x
def _fft2c_mri_multicoil(self, x):
# nb nt nx ny -> nb, nc, nt, nx, ny
X = tf.signal.fftshift(x, 3)
X = tf.transpose(X, perm=[0,1,2,4,3]) # permute to make nx dimension the last one.
X = tf.signal.fft(X)
X = tf.transpose(X, perm=[0,1,2,4,3]) # permute back to original order.
nb, nc, nt, nx, ny = np.float32(x.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))
X = tf.signal.fftshift(X, 3) / tf.sqrt(nx)
X = tf.signal.fftshift(X, 4)
X = tf.signal.fft(X)
X = tf.signal.fftshift(X, 4) / tf.sqrt(ny)
return X

def _ifft2c_mri_multicoil(self, X):
# nb nt nx ny -> nb, nc, nt, nx, ny
x = tf.signal.fftshift(X, 3)
x = tf.transpose(x, perm=[0,1,2,4,3]) # permute a to make nx dimension the last one.
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,1,2,4,3]) # permute back to original order.
nb, nc, nt, nx, ny = np.float32(X.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))

x = tf.signal.fftshift(x, 3) * tf.sqrt(nx)

x = tf.signal.fftshift(x, 4)
x = tf.signal.ifft(x)
x = tf.signal.fftshift(x, 4) * tf.sqrt(ny)
return x

def _fft2c_mri_singlecoil(self, x):
# nb nx ny nt
X = tf.signal.fftshift(x, 2)
X = tf.transpose(X, perm=[0,1,3,2]) # permute to make nx dimension the last one.
X = tf.signal.fft(X)
X = tf.transpose(X, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(x.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))
X = tf.signal.fftshift(X, 2) / tf.sqrt(nx)
X = tf.signal.fftshift(X, 3)
X = tf.signal.fft(X)
X = tf.signal.fftshift(X, 3) / tf.sqrt(ny)
return X

def _ifft2c_mri_singlecoil(self, X):
# nb nx ny nt
x = tf.signal.fftshift(X, 2)
x = tf.transpose(x, perm=[0,1,3,2]) # permute a to make nx dimension the last one.
x = tf.signal.ifft(x)
x = tf.transpose(x, perm=[0,1,3,2]) # permute back to original order.
nb, nt, nx, ny = np.float32(X.shape)
nx = tf.constant(np.complex64(nx + 0j))
ny = tf.constant(np.complex64(ny + 0j))

x = tf.signal.fftshift(x, 2) * tf.sqrt(nx)

x = tf.signal.fftshift(x, 3)
x = tf.signal.ifft(x)
x = tf.signal.fftshift(x, 3) * tf.sqrt(ny)
return x

+ 187
- 0
tools/wavelet.py View File

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 9 15:50:03 2019
@author: wmy
"""

import numpy as np
#import matplotlib.pyplot as plt
import pywt
import tensorflow as tf
#import pylab

#pylab.rcParams['figure.figsize'] = (10.0, 10.0)

def dwt2d(x, wave='haar'):
# shape x: (b, h, w, c)
nc = int(x.shape.dims[3])
# 小波波形
w = pywt.Wavelet(wave)
# 水平低频 垂直低频
ll = np.outer(w.dec_lo, w.dec_lo)
# 水平低频 垂直高频
lh = np.outer(w.dec_hi, w.dec_lo)
# 水平高频 垂直低频
hl = np.outer(w.dec_lo, w.dec_hi)
# 水平高频 垂直高频
hh = np.outer(w.dec_hi, w.dec_hi)
# 卷积核
core = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
core[:, :, 0, 0] = ll[::-1, ::-1]
core[:, :, 0, 1] = lh[::-1, ::-1]
core[:, :, 0, 2] = hl[::-1, ::-1]
core[:, :, 0, 3] = hh[::-1, ::-1]
core = core.astype(np.float32)
kernel = np.array([core], dtype=np.float32)
kernel = tf.convert_to_tensor(kernel)
p = 2 * (len(w.dec_lo) // 2 - 1)
with tf.compat.v1.variable_scope('dwt2d'):
# padding odd length
x = tf.pad(x, tf.constant([[0, 0], [p, p+1], [p, p+1], [0, 0]]))
xh = tf.shape(x)[1] - tf.shape(x)[1]%2
xw = tf.shape(x)[2] - tf.shape(x)[2]%2
x = x[:, 0:xh, 0:xw, :]
# convert to 3d data
x3d = tf.expand_dims(x, 1)
# 切开通道
x3d = tf.split(x3d, int(x3d.shape.dims[4]), 4)
# 贴到维度一
x3d = tf.concat([a for a in x3d], 1)
# 三维卷积
y3d = tf.nn.conv3d(x3d, kernel, padding='VALID', strides=[1, 1, 2, 2, 1])
# 切开维度一
y = tf.split(y3d, int(y3d.shape.dims[1]), 1)
# 贴到通道维
y = tf.concat([a for a in y], 4)
y = tf.reshape(y, (tf.shape(y)[0], tf.shape(y)[2], tf.shape(y)[3], 4*nc))
# 拼贴通道
channels = tf.split(y, nc, 3)
outputs = []
for channel in channels:
(cA, cH, cV, cD) = tf.split(channel, 4, 3)
AH = tf.concat([cA, cH], axis=2)
VD = tf.concat([cV, cD], axis=2)
outputs.append(tf.concat([AH, VD], axis=1))
pass
outputs = tf.concat(outputs, axis=-1)
pass
return outputs

def wavedec2d(x, level=1, wave='haar'):
if level == 0:
return x
y = dwt2d(x, wave=wave)
hcA = tf.math.floordiv(tf.shape(y)[1], 2)
wcA = tf.math.floordiv(tf.shape(y)[2], 2)
cA = y[:, 0:hcA, 0:wcA, :]
cA = wavedec2d(cA, level=level-1, wave=wave)
cA = cA[:, 0:hcA, 0:wcA, :]
hcA = tf.shape(cA)[1]
wcA = tf.shape(cA)[2]
cH = y[:, 0:hcA, wcA:, :]
cV = y[:, hcA:, 0:wcA, :]
cD = y[:, hcA:, wcA:, :]
AH = tf.concat([cA, cH], axis=2)
VD = tf.concat([cV, cD], axis=2)
outputs = tf.concat([AH, VD], axis=1)
return outputs

def idwt2d(x, wave='haar'):
# shape x: (b, h, w, c)
nc = int(x.shape.dims[3])
# 小波波形
w = pywt.Wavelet(wave)
# 水平低频 垂直低频
ll = np.outer(w.dec_lo, w.dec_lo)
# 水平低频 垂直高频
lh = np.outer(w.dec_hi, w.dec_lo)
# 水平高频 垂直低频
hl = np.outer(w.dec_lo, w.dec_hi)
# 水平高频 垂直高频
hh = np.outer(w.dec_hi, w.dec_hi)
# 卷积核
core = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
core[:, :, 0, 0] = ll[::-1, ::-1]
core[:, :, 0, 1] = lh[::-1, ::-1]
core[:, :, 0, 2] = hl[::-1, ::-1]
core[:, :, 0, 3] = hh[::-1, ::-1]
core = core.astype(np.float32)
kernel = np.array([core], dtype=np.float32)
kernel = tf.convert_to_tensor(kernel)
s = 2 * (len(w.dec_lo) // 2 - 1)
# 反变换
with tf.compat.v1.variable_scope('idwt2d'):
hcA = tf.math.floordiv(tf.shape(x)[1], 2)
wcA = tf.math.floordiv(tf.shape(x)[2], 2)
y = []
for c in range(nc):
channel = x[:, :, :, c]
channel = tf.expand_dims(channel, -1)
cA = channel[:, 0:hcA, 0:wcA, :]
cH = channel[:, 0:hcA, wcA:, :]
cV = channel[:, hcA:, 0:wcA, :]
cD = channel[:, hcA:, wcA:, :]
temp = tf.concat([cA, cH, cV, cD], axis=-1)
y.append(temp)
pass
# nc * 4
y = tf.concat(y, axis=-1)
y3d = tf.expand_dims(y, 1)
y3d = tf.split(y3d, nc, 4)
y3d = tf.concat([a for a in y3d], 1)
output_shape = [tf.shape(y)[0], tf.shape(y3d)[1], \
2*(tf.shape(y)[1]-1)+np.shape(ll)[0], \
2*(tf.shape(y)[2]-1)+np.shape(ll)[1], 1]
x3d = tf.nn.conv3d_transpose(y3d, kernel, output_shape=output_shape, padding='VALID', strides=[1, 1, 2, 2, 1])
outputs = tf.split(x3d, nc, 1)
outputs = tf.concat([x for x in outputs], 4)
outputs = tf.reshape(outputs, (tf.shape(outputs)[0], tf.shape(outputs)[2], tf.shape(outputs)[3], nc))
outputs = outputs[:, s:2*(tf.shape(y)[1]-1)+np.shape(ll)[0]-s, \
s:2*(tf.shape(y)[2]-1)+np.shape(ll)[1]-s, :]
pass
return outputs

def dwt2dc(x, wave='haar'):
# b, t, h, w
nt = x.shape[1]
x_2c = tf.transpose(x, perm=[0,2,3,1])
x_2c = tf.concat([tf.math.real(x_2c), tf.math.imag(x_2c)], axis=-1)
dwtx_2c = dwt2d(x_2c)
dwtx = tf.complex(dwtx_2c[:,:,:,0:nt], dwtx_2c[:,:,:,nt:2*nt])
dwtx = tf.transpose(dwtx, perm=[0,3,1,2])
return dwtx

def idwt2dc(x, wave='haar'):
# b, t, h, w
nt = x.shape[1]
x_2c = tf.transpose(x, perm=[0,2,3,1])
x_2c = tf.concat([tf.math.real(x_2c), tf.math.imag(x_2c)], axis=-1)
idwtx_2c = idwt2d(x_2c)
idwtx = tf.complex(idwtx_2c[:,:,:,0:nt], idwtx_2c[:,:,:,nt:2*nt])
idwtx = tf.transpose(idwtx, perm=[0,3,1,2])
return idwtx


"""
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs')
image = plt.imread('test.jpg')
plt.imshow(image)
plt.show()
x = np.array([image, image[:, ::-1, :]])
dec = wavedec2d(inputs, level=5, wave='sym4')
dwt = dwt2d(inputs, wave='sym4')
idwt = idwt2d(dwt, wave='sym4')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
result = sess.run(dec, feed_dict={inputs:x})
plt.imshow(np.array(result[0], dtype=np.uint8))
plt.show()
trans = sess.run(dwt, feed_dict={inputs:x})
plt.imshow(np.array(trans[0], dtype=np.uint8))
plt.show()
recons = sess.run(idwt, feed_dict={inputs:x})
plt.imshow(np.array(recons[0], dtype=np.uint8))
plt.show()
pass
"""

Loading…
Cancel
Save