SPIRAL integration tutorial
Data preparation (using DLPFC as an example)
[2]:
from st_loading_utils import load_DLPFC, load_mHypothalamus
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import anndata
import scipy as sp
import umap.umap_ as umap
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
[3]:
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
    """\
    Construct the spatial neighbor network of an AnnData object.

    Parameters
    ----------
    adata
        AnnData object of the scanpy package. Spot coordinates are read from
        ``adata.obsm['spatial']`` and spot names from ``adata.obs.index``.
    rad_cutoff
        Radius cutoff; required when ``model='Radius'``.
    k_cutoff
        Number of nearest neighbors; required when ``model='KNN'``.
    model
        The network construction model. With ``'Radius'`` a spot is connected
        to the spots within ``rad_cutoff`` of it; with ``'KNN'`` a spot is
        connected to its ``k_cutoff`` nearest neighbors.
    verbose
        Print progress and edge statistics.

    Returns
    -------
    None. The edge list (columns ``Cell1``, ``Cell2``, ``Distance``; cells are
    given by their ``adata.obs`` names) is stored in ``adata.uns['Spatial_Net']``.

    Raises
    ------
    ValueError
        If ``model`` is unknown or the cutoff required by ``model`` is missing.
    """
    # Validate explicitly (an `assert` is stripped under `python -O`, and a
    # missing cutoff would otherwise only fail deep inside sklearn).
    if model not in ('Radius', 'KNN'):
        raise ValueError("model must be 'Radius' or 'KNN', got %r" % (model,))
    if model == 'Radius' and rad_cutoff is None:
        raise ValueError("rad_cutoff is required when model='Radius'")
    if model == 'KNN' and k_cutoff is None:
        raise ValueError("k_cutoff is required when model='KNN'")

    # The notebook only imports `sklearn.decomposition.PCA`, which does not
    # bind the name `sklearn`, so import the neighbor searcher here.
    from sklearn.neighbors import NearestNeighbors

    if verbose:
        print('------Calculating spatial graph...')
    coor = np.asarray(adata.obsm['spatial'])
    cell_names = np.asarray(adata.obs.index)

    if model == 'Radius':
        nbrs = NearestNeighbors(radius=rad_cutoff).fit(coor)
        # One variable-length array of neighbors per spot.
        distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
        counts = [len(row) for row in indices]
        src = np.repeat(np.arange(len(indices)), counts)
        dst = np.concatenate(list(indices))
        dist = np.concatenate(list(distances))
    else:  # model == 'KNN'
        # k_cutoff + 1 because each query spot is returned as its own
        # neighbor at distance 0; it is filtered out below.
        nbrs = NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor)
        distances, indices = nbrs.kneighbors(coor)
        src = np.repeat(np.arange(indices.shape[0]), indices.shape[1])
        dst = indices.ravel()
        dist = distances.ravel()

    # Edges keep the per-spot neighbor order returned by sklearn.
    Spatial_Net = pd.DataFrame({
        'Cell1': cell_names[src],
        'Cell2': cell_names[dst],
        'Distance': dist,
    })
    # Drop zero-distance pairs: self-loops, but also distinct spots that share
    # identical coordinates.
    Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance'] > 0].reset_index(drop=True)
    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))
    adata.uns['Spatial_Net'] = Spatial_Net
[ ]:
# For every slice of each group, export three files in the layout read by the
# SPIRAL cells below:
#   - the log-normalized expression, restricted to the union of per-slice HVGs;
#   - the spot labels;
#   - the spot coordinates.
dirs = "/home/yunfei/spatial_benchmarking/benchmarking_data/DLPFC12"
# Pairwise alternative (then use IDX = np.arange(0, 2) below):
# lists=[[151507,151508],[151508,151509],[151509,151510],[151669,151670],[151670,151671],[151671,151672],[151673,151674],[151674,151675],[151675,151676]]
lists = [[151507, 151508, 151509, 151510],
         [151669, 151670, 151671, 151672],
         [151673, 151674, 151675, 151676]]
