|
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """model"""
- import itertools
- import numpy as np
- import mindspore as ms
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore import Tensor
- from mindspore.common.initializer import initializer, XavierUniform
-
-
def gather_edges(edges, neighbor_idx):
    """Select per-neighbor edge features.

    edges: [B, N, N, C] dense pairwise features; neighbor_idx: [B, N, K]
    indices into axis 2. Returns neighbor features [B, N, K, C].
    """
    b, n, k = neighbor_idx.shape
    c = edges.shape[-1]
    # Broadcast the index tensor over the channel axis so GatherD can pick
    # a full feature vector per (i, neighbor) pair.
    idx = ops.broadcast_to(ops.expand_dims(neighbor_idx, -1), (b, n, k, c))
    return ops.GatherD()(edges, 2, idx)
-
-
def gather_nodes(nodes, neighbor_idx):
    """Gather node features for every neighbor.

    nodes: [B, N, C]; neighbor_idx: [B, N, K]. Returns [B, N, K, C].
    """
    batch = neighbor_idx.shape[0]
    channels = nodes.shape[2]
    # Collapse (N, K) into one axis so a single GatherD along axis 1 suffices.
    flat_idx = neighbor_idx.view((batch, -1))
    flat_idx = ops.broadcast_to(
        ops.expand_dims(flat_idx, -1), (batch, flat_idx.shape[1], channels))
    gathered = ops.GatherD()(nodes, 1, flat_idx)
    # Restore the [B, N, K, C] layout.
    out_shape = tuple(list(neighbor_idx.shape)[:3] + [-1])
    return gathered.view(out_shape)
-
-
def gather_nodes_t(nodes, neighbor_idx):
    """Gather node features at a single decoding step.

    nodes: [B, N, C]; neighbor_idx: [B, K]. Returns [B, K, C].
    """
    b, k = neighbor_idx.shape
    c = nodes.shape[2]
    expanded = ops.broadcast_to(ops.expand_dims(neighbor_idx, -1), (b, k, c))
    return ops.GatherD()(nodes, 1, expanded)
-
-
def cat_neighbors_nodes(h_nodes, h_neighbors, E_idx):
    """Concatenate neighbor features with the gathered node features.

    h_nodes: [B, N, C_v]; h_neighbors: [B, N, K, C_e]; E_idx: [B, N, K].
    Returns [B, N, K, C_e + C_v].
    """
    gathered = gather_nodes(h_nodes, E_idx)
    return ops.Concat(axis=-1)((h_neighbors, gathered))
-
-
class EncLayer(nn.Cell):
    """Encoder block.

    Three stages: (1) neighborhood message passing that updates node states
    h_V, (2) a position-wise feed-forward on nodes, (3) a second message
    pass that updates edge states h_E from the refreshed nodes.
    """
    def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):
        """num_hidden: node/edge state width; num_in: extra width concatenated
        onto each message; scale: divisor for the summed neighbor messages."""
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        # NOTE: this MindSpore nn.Dropout API takes keep_prob, hence 1 - dropout.
        self.dropout1 = nn.Dropout(1 - dropout)
        self.dropout2 = nn.Dropout(1 - dropout)
        self.dropout3 = nn.Dropout(1 - dropout)
        self.norm1 = nn.LayerNorm([num_hidden])
        self.norm2 = nn.LayerNorm([num_hidden])
        self.norm3 = nn.LayerNorm([num_hidden])

        # W1-W3: node-update message MLP; W11-W13: edge-update message MLP.
        self.W1 = nn.Dense(num_hidden + num_in, num_hidden, has_bias=True)
        self.W2 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.W3 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.W11 = nn.Dense(num_hidden + num_in, num_hidden, has_bias=True)
        self.W12 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.W13 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.act = nn.GELU()
        # self.act = nn.ReLU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def construct(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
        """ Parallel computation of full transformer layer """
        # --- node update: each message sees (h_V_i, h_E_ij, h_V_j) ---
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = ops.broadcast_to(ops.expand_dims(h_V, -2), (h_V.shape[0], h_V.shape[1], h_EV.shape[-2], h_V.shape[2]))
        h_EV = ops.Concat(axis=-1)((h_V_expand, h_EV))
        h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
        if mask_attend is not None:
            # Zero out messages from padded/invalid neighbors.
            h_message = ops.expand_dims(mask_attend, -1) * h_message
        dh = ops.ReduceSum()(h_message, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(dh))

        # --- position-wise feed-forward on nodes ---
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        if mask_V is not None:
            mask_V = ops.expand_dims(mask_V, -1)
            h_V = mask_V * h_V

        # --- edge update from the refreshed node states ---
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = ops.broadcast_to(ops.expand_dims(h_V, -2), (h_V.shape[0], h_V.shape[1], h_EV.shape[-2], h_V.shape[2]))
        h_EV = ops.Concat(axis=-1)((h_V_expand, h_EV))
        h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))
        h_E = self.norm3(h_E + self.dropout3(h_message))
        return h_V, h_E
