04_MERFISH_U2OS_cell_line

1_dataset details

Dataset Cell number Gene number TSGs Pattern
merfish_u2os 989 130 123375 ——
merfish_u2os(group1) 621 9 947 6
merfish_u2os(group2) 634 25 1000 5
merfish_u2os(group3) 621 25 1000 6
merfish_u2os(group4) 629 25 1000 7
merfish_u2os(group5) 989 25 23242 8

How are labels annotated?

  • The labels were generated through a semi-automated pipeline that combined the predictions of five classifiers (LightGBM, Random Forest, XGBoost, Gradient Boosting, and Decision Tree) trained on simulated data, followed by manual refinement. This approach accurately annotated 23,242 transcript spots from 25 selected genes, and the results are consistent with known localization patterns (e.g., nuclear genes like MALAT1 labeled correctly, and secreted proteins like THBS1 assigned to the nuclear membrane). The integration of computational and manual validation ensured high-confidence pseudo-labels for the unannotated dataset.

f4_classifiers

2_GRASP preprocessing

step1: Load data

# Read the pickled input bundle for one dataset group and expose the five
# entries that the later steps use as module-level names.
# NOTE(review): pickle.load can execute code embedded in the file; only open
# bundles that were produced by a trusted pipeline.
dataset = "merfish_u2os_group1"
outfile = f'../1_input/pkl_data/{dataset}_data_dict.pkl'

with open(outfile, 'rb') as f:
    pickle_dict = pickle.load(f)

# Unpack the entries in a fixed key order.
(
    df_registered,
    cell_radii,
    cell_boundary,
    nuclear_boundary,
    nuclear_boundary_df_registered,
) = (
    pickle_dict[key]
    for key in (
        'df_registered',
        'cell_radii',
        'cell_boundary',
        'nuclear_boundary',
        'nuclear_boundary_df_registered',
    )
)

step2: Visualize the original and normalized TSGs

# Plot the raw and the registered transcript distributions for the dataset.
# `path` is handed to both helpers (presumably the output directory - confirm
# against utils_code).
path = "../2_scaled_cell"
cep.plot_raw_gene_distribution(
    dataset,
    cell_boundary,
    nuclear_boundary,
    df_registered,
    path,
)
cep.plot_register_gene_distribution(
    dataset,
    df_registered,
    path,
    nuclear_boundary_df_registered,
)

step3: Cell partitioning

import utils_code.partition as pat

# Partition each cell into n_sectors x m_rings areas and, for every gene,
# build the k-nearest-neighbour graph, write the node table and draw the
# partition heatmap under <dir>/<cell>/<cell>_<n_sectors>_<m_rings>_k<k>/.
gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()
# NOTE: `dir` shadows the builtin; the name is kept because the following
# step assigns and reads the same variable.
dir = f"../4_partition_same/{dataset}_partition/"
os.makedirs(dir, exist_ok=True)

r = 1  # constant passed to the partition and heatmap helpers
# Both ranges currently yield a single value (30 sectors, 15 rings); the
# loops are kept so that a parameter sweep only needs the bounds changed.
for n_sectors in range(30, 31, 10):
    for m_rings in range(15, 16, 5):
        # One neighbour per ten partition areas; independent of the cell.
        k_neighbor = int((n_sectors * m_rings) / 10)
        for target_cell in cell_list:
            df = df_registered[df_registered['cell'] == target_cell]
            plot_dir = os.path.join(dir, f"{target_cell}/{target_cell}_{n_sectors}_{m_rings}_k{k_neighbor}")
            os.makedirs(plot_dir, exist_ok=True)
            print(f"This is [target_cell: {target_cell}] - [n_sectors: {n_sectors}] - [m_rings: {m_rings}], and k_neighbor is {k_neighbor}")
            nuclear_boundary_df = nuclear_boundary_df_registered[nuclear_boundary_df_registered['cell'] == target_cell]
            # Iterates over the dataset-wide gene list, so a gene that is
            # absent from this cell produces an empty df_filtered.
            for gene in tqdm(gene_list, desc="Processing genes"):
                df_filtered = df[df['gene'] == gene]
                count_matrix, center_points, point_counts, is_virtual, is_edge = pat.count_points_in_areas_same(df_filtered, n_sectors, m_rings, r)
                nuclear_positions = pat.classify_center_points_with_edge(center_points, nuclear_boundary_df, is_edge)
                edges = pat.build_graph_k_nearest(center_points, k=k_neighbor)
                # G is not used below; the call is kept unchanged in case the
                # helper has side effects.
                G = pat.build_graph_with_networkx(center_points, edges, is_virtual)
                pat.save_node_data_to_csv_old(center_points, is_virtual, plot_dir, gene, point_counts, k=k_neighbor, nuclear_positions=nuclear_positions)
                pat.plot_cell_partition_heatmap(target_cell, gene, point_counts, n_sectors, m_rings, r, plot_dir, nuclear_boundary_df)
                

