SpaceFlow tutorial
[ ]:
import os
import torch
import pandas as pd
import scanpy as sc
from sklearn import metrics
import multiprocessing as mp
import numpy as np
import squidpy as sq
import scanpy as sc
from SpaceFlow import SpaceFlow
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP, load_embryo
from scipy.spatial import *
from sklearn.preprocessing import *
from sklearn.metrics import *
from scipy.spatial.distance import *
[ ]:
def res_search_fixed_clus(adata, fixed_clus_count, increment=0.1, start=0.2, end=2.5):
    '''Find a Leiden resolution that yields exactly ``fixed_clus_count`` clusters.

    Resolutions from ``np.arange(start, end, increment)`` are tried from the
    largest to the smallest. For each one, ``sc.tl.leiden(random_state=0)`` is run
    on ``adata``, and the search stops at the first resolution that produces the
    requested number of clusters.

    arg1(adata)[AnnData matrix]: data to cluster. sc.tl.leiden needs a neighbour
        graph, so call sc.pp.neighbors on it first.
        Side effect: adata.obs['leiden'] keeps the labels of the last resolution tried.
    arg2(fixed_clus_count)[int]: target number of clusters.
    increment, start, end [float]: resolution grid. The defaults reproduce the
        original 0.2 .. 2.4 grid in steps of 0.1.
    return:
        resolution[float]: the first matching resolution. If none matches, a
        warning is emitted and the last (smallest) resolution tried is returned,
        which is what the original implementation returned silently.
    raises:
        ValueError: if the resolution grid is empty.
    '''
    import warnings

    # np.arange is increasing, so reversing it walks from large to small resolutions.
    resolutions = np.arange(start, end, increment)[::-1]
    if len(resolutions) == 0:
        raise ValueError(
            'empty resolution grid: start={}, end={}, increment={}'.format(start, end, increment))
    for res in resolutions:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        if adata.obs['leiden'].nunique() == fixed_clus_count:
            return res
    warnings.warn(
        'no resolution in [{}, {}) with step {} gives {} clusters; '
        'returning {}'.format(start, end, increment, fixed_clus_count, res))
    return res
def fx_1NN(i, location_in):
    """Return the distance from point ``i`` to its nearest *other* point.

    ``location_in`` holds one coordinate row per point. Distances are Euclidean,
    which is the ``distance_matrix`` default. If ``location_in`` has a single
    row, there is no other point and the result is ``inf``.
    """
    coords = np.array(location_in)
    query = coords[[i]]  # keep the leading axis: a (1, n_dims) query
    row = distance_matrix(query, coords)[0]
    row[i] = np.inf  # the point's zero distance to itself must not win
    return row.min()
def fx_kNN(i, location_in, k, cluster_in):
    """Flag spot ``i`` as abnormal if most of its ``k`` nearest neighbours differ in label.

    Parameters
    ----------
    i : int
        Row index of the query spot.
    location_in : array-like, shape (n_spots, n_dims)
        Spot coordinates. Distances are Euclidean (the ``distance_matrix`` default, p=2).
    k : int
        Number of nearest neighbours to inspect. The spot's own distance is set to
        infinity, so it is only picked when ``k >= n_spots``. Even then it never
        counts as differing.
    cluster_in : array-like, shape (n_spots,)
        Cluster label of every spot.

    Returns
    -------
    int
        1 if strictly more than ``k / 2`` of the neighbours have a label different
        from spot ``i``'s label, else 0.
    """
    # np.asarray avoids copying inputs that are already ndarrays. _compute_PAS calls
    # this once per spot with the same arrays, so np.array made n full copies
    # (plus a redundant second copy of the labels).
    location_in = np.asarray(location_in)
    cluster_in = np.asarray(cluster_in)
    dist_array = distance_matrix(location_in[i, :][None, :], location_in)[0, :]
    dist_array[i] = np.inf  # exclude the spot itself
    ind = np.argsort(dist_array)[:k]
    if np.sum(cluster_in[ind] != cluster_in[i]) > (k / 2):
        return 1
    return 0
