conST tutorial
0. Import packages and select a GPU if one is available
[1]:
import torch
import argparse
import random
import numpy as np
import pandas as pd
from src.graph_func import graph_construction
from src.utils_func import mk_dir, adata_preprocess, load_ST_file, res_search_fixed_clus, plot_clustering
from src.training import conST_training
import anndata
from sklearn import metrics
import matplotlib.pyplot as plt
import scanpy as sc
import os
import warnings
warnings.filterwarnings('ignore')
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
/tmp/ipykernel_168201/1628644719.py in <module>
4 import numpy as np
5 import pandas as pd
----> 6 from src.graph_func import graph_construction
7 from src.utils_func import mk_dir, adata_preprocess, load_ST_file, res_search_fixed_clus, plot_clustering
8 from src.training import conST_training
ModuleNotFoundError: No module named 'src'
[ ]:
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so a flag could never be
    switched off from the command line.

    Args:
        value: The raw token (or an already-parsed bool).

    Returns:
        True or False.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


# Hyper-parameters of the conST model; the defaults below are used unless
# overridden in the argument list handed to parse_args().
parser = argparse.ArgumentParser()

# Spatial graph.
parser.add_argument('--k', type=int, default=10, help='parameter k in spatial graph')
parser.add_argument('--knn_distanceType', type=str, default='euclidean',
                    help='graph distance type: euclidean/cosine/correlation')

# Training schedule and network sizes.
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--cell_feat_dim', type=int, default=300, help='Dim of PCA')
parser.add_argument('--feat_hidden1', type=int, default=100, help='Dim of DNN hidden 1-layer.')
parser.add_argument('--feat_hidden2', type=int, default=20, help='Dim of DNN hidden 2-layer.')
parser.add_argument('--gcn_hidden1', type=int, default=32, help='Dim of GCN hidden 1-layer.')
parser.add_argument('--gcn_hidden2', type=int, default=8, help='Dim of GCN hidden 2-layer.')
parser.add_argument('--p_drop', type=float, default=0.2, help='Dropout rate.')

# Boolean switches go through _str2bool so that "--use_img False" is False.
parser.add_argument('--use_img', type=_str2bool, default=False, help='Use histology images.')
parser.add_argument('--img_w', type=float, default=0.1, help='Weight of image features.')
parser.add_argument('--use_pretrained', type=_str2bool, default=True, help='Use pretrained weights.')
parser.add_argument('--using_mask', type=_str2bool, default=False, help='Using mask for multi-dataset.')

# Loss weights and optimiser settings.
parser.add_argument('--feat_w', type=float, default=10, help='Weight of DNN loss.')
parser.add_argument('--gcn_w', type=float, default=0.1, help='Weight of GCN loss.')
parser.add_argument('--dec_kl_w', type=float, default=10, help='Weight of DEC loss.')
parser.add_argument('--gcn_lr', type=float, default=0.01, help='Initial GNN learning rate.')
parser.add_argument('--gcn_decay', type=float, default=0.01, help='Initial decay rate.')

# DEC clustering.
parser.add_argument('--dec_cluster_n', type=int, default=10, help='DEC cluster number.')
parser.add_argument('--dec_interval', type=int, default=20, help='DEC interval nnumber.')
parser.add_argument('--dec_tol', type=float, default=0.00, help='DEC tol.')
parser.add_argument('--seed', type=int, default=0, help='random seed')

# Contrastive learning.
parser.add_argument('--beta', type=float, default=100, help='beta value for l2c')
parser.add_argument('--cont_l2l', type=float, default=0.3, help='Weight of local contrastive learning loss.')
parser.add_argument('--cont_l2c', type=float, default=0.1, help='Weight of context contrastive learning loss.')
parser.add_argument('--cont_l2g', type=float, default=0.1, help='Weight of global contrastive learning loss.')

# Graph augmentation for the two contrastive views.
parser.add_argument('--edge_drop_p1', type=float, default=0.1, help='drop rate of adjacent matrix of the first view')
parser.add_argument('--edge_drop_p2', type=float, default=0.1, help='drop rate of adjacent matrix of the second view')
parser.add_argument('--node_drop_p1', type=float, default=0.2, help='drop rate of node features of the first view')
parser.add_argument('--node_drop_p2', type=float, default=0.3, help='drop rate of node features of the second view')
[ ]:
# ______________ Eval clustering Setting ______________
# Two extra options used only when evaluating the learned embedding.
parser.add_argument(
    '--eval_resolution',
    type=int,
    default=1,
    help='Eval cluster number.',
)
parser.add_argument(
    '--eval_graph_n',
    type=int,
    default=20,
    help='Eval graph kN tol.',
)