-
-
class DecLayer(nn.Cell):
    """Decoder block.

    Like EncLayer's node stage but without the edge update: one message
    pass over precomposed edge inputs h_E, then a position-wise
    feed-forward. Edge inputs already carry sequence/encoder context.
    """
    def __init__(self, num_hidden, num_in, dropout=0.1, scale=30):  # dropout=0.1
        """num_hidden: node state width; num_in: width of the precomposed
        edge input; scale: divisor for the summed neighbor messages."""
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        # NOTE: this MindSpore nn.Dropout API takes keep_prob, hence 1 - dropout.
        self.dropout1 = nn.Dropout(1 - dropout)
        self.dropout2 = nn.Dropout(1 - dropout)
        self.norm1 = nn.LayerNorm([num_hidden])
        self.norm2 = nn.LayerNorm([num_hidden])

        self.W1 = nn.Dense(num_hidden + num_in, num_hidden, has_bias=True)
        self.W2 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.W3 = nn.Dense(num_hidden, num_hidden, has_bias=True)
        self.act = nn.GELU()
        # self.act = nn.ReLU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def construct(self, h_V, h_E, mask_V=None, mask_attend=None):
        """ Parallel computation of full transformer layer """
        # Concatenate h_V_i to h_E_ij
        h_V_expand = ops.broadcast_to(ops.expand_dims(h_V, -2), (h_V.shape[0], h_V.shape[1], h_E.shape[-2], h_V.shape[2]))
        h_EV = ops.Concat(axis=-1)((h_V_expand, h_E))

        h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
        if mask_attend is not None:
            # Zero out messages from padded/invalid neighbors.
            h_message = ops.expand_dims(mask_attend, -1) * h_message
        dh = ops.ReduceSum()(h_message, -2) / self.scale

        h_V = self.norm1(h_V + self.dropout1(dh))

        # Position-wise feedforward
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))

        if mask_V is not None:
            # Zero node states at padded positions.
            mask_V = ops.expand_dims(mask_V, -1)
            h_V = mask_V * h_V
        return h_V