step4: Enhancement of TSGs

import utils_code.augumentation as aug

# Build an augmented copy of every (cell, gene) graph: rotate the nodes by a
# random angle and, when the graph has at least 10 real nodes, additionally
# drop a size-dependent fraction of nodes. Results go to the sibling
# "<...>_aug" directory using the same file names as the originals.
n_sectors = 30
m_rings = 15
k_neighbor = int((n_sectors * m_rings) / 10)
dropout_ratios = [0.1, 0.2, 0.3]

dir = f"../4_partition_same/{dataset}_partition/"
gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()

for cell in tqdm(cell_list, desc="Processing all cells", leave=True):
    path = f"{dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    save_path = f"{dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"
    os.makedirs(save_path, exist_ok=True)
    for gene in tqdm(gene_list, desc="Processing all genes", leave=True):
        nodes_file = f'{path}/{gene}_node_matrix.csv'
        adj_file = f'{path}/{gene}_adj_matrix.csv'
        if not (os.path.exists(nodes_file) and os.path.exists(adj_file)):
            print(f"Skipping {gene} in {cell} (file not found).")
            continue
        node_matrix = pd.read_csv(nodes_file)
        adj_matrix = pd.read_csv(adj_file)
        # Unseeded draw, so the augmentation differs between runs.
        # The value lies in [0, 360] (presumably degrees - confirm in aug.rotate_nodes).
        random_angle = random.uniform(0, 360)
        node_matrix_rotated = aug.rotate_nodes(node_matrix.copy(), random_angle)
        real_nodes_count = (node_matrix_rotated['is_virtual'] == 0).sum()
        if real_nodes_count < 10:
            # Fewer than 10 real nodes: rotation only, no dropout.
            adj_matrix_out, node_matrix_out = adj_matrix, node_matrix_rotated
        else:
            # Dropout ratio by real-node count: <=100, 101-150, >150.
            if real_nodes_count <= 100:
                dropout_ratio = dropout_ratios[0]
            elif real_nodes_count <= 150:
                dropout_ratio = dropout_ratios[1]
            else:
                dropout_ratio = dropout_ratios[2]
            adj_matrix_out, node_matrix_out = aug.dropout_nodes(adj_matrix.copy(), node_matrix_rotated.copy(), dropout_ratio)
        adj_matrix_out.to_csv(f"{save_path}/{gene}_adj_matrix.csv", index=False)
        node_matrix_out.to_csv(f"{save_path}/{gene}_node_matrix.csv", index=False)
            

3_GRASP training

step5: Load all TSGs to prepare for training

# Assemble the original and augmented graphs for training and collect the
# gene / cell identifier attached to each original graph.
cell_list = df_registered['cell'].unique()
gene_list = df_registered['gene'].unique()
cell_numbers, gene_numbers = len(cell_list), len(gene_list)

# Same partition parameters as in the preprocessing steps.
n_sectors, m_rings = 30, 15
k_neighbor = int((n_sectors * m_rings) / 10)

path = f"../4_partition_same/{dataset}_partition"
df = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
original_graphs, augmented_graphs = gra.generate_graph_data_target(
    dataset, df, path, n_sectors, m_rings, k_neighbor
)

print(len(original_graphs))
print(len(augmented_graphs))

gene_labels, cell_labels = [], []
for graph in original_graphs:
    gene_labels.append(graph.gene)
    cell_labels.append(graph.cell)
Processing Graphs: 100%|██████████| 947/947 [3:22<00:00,  7.49it/s]  
# Persist the graphs and their gene / cell labels as a single pickle whose
# file name records the dataset and the cell, gene and graph counts.
graphs_number = len(original_graphs)
print(f"cell_numbers:{cell_numbers} - gene_numbers:{gene_numbers} - graphs_number:{graphs_number}")

save_path = "../5_graph_data"
os.makedirs(save_path, exist_ok=True)

graph_data = {
    "original_graphs": original_graphs,
    "augmented_graphs": augmented_graphs,
    "gene_labels": gene_labels,
    "cell_labels": cell_labels,
}

save_file = f"{save_path}/{dataset}_cell{cell_numbers}_gene{gene_numbers}_graph{graphs_number}.pkl"
with open(save_file, 'wb') as f:
    pickle.dump(graph_data, f)

print(f"Graph data saved to {save_file}")

step6: Clustering and identifying spatial localization patterns

# Hex colour used for each localization pattern in the t-SNE figures.
color_map = dict([
    ('Nuclear', '#ed6ca4'),
    ('Cytoplasmic', '#7bc4e2'),
    ('Protrusion', '#acd372'),
    ('Nuclear edge', '#fbb05b'),
    ('Cell edge', '#EDABB5'),
    ('Random', '#ACD0E4'),
    ('Foci', '#FFD4AB'),
    ('Radial', '#DDC4E0'),
])


def plot_tsne(data, label_col, title, legend_title, save_name):
    """Scatter a t-SNE embedding coloured by a label column and save it.

    Args:
        data: DataFrame with 'tsne_x' and 'tsne_y' columns plus ``label_col``.
        label_col: Column whose values define the point groups.
        title: Figure title.
        legend_title: Title shown above the legend.
        save_name: Output path without extension; '.png', '.pdf' and '.svg'
            files are written and the figure is then shown.
    """
    plt.figure(figsize=(6, 4))
    for group_label, group in data.groupby(label_col):
        # Labels missing from color_map are drawn in black instead of
        # aborting the figure with a KeyError (same fallback as
        # plot_gene_tsne).
        point_color = color_map.get(group_label, '#000000')
        plt.scatter(x=group['tsne_x'], y=group['tsne_y'], color=point_color, label=group_label, s=5)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=16)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=16)
    for tick_label in ax.get_xticklabels():
        tick_label.set_fontweight('bold')
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontweight('bold')
    plt.grid(False)
    # The legend is anchored outside the axes, to the right of the plot.
    plt.legend(title=legend_title, frameon=True, fontsize=12, title_fontsize=13, markerscale=3.0, bbox_to_anchor=(1, 1), loc='upper left')
    plt.title(title, fontsize=16)
    plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')
    plt.tight_layout()
    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(f'{save_name}.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
# Draw the t-SNE twice from the same label table: once coloured by the
# predicted pattern and once by the ground-truth column.
dataset = "merfish_u2os_group4"
params_list = [(0.2, 0.8)]
for a, b in params_list:
    figure_dir = f'../1.5_benchmark/figure/{dataset}'
    path1 = f'{figure_dir}/s1_tsne_ours_a{a}_b{b}'
    path2 = f'{figure_dir}/s1_tsne_gt_a{a}_b{b}'
    our_label = pd.read_csv(f'../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv')
    for column, legend, target in (
        ('pattern', 'GRASP', path1),
        ('groundtruth_wzx', 'Ground truth', path2),
    ):
        plot_tsne(data=our_label, label_col=column, title='', legend_title=legend, save_name=target)

f4_tsne

step7: Statistical data distribution

# Heatmap of how many sampled rows each gene contributes to every
# ground-truth pattern.
for group in ['group1']:
    tmp = pd.read_csv(f"../7_classifier/predicted/sampled_data_{group}.csv")

    # crosstab gives genes as rows and patterns as columns; it is transposed
    # so that patterns run down the y axis.
    gene_label_counts = pd.crosstab(tmp['gene'], tmp['groundtruth_wzx'])
    fig, ax = plt.subplots(figsize=(5, 2.5))
    sns.heatmap(gene_label_counts.T, annot=True, fmt='d', cmap='GnBu')
    ax.set_xlabel('')
    ax.set_ylabel('')
    for axis_name in ('x', 'y'):
        ax.tick_params(axis=axis_name, labelsize=10)
    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        tick_label.set_fontweight('bold')
    plt.tight_layout()
    plt.show()

f4_group4

step8: Analyze individual genes

def plot_gene_tsne(gene, pattern1='Nuclear', pattern2='Random', a=0.2, b=0.8,
                   dataset="merfish_u2os_group4", save_path="../1.5_benchmark/figure"):
    """Highlight one gene's points for two patterns on the t-SNE embedding.

    Reads the label table for ``dataset`` / ``a`` / ``b``, prints how many
    rows of ``gene`` fall into each of six patterns, draws every other point
    in light grey, overlays the ``pattern1`` and ``pattern2`` points of
    ``gene`` in their pattern colours, and saves the figure as PNG, PDF and
    SVG under ``{save_path}/{dataset}/s5_merfish_{gene}``.
    """
    df = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")
    color_map = {
        'Nuclear': '#ed6ca4',
        'Cytoplasmic': '#7bc4e2',
        'Protrusion': '#acd372',
        'Nuclear edge': '#fbb05b',
        'Cell edge': '#EDABB5',
        'Random': '#ACD0E4',
        'Foci': '#FFD4AB',
        'Radial': '#DDC4E0',
    }

    is_gene = df['gene'] == gene
    cond1 = is_gene & (df['pattern'] == pattern1)
    cond2 = is_gene & (df['pattern'] == pattern2)
    cond_gray = ~(cond1 | cond2)

    # Summary statistics
    for pattern in ['Nuclear edge', 'Random', 'Cytoplasmic', 'Cell edge', 'Nuclear', 'Foci']:
        count = len(df[is_gene & (df['pattern'] == pattern)])
        print(f"{gene} - {pattern}: {count} cells")

    # Plotting
    plt.figure(figsize=(5, 4))
    plt.scatter(df.loc[cond_gray, 'tsne_x'], df.loc[cond_gray, 'tsne_y'],
                color='#e5e9ea', label='Others', s=10)
    for mask, chosen_pattern in ((cond1, pattern1), (cond2, pattern2)):
        if not mask.any():
            continue
        # Every selected row carries `chosen_pattern`, so it picks the
        # colour; patterns without an entry are drawn in black.
        highlight = color_map.get(chosen_pattern, '#000000')
        plt.scatter(df.loc[mask, 'tsne_x'], df.loc[mask, 'tsne_y'],
                    color=highlight, label=f'{gene}, {chosen_pattern}', s=15)

    # Axis settings
    ax = plt.gca()
    for side, shown in (('top', False), ('right', False), ('left', True), ('bottom', True)):
        ax.spines[side].set_visible(shown)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=14)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        tick_label.set_fontweight('bold')
    plt.grid(False)
    ax.legend(title='Label', fontsize=12, title_fontsize=14, markerscale=2.0, loc='lower center', bbox_to_anchor=(0.5, 1.05), columnspacing=1, ncol=2, frameon=True)
    plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')

    # Only the PNG is written with an explicit dpi.
    output_prefix = f"{save_path}/{dataset}/s5_merfish_{gene}"
    for ext, extra in (('png', {'dpi': 300}), ('pdf', {}), ('svg', {})):
        plt.savefig(f"{output_prefix}.{ext}", bbox_inches='tight', **extra)
    plt.show()