# Arguments are passed explicitly instead of being read from sys.argv.
params = parser.parse_args(
    args=[
        '--k', '20',
        '--knn_distanceType', 'euclidean',
        '--epochs', '200',
    ]
)

# Seed NumPy and PyTorch with the configured seed.
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print('Using device: ' + device)
params.device = device
def seed_torch(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (``torch.manual_seed`` and
    ``torch.cuda.manual_seed``), exports the seed as ``PYTHONHASHSEED``, and
    sets cuDNN to deterministic mode with benchmarking disabled.

    Args:
        seed: The seed value passed to each generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# Number of train-and-evaluate repetitions run for every dataset section.
iters = 20
1. DLPFC dataset
Change the data directory (the `path` variable in the cell below, referred to as ‘${dir_}’) to ‘path/to/your/DLPFC/data’.
[ ]:
"""DLPFC"""
# Each entry is [number of clusters, section id].
setting_combinations = [
    [7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'],
    [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'],
    [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676'],
]
for setting_combi in setting_combinations:
    n_clusters, data_name = setting_combi
    dataset = data_name
    path = './benchmarking_data/DLPFC12'
    save_root = './output/spatialLIBD/'

    # Load the section, compute its PCA features and build the graph from the
    # spatial coordinates.
    adata_h5 = load_DLPFC(root_dir=path, section_id=data_name)
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

    params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
    params.cell_num = adata_h5.shape[0]

    aris = []
    for iter_ in range(iters):
        # NOTE(review): every repetition is reseeded with the same value, so
        # the runs may be (near-)identical -- confirm this is intended.
        seed_torch(params.seed)

        if not params.use_img:
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
        else:
            # Rescale the image features to the mean/std of the PCA features.
            img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
            img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std()
            img_transformed = img_transformed * adata_X.std() + adata_X.mean()
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
        conST_net.pretraining()
        conST_net.major_training()
        conST_embedding = conST_net.get_embedding()

        # clustering: wrap the embedding in an AnnData that carries over the
        # annotations and coordinates of the input section.
        adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
        adata_conST.uns['spatial'] = adata_h5.uns['spatial']
        adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
        adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

        sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
        # presumably searches for the Leiden resolution that yields
        # n_clusters clusters -- helper is defined in src.utils_func
        eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
        print(eval_resolution)
        cluster_key = "conST_leiden"
        sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

        # Score only the observations whose obs row has no missing value.
        keep_bcs = adata_conST.obs.dropna().index
        adata_conST = adata_conST[keep_bcs].copy()
        ARI = metrics.adjusted_rand_score(
            adata_conST.obs[cluster_key],
            adata_conST.obs['original_clusters'],
        )
        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by all ARI values.
    with open('const_aris.txt', 'a+') as fp:
        fp.write('DLPFC' + dataset + ' ' + ' '.join(str(i) for i in aris) + '\n')
2. BC/MA datasets
[ ]:
"""BC"""
# Each entry is [number of clusters, section id].
setting_combinations = [[20, 'section1']]
for setting_combi in setting_combinations:
    n_clusters, data_name = setting_combi
    dataset = data_name
    path = './benchmarking_data/BC'
    save_root = './output/BC/'

    # Use the dedicated BC loader; this cell previously called load_DLPFC on
    # the BC directory (copy-paste from the DLPFC cell) while load_BC was
    # imported but never used.
    # NOTE(review): assumes load_BC accepts the same root_dir/section_id
    # keywords as the other loaders -- confirm against st_loading_utils.
    adata_h5 = load_BC(root_dir=path, section_id=data_name)
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

    params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
    params.cell_num = adata_h5.shape[0]

    aris = []
    for iter_ in range(iters):
        # NOTE(review): every repetition is reseeded with the same value, so
        # the runs may be (near-)identical -- confirm this is intended.
        seed_torch(params.seed)

        if params.use_img:
            # Rescale the image features to the mean/std of the PCA features.
            img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
            img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
        else:
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
        conST_net.pretraining()
        conST_net.major_training()
        conST_embedding = conST_net.get_embedding()

        # clustering: wrap the embedding in an AnnData that carries over the
        # annotations and coordinates of the input section.
        adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
        adata_conST.uns['spatial'] = adata_h5.uns['spatial']
        adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
        adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

        sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
        eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
        print(eval_resolution)
        cluster_key = "conST_leiden"
        sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

        # Score only the observations whose obs row has no missing value.
        keep_bcs = adata_conST.obs.dropna().index
        adata_conST = adata_conST[keep_bcs].copy()
        ARI = metrics.adjusted_rand_score(adata_conST.obs[cluster_key], adata_conST.obs['original_clusters'])
        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line per section: label followed by all ARI values.
    with open('const_aris.txt', 'a+') as fp:
        fp.write('BC' + dataset + ' ')
        fp.write(' '.join(str(i) for i in aris))
        fp.write('\n')
[ ]:
"""MA"""
# Each entry is [number of clusters, section id].
setting_combinations = [[52, 'MA']]
for setting_combi in setting_combinations:
    n_clusters, data_name = setting_combi
    dataset = data_name
    path = './benchmarking_data/mMAMP'
    save_root = './output/MA/'

    # Load the section, compute its PCA features and build the graph from the
    # spatial coordinates.
    adata_h5 = load_mMAMP(root_dir=path, section_id=data_name)
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

    params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
    params.cell_num = adata_h5.shape[0]

    aris = []
    for iter_ in range(iters):
        # NOTE(review): every repetition is reseeded with the same value, so
        # the runs may be (near-)identical -- confirm this is intended.
        seed_torch(params.seed)

        if not params.use_img:
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
        else:
            # Rescale the image features to the mean/std of the PCA features.
            img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
            img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std()
            img_transformed = img_transformed * adata_X.std() + adata_X.mean()
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
        conST_net.pretraining()
        conST_net.major_training()
        conST_embedding = conST_net.get_embedding()

        # clustering: wrap the embedding in an AnnData that carries over the
        # annotations and coordinates of the input section.
        adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
        adata_conST.uns['spatial'] = adata_h5.uns['spatial']
        adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
        adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

        sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
        eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
        print(eval_resolution)
        cluster_key = "conST_leiden"
        sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

        # Score only the observations whose obs row has no missing value.
        keep_bcs = adata_conST.obs.dropna().index
        adata_conST = adata_conST[keep_bcs].copy()
        ARI = metrics.adjusted_rand_score(
            adata_conST.obs[cluster_key],
            adata_conST.obs['original_clusters'],
        )
        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line for this dataset: label followed by all ARI values.
    with open('const_aris.txt', 'a+') as fp:
        fp.write('mAB ' + ' '.join(str(i) for i in aris) + '\n')
3. mVC/mPFC datasets
[ ]:
"""mVC"""
# Each entry is [number of clusters, file name of the section].
setting_combinations = [[7, 'STARmap_20180505_BY3_1k.h5ad']]
for setting_combi in setting_combinations:
    n_clusters, data_name = setting_combi
    dataset = data_name

    # Parse an explicit empty argument list. A bare parse_args() reads
    # sys.argv, which inside a notebook kernel holds the kernel launcher's own
    # options and makes argparse exit with "unrecognized arguments".
    args = parser.parse_args(args=[])
    # seed
    seed_torch(1)
    path = args.path = './benchmarking_data/STARmap_mouse_visual_cortex'

    adata_h5 = load_mVC(root_dir=path, section_id=data_name)
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

    # mVC results get their own output directory; the her2tumor directory was
    # previously reused here by copy-paste.
    save_root = './output/mVC/'
    params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
    params.cell_num = adata_h5.shape[0]

    aris = []
    for iter_ in range(iters):
        # NOTE(review): every repetition is reseeded with the same value, so
        # the runs may be (near-)identical -- confirm this is intended.
        seed_torch(params.seed)

        if params.use_img:
            # Rescale the image features to the mean/std of the PCA features.
            img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
            img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
        else:
            conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
        conST_net.pretraining()
        conST_net.major_training()
        conST_embedding = conST_net.get_embedding()

        # clustering (uns['spatial'] is not copied for this dataset)
        adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
        adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
        adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

        sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
        eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
        print(eval_resolution)
        cluster_key = "conST_leiden"
        sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

        # Score only the observations whose obs row has no missing value.
        keep_bcs = adata_conST.obs.dropna().index
        adata_conST = adata_conST[keep_bcs].copy()
        ARI = metrics.adjusted_rand_score(adata_conST.obs[cluster_key], adata_conST.obs['original_clusters'])
        print('Dataset:', dataset)
        print('ARI:', ARI)
        aris.append(ARI)

    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    # Append one line for this dataset: label followed by all ARI values.
    with open('const_aris.txt', 'a+') as fp:
        fp.write('mVC ')
        fp.write(' '.join(str(i) for i in aris))
        fp.write('\n')
[ ]:
"""mPFC"""
# Each entry is [number of clusters, section id].
setting_combinations = [[4, '20180417_BZ5_control'], [4, '20180419_BZ9_control'], [4, '20180424_BZ14_control']]

# Remember the requested PCA dimension: the per-section clamp below writes to
# the shared params object and would otherwise leak into later sections and
# into the cells that run after this one.
requested_feat_dim = params.cell_feat_dim
try:
    for setting_combi in setting_combinations:
        n_clusters, data_name = setting_combi
        dataset = data_name

        # Parse an explicit empty argument list. A bare parse_args() reads
        # sys.argv, which inside a notebook kernel holds the kernel launcher's
        # own options and makes argparse exit with "unrecognized arguments".
        args = parser.parse_args(args=[])
        # seed
        seed_torch(1)
        path = args.path = './benchmarking_data/STARmap_mouse_PFC'

        adata_h5 = load_mPFC(root_dir=path, section_id=data_name)

        # Start every section from the requested dimension, then clamp it
        # when the section has fewer genes than requested PCA components.
        params.cell_feat_dim = requested_feat_dim
        if params.cell_feat_dim > len(adata_h5.var.index):
            params.cell_feat_dim = len(adata_h5.var.index) - 1
            print(params.cell_feat_dim)

        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

        # mPFC results get their own output directory; the her2tumor directory
        # was previously reused here by copy-paste.
        save_root = './output/mPFC/'
        params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
        params.cell_num = adata_h5.shape[0]

        aris = []
        for iter_ in range(iters):
            # NOTE(review): every repetition is reseeded with the same value,
            # so the runs may be (near-)identical -- confirm this is intended.
            seed_torch(params.seed)

            if params.use_img:
                # Rescale the image features to the mean/std of the PCA features.
                img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
                img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
            else:
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
            conST_net.pretraining()
            conST_net.major_training()
            conST_embedding = conST_net.get_embedding()

            # clustering (uns['spatial'] is not copied for this dataset)
            adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
            adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
            adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

            sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
            eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
            print(eval_resolution)
            cluster_key = "conST_leiden"
            sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

            # Score only the observations whose obs row has no missing value.
            keep_bcs = adata_conST.obs.dropna().index
            adata_conST = adata_conST[keep_bcs].copy()
            ARI = metrics.adjusted_rand_score(adata_conST.obs[cluster_key], adata_conST.obs['original_clusters'])
            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)

        print('Dataset:', dataset)
        print(aris)
        print(np.mean(aris))
        # Append one line per section: label followed by all ARI values.
        with open('const_aris.txt', 'a+') as fp:
            fp.write('mPFC' + dataset + ' ')
            fp.write(' '.join(str(i) for i in aris))
            fp.write('\n')
finally:
    # Hand the shared params back with the originally requested dimension.
    params.cell_feat_dim = requested_feat_dim
4. mHypothalamus dataset
[ ]:
"""mHypo"""
# Each entry is [number of clusters, section id].
setting_combinations = [[8, '-0.04'], [8, '-0.09'], [8, '-0.14'], [8, '-0.19'], [8, '-0.24'], [8, '-0.29']]

# Remember the requested PCA dimension: the per-section clamps below write to
# the shared params object and would otherwise leak into later sections and
# into the cells that run after this one.
requested_feat_dim = params.cell_feat_dim
try:
    for setting_combi in setting_combinations:
        n_clusters, data_name = setting_combi
        dataset = data_name

        # Parse an explicit empty argument list. A bare parse_args() reads
        # sys.argv, which inside a notebook kernel holds the kernel launcher's
        # own options and makes argparse exit with "unrecognized arguments".
        args = parser.parse_args(args=[])
        # seed
        seed_torch(1)
        path = args.path = './benchmarking_data/mHypothalamus'

        adata_h5 = load_mHypothalamus(root_dir=path, section_id=data_name)

        # Start every section from the requested dimension, then clamp it
        # when the section has fewer genes or fewer cells than requested PCA
        # components.
        params.cell_feat_dim = requested_feat_dim
        if params.cell_feat_dim > len(adata_h5.var.index):
            params.cell_feat_dim = len(adata_h5.var.index) - 1
        if params.cell_feat_dim > len(adata_h5.obs.index):
            params.cell_feat_dim = len(adata_h5.obs.index) - 1

        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

        # mHypothalamus results get their own output directory; the her2tumor
        # directory was previously reused here by copy-paste.
        save_root = './output/mHypothalamus/'
        params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
        params.cell_num = adata_h5.shape[0]

        aris = []
        for iter_ in range(iters):
            # NOTE(review): every repetition is reseeded with the same value,
            # so the runs may be (near-)identical -- confirm this is intended.
            seed_torch(params.seed)

            if params.use_img:
                # Rescale the image features to the mean/std of the PCA features.
                img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
                img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
            else:
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
            conST_net.pretraining()
            conST_net.major_training()
            conST_embedding = conST_net.get_embedding()

            # clustering (uns['spatial'] is not copied for this dataset)
            adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
            adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
            adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

            sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
            eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
            print(eval_resolution)
            cluster_key = "conST_leiden"
            sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

            # Score only the observations whose obs row has no missing value.
            keep_bcs = adata_conST.obs.dropna().index
            adata_conST = adata_conST[keep_bcs].copy()
            ARI = metrics.adjusted_rand_score(adata_conST.obs[cluster_key], adata_conST.obs['original_clusters'])
            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)

        print('Dataset:', dataset)
        print(aris)
        print(np.mean(aris))
        # Append one line per section: label followed by all ARI values.
        with open('const_aris.txt', 'a+') as fp:
            fp.write('mHypothalamus' + dataset + ' ')
            fp.write(' '.join(str(i) for i in aris))
            fp.write('\n')
finally:
    # Hand the shared params back with the originally requested dimension.
    params.cell_feat_dim = requested_feat_dim
5. Her2Tumor dataset
[ ]:
"""Her2st"""
# Each entry is [number of clusters, section id].
setting_combinations = [[5, 'B1'], [4, 'C1'], [4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]

# Remember the requested PCA dimension: the per-section clamp below writes to
# the shared params object, so a small section would otherwise shrink the
# dimension for every section processed after it.
requested_feat_dim = params.cell_feat_dim
try:
    for setting_combi in setting_combinations:
        n_clusters, data_name = setting_combi
        dataset = data_name

        # Parse an explicit empty argument list. A bare parse_args() reads
        # sys.argv, which inside a notebook kernel holds the kernel launcher's
        # own options and makes argparse exit with "unrecognized arguments".
        args = parser.parse_args(args=[])
        # seed
        seed_torch(1)
        path = args.path = './benchmarking_data/Her2_tumor'

        adata_h5 = load_her2_tumor(root_dir=path, section_id=data_name)

        # Start every section from the requested dimension, then clamp it
        # when the section has fewer spots than requested PCA components.
        params.cell_feat_dim = requested_feat_dim
        if params.cell_feat_dim > len(adata_h5.obs.index):
            params.cell_feat_dim = len(adata_h5.obs.index) - 1
            print(params.cell_feat_dim)

        adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
        graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)

        save_root = './output/her2tumor/'
        params.save_path = mk_dir(f'{save_root}/{data_name}/conST')
        params.cell_num = adata_h5.shape[0]

        aris = []
        for iter_ in range(iters):
            # NOTE(review): every repetition is reseeded with the same value,
            # so the runs may be (near-)identical -- confirm this is intended.
            seed_torch(params.seed)

            if params.use_img:
                # Rescale the image features to the mean/std of the PCA features.
                img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
                img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters, img_transformed)
            else:
                conST_net = conST_training(adata_X, graph_dict, params, n_clusters)
            conST_net.pretraining()
            conST_net.major_training()
            conST_embedding = conST_net.get_embedding()

            # clustering (uns['spatial'] is not copied for this dataset)
            adata_conST = anndata.AnnData(conST_embedding, obs=adata_h5.obs)
            adata_conST.obs['original_clusters'] = adata_h5.obs['original_clusters']
            adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

            sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)
            eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
            print(eval_resolution)
            cluster_key = "conST_leiden"
            sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

            # Score only the observations whose obs row has no missing value.
            keep_bcs = adata_conST.obs.dropna().index
            adata_conST = adata_conST[keep_bcs].copy()
            ARI = metrics.adjusted_rand_score(adata_conST.obs[cluster_key], adata_conST.obs['original_clusters'])
            print('Dataset:', dataset)
            print('ARI:', ARI)
            aris.append(ARI)

        print('Dataset:', dataset)
        print(aris)
        print(np.mean(aris))
        # Append one line per section: label followed by all ARI values.
        with open('const_aris.txt', 'a+') as fp:
            fp.write('Her2tumor' + dataset + ' ')
            fp.write(' '.join(str(i) for i in aris))
            fp.write('\n')
finally:
    # Hand the shared params back with the originally requested dimension.
    params.cell_feat_dim = requested_feat_dim