# 04_MERFISH_U2OS_cell_line
#### 1_dataset details
| Dataset | Cell number | Gene number | TSGs | Pattern |
| -------------------- | --------- |------------ |------- | ---------------- |
| merfish_u2os | 989 | 130 | 123375 | —— |
| merfish_u2os(group1) | 621 | 9 | 947 | 6 |
| merfish_u2os(group2) | 634 | 25 | 1000 | 5 |
| merfish_u2os(group3) | 621 | 25 | 1000 | 6 |
| merfish_u2os(group4) | 629 | 25 | 1000 | 7 |
| merfish_u2os(group5) | 989 | 25 | 23242 | 8 |
###### How are labels annotated?
- The labels were generated through a **semi-automated pipeline** that combined the predictions of five classifiers trained on simulated data (LightGBM, Random Forest, XGBoost, Gradient Boosting, and Decision Tree), followed by **manual refinement**. This approach accurately annotated 23,242 transcript spots from 25 selected genes, with results consistent with known localization patterns (e.g., nuclear genes such as *MALAT1* were labeled correctly, and genes encoding secreted proteins such as *THBS1* were assigned to the nuclear membrane). Integrating computational prediction with manual validation ensured high-confidence pseudo-labels for the otherwise unannotated dataset.
#### 2_GRASP preprocessing
##### step1: Load data
```python
# Read the pickled input bundle for the selected dataset group.
dataset = "merfish_u2os_group1"
outfile = f'../1_input/pkl_data/{dataset}_data_dict.pkl'
with open(outfile, 'rb') as handle:
    pickle_dict = pickle.load(handle)

# Unpack the bundle's entries into same-named variables.
(
    df_registered,
    cell_radii,
    cell_boundary,
    nuclear_boundary,
    nuclear_boundary_df_registered,
) = (
    pickle_dict[key]
    for key in (
        'df_registered',
        'cell_radii',
        'cell_boundary',
        'nuclear_boundary',
        'nuclear_boundary_df_registered',
    )
)
```
##### step2: Visualize the original and normalized TSGs
```python
# Directory handed to both plotting helpers (presumably where figures are written — confirm in `cep`).
path = "../2_scaled_cell"
# Raw view: uses the original cell and nuclear boundaries loaded in step 1.
cep.plot_raw_gene_distribution(dataset, cell_boundary, nuclear_boundary, df_registered, path)
# Registered view: uses the registered nuclear boundary table instead.
cep.plot_register_gene_distribution(dataset, df_registered, path, nuclear_boundary_df_registered)
```
##### step3: Cell partitioning
```python
import utils_code.partition as pat
# For every (cell, gene) pair, partition the cell into sectors x rings, build
# the k-nearest-neighbour graph over the partition centres, and write the node
# table and a heatmap under
#   ../4_partition_same/<dataset>_partition/<cell>/<cell>_<sectors>_<rings>_k<k>/
gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()

# Named `partition_dir` rather than `dir` so the builtin is not shadowed.
partition_dir = f"../4_partition_same/{dataset}_partition/"
os.makedirs(partition_dir, exist_ok=True)

# Passed unchanged to the partition and heatmap helpers for every gene.
r = 1

# Each range currently yields a single value (30 sectors, 15 rings); the loops
# are kept so that a parameter sweep only requires widening the ranges.
for n_sectors in range(30, 31, 10):
    for m_rings in range(15, 16, 5):
        # Neighbourhood size scales with the number of partition areas.
        k_neighbor = int((n_sectors * m_rings) / 10)
        for target_cell in cell_list:
            df = df_registered[df_registered['cell'] == target_cell]
            plot_dir = os.path.join(partition_dir, f"{target_cell}/{target_cell}_{n_sectors}_{m_rings}_k{k_neighbor}")
            os.makedirs(plot_dir, exist_ok=True)
            print(f"This is [target_cell: {target_cell}] - [n_sectors: {n_sectors}] - [m_rings: {m_rings}], and k_neighbor is {k_neighbor}")
            nuclear_boundary_df = nuclear_boundary_df_registered[nuclear_boundary_df_registered['cell'] == target_cell]
            # NOTE(review): this iterates over every gene in the dataset, not
            # only the genes detected in this cell — confirm this is intended.
            for gene in tqdm(gene_list, desc="Processing genes"):
                df_filtered = df[df['gene'] == gene]
                count_matrix, center_points, point_counts, is_virtual, is_edge = pat.count_points_in_areas_same(df_filtered, n_sectors, m_rings, r)
                nuclear_positions = pat.classify_center_points_with_edge(center_points, nuclear_boundary_df, is_edge)
                edges = pat.build_graph_k_nearest(center_points, k=k_neighbor)
                # The graph object is not used further in this step.
                G = pat.build_graph_with_networkx(center_points, edges, is_virtual)
                pat.save_node_data_to_csv_old(center_points, is_virtual, plot_dir, gene, point_counts, k=k_neighbor, nuclear_positions=nuclear_positions)
                pat.plot_cell_partition_heatmap(target_cell, gene, point_counts, n_sectors, m_rings, r, plot_dir, nuclear_boundary_df)
```
##### step4: Augmentation of TSGs
```python
import utils_code.augumentation as aug
# Augment every saved (cell, gene) graph: rotate its nodes by a random angle
# and, when the graph has at least 10 real nodes, additionally drop a
# size-dependent fraction of nodes. Results go to a sibling "<run>_aug" folder.
n_sectors = 30
m_rings = 15
k_neighbor = int((n_sectors * m_rings) / 10)
# Dropout ratio per size bucket: <=100, 101-150, and >150 real nodes.
dropout_ratios = [0.1, 0.2, 0.3]
# No trailing slash (and not named `dir`): avoids "//" in the joined paths and
# does not shadow the builtin.
partition_dir = f"../4_partition_same/{dataset}_partition"
gene_list = df_registered['gene'].unique()
cell_list = df_registered['cell'].unique()

for cell in tqdm(cell_list, desc="Processing all cells", leave=True):
    path = f"{partition_dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    save_path = f"{path}_aug"
    os.makedirs(save_path, exist_ok=True)

    for gene in tqdm(gene_list, desc="Processing all genes", leave=True):
        nodes_file = f'{path}/{gene}_node_matrix.csv'
        adj_file = f'{path}/{gene}_adj_matrix.csv'
        # Skip (cell, gene) pairs for which the partition step wrote no files.
        if not (os.path.exists(nodes_file) and os.path.exists(adj_file)):
            print(f"Skipping {gene} in {cell} (file not found).")
            continue

        node_matrix = pd.read_csv(nodes_file)
        adj_matrix = pd.read_csv(adj_file)

        # NOTE(review): the angle is drawn from the unseeded global RNG, so the
        # augmented files differ between runs.
        random_angle = random.uniform(0, 360)
        node_matrix_rotated = aug.rotate_nodes(node_matrix.copy(), random_angle)

        real_nodes_count = (node_matrix_rotated['is_virtual'] == 0).sum()
        if real_nodes_count >= 10:
            if real_nodes_count <= 100:
                dropout_ratio = dropout_ratios[0]
            elif real_nodes_count <= 150:
                dropout_ratio = dropout_ratios[1]
            else:
                dropout_ratio = dropout_ratios[2]
            adj_to_save, nodes_to_save = aug.dropout_nodes(adj_matrix.copy(), node_matrix_rotated.copy(), dropout_ratio)
        else:
            # Too few real nodes for dropout: keep the rotation only.
            adj_to_save, nodes_to_save = adj_matrix, node_matrix_rotated

        adj_to_save.to_csv(f"{save_path}/{gene}_adj_matrix.csv", index=False)
        nodes_to_save.to_csv(f"{save_path}/{gene}_node_matrix.csv", index=False)
```
#### 3_GRASP training
##### step5: Load all TSGs to prepare for training
```python
# Same partition settings as the partition/augmentation steps (30 sectors, 15 rings).
n_sectors, m_rings = 30, 15
k_neighbor = (n_sectors * m_rings) // 10

# Dataset dimensions, reused later when naming the saved graph file.
cell_list = df_registered['cell'].unique()
gene_list = df_registered['gene'].unique()
cell_numbers, gene_numbers = len(cell_list), len(gene_list)

# Build the original and augmented graph lists from the label table.
path = f"../4_partition_same/{dataset}_partition"
df = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
original_graphs, augmented_graphs = gra.generate_graph_data_target(dataset, df, path, n_sectors, m_rings, k_neighbor)
print(len(original_graphs), len(augmented_graphs), sep="\n")

# Per-graph gene and cell identifiers, in the order of `original_graphs`.
gene_labels = [graph.gene for graph in original_graphs]
cell_labels = [graph.cell for graph in original_graphs]
```
Processing Graphs: 100%|██████████| 947/947 [3:22<00:00, 7.49it/s]
```python
# Persist the original and augmented graphs together with their gene/cell
# labels in one pickle whose name records the dataset dimensions.
graphs_number = len(original_graphs)
print(f"cell_numbers:{cell_numbers} - gene_numbers:{gene_numbers} - graphs_number:{graphs_number}")

save_path = "../5_graph_data"
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(save_path, exist_ok=True)

graph_data = {
    "original_graphs": original_graphs,
    "augmented_graphs": augmented_graphs,
    "gene_labels": gene_labels,
    "cell_labels": cell_labels,
}
save_file = f"{save_path}/{dataset}_cell{cell_numbers}_gene{gene_numbers}_graph{graphs_number}.pkl"
with open(save_file, 'wb') as f:
    pickle.dump(graph_data, f)
print(f"Graph data saved to {save_file}")
```
##### step6: Clustering and identifying spatial localization patterns
```python
# Fixed colour for each spatial localization pattern label.
color_map = {'Nuclear':'#ed6ca4', 'Cytoplasmic': '#7bc4e2', 'Protrusion': '#acd372', 'Nuclear edge':'#fbb05b',
             'Cell edge': '#EDABB5', 'Random':'#ACD0E4', 'Foci':'#FFD4AB', 'Radial':'#DDC4E0'}


def plot_tsne(data, label_col, title, legend_title, save_name):
    """Scatter a t-SNE embedding coloured by label and save it as PNG, PDF and SVG.

    Args:
        data: Table with ``tsne_x`` and ``tsne_y`` columns plus ``label_col``.
        label_col: Column whose values define the groups; each group gets one
            colour and one legend entry.
        title: Figure title (may be empty).
        legend_title: Title displayed above the legend.
        save_name: Output path without extension; ``.png``, ``.pdf`` and
            ``.svg`` are appended in turn.
    """
    plt.figure(figsize=(6, 4))
    for label, group in data.groupby(label_col):
        # Labels missing from color_map are drawn in black rather than raising
        # KeyError, matching the fallback used by plot_gene_tsne.
        color = color_map.get(label, '#000000')
        plt.scatter(x=group['tsne_x'], y=group['tsne_y'], color=color, label=label, s=5)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=16)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=16)
    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        tick_label.set_fontweight('bold')

    plt.grid(False)
    # Legend sits outside the axes, anchored at the upper-right corner.
    plt.legend(title=legend_title, frameon=True, fontsize=12, title_fontsize=13, markerscale=3.0, bbox_to_anchor=(1, 1), loc='upper left')
    plt.title(title, fontsize=16)
    plt.xlabel('t-SNE 1', fontsize=16,fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=16,fontweight='bold')
    plt.tight_layout()
    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(f'{save_name}.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
```
```python
# Draw the t-SNE embedding twice per (a, b) setting: once coloured by the
# predicted pattern and once by the ground-truth annotation.
dataset = "merfish_u2os_group4"
params_list = [(0.2, 0.8)]
for a, b in params_list:
    our_label = pd.read_csv(f'../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv')
    path1 = f'../1.5_benchmark/figure/{dataset}/s1_tsne_ours_a{a}_b{b}'
    path2 = f'../1.5_benchmark/figure/{dataset}/s1_tsne_gt_a{a}_b{b}'
    views = (
        ('pattern', 'GRASP', path1),
        ('groundtruth_wzx', 'Ground truth', path2),
    )
    for column, legend, target in views:
        plot_tsne(data=our_label, label_col=column, title='', legend_title=legend, save_name=target)
```
##### step7: Statistical data distribution
```python
# Heatmap of how many sampled rows fall into each (label, gene) combination.
for group in ['group1']:
    sampled = pd.read_csv(f"../7_classifier/predicted/sampled_data_{group}.csv")
    gene_label_counts = pd.crosstab(sampled['gene'], sampled['groundtruth_wzx'])

    fig, ax = plt.subplots(figsize=(5, 2.5))
    # Transposed so that labels are rows and genes are columns.
    sns.heatmap(gene_label_counts.transpose(), annot=True, fmt='d', cmap='GnBu', ax=ax)

    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.tick_params(axis='both', labelsize=10)
    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        tick_label.set_fontweight('bold')

    plt.tight_layout()
    plt.show()
```
##### step8: Analyze individual genes
```python
def plot_gene_tsne(gene, pattern1='Nuclear', pattern2='Random', a=0.2, b=0.8,
                   dataset="merfish_u2os_group4", save_path="../1.5_benchmark/figure"):
    """Highlight one gene's points for two patterns on the t-SNE embedding.

    Reads ``ours_label_a{a}_b{b}.csv`` for ``dataset``, prints how many rows of
    ``gene`` carry each of six patterns, draws every other row in grey, overlays
    the rows of ``gene`` labelled ``pattern1`` / ``pattern2`` in their pattern
    colours, and saves the figure as PNG, PDF and SVG.

    Args:
        gene: Gene name matched against the ``gene`` column.
        pattern1: First pattern to highlight.
        pattern2: Second pattern to highlight.
        a: First value in the label file name.
        b: Second value in the label file name.
        dataset: Dataset sub-directory for both the label file and the figures.
        save_path: Root directory of the figure output.
    """
    df = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")
    color_map = {'Nuclear': '#ed6ca4', 'Cytoplasmic': '#7bc4e2','Protrusion': '#acd372', 'Nuclear edge':'#fbb05b',
                 'Cell edge': '#EDABB5', 'Random':'#ACD0E4', 'Foci':'#FFD4AB', 'Radial':'#DDC4E0'}

    is_gene = df['gene'] == gene
    cond1 = is_gene & (df['pattern'] == pattern1)
    cond2 = is_gene & (df['pattern'] == pattern2)
    cond_gray = ~(cond1 | cond2)

    # Summary statistics
    for pattern in ('Nuclear edge', 'Random', 'Cytoplasmic', 'Cell edge', 'Nuclear', 'Foci'):
        count = int((is_gene & (df['pattern'] == pattern)).sum())
        print(f"{gene} - {pattern}: {count} cells")

    # Plotting
    plt.figure(figsize=(5, 4))
    background = df.loc[cond_gray]
    plt.scatter(background['tsne_x'], background['tsne_y'],
                color='#e5e9ea', label='Others', s=10)
    for mask, pattern_name in ((cond1, pattern1), (cond2, pattern2)):
        if not mask.any():
            continue
        # Every selected row carries `pattern_name`, so it determines the colour.
        selected = df.loc[mask]
        plt.scatter(selected['tsne_x'], selected['tsne_y'],
                    color=color_map.get(pattern_name, '#000000'),
                    label=f'{gene}, {pattern_name}', s=15)

    # Axis settings
    ax = plt.gca()
    for side, visible in (('top', False), ('right', False), ('left', True), ('bottom', True)):
        ax.spines[side].set_visible(visible)
    ax.xaxis.set_major_locator(MultipleLocator(20))
    ax.yaxis.set_major_locator(MultipleLocator(20))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=14)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
    for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
        tick_label.set_fontweight('bold')

    plt.grid(False)
    ax.legend(title='Label', fontsize=12, title_fontsize=14, markerscale=2.0, loc='lower center', bbox_to_anchor=(0.5, 1.05), columnspacing=1, ncol=2, frameon=True)
    plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')

    output_prefix = f"{save_path}/{dataset}/s5_merfish_{gene}"
    # Only the raster output gets an explicit resolution.
    for ext, extra in (('png', {'dpi': 300}), ('pdf', {}), ('svg', {})):
        plt.savefig(f"{output_prefix}.{ext}", bbox_inches='tight', **extra)
    plt.show()
```
```python
# Highlight SRRM2 rows labelled Nuclear vs. Random (the remaining arguments use their defaults).
plot_gene_tsne(gene='SRRM2', pattern1='Nuclear', pattern2='Random')
```
SRRM2 - Nuclear edge: 0 cells
SRRM2 - Random: 30 cells
SRRM2 - Cytoplasmic: 1 cells
SRRM2 - Cell edge: 2 cells
SRRM2 - Nuclear: 67 cells
SRRM2 - Foci: 0 cells
```python
# Highlight TLN1 rows labelled Cell edge vs. Cytoplasmic (the remaining arguments use their defaults).
plot_gene_tsne(gene='TLN1', pattern1='Cell edge', pattern2='Cytoplasmic')
```
TLN1 - Nuclear edge: 2 cells
TLN1 - Random: 1 cells
TLN1 - Cytoplasmic: 57 cells
TLN1 - Cell edge: 40 cells
TLN1 - Nuclear: 0 cells
TLN1 - Foci: 0 cells
```python
# Highlight three (gene, pattern) groups on the t-SNE embedding; every other
# row is drawn in grey as background.
a, b = 0.2, 0.8
dataset = "merfish_u2os_group4"
df = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")

gene1, pattern1 = "COL5A1", "Nuclear edge"
gene2, pattern2 = "MALAT1", "Nuclear"
gene3, pattern3 = "SPTBN1", "Cytoplasmic"

cond_red = (df['gene'] == gene1) & (df['pattern'] == pattern1)
cond_green = (df['gene'] == gene2) & (df['pattern'] == pattern2)
cond_blue = (df['gene'] == gene3) & (df['pattern'] == pattern3)
cond_gray = ~ (cond_red | cond_green | cond_blue)

# Each highlighted group uses the colour of its pattern:
# Nuclear edge, Nuclear and Cytoplasmic respectively.
highlights = [
    (gene1, pattern1, cond_red, '#fbb05b'),
    (gene2, pattern2, cond_green, '#ed6ca4'),
    (gene3, pattern3, cond_blue, '#7bc4e2'),
]

for name, pat_name, cond, _ in highlights:
    print(f"{name} - {pat_name}: {int(cond.sum())} cells")

# Create the canvas
plt.figure(figsize=(5, 4))
plt.scatter(df.loc[cond_gray, 'tsne_x'], df.loc[cond_gray, 'tsne_y'], color='#e5e9ea', label='Others', s=10)
# All three groups are drawn. Previously the third group was excluded from the
# grey background but never plotted, so its points were missing from the figure.
for name, pat_name, cond, colour in highlights:
    plt.scatter(df.loc[cond, 'tsne_x'], df.loc[cond, 'tsne_y'], color=colour, label=f'{name}, {pat_name}', s=15)

ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.xaxis.set_major_locator(MultipleLocator(20))
ax.yaxis.set_major_locator(MultipleLocator(20))
ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=14)
ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
for tick_label in [*ax.get_xticklabels(), *ax.get_yticklabels()]:
    tick_label.set_fontweight('bold')

plt.grid(False)
ax.legend(title='Label', fontsize=12, title_fontsize=14, markerscale=2.0, loc='lower center', bbox_to_anchor=(0.5, 1.05),
          columnspacing=1, ncol=2, frameon=True)
plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')

plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.png', dpi=300, bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.pdf', bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/merfish_three.svg', bbox_inches='tight')
plt.show()
```
COL5A1 - Nuclear edge: 92 cells
MALAT1 - Nuclear: 93 cells
SPTBN1 - Cytoplasmic: 48 cells