SEDR tutorial

0. Import packages and select a GPU if available

[1]:
import os
import torch
import argparse
import warnings
import numpy as np
import anndata
import scanpy as sc
import matplotlib.pyplot as plt
import pandas as pd
from src.graph_func import graph_construction
from src.utils_func import mk_dir, adata_preprocess, load_visium_sge
from src.SEDR_train import SEDR_Train
from sklearn.metrics import adjusted_rand_score
from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP

warnings.filterwarnings('ignore')
# Disable cuDNN. The original line `torch.cuda.cudnn_enabled = False` merely
# attached a new, unused attribute to the `torch.cuda` module and had no effect;
# the actual switch lives under `torch.backends.cudnn`.
torch.backends.cudnn.enabled = False
# Seed the RNGs used by the pipeline so repeated runs are comparable.
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# Run device: a GPU is recommended; fall back to CPU when CUDA is unavailable.
# The benchmark machine used GPU index 3; on machines with fewer visible GPUs
# fall back to GPU 0 instead of failing on the first CUDA call.
PREFERRED_GPU = 3
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if torch.cuda.device_count() > PREFERRED_GPU else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
# iters = 1 # for script testing

1. DLPFC dataset

Change `dir_` in the DLPFC cell below to the directory that holds your DLPFC data (e.g. `path/to/your/DLPFC/data`).

[ ]:
iters = 20  # number of repeated training runs per section (for boxplotting)