-
-
class PositionWiseFeedForward(nn.Cell):
    """Two-layer position-wise MLP with a GELU nonlinearity in between."""

    def __init__(self, num_hidden, num_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.W_in = nn.Dense(num_hidden, num_ff, has_bias=True)
        self.W_out = nn.Dense(num_ff, num_hidden, has_bias=True)
        self.act = nn.GELU()

    def construct(self, h_V):
        """Expand to num_ff, activate, project back to num_hidden."""
        return self.W_out(self.act(self.W_in(h_V)))
-
-
class PositionalEncodings(nn.Cell):
    """Embed relative residue offsets, clipped to +/-max_relative_feature,
    with one extra bucket reserved for cross-chain pairs."""
    def __init__(self, num_embeddings, max_relative_feature=32):
        super(PositionalEncodings, self).__init__()
        self.num_embeddings = num_embeddings
        self.max_relative_feature = max_relative_feature
        # 2*max+1 in-chain offset buckets plus 1 cross-chain bucket.
        self.linear = nn.Dense(2 * max_relative_feature + 1 + 1, num_embeddings)

    def construct(self, offset, mask):
        """offset: signed sequence offsets per edge; mask: 1 for same-chain
        pairs, 0 for cross-chain pairs (those map to the extra bucket)."""
        d = ops.clip_by_value(offset + self.max_relative_feature, ms.Tensor(0),
                              ms.Tensor(2 * self.max_relative_feature)) * mask + (1 - mask) * (
                                  2 * self.max_relative_feature + 1)
        d_onehot = nn.OneHot(depth=2 * self.max_relative_feature + 1 + 1)(d)
        E = self.linear(d_onehot)
        return E
-
-
class ProteinFeatures(nn.Cell):
    """Edge featurization for a k-NN residue graph.

    Builds 25 pairwise-atom RBF distance features (N, Ca, C, O and a
    virtual Cb per residue) plus relative-position / same-chain
    embeddings, projects them to edge_features, and returns them with the
    neighbor index map.
    """
    def __init__(self, edge_features, node_features, num_positional_embeddings=16,
                 num_rbf=16, top_k=30, augment_eps=0.):
        """ Extract protein features """
        super(ProteinFeatures, self).__init__()
        self.edge_features = edge_features
        self.node_features = node_features
        self.top_k = top_k  # neighbors per residue
        self.augment_eps = augment_eps  # std of Gaussian coordinate noise (0 disables)
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings

        self.embeddings = PositionalEncodings(num_positional_embeddings)
        # 25 ordered atom pairs x num_rbf distance bins per edge.
        edge_in = num_positional_embeddings + num_rbf * 25
        self.edge_embedding = nn.Dense(edge_in, edge_features, has_bias=False)
        self.norm_edges = nn.LayerNorm([edge_features])

    def construct(self, X, mask, residue_idx, chain_labels):
        """X: backbone coords indexed [batch, residue, atom(N/Ca/C/O), xyz];
        mask: 1 for valid residues; residue_idx: sequence positions;
        chain_labels: per-residue chain ids. Returns (E, E_idx)."""
        if self.augment_eps > 0:
            # Data augmentation: jitter coordinates with Gaussian noise.
            X = X + self.augment_eps * ops.StandardNormal()(X.shape)
            # X = X + self.augment_eps * ms.Tensor(np.load('/home/zhaoyue/Huawei_BiologicalComputing/ProteinMPNN-main/training/randn_like_X.npy'))

        # Virtual Cb from N/Ca/C geometry; coefficients follow the reference
        # ProteinMPNN implementation.
        b = X[:, :, 1, :] - X[:, :, 0, :]
        c = X[:, :, 2, :] - X[:, :, 1, :]
        a = ms.numpy.cross(b, c)
        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + X[:, :, 1, :]
        Ca = X[:, :, 1, :]
        N = X[:, :, 0, :]
        C = X[:, :, 2, :]
        O = X[:, :, 3, :]

        # k-nearest neighbors by Ca-Ca distance.
        D_neighbors, E_idx = self._dist(Ca, mask)

        # All 25 ordered atom-pair distance RBFs, in a fixed order.
        RBF_all = []
        RBF_all.append(self._rbf(D_neighbors))  # Ca-Ca
        RBF_all.append(self._get_rbf(N, N, E_idx))  # N-N
        RBF_all.append(self._get_rbf(C, C, E_idx))  # C-C
        RBF_all.append(self._get_rbf(O, O, E_idx))  # O-O
        RBF_all.append(self._get_rbf(Cb, Cb, E_idx))  # Cb-Cb
        RBF_all.append(self._get_rbf(Ca, N, E_idx))  # Ca-N
        RBF_all.append(self._get_rbf(Ca, C, E_idx))  # Ca-C
        RBF_all.append(self._get_rbf(Ca, O, E_idx))  # Ca-O
        RBF_all.append(self._get_rbf(Ca, Cb, E_idx))  # Ca-Cb
        RBF_all.append(self._get_rbf(N, C, E_idx))  # N-C
        RBF_all.append(self._get_rbf(N, O, E_idx))  # N-O
        RBF_all.append(self._get_rbf(N, Cb, E_idx))  # N-Cb
        RBF_all.append(self._get_rbf(Cb, C, E_idx))  # Cb-C
        RBF_all.append(self._get_rbf(Cb, O, E_idx))  # Cb-O
        RBF_all.append(self._get_rbf(O, C, E_idx))  # O-C
        RBF_all.append(self._get_rbf(N, Ca, E_idx))  # N-Ca
        RBF_all.append(self._get_rbf(C, Ca, E_idx))  # C-Ca
        RBF_all.append(self._get_rbf(O, Ca, E_idx))  # O-Ca
        RBF_all.append(self._get_rbf(Cb, Ca, E_idx))  # Cb-Ca
        RBF_all.append(self._get_rbf(C, N, E_idx))  # C-N
        RBF_all.append(self._get_rbf(O, N, E_idx))  # O-N
        RBF_all.append(self._get_rbf(Cb, N, E_idx))  # Cb-N
        RBF_all.append(self._get_rbf(C, Cb, E_idx))  # C-Cb
        RBF_all.append(self._get_rbf(O, Cb, E_idx))  # O-Cb
        RBF_all.append(self._get_rbf(C, O, E_idx))  # C-O
        RBF_all = ops.Concat(axis=-1)(tuple(RBF_all))

        # Relative sequence offsets restricted to the kept neighbors.
        offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        offset = gather_edges(offset[:, :, :, None], E_idx)[:, :, :, 0]  # [B, L, K]

        d_chains = ops.Cast()(((chain_labels[:, :, None] - chain_labels[:, None,\
            :]) == 0), ms.int32)  # find self vs non-self interaction
        E_chains = gather_edges(d_chains[:, :, :, None], E_idx)[:, :, :, 0]
        E_positional = self.embeddings(ops.Cast()(offset, ms.int32), E_chains)
        E = ops.Concat(axis=-1)((E_positional, RBF_all))
        E = self.edge_embedding(E)
        # E=ms.Tensor(np.load('E.npy'))
        E = self.norm_edges(E)
        return E, E_idx

    def _dist(self, X, mask, eps=1E-6):
        """Masked pairwise Ca distances; returns the top_k (or fewer)
        nearest neighbors per residue and their indices."""
        mask_2D = ops.expand_dims(mask, 1) * ops.expand_dims(mask, 2)
        dX = ops.expand_dims(X, 1) - ops.expand_dims(X, 2)
        D = mask_2D * ops.Sqrt()(ops.ReduceSum()(dX ** 2, 3) + eps)
        # Push masked pairs to the row maximum so they sort last.
        _, D_max = ops.ArgMaxWithValue(keep_dims=True, axis=-1)(D)
        D_adjust = D + (1. - mask_2D) * D_max
        # D_neighbors, E_idx = ops.Sort(axis=-1)(D_adjust.astype(ms.float32))
        shape=X.shape[1]
        # D_neighbors, E_idx = D_neighbors[:, :, 0:int(np.minimum(self.top_k, shape))], E_idx[:, :, 0:int(np.minimum(self.top_k, shape))]
        # D_neighbors=ops.Zeros()((1,68,48), ms.float32)
        # E_idx=ops.Zeros()((1,68,48), ms.int32)
        # TopK sorts descending; reverse to get ascending distances.
        D_neighbors, E_idx = ops.TopK(sorted=True)(D_adjust, X.shape[1])
        D_neighbors = D_neighbors[...,::-1]
        E_idx = E_idx[...,::-1]
        # Cap neighbor count at the sequence length.
        if self.top_k > shape:
            slice_index= shape
        else:
            slice_index = self.top_k
        D_neighbors = D_neighbors[..., :slice_index]
        E_idx = E_idx[..., :slice_index]
        # D_neighbors=ms.Tensor(np.load('D_neighbors_torch.npy'))
        # E_idx = ms.Tensor(np.load('E_idx_torch.npy'))
        return D_neighbors, E_idx

    def _rbf(self, D):
        """Lift distances onto num_rbf Gaussian bins spanning 2..22 Angstrom."""
        D_min, D_max, D_count = 2., 22., self.num_rbf
        D_mu = ops.linspace(Tensor(D_min, ms.float32), Tensor(D_max, ms.float32), D_count)
        D_mu = D_mu.view((1, 1, 1, -1))
        D_sigma = (D_max - D_min) / D_count
        D_expand = ops.expand_dims(D, -1)
        RBF = ops.exp(-((D_expand - D_mu) / D_sigma) ** 2)
        return RBF

    def _get_rbf(self, A, B, E_idx):
        """RBF features of A-B atom distances restricted to the neighbor set."""
        D_A_B = ops.Sqrt()(ops.ReduceSum()((A[:, :, None, :] - B[:, None, :, :]) ** 2, -1) + 1e-6)  # [B, L, L]
        D_A_B_neighbors = gather_edges(D_A_B[:, :, :, None], E_idx)[:, :, :, 0]  # [B,L,K]
        RBF_A_B = self._rbf(D_A_B_neighbors)
        return RBF_A_B
-
-
def broadcast(src: ms.Tensor, axis: int):
    """Build a TensorScatterUpdate index matrix from per-position indices.

    `src` holds indices along `axis` of some target tensor. Returns an
    [src.size, src.ndim] integer matrix whose rows are the full coordinates
    of src's elements (row-major order), except that the `axis` column is
    replaced by src's own values. Feeding this to ops.TensorScatterUpdate
    emulates torch.Tensor.scatter_ along `axis` (see scatter_ below).
    """
    src_np = src.asnumpy()
    # Enumerate every coordinate of src in C (row-major) order. The previous
    # implementation used np.argwhere(src == src.copy()), which silently
    # drops positions holding NaN (NaN != NaN); enumerating an all-True mask
    # is equivalent for valid index data and robust for any input.
    coords = np.argwhere(np.ones(src_np.shape, dtype=bool))
    # Overwrite the target-axis coordinate with the scatter indices;
    # reshape(-1) matches argwhere's row-major enumeration order.
    coords[:, axis] = src_np.reshape(-1)
    return ms.Tensor(coords)
-
-
def scatter_(src: ms.Tensor, index: ms.Tensor, out: ms.Tensor, axis: int = -1):
    """Out-of-place analogue of torch.Tensor.scatter_ along `axis`.

    Writes src's values into a copy of `out` at positions given by `index`
    (same shape as src) and returns the result; `out` itself is not
    mutated. Relies on broadcast() emitting coordinates in the same
    row-major order as src.reshape(-1).
    """
    index = broadcast(index, axis)
    op = ops.TensorScatterUpdate()
    return op(out, index, src.reshape(-1))
-
-
class ProteinMPNN(nn.Cell):
    """ProteinMPNN

    Graph message-passing network for fixed-backbone protein sequence
    design: an encoder over a k-NN residue graph followed by an
    autoregressive decoder producing per-residue amino-acid
    log-probabilities (construct), plus sampling variants and
    conditional/unconditional probability queries.
    """
    def __init__(self, num_letters, node_features, edge_features,
                 hidden_dim, num_encoder_layers=3, num_decoder_layers=3,
                 vocab=21, k_neighbors=64, augment_eps=0.05, dropout=0.1):  # dropout=0.1
        super(ProteinMPNN, self).__init__()

        # Hyperparameters
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        # Featurization layers
        # NOTE(review): node_features is passed into ProteinFeatures'
        # edge_features slot and vice versa (same as the reference
        # implementation); harmless while the two dims are equal — confirm
        # if they ever diverge.
        self.features = ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)

        self.W_e = nn.Dense(edge_features, hidden_dim, has_bias=True)
        self.W_s = nn.Embedding(vocab, hidden_dim)

        # Encoder layers
        self.encoder_layers = nn.CellList([
            EncLayer(hidden_dim, hidden_dim * 2, dropout=dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.CellList([
            DecLayer(hidden_dim, hidden_dim * 3, dropout=dropout)
            for _ in range(num_decoder_layers)
        ])
        self.W_out = nn.Dense(hidden_dim, num_letters, has_bias=True)

        # Xavier-initialize all weight matrices (dim() > 1 skips biases).
        for p in self.get_parameters():
            if p.dim() > 1:
                # ms.common.initializer.XavierUniform(p)
                # print(p)
                p.set_data(initializer(XavierUniform(), p.shape, ms.float32))

    def my_einsum(self, x1, x2, x3):
        """Equivalent of einsum('ij,biq,bjp->bqp', x1, x2, x3) built from
        MatMul/BatchMatMul. NOTE(review): the reshape(-1, q) step is only
        valid when i == q == j; true at every call site here, where all
        three dims are mask_size — confirm before reusing elsewhere."""
        i, j = x1.shape
        b, i, q = x2.shape
        b, j, p = x3.shape
        x2 = ops.transpose(x2, (0, 2, 1)).reshape(-1, q)
        x2 = ops.MatMul()(x2, x1)  # bqi * ij ==> bqj
        out = ops.BatchMatMul()(x2.reshape(b, q, j), x3)  # bqj * bjp ==> bqp
        return out

    def construct(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn=None, use_input_decoding_order=True,
                  decoding_order=None):
        """ Graph-conditioned sequence model.

        Returns [B, L, num_letters] log-probabilities of the sequence S
        under a random autoregressive decoding order.
        """
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = ops.Zeros()((E.shape[0], E.shape[1], E.shape[-1]), ms.float32)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(ops.expand_dims(mask, -1), E_idx).squeeze(-1)
        mask_attend = ops.expand_dims(mask, -1) * mask_attend
        # h_E = ms.Tensor(np.load('h_E_torch.npy'))
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        S = ops.Cast()(S, ms.int32)
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(ops.ZerosLike()(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_m to include missing regions
        # Random decoding order: designable (chain_M==1) positions get larger
        # sort keys and decode later.
        # NOTE(review): the `decoding_order` parameter is never consumed —
        # both branches sort a random key (they differ only in the noise
        # source, `randn` vs a fresh StandardNormal); confirm the intended
        # semantics of use_input_decoding_order.
        if not use_input_decoding_order:
            _, decoding_order = ops.Sort()(ops.Mul()((chain_M + 0.0001), (ops.Abs()(
                randn))))
        else:
            _, decoding_order = ops.Sort()(ops.Mul()((chain_M + 0.0001), (ops.Abs()(
                ops.StandardNormal()(chain_M.shape)))))
        # _, decoding_order = ops.Sort()(ops.Mul()((chain_M + 0.0001), (ops.Abs()(
        #     ms.Tensor(np.load('/home/zhaoyue/Huawei_BiologicalComputing/ProteinMPNN-main/training/randn_chainm.npy'))))))
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = ops.Cast()(nn.OneHot(depth=mask_size)(decoding_order), ms.float32)
        # permutation_matrix_reverse = ops.Zeros()((1, 68, 68), ms.float32)
        '''
        order_mask_backward = ops.Einsum('ij, biq, bjp->bqp')(
            (1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32)),
            permutation_matrix_reverse, permutation_matrix_reverse))
        '''
        # order_mask_backward[b, q, p] == 1 iff position p decodes before q.
        mask_matrix = 1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32))
        order_mask_backward = self.my_einsum(mask_matrix, permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = ops.expand_dims(ops.GatherD()(order_mask_backward, 2, E_idx), -1)
        mask_1D = mask.view((mask.shape[0], mask.shape[1], 1, 1))
        mask_bw = mask_1D * mask_attend  # neighbors already decoded: see sequence
        mask_fw = mask_1D * (1. - mask_attend)  # future neighbors: encoder info only

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            # Masked positions attend to encoder information, unmasked see.
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
            h_V = layer(h_V, h_ESV, mask)

        logits = self.W_out(h_V)
        log_probs = nn.LogSoftmax(axis=-1)(logits)
        return log_probs

    def sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0,
               omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None, pssm_bias=None,
               pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None, pssm_bias_flag=None,
               bias_by_res=None):
        """Autoregressively sample a sequence, one position per step in a
        random decoding order; fixed positions (chain_mask==0) copy S_true.
        Returns {"S", "probs", "decoding_order"}."""
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = ops.Zeros()((E.shape[0], E.shape[1], E.shape[-1]), ms.float32)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(ops.expand_dims(mask, -1), E_idx).squeeze(-1)
        mask_attend = ops.expand_dims(mask, -1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask  # update chain_m to include missing regions
        _, decoding_order = ops.Sort()((chain_mask + 0.0001) * (ops.Abs()(
            randn)))
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = ops.Cast()(nn.OneHot(depth=mask_size)(decoding_order), ms.float32)
        '''
        order_mask_backward = ms.Tensor(np.einsum('ij, biq, bjp->bqp',
            (1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32)).asnumpy()),
            permutation_matrix_reverse.asnumpy(), permutation_matrix_reverse.asnumpy()))
        '''
        # order_mask_backward[b, q, p] == 1 iff position p decodes before q.
        mask_matrix = 1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32))
        order_mask_backward = self.my_einsum(mask_matrix, permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend=ops.expand_dims(ops.GatherD()(order_mask_backward, 2, E_idx), -1)
        mask_1D = mask.view((mask.shape[0], mask.shape[1], 1, 1))
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.shape[0], X.shape[1]
        all_probs = ops.Zeros()((N_batch, N_nodes, 21), ms.float32)
        h_S = ops.ZerosLike()(h_V)
        S = ops.Zeros()((N_batch, N_nodes), ms.int32)
        # One cached node-state tensor per decoder layer (plus the encoder output).
        h_V_stack = [h_V] + [ops.ZerosLike()(h_V) for _ in range(len(self.decoder_layers))]
        constant = ms.Tensor(omit_AAs_np)
        constant_bias = ms.Tensor(bias_AAs_np, ms.float32)
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(ops.ZerosLike()(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_ in range(N_nodes):
            t = decoding_order[:, t_]  # [B]
            chain_mask_gathered = ops.GatherD()(chain_mask, 1, t[:, None])  # [B]
            bias_by_res_gathered = ops.GatherD()(bias_by_res, 1, ms.numpy.tile(t[:, None, None], (1, 1, 21)))[:, 0,\
                :]  # [B, 21]
            if (chain_mask_gathered == 0).all():
                # Position is fixed in every batch element: copy the true residue.
                S_t = ops.GatherD()(S_true, 1, t[:, None])
            else:
                # Hidden layers
                E_idx_t = ops.GatherD()(E_idx, 1, ms.numpy.tile(t[:, None, None], (1, 1, E_idx.shape[-1])))
                h_E_t = ops.GatherD()(h_E, 1,
                                      ms.numpy.tile(t[:, None, None, None], (1, 1, h_E.shape[-2], h_E.shape[-1])))
                h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                h_EXV_encoder_t = ops.GatherD()(h_EXV_encoder_fw, 1,
                                                ms.numpy.tile(t[:, None, None, None], (1, 1, h_EXV_encoder_fw.shape[-2],
                                                                                       h_EXV_encoder_fw.shape[-1])))
                mask_t = ops.GatherD()(mask, 1, t[:, None])
                for l, layer in enumerate(self.decoder_layers):
                    # Updated relational features for future states
                    h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                    h_V_t = ops.GatherD()(h_V_stack[l], 1,
                                          ms.numpy.tile(t[:, None, None], (1, 1, h_V_stack[l].shape[-1])))
                    h_ESV_t = ops.GatherD()(mask_bw, 1, ms.numpy.tile(t[:, None, None, None], \
                        (1, 1, mask_bw.shape[-2], mask_bw.shape[-1]))) * h_ESV_decoder_t + h_EXV_encoder_t
                    # Write the updated state for position t back into the cache.
                    h_V_stack[l + 1] = scatter_(layer(h_V_t, h_ESV_t, mask_V=mask_t),
                                                ms.numpy.tile(t[:, None, None], (1, 1, h_V.shape[-1])),
                                                h_V_stack[l + 1], axis=1)
                # Sampling step
                h_V_t = ops.GatherD()(h_V_stack[-1], 1,
                                      ms.numpy.tile(t[:, None, None], (1, 1, h_V_stack[-1].shape[-1])))[:, 0]
                logits = self.W_out(h_V_t) / temperature
                # constant * 1e8 suppresses omitted amino acids before softmax.
                probs = ops.Softmax(axis=-1)((logits - constant[None, :] * 1e8 + constant_bias[None, \
                    :] / temperature + bias_by_res_gathered / temperature).astype(ms.float32))
                #probs = ops.Softmax(axis=-1)(logits - constant[None, :] * 1e8 + constant_bias[None, :]) #/ temperature + bias_by_res_gathered / temperature).astype(ms.float32))
                if pssm_bias_flag:
                    # Blend model probabilities with the PSSM prior.
                    pssm_coef_gathered = ops.GatherD()(pssm_coef, 1, t[:, None])[:, 0]
                    pssm_bias_gathered = ops.GatherD()(pssm_bias, 1, \
                                                       ms.numpy.tile(t[:, None, None], (1, 1, pssm_bias.shape[-1])))[:,\
                        0]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:,\
                        None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    # Restrict to PSSM-allowed residues (small floor avoids all-zero rows).
                    pssm_log_odds_mask_gathered = ops.GatherD()(pssm_log_odds_mask, 1, ms.numpy.tile(t[:, None, None],\
                        (1, 1, pssm_log_odds_mask.shape[-1])))[:, 0]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / ops.ReduceSum(keep_dims=True)(probs_masked, axis=-1)
                if omit_AA_mask_flag:
                    # Per-position amino-acid exclusions.
                    omit_AA_mask_gathered = ops.GatherD()(omit_AA_mask, 1,
                                                          ms.numpy.tile(t[:, None, None],
                                                                        (1, 1, omit_AA_mask.shape[-1])))[:, 0]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / ops.ReduceSum(keep_dims=True)(probs_masked, axis=-1)  # [B, 21]

                # Multinomial draw done in NumPy; assumes batch size 1 here
                # (np.squeeze axis=0) — NOTE(review): confirm for B > 1.
                probs_ = np.squeeze(probs.asnumpy(), axis=0).astype("float64")
                #S_t = np.random.choice(a=probs_, size=(1))
                #S_t = ms.Tensor(np.where(S_t==probs_))
                p = np.array([i / np.sum(probs_) for i in probs_])
                #print(p.sum())
                S_t = np.random.multinomial(1, p, size=1)
                S_t = ms.Tensor(np.where(S_t == 1)[1])
                #S_t = ms.Tensor(3)
                '''
                multinomial=ops.Multinomial()
                S_t = multinomial(probs, 1)
                '''
                all_probs = scatter_(ops.Cast()((chain_mask_gathered[:, :, None,] * probs[:, None, :]), ms.float32),
                                     ms.numpy.tile(t[:, None, None], (1, 1, 21)),
                                     all_probs, axis=1)
            # Fixed positions keep the true residue; designed ones the sample.
            S_true_gathered = ops.GatherD()(S_true, 1, t[:, None])
            S_t = ops.Cast()((S_t * chain_mask_gathered + S_true_gathered * (1.0 - chain_mask_gathered)), ms.int32)
            temp1 = self.W_s(S_t)
            h_S = scatter_(temp1, ms.numpy.tile(t[:, None, None], (1, 1, temp1.shape[-1])), h_S, axis=1)
            S_t = ops.Cast()(S_t, ms.float32)
            S = ops.Cast()(S, ms.float32)
            S = scatter_(S_t, t[:, None], S, axis=1)
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict

    def tied_sample(self, X, randn, S_true, chain_mask, chain_encoding_all, residue_idx, mask=None, temperature=1.0,
                    omit_AAs_np=None, bias_AAs_np=None, chain_M_pos=None, omit_AA_mask=None, pssm_coef=None,
                    pssm_bias=None, pssm_multi=None, pssm_log_odds_flag=None, pssm_log_odds_mask=None,
                    pssm_bias_flag=None, tied_pos=None, tied_beta=None, bias_by_res=None):
        """Like sample(), but groups of tied positions (tied_pos) are decoded
        together: their logits are combined (weighted by tied_beta) and all
        members receive the same drawn residue."""
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = ops.Zeros()((E.shape[0], E.shape[1], E.shape[-1]), ms.float32)
        h_E = self.W_e(E)
        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(ops.expand_dims(mask, -1), E_idx).squeeze(-1)
        mask_attend = ops.expand_dims(mask, -1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Decoder uses masked self-attention
        chain_mask = chain_mask * chain_M_pos * mask
        _, decoding_order = ops.Sort()((chain_mask + 0.0001) * (ops.Abs()(
            randn)))

        # Regroup the random order so every tied group appears as one unit,
        # at the first position any of its members would have decoded.
        new_decoding_order = []
        for t_dec in list(decoding_order[0,]):  #.asnumpy()
            if t_dec not in list(itertools.chain(*new_decoding_order)):
                list_a = [item for item in tied_pos if t_dec in item]
                if list_a:
                    new_decoding_order.append(list_a[0])
                else:
                    new_decoding_order.append([t_dec])
        decoding_order = ms.numpy.tile(ms.Tensor(list(itertools.chain(*new_decoding_order)))[None,], (
            X.shape[0], 1))

        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = ops.Cast()(nn.OneHot(depth=mask_size)(decoding_order), ms.float32)
        '''
        order_mask_backward = ms.Tensor(np.einsum('ij, biq, bjp->bqp',
            (1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32)).asnumpy()),
            permutation_matrix_reverse.asnumpy(), permutation_matrix_reverse.asnumpy()))
        '''
        # order_mask_backward[b, q, p] == 1 iff position p decodes before q.
        mask_matrix = 1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32))
        order_mask_backward = self.my_einsum(mask_matrix, permutation_matrix_reverse, permutation_matrix_reverse)
        mask_attend = ops.expand_dims(ops.GatherD()(order_mask_backward, 2, E_idx), -1)
        mask_1D = mask.view((mask.shape[0], mask.shape[1], 1, 1))
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        N_batch, N_nodes = X.shape[0], X.shape[1]
        all_probs = ops.Zeros()((N_batch, N_nodes, 21), ms.float32)
        h_S = ops.ZerosLike()(h_V)
        S = ops.Zeros()((N_batch, N_nodes), ms.int32)
        # One cached node-state tensor per decoder layer (plus the encoder output).
        h_V_stack = [h_V] + [ops.ZerosLike()(h_V) for _ in range(len(self.decoder_layers))]
        constant = ms.Tensor(omit_AAs_np)
        constant_bias = ms.Tensor(bias_AAs_np)
        omit_AA_mask_flag = omit_AA_mask is not None

        h_EX_encoder = cat_neighbors_nodes(ops.ZerosLike()(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for t_list in new_decoding_order:
            logits = 0.0
            logit_list = []
            done_flag = False
            t = None
            for t in t_list:
                t = int(t)
                if (chain_mask[:, t] == 0).all():
                    # Fixed group: copy the true residue to every member.
                    S_t = S_true[:, t]
                    for t1 in t_list:
                        t1 = int(t1)
                        h_S[:, t1, :] = self.W_s(S_t)
                        S[:, t1] = S_t
                    done_flag = True
                    break
                else:
                    # Accumulate tied_beta-weighted logits over the group.
                    E_idx_t = E_idx[:, t:t + 1, :]
                    h_E_t = h_E[:, t:t + 1, :, :]
                    h_ES_t = cat_neighbors_nodes(h_S, h_E_t, E_idx_t)
                    h_EXV_encoder_t = h_EXV_encoder_fw[:, t:t + 1, :, :]
                    mask_t = mask[:, t:t + 1]
                    for l, layer in enumerate(self.decoder_layers):
                        h_ESV_decoder_t = cat_neighbors_nodes(h_V_stack[l], h_ES_t, E_idx_t)
                        h_V_t = h_V_stack[l][:, t:t + 1, :]
                        h_ESV_t = mask_bw[:, t:t + 1, :, :] * h_ESV_decoder_t + h_EXV_encoder_t
                        h_V_stack[l + 1][:, t, :] = layer(h_V_t, h_ESV_t, mask_V=mask_t).squeeze(1)
                    h_V_t = h_V_stack[-1][:, t, :]
                    logit_list.append((self.W_out(h_V_t) / temperature) / len(t_list))
                    logits += tied_beta[t] * (self.W_out(h_V_t) / temperature) / len(t_list)
            if done_flag:
                pass
            else:
                bias_by_res_gathered = bias_by_res[:, t, :]  # [B, 21]
                # constant * 1e8 suppresses omitted amino acids before softmax.
                probs = ops.Softmax(axis=-1)((logits - constant[None, :] * 1e8 + constant_bias[None, \
                    :] / temperature + bias_by_res_gathered / temperature).astype(ms.float32))
                if pssm_bias_flag:
                    # Blend model probabilities with the PSSM prior.
                    pssm_coef_gathered = pssm_coef[:, t]
                    pssm_bias_gathered = pssm_bias[:, t]
                    probs = (1 - pssm_multi * pssm_coef_gathered[:, None]) * probs + pssm_multi * pssm_coef_gathered[:,\
                        None] * pssm_bias_gathered
                if pssm_log_odds_flag:
                    # Restrict to PSSM-allowed residues (small floor avoids all-zero rows).
                    pssm_log_odds_mask_gathered = pssm_log_odds_mask[:, t]
                    probs_masked = probs * pssm_log_odds_mask_gathered
                    probs_masked += probs * 0.001
                    probs = probs_masked / ops.ReduceSum(keep_dims=True)(probs_masked, axis=-1)  # [B, 21]
                if omit_AA_mask_flag:
                    # Per-position amino-acid exclusions.
                    omit_AA_mask_gathered = omit_AA_mask[:, t]
                    probs_masked = probs * (1.0 - omit_AA_mask_gathered)
                    probs = probs_masked / ops.ReduceSum(keep_dims=True)(probs_masked, axis=-1)  # [B, 21]
                '''
                probs_ = np.squeeze(probs.asnumpy(), axis=0)
                S_t_repeat = np.random.choice(a=probs_, size=1).squeeze(-1)
                print(S_t_repeat)
                S_t = ms.Tensor(np.where(S_t_repeat))
                '''
                # Single draw shared by all members of the tied group.
                S_t_repeat = ops.multinomial(probs, 1).squeeze(-1)
                for t in t_list:
                    t = int(t)
                    h_S[:, t, :] = self.W_s(S_t_repeat)
                    S[:, t] = S_t_repeat
                    all_probs[:, t, :] = ops.Cast()(probs, ms.float32)
        output_dict = {"S": S, "probs": all_probs, "decoding_order": decoding_order}
        return output_dict

    def conditional_probs(self, X, S, mask, chain_M, residue_idx, chain_encoding_all, randn, backbone_only=False):
        """Per-position log-probabilities conditioned on the rest of the
        sequence (or on backbone only, if backbone_only): each designable
        position is decoded last (or first) via a crafted order key.
        NOTE(review): the idx loop reads chain_M batch element 0 only —
        assumes a homogeneous batch; confirm for B > 1."""
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V_enc = ops.Zeros()((E.shape[0], E.shape[1], E.shape[-1]), ms.float32)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(ops.expand_dims(mask, -1), E_idx).squeeze(-1)
        mask_attend = ops.expand_dims(mask, -1) * mask_attend
        for layer in self.encoder_layers:
            h_V_enc, h_E = layer(h_V_enc, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        #S = ops.Cast()(S, ms.int32)
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(ops.ZerosLike()(h_S), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V_enc, h_EX_encoder, E_idx)

        chain_M = chain_M * mask  # update chain_m to include missing regions

        chain_M_np = chain_M.asnumpy()
        idx_to_loop = np.argwhere(chain_M_np[0, :] == 1)[:, 0]
        log_conditional_probs = ops.Cast()(ops.Zeros()([X.shape[0], chain_M.shape[1], 21]), ms.float32)

        for idx in idx_to_loop:
            # Fresh copy of encoder node states for each queried position.
            h_V = Tensor.copy(h_V_enc)
            order_mask = ops.Cast()(ops.Zeros()(chain_M.shape[1]), ms.float32)
            if backbone_only:
                # idx decodes first: it sees backbone/encoder context only.
                order_mask = ops.Cast()(ops.Ones()(chain_M.shape[1]), ms.float32)
                order_mask[idx] = 0.
            else:
                # idx decodes last: it sees the full remaining sequence.
                order_mask = ops.Cast()(ops.Zeros()(chain_M.shape[1]), ms.float32)
                order_mask[idx] = 1.
            _, decoding_order = ops.Sort()((order_mask[None,] + 0.0001) * (ops.Abs()(
                randn)))
            mask_size = E_idx.shape[1]
            permutation_matrix_reverse = ops.Cast()(nn.OneHot(depth=mask_size)(decoding_order), ms.float32)
            '''
            order_mask_backward = ms.Tensor(np.einsum('ij, biq, bjp->bqp',
                (1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32)).asnumpy()),
                permutation_matrix_reverse.asnumpy(), permutation_matrix_reverse.asnumpy()))
            '''
            # order_mask_backward[b, q, p] == 1 iff position p decodes before q.
            mask_matrix = 1 - nn.Triu()(ops.Ones()((mask_size, mask_size), ms.float32))
            order_mask_backward = self.my_einsum(mask_matrix, permutation_matrix_reverse, permutation_matrix_reverse)
            mask_attend = ops.expand_dims(ops.GatherD()(order_mask_backward, 2, E_idx), -1)
            mask_1D = mask.view((mask.shape[0], mask.shape[1], 1, 1))
            mask_bw = mask_1D * mask_attend
            mask_fw = mask_1D * (1. - mask_attend)

            h_EXV_encoder_fw = mask_fw * h_EXV_encoder
            for layer in self.decoder_layers:
                # Masked positions attend to encoder information, unmasked see.
                h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
                h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
                h_V = layer(h_V, h_ESV, mask)

            logits = self.W_out(h_V)
            log_probs = nn.LogSoftmax(axis=-1)(logits)
            # Keep only the queried position's distribution from this pass.
            log_conditional_probs[:, idx, :] = log_probs[:, idx, :]
        return log_conditional_probs

    def unconditional_probs(self, X, mask, residue_idx, chain_encoding_all):
        """Per-position log-probabilities given the backbone only: the
        backward mask is all-zero, so the decoder never sees sequence."""
        # Prepare node and edge embeddings
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        h_V = ops.Zeros()((E.shape[0], E.shape[1], E.shape[-1]), ms.float32)
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        mask_attend = gather_nodes(ops.expand_dims(mask, -1), E_idx).squeeze(-1)
        mask_attend = ops.expand_dims(mask, -1) * mask_attend
        for layer in self.encoder_layers:
            h_V, h_E = layer(h_V, h_E, E_idx, mask, mask_attend)

        # Build encoder embeddings
        h_EX_encoder = cat_neighbors_nodes(ops.ZerosLike()(h_V), h_E, E_idx)
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)

        # No position ever counts as "already decoded".
        order_mask_backward = ops.Zeros()((X.shape[0], X.shape[1], X.shape[1]), ms.float32)
        mask_attend = ops.expand_dims(ops.GatherD()(order_mask_backward, 2, E_idx), -1)
        mask_1D = mask.view((mask.shape[0], mask.shape[1], 1, 1))
        mask_fw = mask_1D * (1. - mask_attend)

        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            h_V = layer(h_V, h_EXV_encoder_fw, mask)

        logits = self.W_out(h_V)
        log_probs = nn.LogSoftmax(axis=-1)(logits)
        return log_probs
-
|