# SRRM2: highlight the Nuclear and Random points.
plot_gene_tsne(
    gene='SRRM2',
    pattern1='Nuclear',
    pattern2='Random',
)
SRRM2 - Nuclear edge: 0 cells
SRRM2 - Random: 30 cells
SRRM2 - Cytoplasmic: 1 cells
SRRM2 - Cell edge: 2 cells
SRRM2 - Nuclear: 67 cells
SRRM2 - Foci: 0 cells

f4_SRRM2

# TLN1: highlight the Cell edge and Cytoplasmic points.
plot_gene_tsne(
    gene='TLN1',
    pattern1='Cell edge',
    pattern2='Cytoplasmic',
)
TLN1 - Nuclear edge: 2 cells
TLN1 - Random: 1 cells
TLN1 - Cytoplasmic: 57 cells
TLN1 - Cell edge: 40 cells
TLN1 - Nuclear: 0 cells
TLN1 - Foci: 0 cells

f4_TLN1

# Highlight three (gene, pattern) pairs on the t-SNE embedding; every other
# point is drawn in light grey.
a, b = 0.2, 0.8
dataset = "merfish_u2os_group4"
df = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")
gene1, pattern1 = "COL5A1", "Nuclear edge"
gene2, pattern2 = "MALAT1", "Nuclear"
gene3, pattern3 = "SPTBN1", "Cytoplasmic"
cond_red = (df['gene'] == gene1) & (df['pattern'] == pattern1)
cond_green = (df['gene'] == gene2) & (df['pattern'] == pattern2)
cond_blue = (df['gene'] == gene3) & (df['pattern'] == pattern3)
cond_gray = ~ (cond_red | cond_green | cond_blue)