def _str2bool(value):
    """Convert a command-line string such as 'True', 'false' or '1' to a bool.

    argparse with ``type=bool`` calls ``bool('False')``, which is True, so a
    boolean option could never be switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--k', type=int, default=10, help='parameter k in spatial graph')
parser.add_argument('--knn_distanceType', type=str, default='euclidean',
                    help='graph distance type: euclidean/cosine/correlation')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--cell_feat_dim', type=int, default=200, help='Dim of PCA')
parser.add_argument('--feat_hidden1', type=int, default=100, help='Dim of DNN hidden 1-layer.')
parser.add_argument('--feat_hidden2', type=int, default=20, help='Dim of DNN hidden 2-layer.')
parser.add_argument('--gcn_hidden1', type=int, default=32, help='Dim of GCN hidden 1-layer.')
parser.add_argument('--gcn_hidden2', type=int, default=8, help='Dim of GCN hidden 2-layer.')
parser.add_argument('--p_drop', type=float, default=0.2, help='Dropout rate.')
parser.add_argument('--using_dec', type=_str2bool, default=True, help='Using DEC loss.')
parser.add_argument('--using_mask', type=_str2bool, default=False, help='Using mask for multi-dataset.')
parser.add_argument('--feat_w', type=float, default=10, help='Weight of DNN loss.')
parser.add_argument('--gcn_w', type=float, default=0.1, help='Weight of GCN loss.')
parser.add_argument('--dec_kl_w', type=float, default=10, help='Weight of DEC loss.')
parser.add_argument('--gcn_lr', type=float, default=0.01, help='Initial GNN learning rate.')
parser.add_argument('--gcn_decay', type=float, default=0.01, help='Initial decay rate.')
parser.add_argument('--dec_cluster_n', type=int, default=10, help='DEC cluster number.')
parser.add_argument('--dec_interval', type=int, default=20, help='DEC interval nnumber.')
parser.add_argument('--dec_tol', type=float, default=0.00, help='DEC tol.')
# ______________ Eval clustering Setting _________
# A Leiden resolution is a float, so parse it as one (the default stays 1).
parser.add_argument('--eval_resolution', type=float, default=1, help='Eval cluster number.')
parser.add_argument('--eval_graph_n', type=int, default=20, help='Eval graph kN tol.')

# parse_known_args() instead of parse_args(): argv entries this parser does not
# define (e.g. the connection-file flag a Jupyter kernel is started with) are
# collected in `unknown` instead of aborting with "unrecognized arguments".
params, unknown = parser.parse_known_args()
params.device = device


def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    '''
        Search for a Leiden resolution giving exactly `fixed_clus_count` clusters.

        Candidate resolutions np.arange(0.2, 2.5, increment) are tried from the
        largest down. Each try runs sc.tl.leiden(random_state=0), which writes
        adata.obs['leiden'], and the first resolution whose clustering has
        `fixed_clus_count` distinct labels is returned. The callers build a
        neighbour graph with sc.pp.neighbors before calling this.

        arg1(adata)[AnnData matrix]
        arg2(fixed_clus_count)[int]: target number of clusters
        arg3(increment)[float]: spacing between candidate resolutions (> 0)

        return:
            resolution[float]. If no candidate reaches the target, the
            smallest resolution tried is returned and a message is printed.

        raises:
            ValueError: if `increment` is not positive.
    '''
    if increment <= 0:
        raise ValueError(f'increment must be positive, got {increment!r}')
    # np.arange is increasing, so reversing it tries the largest resolution first.
    candidates = np.arange(0.2, 2.5, increment)[::-1]
    for res in candidates:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        if adata.obs['leiden'].nunique() == fixed_clus_count:
            return res
    # No exact match: keep the historical fallback (smallest resolution tried)
    # but make it visible instead of silently reporting a wrong cluster count.
    print(f'res_search_fixed_clus: no resolution gave {fixed_clus_count} clusters; '
          f'falling back to resolution={res:.2f}')
    return res
[ ]:
"""DLPFC"""
# Benchmark SEDR on the 12 DLPFC sections. Each entry is
# [number of ground-truth clusters, section id]. Every section is trained
# `iters` times; each run's embedding is Leiden-clustered at a resolution
# searched to give `n_clusters` clusters, scored by ARI against
# obs['original_clusters'], and the ARIs are appended to sedr_aris.txt.
setting_combinations = [[7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'],
                        [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'],
                        [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676']]
dir_ = './benchmarking_data/DLPFC12'  # change to the location of your DLPFC data
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_DLPFC(root_dir=dir_, section_id=dataset)

    aris = []
    # Preprocessing and the spatial graph depend only on the section, so they
    # are built once and reused by every training run below.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.uns['spatial'] = adata_h5.uns['spatial']
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over spots without missing values. NOTE: dropna() drops a spot if
        # *any* obs column is NaN, not only the ground-truth column.
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('DLPFC' + dataset + ' ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')

2. BC/MA datasets

[ ]:
"""BC"""
# Benchmark SEDR on the human breast cancer section. Each entry is
# [number of ground-truth clusters, section id]; results are appended to
# sedr_aris.txt under the label 'HBRC1'.
setting_combinations = [[20, 'section1']]
# Relative path, consistent with the other dataset cells (run the notebook
# from the directory that contains benchmarking_data/).
dir_ = './benchmarking_data/BC'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_BC(root_dir=dir_, section_id=dataset)

    aris = []
    # Preprocessing and the spatial graph are built once per section.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.uns['spatial'] = adata_h5.uns['spatial']
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over spots without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('HBRC1 ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')
[ ]:
"""load mMAMP ma section"""
# Benchmark SEDR on the mMAMP 'MA' section; results are appended to
# sedr_aris.txt under the label 'mABC'.
setting_combinations = [[52, 'MA']]
dir_ = './benchmarking_data/mMAMP'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_mMAMP(root_dir=dir_, section_id=dataset)

    # Use a per-section copy of the hyper-parameters. The original code shrank
    # the shared `params.cell_feat_dim` in place, so a small gene panel silently
    # lowered the PCA dimension for every later section and notebook cell.
    section_params = argparse.Namespace(**vars(params))
    # Keep the PCA dimension strictly below the gene count (the original
    # `n_genes - 1` cap, now also applied when the two are equal).
    section_params.cell_feat_dim = min(params.cell_feat_dim, adata_h5.n_vars - 1)

    aris = []
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=section_params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], section_params)
    section_params.cell_num = adata_h5.shape[0]
    section_params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, section_params)
        if section_params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates
        # (this dataset's uns['spatial'] is not copied).
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=section_params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over spots without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mABC ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')

3. mVC/mPFC datasets

[ ]:
"""mVC"""
# Benchmark SEDR on the STARmap mouse visual cortex section; results are
# appended to sedr_aris.txt under the label 'mVC'.
setting_combinations = [[7, 'STARmap_20180505_BY3_1k.h5ad']]
dir_ = './benchmarking_data/STARmap_mouse_visual_cortex'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_mVC(root_dir=dir_, section_id=dataset)

    aris = []
    # Preprocessing and the spatial graph are built once per section.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates
        # (this dataset's uns['spatial'] is not copied).
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over cells without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mVC ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')
[ ]:
"""mPFC"""
# Benchmark SEDR on three STARmap mouse prefrontal cortex sections. Each entry
# is [number of ground-truth clusters, section id]; results are appended to
# sedr_aris.txt as 'mPFC<section id>'.
setting_combinations = [[4, '20180417_BZ5_control'], [4, '20180419_BZ9_control'], [4, '20180424_BZ14_control']]
dir_ = './benchmarking_data/STARmap_mouse_PFC'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_mPFC(root_dir=dir_, section_id=dataset)

    # Use a per-section copy of the hyper-parameters so that capping the PCA
    # dimension for a small gene panel does not leak into later sections or
    # cells (the original code mutated the shared `params.cell_feat_dim`).
    section_params = argparse.Namespace(**vars(params))
    # Keep the PCA dimension strictly below the gene count (the original
    # `n_genes - 1` cap, now also applied when the two are equal).
    section_params.cell_feat_dim = min(params.cell_feat_dim, adata_h5.n_vars - 1)

    aris = []
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=section_params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], section_params)
    section_params.cell_num = adata_h5.shape[0]
    section_params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, section_params)
        if section_params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates
        # (this dataset's uns['spatial'] is not copied).
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=section_params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over cells without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mPFC' + dataset + ' ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')

4. mHypothalamus dataset

[ ]:
"""mHypo"""
# Benchmark SEDR on five mouse hypothalamus sections (ids are the slice
# positions). Each entry is [number of ground-truth clusters, section id];
# results are appended to sedr_aris.txt as 'mHypothalamus<section id>'.
setting_combinations = [[8, '-0.04'], [8, '-0.09'], [8, '-0.14'], [8, '-0.19'], [8, '-0.24']]
dir_ = './benchmarking_data/mHypothalamus'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_mHypothalamus(root_dir=dir_, section_id=dataset)

    # Use a per-section copy of the hyper-parameters so that capping the PCA
    # dimension for a small gene panel does not leak into later sections or
    # cells (the original code mutated the shared `params.cell_feat_dim`).
    section_params = argparse.Namespace(**vars(params))
    # Keep the PCA dimension strictly below the gene count (the original
    # `n_genes - 1` cap, now also applied when the two are equal).
    section_params.cell_feat_dim = min(params.cell_feat_dim, adata_h5.n_vars - 1)

    aris = []
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=section_params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], section_params)
    section_params.cell_num = adata_h5.shape[0]
    section_params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, section_params)
        if section_params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates
        # (this dataset's uns['spatial'] is not copied).
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=section_params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over cells without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('mHypothalamus' + dataset + ' ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')

5. Her2Tumor dataset

[ ]:
"""Her2"""
# Benchmark SEDR on eight HER2-positive tumour sections. Each entry is
# [number of ground-truth clusters, section id]; results are appended to
# sedr_aris.txt as 'Her2tumor<section id>'.
setting_combinations = [[6, 'A1'], [5, 'B1'], [4, 'C1'], [4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]
dir_ = './benchmarking_data/Her2_tumor'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = load_her2_tumor(root_dir=dir_, section_id=dataset)

    # Use a per-section copy of the hyper-parameters. The original code mutated
    # the shared `params.cell_feat_dim`, so running this cell after a
    # small-panel dataset silently reused that dataset's reduced PCA dimension.
    section_params = argparse.Namespace(**vars(params))
    # Keep the PCA dimension strictly below the gene count (the original
    # `n_genes - 1` cap, now also applied when the two are equal).
    section_params.cell_feat_dim = min(params.cell_feat_dim, adata_h5.n_vars - 1)

    aris = []
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=section_params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], section_params)
    section_params.cell_num = adata_h5.shape[0]
    section_params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, section_params)
        if section_params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates
        # (this dataset's uns['spatial'] is not copied).
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=section_params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over spots without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['original_clusters'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
    with open('sedr_aris.txt', 'a+') as fp:
        fp.write('Her2tumor' + dataset + ' ')
        fp.write(' '.join(str(score) for score in aris))
        fp.write('\n')

6. Mouse hippocampus (Slide-seqV2) dataset

[ ]:
iters = 20  # repeated training runs per section (for boxplotting)
parser = argparse.ArgumentParser()


# The parser declares no options: parse_known_args() just yields an empty
# Namespace, and any extra argv entries end up in `unknown`.
params, unknown = parser.parse_known_args()
# Hyper-parameters for the mouse hippocampus run, written straight into the
# Namespace's attribute dict (same names, values and insertion order).
vars(params).update(
    device=device,
    cell_feat_dim=100,
    k=6,
    knn_distanceType='euclidean',
    epochs=200,
    feat_hidden1=50,
    feat_hidden2=10,
    gcn_hidden1=16,
    gcn_hidden2=8,
    p_drop=0.2,
    using_dec=True,
    using_mask=False,
    feat_w=10,
    gcn_w=0.1,
    dec_kl_w=10,
    gcn_lr=0.01,
    gcn_decay=0.01,
    dec_cluster_n=10,
    dec_interval=20,
    dec_tol=0,
    eval_resolution=1,
    eval_graph_n=20,
)

def res_search_fixed_clus(adata, fixed_clus_count, increment=0.02):
    '''
        Search for a Leiden resolution giving exactly `fixed_clus_count` clusters.

        Candidate resolutions np.arange(0.2, 2.5, increment) are tried from the
        largest down. Each try runs sc.tl.leiden(random_state=0), which writes
        adata.obs['leiden'], and the first resolution whose clustering has
        `fixed_clus_count` distinct labels is returned. The callers build a
        neighbour graph with sc.pp.neighbors before calling this.

        arg1(adata)[AnnData matrix]
        arg2(fixed_clus_count)[int]: target number of clusters
        arg3(increment)[float]: spacing between candidate resolutions (> 0)

        return:
            resolution[float]. If no candidate reaches the target, the
            smallest resolution tried is returned and a message is printed.

        raises:
            ValueError: if `increment` is not positive.
    '''
    if increment <= 0:
        raise ValueError(f'increment must be positive, got {increment!r}')
    # np.arange is increasing, so reversing it tries the largest resolution first.
    candidates = np.arange(0.2, 2.5, increment)[::-1]
    for res in candidates:
        sc.tl.leiden(adata, random_state=0, resolution=res)
        if adata.obs['leiden'].nunique() == fixed_clus_count:
            return res
    # No exact match: keep the historical fallback (smallest resolution tried)
    # but make it visible instead of silently reporting a wrong cluster count.
    print(f'res_search_fixed_clus: no resolution gave {fixed_clus_count} clusters; '
          f'falling back to resolution={res:.2f}')
    return res


"""mouse hippocampus (Slide-seqV2)"""
# Benchmark SEDR on the Slide-seqV2 mouse hippocampus section. The entry is
# [number of ground-truth clusters, h5ad file name]; ground truth is
# obs['cluster']. ARIs are collected and printed only (not written to
# sedr_aris.txt). `adata_sedr` from the last run is reused by the plotting cell.
setting_combinations = [[14, 'sshippo.h5ad']]
# Relative path, consistent with the other dataset cells (run the notebook
# from the directory that contains benchmarking_data/).
dir_ = './benchmarking_data/mouse_hyppocampus_slideseqv2'
for n_clusters, dataset in setting_combinations:
    save_fold = os.path.join('./output/', dataset)
    adata_h5 = sc.read_h5ad(os.path.join(dir_, dataset))
    # This file stores coordinates in obs['x'] / obs['y']; stack them as an
    # (n_cells, 2) array in obsm['spatial'], which graph_construction reads.
    adata_h5.obsm['spatial'] = np.column_stack(
        (adata_h5.obs['x'].to_numpy(), adata_h5.obs['y'].to_numpy()))

    aris = []
    # Preprocessing and the spatial graph are built once per section.
    adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
    graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
    params.cell_num = adata_h5.shape[0]
    params.save_path = mk_dir(save_fold)
    print('==== Graph Construction Finished')
    for _run in range(iters):  # `_run` avoids shadowing the builtin `iter`
        sedr_net = SEDR_Train(adata_X, graph_dict, params)
        if params.using_dec:
            sedr_net.train_with_dec()
        else:
            sedr_net.train_without_dec()
        sedr_feat, _, _, _ = sedr_net.process()

        # Wrap the learned embedding with the section's metadata and coordinates.
        adata_sedr = anndata.AnnData(sedr_feat, obs=adata_h5.obs)
        adata_sedr.uns['spatial'] = adata_h5.uns['spatial']
        adata_sedr.obsm['spatial'] = adata_h5.obsm['spatial']
        sc.pp.neighbors(adata_sedr, n_neighbors=params.eval_graph_n)

        eval_resolution = res_search_fixed_clus(adata_sedr, n_clusters)
        sc.tl.leiden(adata_sedr, key_added="SEDR_leiden", resolution=eval_resolution)
        print(adata_sedr.obs)
        # ARI over beads without missing values (dropna() checks every obs column).
        obs_df = adata_sedr.obs.dropna()
        ari = adjusted_rand_score(obs_df['SEDR_leiden'], obs_df['cluster'])

        print('Dataset:', dataset)
        print('ARI:', ari)
        aris.append(ari)
    print('Dataset:', dataset)
    print(aris)
    print(np.mean(aris))
[ ]:
# Figure defaults: 300-dpi output and TrueType (Type 42) font embedding in
# PDF/PS files.
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.dpi'] = 300
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42

SMALL_SIZE = 15
MEDIUM_SIZE = 18
BIGGER_SIZE = 26

# Same settings as the per-group plt.rc(...) calls, expressed as rcParams keys.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y labels
    'xtick.labelsize': SMALL_SIZE,    # tick labels
    'ytick.labelsize': SMALL_SIZE,    # tick labels
    'legend.fontsize': SMALL_SIZE,    # legend
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

# Side-by-side spatial maps of the SEDR Leiden clusters and the ground truth
# (obs['cluster']) for the `adata_sedr` left by the hippocampus cell.
sc.pl.spatial(adata_sedr,
              color=["SEDR_leiden", "cluster"],
              title=["SEDR", "Ground Truth"],
              show=False, spot_size=20)
fig_dir = "/home/yunfei/spatial_benchmarking/BenchmarkST/analysis1110/clustering/mousehippo"
# savefig does not create missing directories, so make sure it exists first.
os.makedirs(fig_dir, exist_ok=True)
plt.savefig(os.path.join(fig_dir, "hippocampus_sedr.pdf"), bbox_inches='tight')