conGI tutorial
0. import packages and select GPU if accessible
[ ]:
import os
import random
import numpy as np
import scanpy as sc
import torch
from torch.utils.data import DataLoader
import argparse
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP
from model import SpaCLR, TrainerSpaCLR
from utils import get_predicted_results, load_ST_file
import pandas as pd
import warnings
from dataset import Dataset
warnings.filterwarnings("ignore")
[ ]:
def seed_torch(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to its deterministic mode and disables the cuDNN
    auto-tuner.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched after this point, not the running one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one: the default device of
    # this tutorial ("cuda:3") is not the current CUDA device. The call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Command-line / notebook configuration shared by every dataset cell below.
parser = argparse.ArgumentParser()

# preprocess
parser.add_argument('--dataset', type=str, default="SpatialLIBD")
parser.add_argument('--path', type=str, default="../spatialLIBD")
parser.add_argument("--gene_preprocess", choices=("pca", "hvg"), default="pca")
# type=int is required: without it a value given on the command line stays a
# string and can never match the integer choices.
parser.add_argument("--n_gene", type=int, choices=(300, 1000), default=300)
parser.add_argument('--img_size', type=int, default=112)
parser.add_argument('--num_workers', type=int, default=8)

# model
parser.add_argument('--last_dim', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.0003)
parser.add_argument('--p_drop', type=float, default=0)
parser.add_argument('--w_g2i', type=float, default=1)
parser.add_argument('--w_g2g', type=float, default=0.1)
parser.add_argument('--w_i2i', type=float, default=0.1)
parser.add_argument('--w_recon', type=float, default=0)

# data augmentation
parser.add_argument('--prob_mask', type=float, default=0.5)
parser.add_argument('--pct_mask', type=float, default=0.2)
parser.add_argument('--prob_noise', type=float, default=0.5)
parser.add_argument('--pct_noise', type=float, default=0.8)
parser.add_argument('--sigma_noise', type=float, default=0.5)
parser.add_argument('--prob_swap', type=float, default=0.5)
parser.add_argument('--pct_swap', type=float, default=0.1)

# train
parser.add_argument('--batch_size', type=int, default=96)
parser.add_argument('--epochs', type=int, default=35)
# NOTE(review): the default assumes a machine with at least four GPUs; pass
# --device (e.g. "cuda:0") on smaller machines.
parser.add_argument('--device', type=str, default="cuda:3")
parser.add_argument('--log_name', type=str, default="log_name")
parser.add_argument('--name', type=str, default="None")

# Number of repeated train/evaluate runs per slice in the dataset cells.
iters = 20
1. DLPFC dataset
Change the `path` assignment in the cell below (`path = args.path = ...`) to 'path/to/your/DLPFC/data'.
[ ]:
"""DLPFC"""
# Each entry is [number of clusters, slice id]. Only the slice id is read
# below; the trainer takes the cluster count from trainset.n_clusters.
setting_combinations = [
    [7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'],
    [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'],
    [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676'],
]
# Shorter run on three slices only:
# setting_combinations = [[7, '151674'], [7, '151675'], [7, '151676']]

for setting_combi in setting_combinations:
    # parse_known_args() ignores the "-f <connection file>" argument that a
    # Jupyter kernel puts in sys.argv; plain parse_args() would exit on it.
    args = parser.parse_known_args()[0]

    # seed (once per slice; the repeated runs below then differ from each other)
    seed_torch(1)

    # Change this to the directory holding your DLPFC data.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/DLPFC12'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims = [n_gene, 2 * last_dim]
    image_dims = [n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'DLPFC'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap

    # One ARI per repeated run of this slice.
    aris = []
    for iter_ in range(iters):
        # dataset: the same settings are used twice, once with train=True
        # (shuffled loader) and once with train=False (ordered loader).
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                           prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                           pct_noise=pct_noise, sigma_noise=sigma_noise,
                           prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=True)
        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                          prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                          pct_noise=pct_noise, sigma_noise=sigma_noise,
                          prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)

        # network
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop,
                         n_pos=trainset.n_pos, backbone='densenet',
                         projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)

        # evaluate: combine the first two outputs of valid(), down-weighting
        # the second by 0.1 (presumably gene and image embeddings - confirm).
        xg, xi, _ = trainer.valid(testloader)
        z = xg + 0.1*xi
        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)
        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)

    # Per-slice summary, then append one line "<label><slice> <ari> ..." to the
    # shared results file.
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('DLPFC' + name + ' ')
        fp.write(' '.join(map(str, aris)))
        fp.write('\n')
2. BC/MA datasets (2 slides)
[ ]:
"""BC"""
# Each entry is [number of clusters, section name]. Only the section name is
# read below; the trainer takes the cluster count from trainset.n_clusters.
setting_combinations = [[20, 'section1']]