count = df[cond_red].shape[0]
print(f"{gene1} - {pattern1}: {count} cells")
count = df[cond_green].shape[0]
print(f"{gene2} - {pattern2}: {count} cells")
count = df[cond_blue].shape[0]
print(f"{gene3} - {pattern3}: {count} cells")

# Create the canvas
plt.figure(figsize=(5, 4))
plt.scatter(df.loc[cond_gray, 'tsne_x'], df.loc[cond_gray, 'tsne_y'], color='#e5e9ea', label='Others', s=10)
plt.scatter(df.loc[cond_red, 'tsne_x'], df.loc[cond_red, 'tsne_y'], color='#fbb05b', label=f'{gene1}, {pattern1}', s=15)
plt.scatter(df.loc[cond_green, 'tsne_x'], df.loc[cond_green, 'tsne_y'], color='#ed6ca4', label=f'{gene2}, {pattern2}', s=15)
# The third group is excluded from the grey background, so it has to be
# drawn here; otherwise its points are missing from the figure. The colour
# is the 'Cytoplasmic' entry of the pattern palette.
plt.scatter(df.loc[cond_blue, 'tsne_x'], df.loc[cond_blue, 'tsne_y'], color='#7bc4e2', label=f'{gene3}, {pattern3}', s=15)
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.yaxis.set_major_locator(MultipleLocator(20))
ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=14)
ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
for label in ax.get_xticklabels():
    label.set_fontweight('bold')
for label in ax.get_yticklabels():
    label.set_fontweight('bold')
plt.grid(False)
ax.legend(title='Label', fontsize=12, title_fontsize=14, markerscale=2.0, loc='lower center', bbox_to_anchor=(0.5, 1.05),
          columnspacing=1, ncol=2, frameon=True)
plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')

plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.png', dpi=300, bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.pdf', bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.svg', bbox_inches='tight')
plt.show()
COL5A1 - Nuclear edge: 92 cells
MALAT1 - Nuclear: 93 cells
SPTBN1 - Cytoplasmic: 48 cells

f4_MALAT1