02_Simulated_datasets

1_dataset details

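| Dataset | Cell number | Gene number | Cell type | TSGs | Pattern | Ratio | Batch |
|---|---|---|---|---|---|---|---|
| Simulated_data1 | 10 | 80 | 1 | 800 | 8 | 5 | 2 |
| Simulated_data2 | 100 | 25 | 5 | 2500 | 5 | 5 | 1 |
| Simulated_data3 | 50 | 400 | 2 | 15000 | 8 | 10 | 5 |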

f2_summary

Tips

  • Using the 318 built-in cell templates from the simfish Python package, 10/20 cells were randomly selected;

  • RNA spatial localization data were generated for 8 gene spatial distribution patterns: Random, Foci, Intranuclear, Extranuclear, Perinuclear, Pericellular, Protrusion, and Radial (ours);

f2_pattern

  • For each spatial pattern, 5 pattern intensities were simulated: 60%, 70%, 80%, 90%, and 100%; a higher intensity produces a more pronounced spatial distribution pattern;

  • For each spatial pattern, 5 transcript counts were also simulated: 25, 50, 100, 150, and 200 transcripts;

  • Each variant of a spatial pattern (a given intensity and transcript count) can be regarded as a separate gene that exhibits that pattern;

  • Foci is a special pattern for which the number of transcripts does not need to be specified.

2_GRASP preprocessing

step1: Load data

dataset = "simulated1"
outfile = f'../1_input/pkl_data/{dataset}_data_dict.pkl'

# Load the preprocessed per-dataset dictionary.
# NOTE: pickle.load can execute arbitrary code; only open files produced by
# this pipeline, never pickles from an untrusted source.
with open(outfile, 'rb') as pkl_file:
    pickle_dict = pickle.load(pkl_file)

# Pull the fields used by the following steps out of the dictionary.
(df_registered,
 cell_radii,
 cell_boundary,
 nuclear_boundary,
 nuclear_boundary_df_registered) = (
    pickle_dict[key]
    for key in (
        'df_registered',
        'cell_radii',
        'cell_boundary',
        'nuclear_boundary',
        'nuclear_boundary_df_registered',
    )
)

step2: Visualize the original and normalized TSGs

# Root directory under which the per-cell figures are written.
path = "../2_scaled_cell"

# Transcript positions in the original coordinates, drawn with the cell and
# nuclear boundaries.
cep.plot_raw_gene_distribution(
    dataset, cell_boundary, nuclear_boundary, df_registered, path
)

# Registered transcript positions, drawn with the registered nuclear boundary.
cep.plot_register_gene_distribution(
    dataset, df_registered, path, nuclear_boundary_df_registered
)
Processing cells: 100%|██████████| 10/10 [11:26<00:00, 68.62s/it]
All cell images have been saved to ../2_scaled_cell/simulated1/raw_gene/cell_cell_257
Plotting per cell: 100%|██████████| 10/10 [12:38<00:00, 75.84s/it]

step3: Cell partitioning

import utils_code.partition as pat

gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()
# Named `partition_dir` (not `dir`) so the built-in dir() is not shadowed.
partition_dir = f"../4_partition_same/{dataset}_partition/"
os.makedirs(partition_dir, exist_ok=True)

# Polar-grid sweep. Each range currently yields a single value
# (30 sectors, 15 rings); widen the ranges to explore other resolutions.
for n_sectors in range(30, 31, 10):
    for m_rings in range(15, 16, 5):
        # k for the k-nearest-neighbour graph: one tenth of
        # n_sectors * m_rings (presumably the number of grid areas).
        # It only depends on the grid, so compute it once per grid.
        k_neighbor = int((n_sectors * m_rings) / 10)
        for target_cell in cell_list:
            df = df_registered[df_registered['cell'] == target_cell]
            # Only genes that actually have transcripts in this cell. Iterating
            # the global gene_list instead would hand empty frames to the
            # partitioner for genes absent from this cell.
            genes = df['gene'].unique()
            plot_dir = os.path.join(partition_dir, f"{target_cell}/{target_cell}_{n_sectors}_{m_rings}_k{k_neighbor}")
            os.makedirs(plot_dir, exist_ok=True)
            print(f"This is [target_cell: {target_cell}] - [n_sectors: {n_sectors}] - [m_rings: {m_rings}], and k_neighbor is {k_neighbor}")
            nuclear_boundary_df = nuclear_boundary_df_registered[nuclear_boundary_df_registered['cell'] == target_cell]
            # Radius handed to the partitioner. It is fixed at 1 instead of
            # cell_radii[target_cell], presumably because df_registered is
            # already scaled to a unit cell. TODO confirm.
            r = 1
            for gene in tqdm(genes, desc="Processing genes"):
                df_filtered = df[df['gene'] == gene]
                count_matrix, center_points, point_counts, is_virtual, is_edge = pat.count_points_in_areas_same(df_filtered, n_sectors, m_rings, r)
                nuclear_positions = pat.classify_center_points_with_edge(center_points, nuclear_boundary_df, is_edge)
                edges = pat.build_graph_k_nearest(center_points, k=k_neighbor)
                # G is not used further in this cell; kept for interactive inspection.
                G = pat.build_graph_with_networkx(center_points, edges, is_virtual)
                pat.save_node_data_to_csv_old(center_points, is_virtual, plot_dir, gene, point_counts, k=k_neighbor, nuclear_positions=nuclear_positions)
                pat.plot_cell_partition_heatmap(target_cell, gene, point_counts, n_sectors, m_rings, r, plot_dir, nuclear_boundary_df)

step4: Augmentation of TSGs (random rotation and node dropout)

import utils_code.augumentation as aug

n_sectors = 30
m_rings = 15
k_neighbor = int((n_sectors * m_rings) / 10)
dropout_ratios = [0.1, 0.2, 0.3]
# Inclusive upper bounds on the real-node count for the first two dropout
# tiers. Graphs above the last bound use dropout_ratios[2].
dropout_thresholds = (100, 150)
# Graphs with fewer real nodes than this are rotated but not dropped out.
min_nodes_for_dropout = 10


def _select_dropout_ratio(real_nodes_count, ratios, thresholds):
    """Return the dropout ratio for a graph with *real_nodes_count* real nodes.

    ``ratios[i]`` applies up to and including ``thresholds[i]``. Counts above
    the last threshold get ``ratios[len(thresholds)]``.
    """
    for bound, ratio in zip(thresholds, ratios):
        if real_nodes_count <= bound:
            return ratio
    return ratios[len(thresholds)]


# No trailing slash, so the f-strings below do not produce "//" in paths.
partition_dir = f"../4_partition_same/{dataset}_partition"
gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()

for cell in tqdm(cell_list, desc="Processing all cells", leave=True):
    graph_dir = f"{partition_dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    save_path = f"{graph_dir}_aug"
    os.makedirs(save_path, exist_ok=True)
    for gene in tqdm(gene_list, desc="Processing all genes", leave=True):
        nodes_file = f'{graph_dir}/{gene}_node_matrix.csv'
        adj_file = f'{graph_dir}/{gene}_adj_matrix.csv'
        if not (os.path.exists(nodes_file) and os.path.exists(adj_file)):
            print(f"Skipping {gene} in {cell} (file not found).")
            continue
        node_matrix = pd.read_csv(nodes_file)
        adj_matrix = pd.read_csv(adj_file)
        # Random rotation angle in degrees. It is drawn from the unseeded
        # global `random` module, so results differ between runs unless
        # random.seed(...) is called beforehand.
        random_angle = random.uniform(0, 360)
        node_matrix_rotated = aug.rotate_nodes(node_matrix.copy(), random_angle)
        real_nodes_count = (node_matrix_rotated['is_virtual'] == 0).sum()
        if real_nodes_count >= min_nodes_for_dropout:
            dropout_ratio = _select_dropout_ratio(real_nodes_count, dropout_ratios, dropout_thresholds)
            adj_out, node_out = aug.dropout_nodes(adj_matrix.copy(), node_matrix_rotated.copy(), dropout_ratio)
        else:
            # Too few real nodes to drop any: keep the rotated graph as is.
            adj_out, node_out = adj_matrix, node_matrix_rotated
        adj_out.to_csv(f"{save_path}/{gene}_adj_matrix.csv", index=False)
        node_out.to_csv(f"{save_path}/{gene}_node_matrix.csv", index=False)
        # For visual debugging: aug.plot_graph(adj_matrix, node_matrix, adj_out, node_out, f"{cell}_{gene}", save_path)
            

3_GRASP training

step5: Load all TSGs to prepare for training

cell_list = df_registered['cell'].unique()
gene_list = df_registered['gene'].unique()
cell_numbers, gene_numbers = len(cell_list), len(gene_list)

# Polar-grid configuration: 30 sectors, 15 rings, and k equal to one tenth
# of their product.
n_sectors, m_rings = 30, 15
k_neighbor = int((n_sectors * m_rings) / 10)

# Build the original and augmented graphs from the partition directory,
# using the dataset's label table.
path = f"../4_partition_same/{dataset}_partition"
df = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
original_graphs, augmented_graphs = gra.generate_graph_data_target(
    dataset, df, path, n_sectors, m_rings, k_neighbor
)

print(len(original_graphs))
print(len(augmented_graphs))

# Collect the gene and cell identifier of every original graph, in order.
gene_labels = []
cell_labels = []
for graph in original_graphs:
    gene_labels.append(graph.gene)
    cell_labels.append(graph.cell)
Processing Graphs: 100%|██████████| 800/800 [3:22<00:00,  7.49it/s]  
graphs_number = len(original_graphs)
print(f"cell_numbers:{cell_numbers} - gene_numbers:{gene_numbers} - graphs_number:{graphs_number}")

# Create the output directory on first use.
save_path = "../5_graph_data"
if not os.path.exists(save_path):
    os.makedirs(save_path)

# Bundle the graphs and their labels into a single pickle.
graph_data = dict(
    original_graphs=original_graphs,
    augmented_graphs=augmented_graphs,
    gene_labels=gene_labels,
    cell_labels=cell_labels,
)
save_file = f"{save_path}/{dataset}_cell{cell_numbers}_gene{gene_numbers}_graph{graphs_number}.pkl"
with open(save_file, 'wb') as pkl_out:
    pickle.dump(graph_data, pkl_out)

print(f"Graph data saved to {save_file}")
cell_numbers:10 - gene_numbers:80 - graphs_number:800
Graph data saved to ../5_graph_data/simulated1_cell10_gene80_graph800.pkl

step6: Clustering and identifying spatial localization patterns

# a and b select which GRASP result file is loaded (ours_label_a{a}_b{b}.csv)
# and are embedded in the output figure names.
a, b = 0.4, 0.6
dataset = "simulated1"
path1 = f'../1.5_benchmark/figure/{dataset}/s1_tsne_ours_a{a}_b{b}'
path2 = f'../1.5_benchmark/figure/{dataset}/s1_tsne_gt_a{a}_b{b}'

ours_label = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")
# A bare expression is only displayed when it is the last line of a notebook
# cell. Here it sits mid-cell and its result would be discarded, so print it.
print(ours_label['pattern'].value_counts())

# Define color_map: one colour per pattern label.
# NOTE(review): these label names (e.g. 'Nuclear', 'Cell edge') differ from the
# pattern names in the dataset description (e.g. 'Intranuclear', 'Pericellular');
# presumably they are the values stored in the label CSV. Confirm.
color_map = {'Nuclear': '#ed6ca4', 'Cytoplasmic': '#7bc4e2', 'Protrusion': '#acd372', 'Nuclear edge':'#fbb05b',
             'Cell edge': '#EDABB5', 'Random':'#ACD0E4', 'Foci':'#FFD4AB', 'Radial':'#DDC4E0'}

# t-SNE coloured by the GRASP-predicted pattern.
plot_tsne_by_label(df=ours_label, label_col='pattern', color_map=color_map, title='', save_prefix=path1,
    legend_title='GRASP', legend_loc='upper left', legend_bbox=(0.5, 1), legend_ncol=3)

# Same embedding coloured by the ground-truth pattern.
plot_tsne_by_label(df=ours_label, label_col='groundtruth', color_map=color_map, title='', save_prefix=path2,
    legend_title='Ground truth', legend_loc='upper left', legend_bbox=(0.5, 1.0), legend_ncol=3)

f2_simulated1