out_dirs = "./data/DLPFC/"
# DataFrame.to_csv does not create missing directories.
os.makedirs(out_dirs + "gtt_input_scanpy/", exist_ok=True)
for sample_name in lists:
    # IDX=np.arange(0,2)
    IDX = np.arange(0, 4)
    VF = []
    MAT = []
    # Group tag shared by every output file, e.g. "151507-151508-151509-151510_".
    flags = '-'.join(str(sample_name[j]) for j in IDX) + "_"
    for k in np.arange(len(IDX)):
        section = str(sample_name[IDX[k]])
        adata = load_DLPFC(root_dir=dirs, section_id=section)
        adata.var_names_make_unique()
        # HVGs are selected before normalization (the seurat_v3 flavor expects counts).
        sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=5000)
        sc.pp.normalize_total(adata, target_sum=1e4)
        sc.pp.log1p(adata)
        adata.obs['batch'] = section
        # Prefix barcodes with the section id so they stay unique across slices.
        cells = [section + '-' + barcode for barcode in adata.obs_names]
        # Accept both sparse and dense expression matrices.
        X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
        mat1 = pd.DataFrame(X, columns=adata.var_names, index=cells)
        coord1 = pd.DataFrame(adata.obsm['spatial'], columns=['x', 'y'], index=cells)
        # Copy so renaming does not touch a view of adata.obs.
        meta1 = adata.obs[['original_clusters', 'batch']].copy()
        meta1.columns = ['celltype', 'batch']
        meta1.index = cells
        meta1.to_csv(out_dirs + "gtt_input_scanpy/" + flags + section + "_label-1.txt")
        coord1.to_csv(out_dirs + "gtt_input_scanpy/" + flags + section + "_positions-1.txt")
        MAT.append(mat1)
        VF = np.union1d(VF, adata.var_names[adata.var['highly_variable']])
    # Write every slice restricted to the same gene set (union of all HVGs).
    for i in np.arange(len(IDX)):
        mat = MAT[i].loc[:, VF]
        mat.to_csv(out_dirs + "gtt_input_scanpy/" + flags + str(sample_name[IDX[i]]) + "_features-1.txt")
[ ]:
# Build a KNN spatial graph for every exported slice and save it as a
# two-column (Cell1 Cell2) edge list next to the feature files.
from scipy.sparse import csr_matrix  # `sp` is top-level scipy, which has no csr_matrix

rad = 150
KNN = 6
# lists=[[151507,151508],[151508,151509],[151509,151510],[151669,151670],[151670,151671],[151671,151672],[151673,151674],[151674,151675],[151675,151676]]
lists = [[151507, 151508, 151509, 151510],
         [151669, 151670, 151671, 151672],
         [151673, 151674, 151675, 151676]]
dirs = "/home/yunfei/spatial_benchmarking/3d_recon/SPIRAL/data/DLPFC/"
for sample_name in lists:
    # IDX=[0,1]
    IDX = [0, 1, 2, 3]
    flags = '-'.join(str(sample_name[j]) for j in IDX)
    for i in IDX:
        sample1 = sample_name[i]
        prefix = dirs + "gtt_input_scanpy/" + flags + '_' + str(sample1)
        features = pd.read_csv(prefix + "_features-1.txt", header=0, index_col=0, sep=',')
        coord = pd.read_csv(prefix + "_positions-1.txt", header=0, index_col=0, sep=',')
        adata = sc.AnnData(features)
        adata.var_names_make_unique()
        adata.X = csr_matrix(adata.X)
        # Align coordinates to the expression rows by spot name.
        adata.obsm["spatial"] = coord.loc[adata.obs_names, ['x', 'y']].to_numpy()
        # Use KNN (not a literal) so the graph matches the "_edge_KNN_<KNN>" file name.
        Cal_Spatial_Net(adata, rad_cutoff=rad, k_cutoff=KNN, model='KNN', verbose=True)
        G_df = adata.uns['Spatial_Net'].copy()
        np.savetxt(prefix + "_edge_KNN_" + str(KNN) + ".csv", G_df.values[:, :2], fmt='%s')
DLPFC data integration (multiple slices)
[4]:
import os
import numpy as np
import argparse
import pandas as pd
from sklearn.decomposition import PCA
from operator import itemgetter
import random
import matplotlib.pyplot as plt
import umap.umap_ as umap
import time
import torch
from spiral.main import SPIRAL_integration
from spiral.layers import *
from spiral.utils import *
from spiral.CoordAlignment import CoordAlignment
# R_dirs="/home/tguo/tguo2/miniconda3/envs/stnet/lib/R"
# os.environ['R_HOME']=R_dirs
# Default to GPU 1, but respect a device selection already made by the user.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
[ ]:
# Train SPIRAL on a group of DLPFC slices, then save the spot embedding and
# the decoded (batch-corrected) expression matrix.
dirs = "/home/yunfei/spatial_benchmarking/3d_recon/SPIRAL/data/DLPFC/gtt_input_scanpy/"
# Folder for the corrected expression matrix (kept at its original location).
correct_dirs = "/home/yunfei/spatial_benchmarking/3d_recon/SPIRAL/out/gtt_output/"
sn_ = [[151507, 151508, 151509, 151510], [151510, 151669, 151670, 151671],
       [151671, 151672, 151674, 151675]]