def _compute_CHAOS(clusterlabel, location):
    """Spatial chaos score (CHAOS) of a clustering.

    The coordinates are standardised column-wise with ``StandardScaler``. Then,
    for every cluster with more than two spots, each spot's distance to its
    nearest same-cluster spot (``fx_1NN``) is summed. The grand total is divided
    by the total number of spots, so smaller values mean same-cluster spots sit
    closer together. Clusters with one or two spots contribute nothing.

    Args:
        clusterlabel: one cluster label per spot.
        location: spot coordinates, one row per spot.

    Returns:
        The CHAOS score as a float.
    """
    labels = np.array(clusterlabel)
    scaled = StandardScaler().fit_transform(np.array(location))
    unique_labels = np.unique(labels)
    # One slot per label, filled front to back; unused trailing slots stay 0.
    per_cluster = np.zeros(len(unique_labels))
    slot = 0
    for label in unique_labels:
        members = scaled[labels == label, :]
        n_members = len(members)
        if n_members <= 2:
            continue
        per_cluster[slot] = np.sum([fx_1NN(j, members) for j in range(n_members)])
        slot += 1
    return np.sum(per_cluster) / len(labels)
def _compute_PAS(clusterlabel, location):
    """Percentage of abnormal spots (PAS) of a clustering.

    A spot is abnormal when more than half of its 10 nearest spots have a
    different label (see ``fx_kNN``). Distances are Euclidean on the raw,
    unscaled coordinates. Returns the number of abnormal spots divided by the
    number of labels.

    Args:
        clusterlabel: one cluster label per spot.
        location: spot coordinates, one row per spot.
    """
    labels = np.array(clusterlabel)
    coords = np.array(location)
    flags = [fx_kNN(idx, coords, k=10, cluster_in=labels)
             for idx in range(coords.shape[0])]
    return np.sum(flags) / len(labels)
def compute_CHAOS(adata, pred_key, spatial_key='spatial'):
    """CHAOS score of the labels in ``adata.obs[pred_key]``.

    The spot coordinates are taken from ``adata.obsm[spatial_key]``.
    """
    labels = adata.obs[pred_key]
    coords = adata.obsm[spatial_key]
    return _compute_CHAOS(labels, coords)
def compute_PAS(adata, pred_key, spatial_key='spatial'):
    """PAS score of the labels in ``adata.obs[pred_key]``.

    The spot coordinates are taken from ``adata.obsm[spatial_key]``.
    """
    labels = adata.obs[pred_key]
    coords = adata.obsm[spatial_key]
    return _compute_PAS(labels, coords)
def compute_ASW(adata, pred_key, spatial_key='spatial'):
    """Average silhouette width of ``adata.obs[pred_key]`` in physical space.

    A full pairwise Euclidean distance matrix (``pdist`` default metric) over
    ``adata.obsm[spatial_key]`` is built and passed to ``silhouette_score`` as
    precomputed distances. Memory use is therefore quadratic in the number of spots.
    """
    pairwise = squareform(pdist(adata.obsm[spatial_key]))
    labels = adata.obs[pred_key]
    return silhouette_score(X=pairwise, labels=labels, metric='precomputed')