for setting_combi in setting_combinations:
    # parse_known_args() ignores the "-f <connection file>" argument that a
    # Jupyter kernel puts in sys.argv; plain parse_args() would exit on it.
    args = parser.parse_known_args()[0]

    # seed (once per section; the repeated runs below then differ from each other)
    seed_torch(1)

    # Change this to the directory holding your BC data.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/BC'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims = [n_gene, 2 * last_dim]
    image_dims = [n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'BC'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap

    # One ARI per repeated run of this section.
    aris = []
    for iter_ in range(iters):
        # dataset: the same settings are used twice, once with train=True
        # (shuffled loader) and once with train=False (ordered loader).
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                           prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                           pct_noise=pct_noise, sigma_noise=sigma_noise,
                           prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=True)
        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                          prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                          pct_noise=pct_noise, sigma_noise=sigma_noise,
                          prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)

        # network
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop,
                         n_pos=trainset.n_pos, backbone='densenet',
                         projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)

        # evaluate: combine the first two outputs of valid(), down-weighting
        # the second by 0.1 (presumably gene and image embeddings - confirm).
        xg, xi, _ = trainer.valid(testloader)
        z = xg + 0.1*xi
        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)
        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)

    # Per-section summary, then append one line "<label><section> <ari> ..." to
    # the shared results file.
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('BC' + name + ' ')
        fp.write(' '.join(map(str, aris)))
        fp.write('\n')
[ ]:
"""MA"""
# Each entry is [number of clusters, sample name]. Only the sample name is
# read below; the trainer takes the cluster count from trainset.n_clusters.
setting_combinations = [[52, 'MA']]

for setting_combi in setting_combinations:
    # parse_known_args() ignores the "-f <connection file>" argument that a
    # Jupyter kernel puts in sys.argv; plain parse_args() would exit on it.
    args = parser.parse_known_args()[0]

    # seed (once per sample; the repeated runs below then differ from each other)
    seed_torch(1)

    # Change this to the directory holding your MA data.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/mMAMP'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims = [n_gene, 2 * last_dim]
    image_dims = [n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'MA'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap

    # One ARI per repeated run of this sample.
    aris = []
    for iter_ in range(iters):
        # dataset: the same settings are used twice, once with train=True
        # (shuffled loader) and once with train=False (ordered loader).
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                           prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                           pct_noise=pct_noise, sigma_noise=sigma_noise,
                           prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=True)
        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                          prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                          pct_noise=pct_noise, sigma_noise=sigma_noise,
                          prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)

        # network
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop,
                         n_pos=trainset.n_pos, backbone='densenet',
                         projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)

        # evaluate: combine the first two outputs of valid(), down-weighting
        # the second by 0.1 (presumably gene and image embeddings - confirm).
        xg, xi, _ = trainer.valid(testloader)
        z = xg + 0.1*xi
        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)
        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)

    # Per-sample summary, then append one line "<label><sample> <ari> ..." to
    # the shared results file.
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('mAB' + name + ' ')
        fp.write(' '.join(map(str, aris)))
        fp.write('\n')
3. Her2Tumor dataset (8 slides)
[ ]:
"""Her2st"""
# Each entry is [number of clusters, slide name]. Only the slide name is read
# below; the trainer takes the cluster count from trainset.n_clusters.
# NOTE(review): [6, 'A1'], [5, 'B1'] and [4, 'C1'] are left out of the active
# list; prepend them to run all eight slides.
setting_combinations = [[4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]

for setting_combi in setting_combinations:
    # parse_known_args() ignores the "-f <connection file>" argument that a
    # Jupyter kernel puts in sys.argv; plain parse_args() would exit on it.
    args = parser.parse_known_args()[0]

    # seed (once per slide; the repeated runs below then differ from each other)
    seed_torch(1)

    # Change this to the directory holding your Her2 tumor data.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/Her2_tumor'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims = [n_gene, 2 * last_dim]
    image_dims = [n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'Her2st'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap

    # One ARI per repeated run of this slide.
    aris = []
    for iter_ in range(iters):
        # dataset: the same settings are used twice, once with train=True
        # (shuffled loader) and once with train=False (ordered loader).
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                           prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                           pct_noise=pct_noise, sigma_noise=sigma_noise,
                           prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                 num_workers=num_workers, pin_memory=True)
        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                          prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise,
                          pct_noise=pct_noise, sigma_noise=sigma_noise,
                          prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)

        # network
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop,
                         n_pos=trainset.n_pos, backbone='densenet',
                         projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)

        # evaluate: combine the first two outputs of valid(), down-weighting
        # the second by 0.1 (presumably gene and image embeddings - confirm).
        xg, xi, _ = trainer.valid(testloader)
        z = xg + 0.1*xi
        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)
        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)

    # Per-slide summary, then append one line "<label><slide> <ari> ..." to the
    # shared results file.
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('Her2tumor' + name + ' ')
        fp.write(' '.join(map(str, aris)))
        fp.write('\n')