for sn in sn_[:1]:  # only the first group is integrated here
    sample_name = np.array(sn)
    samples = sample_name[:]
    SEP = ','
    net_cate = '_KNN_'
    rad = 150
    knn = 6
    N_WALKS = knn
    WALK_LEN = 1
    N_WALK_LEN = knn
    NUM_NEG = knn

    # Input files are named "<s1-s2-...>_<section>_<kind>" by the preparation cells.
    flags1 = '-'.join(str(s) for s in samples)
    feat_file = []
    edge_file = []
    meta_file = []
    coord_file = []
    flags = ''
    for s in sample_name:
        prefix = dirs + flags1 + '_' + str(s)
        feat_file.append(prefix + "_features-1.txt")
        edge_file.append(prefix + "_edge_KNN_" + str(knn) + ".csv")
        meta_file.append(prefix + "_label-1.txt")
        coord_file.append(prefix + "_positions-1.txt")
        flags = flags + '_' + str(s)

    # N: number of feature columns (genes). M: last entry of CLdims/DIdims,
    # 1 for a pair of slices, otherwise the number of slices.
    N = pd.read_csv(feat_file[0], header=0, index_col=0).shape[1]
    M = 1 if len(sample_name) == 2 else len(sample_name)

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0, help='The seed of initialization.')
    parser.add_argument('--AEdims', type=list, default=[N,[512],32], help='Dim of encoder.')
    parser.add_argument('--AEdimsR', type=list, default=[32,[512],N], help='Dim of decoder.')
    parser.add_argument('--GSdims', type=list, default=[512,32], help='Dim of GraphSAGE.')
    parser.add_argument('--zdim', type=int, default=32, help='Dim of embedding.')
    parser.add_argument('--znoise_dim', type=int, default=4, help='Dim of noise embedding.')
    parser.add_argument('--CLdims', type=list, default=[4,[],M], help='Dim of classifier.')
    parser.add_argument('--DIdims', type=list, default=[28,[32,16],M], help='Dim of discriminator.')
    parser.add_argument('--beta', type=float, default=1.0, help='weight of GraphSAGE.')
    parser.add_argument('--agg_class', type=str, default=MeanAggregator, help='Function of aggregator.')
    parser.add_argument('--num_samples', type=str, default=knn, help='number of neighbors to sample.')
    parser.add_argument('--N_WALKS', type=int, default=N_WALKS, help='number of walks of random work for postive pairs.')
    parser.add_argument('--WALK_LEN', type=int, default=WALK_LEN, help='walk length of random work for postive pairs.')
    parser.add_argument('--N_WALK_LEN', type=int, default=N_WALK_LEN, help='number of walks of random work for negative pairs.')
    parser.add_argument('--NUM_NEG', type=int, default=NUM_NEG, help='number of negative pairs.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train.')
    parser.add_argument('--batch_size', type=int, default=1024, help='Size of batches to train.') ####512 for withon donor;1024 for across donor###
    parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate.')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay.')
    parser.add_argument('--alpha1', type=float, default=N, help='Weight of decoder loss.')
    parser.add_argument('--alpha2', type=float, default=1, help='Weight of GraphSAGE loss.')
    parser.add_argument('--alpha3', type=float, default=1, help='Weight of classifier loss.')
    parser.add_argument('--alpha4', type=float, default=1, help='Weight of discriminator loss.')
    parser.add_argument('--lamda', type=float, default=1, help='Weight of GRL.')
    parser.add_argument('--Q', type=float, default=10, help='Weight negative loss for sage losss.')
    # parse_known_args ignores the extra arguments Jupyter passes to the kernel.
    params, unknown = parser.parse_known_args()

    SPII = SPIRAL_integration(params, feat_file, edge_file, meta_file)
    SPII.train()
    SPII.model.eval()

    # Inference only: no_grad avoids building an autograd graph over all spots.
    with torch.no_grad():
        all_idx = np.arange(SPII.feat.shape[0])
        all_layer, all_mapping = layer_map(all_idx.tolist(), SPII.adj, len(SPII.params.GSdims))
        all_rows = SPII.adj.tolil().rows[all_layer[0]]
        all_feature = torch.Tensor(SPII.feat.iloc[all_layer[0], :].values).float().cuda()
        all_embed, ae_out, clas_out, disc_out = SPII.model(all_feature, all_layer, all_mapping, all_rows, SPII.params.lamda, SPII.de_act, SPII.cl_act)
        [ae_embed, gs_embed, embed] = all_embed
        [x_bar, x] = ae_out
        embed = embed.cpu().detach()

    # NOTE(review): rows are labelled with SPII.feat.index, which assumes
    # layer_map keeps all_layer[0] in the order of all_idx; confirm.
    names = ['GTT_' + str(i) for i in range(embed.shape[1])]
    embed1 = pd.DataFrame(np.array(embed), index=SPII.feat.index, columns=names)
    os.makedirs(dirs + "gtt_output/", exist_ok=True)
    embed_file = dirs + "gtt_output/SPIRAL" + flags + "_embed_" + str(SPII.params.batch_size) + ".csv"
    embed1.to_csv(embed_file)
    meta = SPII.meta.values

    # Corrected expression: zero the first znoise_dim columns (the noise
    # embedding, per --znoise_dim) and decode the rest. `embed` is a torch
    # tensor, so it is sliced directly (tensors have no .iloc).
    embed_new = torch.cat((torch.zeros((embed.shape[0], SPII.params.znoise_dim)), embed[:, SPII.params.znoise_dim:]), dim=1)
    with torch.no_grad():
        xbar_new = np.array(SPII.model.agc.ae.de(embed_new.cuda(), torch.nn.Sigmoid())[1].cpu().detach())
    xbar_new1 = pd.DataFrame(xbar_new, index=SPII.feat.index, columns=SPII.feat.columns)
    os.makedirs(correct_dirs, exist_ok=True)
    xbar_new1.to_csv(correct_dirs + "SPIRAL" + flags + "_correct_" + str(SPII.params.batch_size) + ".csv")