DLPFC (dorsolateral prefrontal cortex)
[ ]:
"""DLPFC"""
setting_combinations = [[7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'], [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'], [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676']]
for setting_combi in setting_combinations:
n_clusters = setting_combi[0] # 7
dataset = setting_combi[1] # '151673'
dir_ = '../benchmarking_data/DLPFC12'
adata = load_DLPFC(root_dir=dir_, section_id=dataset)
sc.pp.filter_genes(adata, min_cells=3)
adata.var_names_make_unique()
sf = SpaceFlow.SpaceFlow(adata=adata)
#preprocess
sf.preprocessing_data(n_top_genes=3000)
ari_list = []
nmi_list = []
ami_list = []
hm_list = []
time_list = []
chaos_list = []
pas_list = []
asw_list = []
for iter in range(20):
import tracemalloc
import time
tracemalloc.start()
start_time=time.time()
sf.train(spatial_regularization_strength=0.1,
embedding_save_filepath="./results_0424/DLPFC/"+dataset+"_"+str(iter)+"embedding.tsv",
z_dim=50,
lr=1e-3,
epochs=1000,
max_patience=50,
min_stop=100,
random_seed=42,
gpu=1,
regularization_acceleration=True,
edge_subset_sz=1000000)
# n_clusters=7
sc.pp.neighbors(adata, n_neighbors=50)
eval_resolution = res_search_fixed_clus(adata, n_clusters)
sf.segmentation(domain_label_save_filepath="./results_0424/DLPFC/"+dataset+"_"+str(iter)+"domains.tsv".format(iter+1),
n_neighbors=50,
resolution=eval_resolution)
pred=pd.read_csv("./results_0424/DLPFC/"+dataset+"_"+str(iter)+"domains.tsv".format(iter+1),header=None)
pred_list=pred.iloc[:,0].to_list()
adata.obs['pred_{}'.format(iter+1)] = np.array(pred_list)
obs_df = adata.obs.dropna()
ari = metrics.adjusted_rand_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
nmi = metrics.normalized_mutual_info_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
# print("AMI")
ami = metrics.adjusted_mutual_info_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
# print("homogeneity")
homogeneity = metrics.homogeneity_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
chaos = compute_CHAOS(adata, 'pred_{}'.format(iter+1))
pas = compute_PAS(adata, 'pred_{}'.format(iter+1))
asw = compute_ASW(adata, 'pred_{}'.format(iter+1))
ari_list.append(ari)
nmi_list.append(nmi)
ami_list.append(ami)
hm_list.append(homogeneity)
chaos_list.append(chaos)
pas_list.append(pas)
asw_list.append(asw)
end_time=time.time()
during=end_time-start_time
size, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
# memory[i]=peak /1024/1024
time_list.append(during)
# print('memory blocks peak:{:>10.4f} MB'.format(memory[i]))
print('time: {:.4f} s'.format(during))
print('ARI:{}'.format(ari))
print('NMI:{}'.format(nmi))
print('AMI:{}'.format(ami))
print('Homogeneity:{}'.format(homogeneity))
print('chaos:{}'.format(chaos))
print('pas:{}'.format(pas))
print('asw:{}'.format(asw))
mHypo (mouse hypothalamus)
[ ]:
"""MHypo"""
setting_combinations = [[8, '-0.04'], [8, '-0.09'], [8, '-0.14'], [8, '-0.19'], [8, '-0.24']]
for setting_combi in setting_combinations:
n_clusters = setting_combi[0] # 7
dataset = setting_combi[1] # '151673'
dir_ = '../benchmarking_data/mHypothalamus'
adata = load_mHypothalamus(root_dir=dir_, section_id=dataset)
sc.pp.filter_genes(adata, min_cells=3)
adata.var_names_make_unique()
sf = SpaceFlow.SpaceFlow(adata=adata)
#preprocess
sf.preprocessing_data(n_top_genes=3000)
ari_list = []
nmi_list = []
ami_list = []
hm_list = []
time_list = []
chaos_list = []
pas_list = []
asw_list = []
for iter in range(20):
import tracemalloc
import time
tracemalloc.start()
start_time=time.time()
sf.train(spatial_regularization_strength=0.1,
embedding_save_filepath="./results_0424/mHypo/"+dataset+"_"+str(iter)+"embedding.tsv",
z_dim=50,
lr=1e-3,
epochs=1000,
max_patience=50,
min_stop=100,
random_seed=42,
gpu=1,
regularization_acceleration=True,
edge_subset_sz=1000000)
# n_clusters=7
sc.pp.neighbors(adata, n_neighbors=50)
eval_resolution = res_search_fixed_clus(adata, n_clusters)
sf.segmentation(domain_label_save_filepath="./results_0424/mHypo/"+dataset+"_"+str(iter)+"domains.tsv".format(iter+1),
n_neighbors=50,
resolution=eval_resolution)
pred=pd.read_csv("./results_0424/mHypo/"+dataset+"_"+str(iter)+"domains.tsv".format(iter+1),header=None)
pred_list=pred.iloc[:,0].to_list()
adata.obs['pred_{}'.format(iter+1)] = np.array(pred_list)
obs_df = adata.obs.dropna()
ari = metrics.adjusted_rand_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
nmi = metrics.normalized_mutual_info_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
# print("AMI")
ami = metrics.adjusted_mutual_info_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
# print("homogeneity")
homogeneity = metrics.homogeneity_score(obs_df['original_clusters'], obs_df['pred_{}'.format(iter+1)])
chaos = compute_CHAOS(adata, 'pred_{}'.format(iter+1))
pas = compute_PAS(adata, 'pred_{}'.format(iter+1))
asw = compute_ASW(adata, 'pred_{}'.format(iter+1))
ari_list.append(ari)
nmi_list.append(nmi)
ami_list.append(ami)
hm_list.append(homogeneity)
chaos_list.append(chaos)
pas_list.append(pas)
asw_list.append(asw)
end_time=time.time()
during=end_time-start_time
size, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
# memory[i]=peak /1024/1024
time_list.append(during)
# print('memory blocks peak:{:>10.4f} MB'.format(memory[i]))
print('time: {:.4f} s'.format(during))
print('ARI:{}'.format(ari))
print('NMI:{}'.format(nmi))
print('AMI:{}'.format(ami))
print('Homogeneity:{}'.format(homogeneity))
print('chaos:{}'.format(chaos))
print('pas:{}'.format(pas))
print('asw:{}'.format(asw))