@@ -0,0 +1,110 @@ | |||
# --- Base packages --- | |||
import os | |||
import numpy as np | |||
#import matplotlib.pyplot as plt | |||
import sklearn.metrics as metrics | |||
# --- PyTorch packages --- | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.optim as optim | |||
import torch.utils.data as data | |||
import torchvision.models as models | |||
import torchvision.datasets as datasets | |||
import torchvision.transforms as transforms | |||
# --- Project Packages --- | |||
from utils import save, load, train, test | |||
from datasets import NLMCXR, MIMIC | |||
from models import Classifier, TNN | |||
from baselines.transformer.models import LSTM_Attn | |||
# --- Instructions --- | |||
# Step 1: Use train_text.py, train LSTM/Transformer models on the MIMIC-CXR dataset (14 diseases + 100 noun-phrases = 114 labels) | |||
# Step 2: Use extract_label.py, load the NLMCXR dataset and predict labels using the trained LSTM/Transformer models (CheXpert) | |||
# Step 3: Save the predicted labels and load the NLMCXR dataset again with the saved labels | |||
# Step 4: Copy file2label.json to the NLMCXR dataset folder | |||
# Step 5: Use train_text.py, train LSTM/Transformer models on the NLMCXR dataset | |||
# --- Hyperparameters --- | |||
os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |||
os.environ["OMP_NUM_THREADS"] = "1" | |||
torch.set_num_threads(1) | |||
torch.manual_seed(seed=0) | |||
MODEL_NAME = 'Transformer' # Transformer / LSTM | |||
BATCH_SIZE = 16 | |||
if __name__ == "__main__": | |||
# --- Choose Inputs/Outputs | |||
if MODEL_NAME == 'Transformer': | |||
SOURCES = ['caption'] | |||
TARGETS = ['label'] | |||
KW_SRC = ['txt'] # kwargs of Classifier | |||
KW_TGT = None | |||
KW_OUT = None | |||
elif MODEL_NAME == 'LSTM': | |||
SOURCES = ['caption', 'caption_length'] | |||
TARGETS = ['label'] | |||
KW_SRC = ['caption', 'caption_length'] # kwargs of LSTM_Attn | |||
KW_TGT = None | |||
KW_OUT = None | |||
else: | |||
raise ValueError('Invalid MODEL_NAME') | |||
# --- Choose a Dataset --- | |||
mimic_dataset = MIMIC('/home/dongxinxin/mimic_cxr/', view_pos=['AP','PA','LATERAL'], sources=SOURCES, targets=TARGETS, vocab_file='mimic_unigram_1000.model') | |||
dataset = NLMCXR('/home/dongxinxin/iu_xray/', view_pos=['AP','PA','LATERAL'], sources=SOURCES, targets=TARGETS, vocab_file='mimic_unigram_1000.model') | |||
# Use the same vocab_file as MIMIC because language models were trained on this. | |||
NUM_LABELS = 114 # (14 diseases + 100 top noun-phrases) <-- MIMIC-CXR | |||
NUM_CLASSES = 2 | |||
VOCAB_SIZE = len(dataset.vocab) | |||
POSIT_SIZE = dataset.max_len | |||
# --- Choose a Model --- | |||
if MODEL_NAME == 'Transformer': | |||
NUM_EMBEDS = 256 | |||
NUM_HEADS = 8 | |||
FWD_DIM = 256 | |||
NUM_LAYERS = 1 | |||
DROPOUT = 0.1 | |||
tnn = TNN(embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, fwd_dim=FWD_DIM, dropout=DROPOUT, num_layers=NUM_LAYERS, num_tokens=VOCAB_SIZE, num_posits=POSIT_SIZE) | |||
model = Classifier(num_topics=NUM_LABELS, num_states=NUM_CLASSES, cnn=None, tnn=tnn, embed_dim=NUM_EMBEDS, num_heads=NUM_HEADS, dropout=DROPOUT) | |||
elif MODEL_NAME == 'LSTM': | |||
# Justin et al. hyper-parameters | |||
NUM_EMBEDS = 256 | |||
HIDDEN_SIZE = 128 | |||
DROPOUT = 0.1 | |||
model = LSTM_Attn(num_tokens=VOCAB_SIZE, embed_dim=NUM_EMBEDS, hidden_size=HIDDEN_SIZE, num_topics=NUM_LABELS, num_states=NUM_CLASSES, dropout=DROPOUT) | |||
else: | |||
raise ValueError('Invalid MODEL_NAME') | |||
# --- Main program --- | |||
data_loader = data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8) | |||
model = nn.DataParallel(model).cuda() | |||
COMMENT = 'MaxView{}_NumLabel{}'.format(2, 114) | |||
checkpoint_path_from = 'checkpoints/{}_{}_{}.pt'.format('MIMIC',MODEL_NAME,COMMENT) | |||
last_epoch, (best_metric, test_metric) = load(checkpoint_path_from, model) | |||
print('Reload From: {} | Last Epoch: {} | Validation Metric: {} | Test Metric: {}'.format(checkpoint_path_from, last_epoch, best_metric, test_metric)) | |||
loss, outputs, _ = test(data_loader, model, device='cuda', kw_src=KW_SRC, kw_tgt=KW_TGT, kw_out=KW_OUT) | |||
# --- Label Extraction --- | |||
threshold = 0.5 | |||
label = (outputs[:,:14,1] > threshold).long().cpu().numpy() # Extract only 14 common diseases! | |||
import json | |||
file_to_label = {} | |||
for i in range(len(label)): | |||
file_to_label[dataset.file_list[i]] = label[i].tolist() | |||
json.dump(file_to_label, open('file2label.json', 'w')) |
Dear OpenI User
Thank you for your continuous support to the Openl Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the Openl Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more agreement content, please refer to the《Openl Qizhi Community AI Collaboration Platform Usage Agreement》