Commit · 2216d16
1 Parent(s): 813c6b1
add functions
README.md CHANGED
@@ -1,3 +1,13 @@
-
-
-
+---
+license: cc-by-nc-nd-4.0
+---
+
+This repo contains important large files for [PeptiVerse](https://huggingface.co/spaces/ChatterjeeLab/PeptiVerse), an interactive app for peptide property prediction.
+
+- `embeddings` folder contains processed Hugging Face datasets with PeptideCLM embeddings; the `.csv` files are the pre-processed data.
+- `metrics` folder contains the model performance on the validation data.
+- `models` hosts all trained model weights.
+- `training_data` hosts all **raw data** used to train the classifiers.
+- `functions` contains files for using the trained weights and classifiers.
+- `train` contains the script to train classifiers on the pre-processed embeddings, using either XGBoost or MLPs.
+- `scoring_function.py` contains a class that aggregates all trained classifiers for diverse downstream sampling applications.
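The `train` script described in the README fits classifiers on the pre-computed PeptideCLM embeddings using XGBoost or MLPs. As a rough, hypothetical sketch of that workflow (not the repo's actual training code), the snippet below loads an embeddings dataset saved with `Dataset.save_to_disk` and fits an XGBoost classifier; the local path and the `embedding`/`labels` column names are assumptions borrowed from the embedding script later in this commit.

```python
# Hypothetical sketch: fit an XGBoost classifier on pre-computed PeptideCLM
# embeddings stored as a Hugging Face dataset on disk. The path and the
# "embedding"/"labels" column names are assumptions, not a documented API.
import numpy as np
from datasets import load_from_disk
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from xgboost import XGBClassifier

dataset = load_from_disk("embeddings/nonfouling")      # assumed local path
X = np.array(dataset["embedding"], dtype=np.float32)   # mean-pooled embeddings
y = np.array(dataset["labels"])

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

clf = XGBClassifier(n_estimators=500, max_depth=6, learning_rate=0.05,
                    eval_metric="logloss")
clf.fit(X_train, y_train)

val_prob = clf.predict_proba(X_val)[:, 1]
print(f"Validation ROC-AUC: {roc_auc_score(y_val, val_prob):.3f}")
```

Training a gradient-boosted head on frozen embeddings keeps the classifier cheap to retrain; the same `X`/`y` arrays could equally feed a small MLP head.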
embeddings/fast_embedding_generation.py CHANGED
@@ -1,113 +1,3 @@
-import torch
-import pandas as pd
-import numpy as np
-from transformers import AutoModelForMaskedLM
-from datasets import Dataset
-import sys
-from tqdm import tqdm
-from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
-
-# Configuration
-MAX_LENGTH = 768
-BATCH_SIZE = 128  # Adjust based on your GPU memory
-
-# Setup device
-if torch.cuda.is_available():
-    device = torch.device('cuda:6')
-    print(f"Using device: {device}")
-else:
-    device = torch.device('cpu')
-    print(f"CUDA not available. Using device: {device}")
-    print("To use GPU, reinstall PyTorch with CUDA support:")
-
-# Load tokenizer and model
-print("Loading tokenizer and model...")
-tokenizer = SMILES_SPE_Tokenizer(
-    '/scratch/pranamlab/sophtang/home/scoring/PeptideCLM/tokenizer/new_vocab.txt',
-    '/scratch/pranamlab/sophtang/home/scoring/PeptideCLM/tokenizer/new_splits.txt'
-)
-
-embedding_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
-embedding_model.to(device)
-embedding_model.eval()
-
-# Load CSV file
-print("Loading CSV file...")
-csv_path = "/scratch/pranamlab/sophtang/home/scoring/functions/nonfouling/combined_nonfouling.csv"
-df = pd.read_csv(csv_path)
-
-sequences = df['SMILES'].tolist()
-labels = df['LABEL'].tolist()
-print(f"Total sequences: {len(sequences)}")
-print(f"First sequence: {sequences[0]}")
-
-# Filter sequences by length (faster - no tokenization)
-print("Filtering sequences by length...")
-valid_data = []
-for seq, label in zip(sequences, labels):
-    if not isinstance(seq, str):
-        continue
-    # Quick pre-filter: tokenize once to check length
-    tokenized = tokenizer(seq, return_tensors='pt', max_length=MAX_LENGTH, truncation=True)
-    if tokenized['input_ids'].shape[1] <= MAX_LENGTH:
-        valid_data.append((seq, label))
-
-filtered_sequences = [item[0] for item in valid_data]
-filtered_labels = [item[1] for item in valid_data]
-print(f"Filtered sequences: {len(filtered_sequences)}")
-
-# Generate embeddings in batches
-print("Generating embeddings...")
-def generate_embeddings_batched(sequences, batch_size=BATCH_SIZE):
-    embeddings = []
-
-    for i in tqdm(range(0, len(sequences), batch_size), desc="Processing batches"):
-        batch_sequences = sequences[i:i + batch_size]
-
-        # Tokenize batch
-        tokenized = tokenizer(
-            batch_sequences,
-            return_tensors='pt',
-            padding=True,
-            max_length=MAX_LENGTH,
-            truncation=True
-        )
-
-        # Move to device
-        input_ids = tokenized['input_ids'].to(device)
-        attention_mask = tokenized['attention_mask'].to(device)
-
-        # Generate embeddings
-        with torch.no_grad():
-            outputs = embedding_model(input_ids=input_ids, attention_mask=attention_mask)
-            last_hidden_state = outputs.last_hidden_state
-
-        # Mean pooling with attention mask
-        mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
-        sum_embeddings = torch.sum(last_hidden_state * mask_expanded, dim=1)
-        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
-        batch_embeddings = (sum_embeddings / sum_mask).cpu().numpy()
-
-        embeddings.append(batch_embeddings)
-
-    return np.vstack(embeddings)
-
-embeddings = generate_embeddings_batched(filtered_sequences)
-print(f"Embeddings shape: {embeddings.shape}")
-
-# Create and save dataset
-print("Creating dataset...")
-data = {
-    "sequence": filtered_sequences,
-    "labels": filtered_labels,
-    "embedding": embeddings
-}
-dataset = Dataset.from_dict(data)
-
-output_path = '/scratch/pranamlab/sophtang/home/scoring/data/nonfouling'
-print(f"Saving dataset to {output_path}...")
-dataset.save_to_disk(output_path)
-
-print(f"✓ Dataset saved successfully!")
-print(f"  Total samples: {len(dataset)}")
-print(f"  Embedding dimension: {embeddings.shape[1]}")
+version https://git-lfs.github.com/spec/v1
+oid sha256:0396104ebf1dc28b0d297bdfebede1927aba9a23417c9f06cd8f39d999d099d3
+size 3900
metrics/binding/binding_affinity_model_clean.ipynb DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:e6b03b4f485202c9264bb4c1cd8a65f6e1d03d705e0466f8569d5fe3b6f2f6ee
-size 565356

metrics/binding/binding_utils.py DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1a7bbd43f4f65d00f52658269e707a59dc3a4b4d4e53a3d2e578b5d49e411940
-size 9633
train/binding_affinity_model_clean.ipynb ADDED
The diff for this file is too large to render. See raw diff.
train/binding_utils.py ADDED
@@ -0,0 +1,291 @@
+from torch import nn
+import pdb
+import torch
+import torch.nn.functional as F  # required for the F.softmax calls below
+import numpy as np
+
+def to_var(x):
+    if torch.cuda.is_available():
+        x = x.cuda()
+    return x
+
+class MultiHeadAttentionSequence(nn.Module):
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+
+        super().__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.W_Q = nn.Linear(d_model, n_head*d_k)
+        self.W_K = nn.Linear(d_model, n_head*d_k)
+        self.W_V = nn.Linear(d_model, n_head*d_v)
+        self.W_O = nn.Linear(n_head*d_v, d_model)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, q, k, v):
+
+        batch, len_q, _ = q.size()
+        batch, len_k, _ = k.size()
+        batch, len_v, _ = v.size()
+
+        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
+        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
+        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
+
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2).transpose(2, 3)
+        V = V.transpose(1, 2)
+
+        attention = torch.matmul(Q, K)
+
+        attention = attention / np.sqrt(self.d_k)
+
+        attention = F.softmax(attention, dim=-1)
+
+        output = torch.matmul(attention, V)
+
+        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+
+        output = self.W_O(output)
+
+        output = self.dropout(output)
+
+        output = self.layer_norm(output + q)
+
+        return output, attention
+
+class MultiHeadAttentionReciprocal(nn.Module):
+
+    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+
+        super().__init__()
+
+        self.n_head = n_head
+        self.d_model = d_model
+        self.d_k = d_k
+        self.d_v = d_v
+
+        self.W_Q = nn.Linear(d_model, n_head*d_k)
+        self.W_K = nn.Linear(d_model, n_head*d_k)
+        self.W_V = nn.Linear(d_model, n_head*d_v)
+        self.W_O = nn.Linear(n_head*d_v, d_model)
+        self.W_V_2 = nn.Linear(d_model, n_head*d_v)
+        self.W_O_2 = nn.Linear(n_head*d_v, d_model)
+
+        self.layer_norm = nn.LayerNorm(d_model)
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.layer_norm_2 = nn.LayerNorm(d_model)
+
+        self.dropout_2 = nn.Dropout(dropout)
+
+    def forward(self, q, k, v, v_2):
+
+        batch, len_q, _ = q.size()
+        batch, len_k, _ = k.size()
+        batch, len_v, _ = v.size()
+        batch, len_v_2, _ = v_2.size()
+
+        Q = self.W_Q(q).view([batch, len_q, self.n_head, self.d_k])
+        K = self.W_K(k).view([batch, len_k, self.n_head, self.d_k])
+        V = self.W_V(v).view([batch, len_v, self.n_head, self.d_v])
+        V_2 = self.W_V_2(v_2).view([batch, len_v_2, self.n_head, self.d_v])
+
+        Q = Q.transpose(1, 2)
+        K = K.transpose(1, 2).transpose(2, 3)
+        V = V.transpose(1, 2)
+        V_2 = V_2.transpose(1, 2)
+
+        attention = torch.matmul(Q, K)
+
+        attention = attention / np.sqrt(self.d_k)
+
+        attention_2 = attention.transpose(-2, -1)
+
+        attention = F.softmax(attention, dim=-1)
+
+        attention_2 = F.softmax(attention_2, dim=-1)
+
+        output = torch.matmul(attention, V)
+
+        output_2 = torch.matmul(attention_2, V_2)
+
+        output = output.transpose(1, 2).reshape([batch, len_q, self.d_v*self.n_head])
+
+        output_2 = output_2.transpose(1, 2).reshape([batch, len_k, self.d_v*self.n_head])
+
+        output = self.W_O(output)
+
+        output_2 = self.W_O_2(output_2)
+
+        output = self.dropout(output)
+
+        output = self.layer_norm(output + q)
+
+        output_2 = self.dropout(output_2)
+
+        output_2 = self.layer_norm(output_2 + k)
+
+        return output, output_2, attention, attention_2
+
+
+class FFN(nn.Module):
+
+    def __init__(self, d_in, d_hid, dropout=0.1):
+        super().__init__()
+
+        self.layer_1 = nn.Conv1d(d_in, d_hid, 1)
+        self.layer_2 = nn.Conv1d(d_hid, d_in, 1)
+        self.relu = nn.ReLU()
+        self.layer_norm = nn.LayerNorm(d_in)
+
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+
+        residual = x
+        output = self.layer_1(x.transpose(1, 2))
+
+        output = self.relu(output)
+
+        output = self.layer_2(output)
+
+        output = self.dropout(output)
+
+        output = self.layer_norm(output.transpose(1, 2) + residual)
+
+        return output
+
+class ConvLayer(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
+        super(ConvLayer, self).__init__()
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.relu(out)
+        return out
+
+
+class DilatedCNN(nn.Module):
+    def __init__(self, d_model, d_hidden):
+        super(DilatedCNN, self).__init__()
+        self.first_ = nn.ModuleList()
+        self.second_ = nn.ModuleList()
+        self.third_ = nn.ModuleList()
+
+        dilation_tuple = (1, 2, 3)
+        dim_in_tuple = (d_model, d_hidden, d_hidden)
+        dim_out_tuple = (d_hidden, d_hidden, d_hidden)
+
+        for i, dilation_rate in enumerate(dilation_tuple):
+            self.first_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=3, padding=dilation_rate,
+                                         dilation=dilation_rate))
+
+        for i, dilation_rate in enumerate(dilation_tuple):
+            self.second_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=5, padding=2*dilation_rate,
+                                          dilation=dilation_rate))
+
+        for i, dilation_rate in enumerate(dilation_tuple):
+            self.third_.append(ConvLayer(dim_in_tuple[i], dim_out_tuple[i], kernel_size=7, padding=3*dilation_rate,
+                                         dilation=dilation_rate))
+
+    def forward(self, protein_seq_enc):
+        # pdb.set_trace()
+        protein_seq_enc = protein_seq_enc.transpose(1, 2)  # protein_seq_enc's shape: B*L*d_model -> B*d_model*L
+
+        first_embedding = protein_seq_enc
+        second_embedding = protein_seq_enc
+        third_embedding = protein_seq_enc
+
+        for i in range(len(self.first_)):
+            first_embedding = self.first_[i](first_embedding)
+
+        for i in range(len(self.second_)):
+            second_embedding = self.second_[i](second_embedding)
+
+        for i in range(len(self.third_)):
+            third_embedding = self.third_[i](third_embedding)
+
+        # pdb.set_trace()
+
+        protein_seq_enc = first_embedding + second_embedding + third_embedding
+
+        return protein_seq_enc.transpose(1, 2)
+
+
+class ReciprocalLayerwithCNN(nn.Module):
+
+    def __init__(self, d_model, d_inner, d_hidden, n_head, d_k, d_v):
+        super().__init__()
+
+        self.cnn = DilatedCNN(d_model, d_hidden)
+
+        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
+
+        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_hidden, d_k, d_v)
+
+        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_hidden, d_k, d_v)
+
+        self.ffn_seq = FFN(d_hidden, d_inner)
+
+        self.ffn_protein = FFN(d_hidden, d_inner)
+
+    def forward(self, sequence_enc, protein_seq_enc):
+        # pdb.set_trace()  # protein_seq_enc.shape = B * L * d_model
+        protein_seq_enc = self.cnn(protein_seq_enc)
+        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
+
+        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
+
+        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
+
+        prot_enc = self.ffn_protein(prot_enc)
+
+        seq_enc = self.ffn_seq(seq_enc)
+
+        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
+
+
+class ReciprocalLayer(nn.Module):
+
+    def __init__(self, d_model, d_inner, n_head, d_k, d_v):
+
+        super().__init__()
+
+        self.sequence_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
+
+        self.protein_attention_layer = MultiHeadAttentionSequence(n_head, d_model, d_k, d_v)
+
+        self.reciprocal_attention_layer = MultiHeadAttentionReciprocal(n_head, d_model, d_k, d_v)
+
+        self.ffn_seq = FFN(d_model, d_inner)
+
+        self.ffn_protein = FFN(d_model, d_inner)
+
+    def forward(self, sequence_enc, protein_seq_enc):
+        prot_enc, prot_attention = self.protein_attention_layer(protein_seq_enc, protein_seq_enc, protein_seq_enc)
+
+        seq_enc, sequence_attention = self.sequence_attention_layer(sequence_enc, sequence_enc, sequence_enc)
+
+        prot_enc, seq_enc, prot_seq_attention, seq_prot_attention = self.reciprocal_attention_layer(prot_enc, seq_enc, seq_enc, prot_enc)
+        prot_enc = self.ffn_protein(prot_enc)
+
+        seq_enc = self.ffn_seq(seq_enc)
+
+        return prot_enc, seq_enc, prot_attention, sequence_attention, prot_seq_attention, seq_prot_attention
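As a quick sanity check of the cross-attention modules added in `train/binding_utils.py`, a minimal forward pass through `ReciprocalLayer` could look like the sketch below. The import path, batch size, sequence lengths, and dimensions are all assumptions for illustration, not documented usage of the repo.

```python
# Hypothetical smoke test for ReciprocalLayer from train/binding_utils.py.
# Import path and tensor sizes are assumptions, not part of the repo's docs.
import torch
from binding_utils import ReciprocalLayer  # assumed to be importable locally

layer = ReciprocalLayer(d_model=64, d_inner=128, n_head=4, d_k=16, d_v=16)

peptide = torch.randn(2, 10, 64)   # batch of 2 peptide encodings, length 10
protein = torch.randn(2, 50, 64)   # batch of 2 protein encodings, length 50

# forward(sequence_enc, protein_seq_enc) returns the updated protein and
# peptide encodings plus the four attention maps computed along the way.
prot_enc, seq_enc, prot_attn, seq_attn, prot_seq_attn, seq_prot_attn = layer(peptide, protein)
print(prot_enc.shape)  # torch.Size([2, 50, 64])
print(seq_enc.shape)   # torch.Size([2, 10, 64])
```

Each branch first self-attends over its own sequence, then `MultiHeadAttentionReciprocal` exchanges information between the protein and peptide tracks before the per-track feed-forward layers.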