Mouse hypothalamus data integration (multiple slices)
[ ]:
# Integrate mouse hypothalamus slices with SPIRAL, then:
#   - cluster the embedding with mclust;
#   - align the slice coordinates with CoordAlignment.
# The embedding and cluster CSVs are cached and reused on later runs.
dirs = "/home/yunfei/spatial_benchmarking/3d_recon/SPIRAL/data/mhypo/gtt_input_scanpy/"
sn_ = [['-0.04', '-0.09', '-0.14', '-0.19', '-0.24']]
c_ = [8]  # number of mclust clusters for each group in sn_
for iii, sn in enumerate(sn_):
    sample_name = np.array(sn)
    samples = sample_name[:]
    SEP = ','
    net_cate = '_KNN_'
    rad = 150
    knn = 6
    N_WALKS = knn
    WALK_LEN = 1
    N_WALK_LEN = knn
    NUM_NEG = knn

    # Input files are named "<s1-s2-...>_<section>_<kind>" by the preparation step.
    flags1 = '-'.join(str(s) for s in samples)
    feat_file = []
    edge_file = []
    meta_file = []
    coord_file = []
    flags = ''
    for s in sample_name:
        prefix = dirs + flags1 + '_' + str(s)
        feat_file.append(prefix + "_features-1.txt")
        edge_file.append(prefix + "_edge_KNN_" + str(knn) + ".csv")
        meta_file.append(prefix + "_label-1.txt")
        coord_file.append(prefix + "_positions-1.txt")
        flags = flags + '_' + str(s)

    # N: number of feature columns (genes). M: last entry of CLdims/DIdims,
    # 1 for a pair of slices, otherwise the number of slices.
    N = pd.read_csv(feat_file[0], header=0, index_col=0).shape[1]
    M = 1 if len(sample_name) == 2 else len(sample_name)

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0, help='The seed of initialization.')
    parser.add_argument('--AEdims', type=list, default=[N,[512],32], help='Dim of encoder.')
    parser.add_argument('--AEdimsR', type=list, default=[32,[512],N], help='Dim of decoder.')
    parser.add_argument('--GSdims', type=list, default=[512,32], help='Dim of GraphSAGE.')
    parser.add_argument('--zdim', type=int, default=32, help='Dim of embedding.')
    parser.add_argument('--znoise_dim', type=int, default=4, help='Dim of noise embedding.')
    parser.add_argument('--CLdims', type=list, default=[4,[],M], help='Dim of classifier.')
    parser.add_argument('--DIdims', type=list, default=[28,[32,16],M], help='Dim of discriminator.')
    parser.add_argument('--beta', type=float, default=1.0, help='weight of GraphSAGE.')
    parser.add_argument('--agg_class', type=str, default=MeanAggregator, help='Function of aggregator.')
    parser.add_argument('--num_samples', type=str, default=knn, help='number of neighbors to sample.')
    parser.add_argument('--N_WALKS', type=int, default=N_WALKS, help='number of walks of random work for postive pairs.')
    parser.add_argument('--WALK_LEN', type=int, default=WALK_LEN, help='walk length of random work for postive pairs.')
    parser.add_argument('--N_WALK_LEN', type=int, default=N_WALK_LEN, help='number of walks of random work for negative pairs.')
    parser.add_argument('--NUM_NEG', type=int, default=NUM_NEG, help='number of negative pairs.')
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train.')
    parser.add_argument('--batch_size', type=int, default=1024, help='Size of batches to train.') ####512 for withon donor;1024 for across donor###
    parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate.')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay.')
    parser.add_argument('--alpha1', type=float, default=N, help='Weight of decoder loss.')
    parser.add_argument('--alpha2', type=float, default=1, help='Weight of GraphSAGE loss.')
    parser.add_argument('--alpha3', type=float, default=1, help='Weight of classifier loss.')
    parser.add_argument('--alpha4', type=float, default=1, help='Weight of discriminator loss.')
    parser.add_argument('--lamda', type=float, default=1, help='Weight of GRL.')
    parser.add_argument('--Q', type=float, default=10, help='Weight negative loss for sage losss.')
    # parse_known_args ignores the extra arguments Jupyter passes to the kernel.
    params, unknown = parser.parse_known_args()

    SPII = SPIRAL_integration(params, feat_file, edge_file, meta_file)
    out_dir = dirs + "gtt_output/"
    os.makedirs(out_dir, exist_ok=True)
    embed_file = out_dir + "SPIRAL" + flags + "_embed_" + str(SPII.params.batch_size) + ".csv"

    # Both branches leave `embed` as a DataFrame indexed by spot name.
    if not os.path.exists(embed_file):
        SPII.train()
        SPII.model.eval()
        # Inference only: no_grad avoids building an autograd graph over all spots.
        with torch.no_grad():
            all_idx = np.arange(SPII.feat.shape[0])
            all_layer, all_mapping = layer_map(all_idx.tolist(), SPII.adj, len(SPII.params.GSdims))
            all_rows = SPII.adj.tolil().rows[all_layer[0]]
            all_feature = torch.Tensor(SPII.feat.iloc[all_layer[0], :].values).float().cuda()
            all_embed, ae_out, clas_out, disc_out = SPII.model(all_feature, all_layer, all_mapping, all_rows, SPII.params.lamda, SPII.de_act, SPII.cl_act)
            [ae_embed, gs_embed, embed_t] = all_embed
            embed_t = embed_t.cpu().detach()
        # NOTE(review): rows are labelled with SPII.feat.index, which assumes
        # layer_map keeps all_layer[0] in the order of all_idx; confirm.
        names = ['GTT_' + str(i) for i in range(embed_t.shape[1])]
        embed = pd.DataFrame(np.array(embed_t), index=SPII.feat.index, columns=names)
        embed.to_csv(embed_file)
    else:
        # Reuse the embedding from a previous run instead of retraining.
        embed = pd.read_csv(embed_file, header=0, index_col=0, sep=',')

    ann = anndata.AnnData(SPII.feat)
    print(ann.X)
    ann.obsm['spiral'] = embed.to_numpy()
    print(ann.obsm['spiral'])

    cluster_file = out_dir + "SPIRAL" + flags + "_mclust.csv"
    if not os.path.exists(cluster_file):
        n_clust = c_[iii]
        ann = mclust_R(ann, used_obsm='spiral', num_cluster=n_clust)
        sc.pp.neighbors(ann, use_rep='spiral')
        ann.obs['batch'] = SPII.meta.loc[:, 'batch'].values
        sc.tl.umap(ann)
        coord = pd.concat([pd.read_csv(f, header=0, index_col=0) for f in coord_file])
        coord.columns = ['y', 'x']
        ann.obsm['spatial'] = coord.loc[ann.obs_names, :].values
        # The written file has a single cluster column named 'mclust'.
        pd.DataFrame(ann.obs['mclust']).to_csv(cluster_file)

    # Must name the column stored in cluster_file ('mclust', not 'louvain').
    clust_cate = 'mclust'
    input_file = [meta_file, coord_file, embed_file, cluster_file]
    output_dirs = out_dir + "SPIRAL_alignment/"
    os.makedirs(output_dirs, exist_ok=True)
    ub = samples
    alpha = 0.5
    types = "weighted_mean"
    R_dirs = "/home/yunfei/anaconda3/envs/spiral/lib/R"
    CA = CoordAlignment(input_file=input_file, output_dirs=output_dirs, ub=ub, flags=flags, clust_cate=clust_cate, R_dirs=R_dirs, alpha=alpha, types=types)
    New_Coord = CA.New_Coord
    New_Coord.to_csv(output_dirs + "new_coord" + flags + "_modify.csv")