"""DLPFC"""
# Pairs of adjacent sections to align. The IDs form three runs of
# consecutive sections (presumably one run per DLPFC donor — verify);
# every neighbouring pair within a run is aligned.
section_ids_list = [
    [first, second]
    for sections in (
        ['151507', '151508', '151509', '151510'],
        ['151669', '151670', '151671', '151672'],
        ['151673', '151674', '151675', '151676'],
    )
    for first, second in zip(sections, sections[1:])
]
# Elapsed seconds, one entry per (repeat, section pair) run.
run_times = []
# Benchmark: for every repeat and every section pair, load and preprocess
# both slices, align them with VariationalGPSA, write the aligned
# coordinates to ./results/ and record the elapsed time of the whole run.
DLPFC_ROOT_DIR = "../benchmarking_data/DLPFC12"


def _load_slice(section_id, batch):
    """Load one DLPFC section, preprocess it and label every spot with ``batch``."""
    adata = load_DLPFC(root_dir=DLPFC_ROOT_DIR, section_id=section_id)
    adata = process_data(adata, n_top_genes=200)
    adata.obs['batch'] = batch
    return adata


def _subsample(adata, n_samples):
    """Return ``n_samples`` spots of ``adata`` drawn without replacement."""
    rand_idx = np.random.choice(
        np.arange(adata.shape[0]), size=n_samples, replace=False
    )
    return adata[rand_idx]


def _standardize(expr):
    """Z-score each column (axis 0) of ``expr`` and return a plain ndarray.

    Columns with zero standard deviation are only centred: dividing by
    their zero std would fill them with NaN and poison training.
    """
    expr = np.asarray(expr)  # np.matrix from .todense() -> ndarray
    std = expr.std(0)
    std[std == 0] = 1.0
    return (expr - expr.mean(0)) / std


for iter_ in range(iters):
    for section_ids in section_ids_list:
        dataset = section_ids[0] + '_' + section_ids[1]
        start_time = time.perf_counter()
        output = '.'  # NOTE(review): unused here; kept in case later code reads it.

        data_slice1 = _load_slice(section_ids[0], batch=0)
        data_slice2 = _load_slice(section_ids[1], batch=1)
        if N_SAMPLES is not None:
            data_slice1 = _subsample(data_slice1, N_SAMPLES)
            data_slice2 = _subsample(data_slice2, N_SAMPLES)
        # Concatenate only AFTER subsampling so that data.obs has exactly one
        # row per aligned coordinate written out below.
        data = anndata.concat([data_slice1, data_slice2])
        n_samples_list = [data_slice1.shape[0], data_slice2.shape[0]]

        X1 = scale_spatial_coords(data_slice1.obsm["spatial"])
        X2 = scale_spatial_coords(data_slice2.obsm["spatial"])
        Y1 = _standardize(data_slice1.X.todense())
        Y2 = _standardize(data_slice2.X.todense())
        X = np.concatenate([X1, X2])
        Y = np.concatenate([Y1, Y2])
        n_outputs = Y.shape[1]
        x = torch.from_numpy(X).float().clone().to(device)
        y = torch.from_numpy(Y).float().clone().to(device)
        data_dict = {
            "expression": {
                "spatial_coords": x,
                "outputs": y,
                "n_samples_list": n_samples_list,
            }
        }

        model = VariationalGPSA(
            data_dict,
            n_spatial_dims=n_spatial_dims,
            m_X_per_view=m_X_per_view,
            m_G=m_G,
            data_init=True,
            minmax_init=False,
            grid_init=False,
            n_latent_gps=N_LATENT_GPS,
            mean_function="identity_fixed",
            kernel_func_warp=rbf_kernel,
            kernel_func_data=rbf_kernel,
            fixed_view_idx=0,
        ).to(device)
        view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
        for _ in tqdm(range(N_EPOCHS), desc="Training Progress"):
            loss, G_means = train(model, model.loss_fn, optimizer)
        # Aligned coordinates returned by the final training step.
        curr_aligned_coords = G_means["expression"].detach().cpu().numpy()
        print("Done!")

        # Append the aligned coordinates positionally to the spot metadata.
        results = data.obs.assign(
            aligned_x=curr_aligned_coords[:, 0],
            aligned_y=curr_aligned_coords[:, 1],
        )
        # One file per repeat; a hard-coded suffix made repeats overwrite each other.
        results.to_csv(f"./results/{dataset}_{iter_}.csv")
        end_time = time.perf_counter()
        run_times.append(end_time - start_time)