### 03_SeqFISH+_fibroblast_cell_line
#### 1_dataset details
| Dataset | Cell number | Gene number | Graph | Pattern |
| -------------------- | --------- |------------ |------- | ---------------- |
| seqfish_fibroblast | 171 | 2747 | —— | —— |
| seqfish_fibroblast(part) | 171 | 60 | 8068 | 3 seqfish+ original study |
- The 60 genes selected from seqfish_fibroblast are:
| Pattern | Gene set |
| ---------------------------- | ------------------------------------------ |
| Nuclear or nuclear edge | Col1a1, Fn1, Fbln2, Col6a2, Bgn, Nid1, Lox, P4hb, Aebp1, Emp1, Col5a1, Sdc4, Postn, Col3a1, Pdia6, Col5a2, Itgb1, Calu, Pdia3, Cyr61 |
| Cytoplasmic | Ddb1, Myh9, Actn1, Tagln2, Kpnb1, Hnrnpf, Ppp1ca, Hnrnpl, Pcbp1, Tagln, Fscn1, Psat1, Cald1, Snd1, Uba1, Hnrnpm, Cap1, Ssrp1, Ugdh, Caprin1 |
| Protrusion | Cyb5r3, Sh3pxd2a, Ddr2, Net1, Trak2, Kif1c, Kctd10, Dynll2, Arhgap11a, Gxylt1, H6pd, Gdf11, Dync1li2, Palld, Ppfia1, Naa50, Ptgfr, Zeb1, Arhgap32, Scd1 |
#### 2_GRASP preprocessing
##### step1: Load data
```python
import pickle

dataset = "seqfish_fibroblast"
outfile = f'../1_input/pkl_data/{dataset}_data_dict.pkl'

# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only open pickles produced by this project's own preprocessing.
with open(outfile, 'rb') as f:
    pickle_dict = pickle.load(f)

# Unpack the entries that the following cells read by name.
df_registered = pickle_dict['df_registered']
cell_radii = pickle_dict['cell_radii']
cell_boundary = pickle_dict['cell_boundary']
nuclear_boundary = pickle_dict['nuclear_boundary']
nuclear_boundary_df_registered = pickle_dict['nuclear_boundary_df_registered']
```
##### step2: Visualize the original and normalized TSGs
```python
path = "../2_scaled_cell"
# Plot a single example gene. The subset is kept in its own variable so that
# `df_registered` still holds every gene for the statistics in step 3
# (the original reassigned `df_registered` to the Cyb5r3 rows only).
df_example_gene = df_registered[df_registered['gene'] == 'Cyb5r3']
cep.plot_raw_gene_distribution(dataset, cell_boundary, nuclear_boundary, df_example_gene, path)
```
Processing cells: 0%| | 0/179 [00:00, ?it/s]
Processing cells: 100%|██████████| 179/179 [02:54<00:00, 1.02it/s]
All cell images have been saved to ../2_scaled_cell/seqfish_fibroblast/raw_gene/cell_9-16
##### step3: Statistical valid TSGs
```python
cell_list = df_registered['cell'].unique()
gene_list = df_registered['gene'].unique()
print(f'{len(cell_list)} cells - {len(gene_list)} genes')

save_dir = f'../3_filter/{dataset}/'
os.makedirs(save_dir, exist_ok=True)

# Number of rows of df_registered for every observed (gene, cell) pair.
grouped = df_registered.groupby(['gene', 'cell'], observed=True).size().reset_index(name='num_points')
print(grouped)

# A gene is kept only when it appears in at least 10 distinct cells.
gene_cell_count = df_registered.groupby('gene', observed=True)['cell'].nunique()
valid_genes = gene_cell_count[gene_cell_count >= 10].index
invalid_genes = gene_cell_count[gene_cell_count < 10].index
print("List of invalid genes (expressed in fewer than 10 cells):")
print(invalid_genes.tolist())

df_registered = df_registered[df_registered['gene'].isin(valid_genes)]
filtered_grouped = grouped[grouped['gene'].isin(valid_genes)]
print(filtered_grouped)

# Classify the pairs with boolean masks instead of a Python-level loop over
# every row: exactly 0 points -> "no points", otherwise fewer than 10 points
# -> "low points" (same conditions as the original if/elif).
pair_columns = ['gene', 'cell']
num_points = filtered_grouped['num_points']
no_points_df = filtered_grouped.loc[num_points == 0, pair_columns]
low_points_df = filtered_grouped.loc[(num_points != 0) & (num_points < 10), pair_columns]

# Keep the original list-of-dicts variables available to later cells.
no_points_list = no_points_df.to_dict(orient='records')
low_points_list = low_points_df.to_dict(orient='records')

# Writing from frames that always carry the gene/cell columns gives the CSVs
# a header even when no pair qualifies.
no_points_df.to_csv(f'{save_dir}/all_no_points_list.csv', index=False)
low_points_df.to_csv(f'{save_dir}/all_low_points_list.csv', index=False)
```
171 cells - 2747 genes
gene cell num_points
0 5830417i10rik 0-0 5
1 5830417i10rik 0-1 6
2 5830417i10rik 0-11 9
3 5830417i10rik 0-12 21
4 5830417i10rik 0-14 11
... ... ... ...
307043 Zzef1 9-1 5
307044 Zzef1 9-14 9
307045 Zzef1 9-4 5
307046 Zzef1 9-5 5
307047 Zzef1 9-8 5
[307048 rows x 3 columns]
List of invalid genes (expressed in fewer than 10 cells):
[]
gene cell num_points
0 5830417i10rik 0-0 5
1 5830417i10rik 0-1 6
2 5830417i10rik 0-11 9
3 5830417i10rik 0-12 21
4 5830417i10rik 0-14 11
... ... ... ...
307043 Zzef1 9-1 5
307044 Zzef1 9-14 9
307045 Zzef1 9-4 5
307046 Zzef1 9-5 5
307047 Zzef1 9-8 5
[307048 rows x 3 columns]
Processing genes and cells: 100%|██████████| 307048/307048 [00:26<00:00, 11770.99it/s]
```python
# Locations of the two exclusion lists under ../3_filter/<dataset>/.
no_points_path = f'../3_filter/{dataset}/all_no_points_list.csv'
low_points_path = f'../3_filter/{dataset}/all_low_points_list.csv'

# Keep only the gene and cell columns, then collapse repeated pairs.
df_list = df_registered.loc[:, ['gene', 'cell']].copy()
print("raw df_list shape:", df_list.shape)

df_list_unique = df_list.drop_duplicates()
print("duplicate df_list shape:", df_list_unique.shape)
def safe_read_csv(path):
    """Read a gene/cell CSV, tolerating a missing or empty file.

    Args:
        path: Location of the CSV file.

    Returns:
        The parsed DataFrame, or an empty DataFrame with the columns
        ``gene`` and ``cell`` when the file does not exist or contains
        nothing to parse.
    """
    try:
        return pd.read_csv(path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Both cases mean "no pairs recorded"; hand back an empty table with
        # the expected columns so callers can concat/filter unconditionally.
        return pd.DataFrame(columns=['gene', 'cell'])
# Load both exclusion lists and report their sizes.
low_points_list = safe_read_csv(low_points_path)
print("low_points_list:", low_points_list.shape)
no_points_list = safe_read_csv(no_points_path)
print("no_points_list:", no_points_list.shape)

# Union of every (gene, cell) pair that must be removed.
filter_list = pd.concat([low_points_list, no_points_list]).drop_duplicates()
print("filter_list:", filter_list.shape)

# Compare the pairs as MultiIndex entries and keep those not in the union.
pair_keys = ['gene', 'cell']
candidate_pairs = df_list_unique.set_index(pair_keys).index
excluded_pairs = filter_list.set_index(pair_keys).index
mask = ~candidate_pairs.isin(excluded_pairs)
result = df_list_unique[mask]

print("Result shape after filtering:", result.shape)
for column, noun in (('cell', 'cells'), ('gene', 'genes')):
    print(f"Number of unique {noun}:", len(result[column].unique()))

result.to_csv(f"../3_filter/{dataset}/load_graph_data.csv", index=False)
```
raw df_list shape: (4163087, 2)
duplicate df_list shape: (307048, 2)
low_points_list: (163259, 2)
no_points_list: (0, 2)
filter_list: (163259, 2)
Result shape after filtering: (143789, 2)
Number of unique cells: 171
Number of unique genes: 2734
##### step4: Cell partitioning
```python
import os
import pandas as pd
from tqdm import tqdm
import utils_code.partition as pat
from multiprocessing import Pool, cpu_count
# Partition grid settings; k_neighbor is one tenth of the grid size.
n_sectors, m_rings = 30, 15
k_neighbor = int((n_sectors * m_rings) / 10)
r = 1

# Output root for the per-cell partition files.
dir = f"../4_partition_same/{dataset}_partition/"
os.makedirs(dir, exist_ok=True)

# Label table whose rows are the (cell, gene) tasks processed below.
result = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
print("Number of TSGs:", result.shape)

# Per-process group-by handles, filled in by init_globals.
df_registered_group = nuclear_boundary_group = None
def init_globals(df_reg, nuclear_boundary_reg):
    """Group both tables by cell and store the groupings in module globals.

    Passed as the ``initializer`` of the worker pool so that ``process_row``
    can look up a cell's rows through ``get_group``.

    Args:
        df_reg: DataFrame with a ``cell`` column.
        nuclear_boundary_reg: DataFrame with a ``cell`` column.
    """
    global df_registered_group, nuclear_boundary_group
    df_registered_group = df_reg.groupby("cell")
    nuclear_boundary_group = nuclear_boundary_reg.groupby("cell")
def process_row(row):
    """Partition one (cell, gene) pair and write its node CSV.

    Args:
        row: Mapping-like row with the keys ``cell`` and ``gene``.

    Returns:
        None. The row is skipped when the cell is missing from either
        grouping, when the gene has no rows in that cell, or when
        ``<gene>_node.csv`` already exists in the output directory.
    """
    target_cell = row["cell"]
    target_gene = row["gene"]
    try:
        df = df_registered_group.get_group(target_cell)
        df_filtered = df[df["gene"] == target_gene]
        if df_filtered.empty:
            return
        nuclear_boundary_df = nuclear_boundary_group.get_group(target_cell)
    except KeyError:
        # get_group raises KeyError for a cell absent from the grouping.
        return

    plot_dir = os.path.join(dir, f"{target_cell}/{target_cell}_{n_sectors}_{m_rings}_k{k_neighbor}")
    csv_path = os.path.join(plot_dir, f"{target_gene}_node.csv")
    if os.path.exists(csv_path):
        # Already processed on an earlier run.
        return
    os.makedirs(plot_dir, exist_ok=True)

    # The first return value (count matrix) is not needed here.
    _, center_points, point_counts, is_virtual, is_edge = pat.count_points_in_areas_same(df_filtered, n_sectors, m_rings, r)
    nuclear_positions = pat.classify_center_points_with_edge(center_points, nuclear_boundary_df, is_edge)
    # NOTE(review): `edges` and `G` are not consumed below; the calls are kept
    # in case the helpers have side effects - confirm against utils_code.partition.
    edges = pat.build_graph_k_nearest(center_points, k=k_neighbor)
    G = pat.build_graph_with_networkx(center_points, edges, is_virtual)
    pat.save_node_data_to_csv_old(center_points, is_virtual, plot_dir, target_gene, point_counts, k=k_neighbor, nuclear_positions=nuclear_positions)
if __name__ == "__main__":
    # Every worker calls init_globals once, so process_row finds the grouped
    # tables in its own process. Wrapping the iterator in list() drains it and
    # drives the progress bar.
    with Pool(processes=cpu_count(), initializer=init_globals,
              initargs=(df_registered, nuclear_boundary_df_registered)) as pool:
        list(tqdm(pool.imap_unordered(process_row, [row for _, row in result.iterrows()]),
                  total=result.shape[0], desc="Parallel processing"))
```
Number of TSGs: (8068, 3)
Parallel processing: 100%|██████████| 8068/8068 [05:22<00:00, 24.99it/s]
##### step5: Enhancement of TSGs
```python
import os
import pandas as pd
import random
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import utils_code.augumentation as aug
# Dataset and partition grid; these values must match the partition step so
# the directory names built from them line up.
dataset = "seqfish_fibroblast"
n_sectors, m_rings = 30, 15
k_neighbor = int((n_sectors * m_rings) / 10)

# Root of the partition output and the three dropout levels used below.
dir = f"../4_partition_same/{dataset}_partition/"
dropout_ratios = [0.1, 0.2, 0.3]
def process_cell_gene(row_dict):
    """Write an augmented copy of one (cell, gene) graph.

    The node matrix is rotated by a random angle. When it has at least 10
    rows with ``is_virtual == 0``, nodes are additionally dropped with a ratio
    chosen by that count (<=100: first, <=150: second, otherwise third entry
    of ``dropout_ratios``). The result is saved into the ``..._aug`` directory
    next to the original.

    Args:
        row_dict: Mapping with the keys ``cell`` and ``gene``.

    Returns:
        A status string starting with ``skip``, ``finish`` or ``error``.
    """
    cell = row_dict['cell']
    gene = row_dict['gene']
    path = f"{dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}"
    save_path = f"{dir}/{cell}/{cell}_{n_sectors}_{m_rings}_k{k_neighbor}_aug"
    nodes_file = f'{path}/{gene}_node_matrix.csv'
    adj_file = f'{path}/{gene}_adj_matrix.csv'

    if not os.path.exists(nodes_file) or not os.path.exists(adj_file):
        return f"skip {cell} - {gene}"

    try:
        node_matrix = pd.read_csv(nodes_file)
        adj_matrix = pd.read_csv(adj_file)

        random_angle = random.uniform(0, 360)
        node_matrix_rotated = aug.rotate_nodes(node_matrix.copy(), random_angle)
        real_nodes_count = (node_matrix_rotated['is_virtual'] == 0).sum()

        os.makedirs(save_path, exist_ok=True)
        if real_nodes_count >= 10:
            if real_nodes_count <= 100:
                dropout_ratio = dropout_ratios[0]
            elif real_nodes_count <= 150:
                dropout_ratio = dropout_ratios[1]
            else:
                dropout_ratio = dropout_ratios[2]
            adj_matrix_dropped, node_matrix_dropped = aug.dropout_nodes(
                adj_matrix.copy(), node_matrix_rotated.copy(), dropout_ratio)
            adj_matrix_dropped.to_csv(f"{save_path}/{gene}_adj_matrix.csv", index=False)
            node_matrix_dropped.to_csv(f"{save_path}/{gene}_node_matrix.csv", index=False)
        else:
            # Too few real nodes for dropout: keep the rotation only.
            adj_matrix.to_csv(f"{save_path}/{gene}_adj_matrix.csv", index=False)
            node_matrix_rotated.to_csv(f"{save_path}/{gene}_node_matrix.csv", index=False)
        return f"finish {cell} - {gene}"
    except Exception as e:
        # Deliberate best-effort: one failing pair is reported as a status
        # string instead of aborting the whole batch.
        return f"error {cell} - {gene}:{str(e)}"
# The label table gets its own name so the table held in `df_registered`
# (used by the partition step) is not replaced by it.
label_df = pd.read_csv(f"../1_input/label/{dataset}_label.csv")

com_gene = ["Col1a1", "Fn1", "Fbln2", "Col6a2", "Bgn",  # nuclear or nuclear edge
            "Nid1", "Lox", "P4hb", "Aebp1", "Emp1",
            "Col5a1", "Sdc4", "Postn", "Col3a1", "Pdia6",
            "Col5a2", "Itgb1", "Calu", "Pdia3", "Cyr61",
            "Ddb1", "Myh9", "Actn1", "Tagln2", "Kpnb1",  # cytoplasmic
            "Hnrnpf", "Ppp1ca", "Hnrnpl", "Pcbp1", "Tagln",
            "Fscn1", "Psat1", "Cald1", "Snd1", "Uba1",
            "Hnrnpm", "Cap1", "Ssrp1", "Ugdh", "Caprin1",
            "Cyb5r3", "Sh3pxd2a", "Ddr2", "Net1", "Trak2",  # 20 protrusion
            "Kif1c", "Kctd10", "Dynll2", "Arhgap11a", "Gxylt1",
            "H6pd", "Gdf11", "Dync1li2", "Palld", "Ppfia1",
            "Naa50", "Ptgfr", "Zeb1", "Arhgap32", "Scd1"]
# Optional: restrict the run to genes outside `com_gene`.
# label_df = label_df[~label_df['gene'].isin(com_gene)]
print(label_df.shape)

task_list = label_df[['cell', 'gene']].drop_duplicates().to_dict(orient='records')
with ProcessPoolExecutor(max_workers=8) as executor:
    results = list(tqdm(executor.map(process_cell_gene, task_list),
                        total=len(task_list), desc="In multi-process processing"))

# The loop variable is not called `r`: that name holds the radius argument
# read by process_row in the partition step.
for status in results:
    print(status)
```
#### 3_GRASP training
##### step6: Load all TSGs to prepare for training
```python
import gnn_model.graphloader as gra
# Same dataset and partition grid as the preprocessing steps.
dataset = "seqfish_fibroblast"
n_sectors, m_rings = 30, 15
k_neighbor = int((n_sectors * m_rings) / 10)
path = f"../4_partition_same/{dataset}_partition"

# Label table handed to the graph loader.
df = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
print(df.shape)

# Load the original graphs and their augmented counterparts.
original_graphs, augmented_graphs = gra.generate_graph_data_target(
    dataset, df, path, n_sectors, m_rings, k_neighbor)
print(len(original_graphs))
print(len(augmented_graphs))

# Gene and cell attribute of every original graph, in loader order.
gene_labels = [graph.gene for graph in original_graphs]
cell_labels = [graph.cell for graph in original_graphs]
```
(8068, 3)
Processing Graphs generate_graph_data_target: 100%|██████████| 8068/8068 [16:37<00:00, 8.09it/s]
8068
8068
```python
dataset = "seqfish_fibroblast"

# Summary counts; they also become part of the output file name.
graphs_number = len(original_graphs)
cell_numbers = len(df['cell'].unique())
gene_numbers = len(df['gene'].unique())
print(f"cell_numbers:{cell_numbers} - gene_numbers:{gene_numbers} - graphs_number:{graphs_number}")

save_path = "../5.1_graph_data"
os.makedirs(save_path, exist_ok=True)

graph_data = {"original_graphs": original_graphs,
              "augmented_graphs": augmented_graphs,
              "gene_labels": gene_labels,
              "cell_labels": cell_labels}
save_file = f"{save_path}/{dataset}_cell{cell_numbers}_gene{gene_numbers}_graph{graphs_number}.pkl"
with open(save_file, 'wb') as f:
    pickle.dump(graph_data, f)
print(f"Graph data saved to {save_file}")
```
cell_numbers:171 - gene_numbers:59 - graphs_number:8068
Graph data saved to ../5.1_graph_data/seqfish_fibroblast_cell171_gene59_graph8068.pkl
##### step7: Clustering and identifying spatial localization patterns
```python
dataset = "seqfish_fibroblast"
a, b, lr, epoch = 0.2, 0.8, 0.005, 200
our_label = pd.read_csv(f'../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv')
color_map = {'Nuclear or nuclear edge': '#fbb05b', 'Cytoplasmic': '#7bc4e2', 'Protrusion': '#EDABB5'}


def plot_tsne(data, label_col, title, legend_title, save_name):
    """Scatter the graph-level t-SNE coordinates coloured by a label column.

    Args:
        data: DataFrame with ``tsne_x``, ``tsne_y`` and ``label_col`` columns.
        label_col: Column whose values select the colour; every value must be
            a key of ``color_map`` (a missing key raises ``KeyError``).
        title: Axes title.
        legend_title: Title of the legend box.
        save_name: File stem; the figure is written as png, pdf and svg under
            ``../1.5_benchmark/figure/<dataset>/``.
    """
    plt.figure(figsize=(7, 4))
    for label, group in data.groupby(label_col):
        plt.scatter(x=group['tsne_x'], y=group['tsne_y'], color=color_map[label], label=label, s=2)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)
    ax.xaxis.set_major_locator(MultipleLocator(40))
    ax.yaxis.set_major_locator(MultipleLocator(40))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=16)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=16)
    for tick_label in ax.get_xticklabels():
        tick_label.set_fontweight('bold')
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontweight('bold')

    plt.grid(False)
    plt.legend(title=legend_title, frameon=True, fontsize=12, title_fontsize=13, markerscale=5.0, bbox_to_anchor=(1, 1), loc='upper left')
    plt.title(title, fontsize=16)
    plt.xlabel('t-SNE 1', fontsize=16, fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=16, fontweight='bold')
    plt.tight_layout()
    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(f'../1.5_benchmark/figure/{dataset}/{save_name}.{ext}', bbox_inches='tight', dpi=300)
    plt.show()


plot_tsne(data=our_label, label_col='graph_level_pattern', title='', legend_title='GRASP', save_name='s1_graphlevel_tsne_ours')
plot_tsne(data=our_label, label_col='groundtruth', title='', legend_title='Ground truth', save_name='s1_graphlevel_tsne_gt')
```
```python
dataset = "seqfish_fibroblast"
# `a` and `b` are expected to be defined by an earlier cell.
our_label = pd.read_csv(f'../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}_genelevel.csv')
color_map = {'Nuclear or nuclear edge': '#fbb05b', 'Cytoplasmic': '#7bc4e2', 'Protrusion': '#EDABB5'}


def plot_gene_tsne(data, label_col, title, legend_title, save_name):
    """Scatter the gene-level t-SNE coordinates coloured by a label column.

    Args:
        data: DataFrame with ``tsne_x``, ``tsne_y`` and ``label_col`` columns.
        label_col: Column whose values select the colour; every value must be
            a key of ``color_map`` (a missing key raises ``KeyError``).
        title: Axes title.
        legend_title: Accepted for signature parity with ``plot_tsne``; this
            function draws no legend, so the value is not used.
        save_name: File stem; the figure is written as png, pdf and svg under
            ``../1.5_benchmark/figure/<dataset>/``.
    """
    plt.figure(figsize=(3.5, 3), dpi=120)
    for label, group in data.groupby(label_col):
        plt.scatter(x=group['tsne_x'], y=group['tsne_y'], color=color_map[label], label=label, s=20)

    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(1)
    ax.spines['bottom'].set_linewidth(1)
    ax.xaxis.set_major_locator(MultipleLocator(2))
    ax.yaxis.set_major_locator(MultipleLocator(2))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, labelsize=14)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
    for tick_label in ax.get_xticklabels():
        tick_label.set_fontweight('bold')
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontweight('bold')

    plt.grid(False)
    plt.title(title, fontsize=16)
    plt.xlabel('t-SNE 1', fontsize=14, fontweight='bold')
    plt.ylabel('t-SNE 2', fontsize=14, fontweight='bold')
    plt.tight_layout()
    for ext in ['png', 'pdf', 'svg']:
        plt.savefig(f'../1.5_benchmark/figure/{dataset}/{save_name}.{ext}', bbox_inches='tight', dpi=300)
    plt.show()


plot_gene_tsne(data=our_label, label_col='gene_level_pattern', title='', legend_title='Spatial pattern', save_name='s2_genelevel_tsne_ours')
plot_gene_tsne(data=our_label, label_col='groundtruth', title='', legend_title='Spatial pattern', save_name='s2_genelevel_tsne_gt')
```
##### step8: Calculate the GRASP similarity of the embedding
```python
# Identify the training run whose embedding is analysed.
dataset = "seqfish_fibroblast"
cell_numbers, gene_numbers, graphs_number = 171, 59, 8068
a, b = 0.2, 0.8
lr, epoch = 0.005, 200
batch = "seqfish_cosine_nocluster_noclip_js_a02b08_20250704_153934"

# Directory of that run, then the embedding table and the label table.
save_path = f"../6.1_embedding/{dataset}/graph{graphs_number}/gat_moco3/pos4_temp0.07_a{a}_b{b}_c0.0/n30_m15_cluster3_uniform_scaler/{batch}"
df = pd.read_csv(f"{save_path}/epoch{epoch}_lr{lr}_embedding.csv")
label = pd.read_csv(f"../1_input/label/{dataset}_label.csv")
print(label.shape)
```
(8068, 3)
```python
list1 = ["Col1a1", "Fn1", "Fbln2", "Col6a2", "Bgn",  # nuclear or nuclear edge genes
         "Nid1", "Lox", "P4hb", "Aebp1", "Emp1",
         "Col5a1", "Sdc4", "Postn", "Col3a1", "Pdia6",
         "Col5a2", "Itgb1", "Calu", "Pdia3", "Cyr61"]
list2 = ["Ddb1", "Myh9", "Actn1", "Tagln2", "Kpnb1",  # cytoplasmic
         "Hnrnpf", "Ppp1ca", "Hnrnpl", "Pcbp1", "Tagln",
         "Fscn1", "Psat1", "Cald1", "Snd1", "Uba1",
         "Hnrnpm", "Cap1", "Ssrp1", "Ugdh", "Caprin1"]
list3 = ["Cyb5r3", "Sh3pxd2a", "Ddr2", "Net1", "Trak2",  # 20 protrusion
         "Kif1c", "Kctd10", "Dynll2", "Arhgap11a", "Gxylt1",
         "H6pd", "Gdf11", "Dync1li2", "Palld", "Ppfia1",
         "Naa50", "Ptgfr", "Zeb1", "Arhgap32", "Scd1"]


def assign_groundtruth(gene):
    """Return the ground-truth pattern name for a gene symbol.

    Args:
        gene: Gene symbol, matched exactly (case-sensitive).

    Returns:
        ``'Nuclear or nuclear edge'``, ``'Cytoplasmic'`` or ``'Protrusion'``
        for members of ``list1``, ``list2`` and ``list3`` respectively
        (checked in that order), otherwise ``'unknown'``.
    """
    pattern_gene_sets = (
        ('Nuclear or nuclear edge', list1),
        ('Cytoplasmic', list2),
        ('Protrusion', list3),
    )
    for pattern_name, genes in pattern_gene_sets:
        if gene in genes:
            return pattern_name
    return 'unknown'
# Names of the embedding columns: feature_1 ... feature_128.
feature_cols = ['feature_' + str(i) for i in range(1, 129)]

# Work on copies so the loaded embedding and label tables stay untouched.
df_copy = df.copy(deep=True)        # embedding
label_copy = label.copy(deep=True)  # ground-truth labels

# Average the embedding rows of each gene into one gene-level vector.
df_copy = df_copy.groupby('gene')[feature_cols].mean().reset_index()
df_copy['groundtruth'] = df_copy['gene'].map(assign_groundtruth)

true_labels = df_copy['groundtruth'].astype(str).to_numpy()
features = df_copy.drop(columns=['gene', 'groundtruth']).to_numpy()  # feature matrix

# Group the gene-level vectors by pattern and take the mean per pattern.
pattern_groups = df_copy.groupby('groundtruth')
pattern_means = pattern_groups[feature_cols].mean()
```python
similarity_data = []

# ---- 1. Pattern internal similarity ----
# Upper triangle (k=1) keeps each unordered gene pair once and drops the
# diagonal of self-similarities.
for pattern, group in pattern_groups:
    embeddings = group[feature_cols].values
    sim_matrix = cosine_similarity(embeddings)
    triu_indices = np.triu_indices_from(sim_matrix, k=1)
    sims = sim_matrix[triu_indices]
    for s in sims:
        similarity_data.append({
            'pattern_pair': f'{pattern}',
            'similarity': s,
            'type': 'within'
        })

# ---- 2. Similarity between patterns ----
for p1, p2 in combinations(pattern_groups.groups.keys(), 2):
    group1 = pattern_groups.get_group(p1)
    group2 = pattern_groups.get_group(p2)
    sims = cosine_similarity(group1[feature_cols].values, group2[feature_cols].values).flatten()
    for s in sims:
        similarity_data.append({
            'pattern_pair': f'{p1} to {p2}',
            'similarity': s,
            'type': 'between'
        })

sim_df = pd.DataFrame(similarity_data)
sim_df_within = sim_df[sim_df['type'] == 'within']
sim_df_between = sim_df[sim_df['type'] == 'between']
sim_df_clean = pd.concat([sim_df_within, sim_df_between])

# Shorten the labels for plotting, e.g. "Cytoplasmic to Protrusion" -> "Cyto → Pro".
short_names = {'Cytoplasmic': 'Cyto', 'Protrusion': 'Pro', 'Nuclear or nuclear edge': 'Nuc or NE'}
unique_patterns = sim_df_clean['pattern_pair'].unique()
label_map = {}
for p in unique_patterns:
    parts = p.split(' to ')
    if len(parts) == 2:
        # Not unpacked into `a, b`: those names hold the numeric
        # hyperparameters that other cells interpolate into file paths.
        first_pattern, second_pattern = parts
        first_short = short_names.get(first_pattern.strip(), first_pattern.strip())
        second_short = short_names.get(second_pattern.strip(), second_pattern.strip())
        label_map[p] = f"{first_short} → {second_short}"
    else:
        label_map[p] = short_names.get(p.strip(), p.strip())
sim_df_clean['pattern_pair'] = sim_df_clean['pattern_pair'].replace(label_map)
```
```python
# From here on `sim_df` refers to the relabelled similarity table.
sim_df = sim_df_clean
fig, ax = plt.subplots(figsize=(4, 4))
# One colour per comparison type.
my_palette = {'within': '#EDABB5', 'between': '#7bc4e2'}
def sort_by_median(data, y_col='similarity', x_col='pattern_pair'):
    """Return the categories of ``x_col`` ordered by descending median of ``y_col``.

    Args:
        data: DataFrame containing both columns.
        y_col: Numeric column whose per-category median is ranked.
        x_col: Column holding the category labels.

    Returns:
        List of category labels, highest median first.
    """
    medians = data.groupby(x_col)[y_col].median().sort_values(ascending=False)
    return medians.index.tolist()
# Boxes are ordered from the highest to the lowest median similarity.
sorted_pairs = sort_by_median(sim_df)
flier_style = dict(marker='o', markersize=1, markerfacecolor='black', markeredgecolor='black', alpha=0.6)
sns.boxplot(
    data=sim_df, x='pattern_pair', y='similarity', hue='type', palette=my_palette,
    width=0.5, order=sorted_pairs, ax=ax, showcaps=True, showbox=True, showfliers=True,
    linewidth=1, saturation=0.9, flierprops=flier_style,
)

ax.set_ylabel("Cosine Similarity", fontsize=14)
ax.set_xlabel("", fontsize=12)

# Black bottom/left axes, no top/right frame.
for side in ('bottom', 'left'):
    ax.spines[side].set_color('black')
for side in ('right', 'top'):
    ax.spines[side].set_color('none')
for side in ('bottom', 'left'):
    ax.spines[side].set_linewidth(1)

ax.xaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.5))
ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=14)
ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, rotation=45, labelsize=14)
ax.grid(False)
sns.despine(ax=ax, top=True, right=True)
plt.legend(title='Type', bbox_to_anchor=(1, 1), loc='upper left')
plt.tight_layout()

# The raster export carries an explicit dpi; the vector exports do not.
figure_stem = f'../1.5_benchmark/figure/{dataset}/{dataset}_embedding1'
plt.savefig(f'{figure_stem}.png', dpi=300, bbox_inches='tight')
for vector_ext in ('pdf', 'svg'):
    plt.savefig(f'{figure_stem}.{vector_ext}', bbox_inches='tight')
plt.show()
# Quartile summary of every pattern pair, in the same order as the boxes.
stats = []
for pair in sorted_pairs:
    pair_data = sim_df[sim_df['pattern_pair'] == pair]['similarity']
    q1, q2, q3 = np.percentile(pair_data, [25, 50, 75])
    iqr = q3 - q1
    # 1.5 * IQR fences.
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    stats.append({
        'Pattern Pair': pair,
        'Count': len(pair_data),
        'Median': q2,
        'Q1': q1,
        'Q3': q3,
        'IQR': iqr,
        'Lower Bound': lower_bound,
        'Upper Bound': upper_bound
    })
stats_df = pd.DataFrame(stats)
stats_df.to_csv(f'../1.5_benchmark/figure/{dataset}/{dataset}_pattern_pair_stats.csv', index=False)
```
```python
def get_significance_symbol(p):
    """Map a p-value to its star annotation.

    Args:
        p: The p-value.

    Returns:
        ``'****'`` for p <= 0.0001, ``'***'`` for p <= 0.001, ``'**'`` for
        p <= 0.01, ``'*'`` for p <= 0.05, otherwise ``'n.s.'``.
    """
    # Thresholds are checked from strictest to loosest; the first hit wins.
    thresholds = ((0.0001, '****'), (0.001, '***'), (0.01, '**'), (0.05, '*'))
    for upper_limit, symbol in thresholds:
        if p <= upper_limit:
            return symbol
    return 'n.s.'
# Colour of every (already shortened) pattern-pair label.
custom_colors = {'Cyto': '#7bc4e2', 'Pro': '#fbb05b', 'Nuc or NE': '#ed6ca4',
                 'Cyto → Nuc or NE': '#f98fac', 'Cyto → Pro': '#d6b3e4', 'Nuc or NE → Pro': '#aad8d3'}
# 'Nuclear or nuclear edge' is included so that un-shortened labels map onto
# the keys used in `custom_colors`.
short_names = {'Cytoplasmic': 'Cyto', 'Protrusion': 'Pro', 'Radial': 'Rad', 'Nuclear edge': 'NE',
               'Cell edge': 'CE', 'Nuclear': 'Nuc', 'Foci': 'Foci', 'Random': 'Rnd',
               'Nuclear or nuclear edge': 'Nuc or NE'}

unique_patterns = sim_df_clean['pattern_pair'].unique()
label_map = {}
for p in unique_patterns:
    parts = p.split(' to ')
    if len(parts) == 2:
        # Not unpacked into `a, b`: those names hold the numeric
        # hyperparameters that other cells interpolate into file paths.
        first_pattern, second_pattern = parts
        first_short = short_names.get(first_pattern.strip(), first_pattern.strip())
        second_short = short_names.get(second_pattern.strip(), second_pattern.strip())
        # Spaces around the arrow match the keys of `custom_colors`.
        label_map[p] = f"{first_short} → {second_short}"
    else:
        label_map[p] = short_names.get(p.strip(), p.strip())
sim_df_clean['pattern_pair'] = sim_df_clean['pattern_pair'].replace(label_map)
# One panel per pattern.
fig, axes = plt.subplots(1, 3, figsize=(8, 4))
titles = ["Cytoplasmic", "Protrusion", "Nuclear or nuclear edge"]
reference_keywords = ['Cyto', 'Pro', 'Nuc or NE']

# Each panel receives every pair label that mentions its pattern.
sim_df1, sim_df2, sim_df3 = (
    sim_df[sim_df['pattern_pair'].str.contains(keyword)]
    for keyword in reference_keywords
)
patterns = [sim_df1, sim_df2, sim_df3]

# Filled with one entry per significance test run while drawing the panels.
p_values = []
# One panel per pattern: the within-pattern similarities (reference group)
# next to the between-pattern similarities that involve that pattern, with a
# two-sided Mann-Whitney U test of the reference against each other group.
for i, (df_sub, ax, title) in enumerate(zip(patterns, axes, titles)):
    ref_keyword = reference_keywords[i]
    pair_labels = list(df_sub['pattern_pair'].unique())

    # Prefer the label that equals the keyword (the within-pattern group);
    # fall back to the first label containing it, as the original did.
    if ref_keyword in pair_labels:
        ref_group = ref_keyword
    else:
        candidates = [g for g in pair_labels if ref_keyword in g]
        if not candidates:
            continue
        ref_group = candidates[0]

    other_groups = [g for g in pair_labels if g != ref_group]
    medians = df_sub[df_sub['pattern_pair'].isin(other_groups)].groupby('pattern_pair')['similarity'].median()
    sorted_other = medians.sort_values(ascending=False).index.tolist()
    sorted_groups = [ref_group] + sorted_other

    sns.violinplot(data=df_sub, x='pattern_pair', y='similarity', palette=custom_colors, saturation=0.8, order=sorted_groups,
                   width=0.8, ax=ax, linewidth=1, inner=None)
    sns.boxplot(data=df_sub, x='pattern_pair', y='similarity', palette=custom_colors, order=sorted_groups,
                width=0.3, ax=ax, showcaps=False, showbox=True, showfliers=False, linewidth=1,
                saturation=0.9, legend=False)

    ax.set_title(title, fontsize=16, weight='bold', pad=25)
    ax.set_ylabel("Cosine Similarity", fontsize=14, fontweight='bold')
    ax.set_xlabel("", fontsize=12)
    ax.spines['bottom'].set_color('black')
    ax.spines['left'].set_color('black')
    ax.spines['right'].set_color('none')
    ax.spines['top'].set_color('none')
    ax.spines['bottom'].set_linewidth(1)
    ax.spines['left'].set_linewidth(1)
    ax.xaxis.set_major_locator(MultipleLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(0.25))
    ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, rotation=45, labelsize=14)
    ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=12)
    for tick_label in ax.get_xticklabels():
        tick_label.set_fontweight('bold')
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontweight('bold')
    sns.despine(ax=ax, top=True, right=True)
    ax.grid(False)
    # Head-room above 1.0 leaves space for the significance brackets.
    ax.set_ylim(0.4, 1.24)

    groups = sorted_groups
    group_map = {g: df_sub[df_sub['pattern_pair'] == g]['similarity'] for g in groups}
    ref_values = group_map[ref_group]
    for j, g in enumerate(groups):
        if g == ref_group:
            continue
        values = group_map[g]
        _, p = mannwhitneyu(ref_values, values, alternative='two-sided')
        x1 = groups.index(ref_group)
        x2 = groups.index(g)
        # Each further comparison is drawn one step higher so brackets stack.
        step = 0.06 * (j + 1)
        line_y = 0.96 + step
        text_y = line_y + 0.008
        ax.plot([x1, x1, x2, x2], [line_y, line_y + 0.02, line_y + 0.02, line_y], lw=1, c='black')
        ax.text((x1 + x2) / 2, text_y + 0.001, get_significance_symbol(p), ha='center', va='bottom', fontsize=10)
        p_values.append({
            'Pattern Pair': f"{ref_group} vs {g}",
            'p-value': p,
            'Significance': get_significance_symbol(p)
        })
# Final layout, then export the three-panel figure.
plt.tight_layout()
plt.subplots_adjust(wspace=0.8, hspace=0.6)

figure_stem = f'../1.5_benchmark/figure/{dataset}/{dataset}_emebdding2'
plt.savefig(f'{figure_stem}.png', dpi=300, bbox_inches='tight')
for vector_ext in ('pdf', 'svg'):
    plt.savefig(f'{figure_stem}.{vector_ext}', bbox_inches='tight')
plt.show()

# Persist the significance tests collected while drawing.
p_values_df = pd.DataFrame(p_values)
p_values_df.to_csv(f'../1.5_benchmark/figure/{dataset}/{dataset}_p_values3.csv', index=False)
```
```python
# Display the table of significance-test results as the cell output.
p_values_df
```
|
Pattern Pair |
p-value |
Significance |
| 0 |
Cyto vs Cyto → Pro |
3.343267e-78 |
**** |
| 1 |
Cyto vs Cyto → Nuc or NE |
8.777574e-86 |
**** |
| 2 |
Pro vs Cyto → Pro |
3.641106e-54 |
**** |
| 3 |
Pro vs Nuc or NE → Pro |
1.352650e-66 |
**** |
| 4 |
Nuc or NE vs Cyto → Nuc or NE |
7.410871e-80 |
**** |
| 5 |
Nuc or NE vs Nuc or NE → Pro |
1.768558e-79 |
**** |
```python
my_palette = {'within': '#EDABB5', 'between': '#7bc4e2'}
p_values = []

fig, ax = plt.subplots(figsize=(2.5, 3))
sns.violinplot(data=sim_df, x='type', y='similarity', hue='type', palette=my_palette, saturation=0.8,
               width=0.8, ax=ax, linewidth=1, inner=None)
sns.boxplot(data=sim_df, x='type', y='similarity', hue='type', palette=my_palette,
            width=0.3, ax=ax, showcaps=False, showbox=True, showfliers=False, linewidth=1,
            saturation=0.9, legend=False)

ax.set_ylabel("Cosine Similarity", fontsize=14)
ax.set_xlabel("", fontsize=12)
ax.spines['bottom'].set_color('black')
ax.spines['left'].set_color('black')
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['bottom'].set_linewidth(1)
ax.spines['left'].set_linewidth(1)
ax.xaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(0.25))
ax.tick_params(axis='x', which='both', direction='out', length=3, width=1, color='black', top=False, bottom=True, rotation=45, labelsize=14)
ax.tick_params(axis='y', which='both', direction='out', length=3, width=1, color='black', right=False, left=True, labelsize=12)
sns.despine(ax=ax, top=True, right=True)
ax.grid(False)
ax.set_ylim(0.4, 1.14)

# Two-sided Mann-Whitney U test of all within- against all between-pattern
# similarities.
within = sim_df[sim_df['type'] == 'within']['similarity']
between = sim_df[sim_df['type'] == 'between']['similarity']
stat, p = mannwhitneyu(within, between, alternative='two-sided')
# Same helper as the table row below, so bracket and CSV cannot disagree.
symbol = get_significance_symbol(p)

# Place the bracket above the data, scaled by the observed value range.
y_max = sim_df['similarity'].max()
y_min = sim_df['similarity'].min()
y_range = y_max - y_min - 0.1
line_y = y_max + 0.15 * y_range
text_y = line_y + 0.01 * y_range
x1, x2 = 0, 1
ax.plot([x1, x1, x2, x2], [line_y, line_y + 0.02, line_y + 0.02, line_y], lw=1, c='black')
ax.text((x1 + x2) / 2, text_y, symbol, ha='center', va='bottom', fontsize=11)
plt.tight_layout()

p_values.append({'Pattern Pair': 'within vs between', 'p-value': p, 'Significance': symbol})
p_values_df = pd.DataFrame(p_values)
print(p_values_df)
p_values_df.to_csv(f'../1.5_benchmark/figure/{dataset}/{dataset}_p_values4.csv', index=False)

plt.savefig(f'../1.5_benchmark/figure/{dataset}/{dataset}_emebdding3.png', dpi=300, bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/{dataset}_emebdding3.pdf', bbox_inches='tight')
plt.savefig(f'../1.5_benchmark/figure/{dataset}/{dataset}_emebdding3.svg', bbox_inches='tight')
plt.show()
```
Pattern Pair p-value Significance
0 within vs between 3.300398e-214 ****
##### step9: Plot a TSG clustering heatmap
```python
# Hyperparameters and dataset that select the label file.
dataset = "seqfish_fibroblast"
a, b = 0.2, 0.8
df = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")

# Distinct genes per ground-truth pattern.
gene_counts = df.groupby('groundtruth')['gene'].nunique().reset_index(name='unique_genes')

# Distinct cell_gene identifiers per ground-truth pattern.
df['cell_gene'] = df['cell'] + '_' + df['gene']
cell_gene_counts = df.groupby('groundtruth')['cell_gene'].nunique().reset_index(name='unique_cell_genes')

stats_df = gene_counts.merge(cell_gene_counts, on='groundtruth')
print(stats_df)
```
groundtruth unique_genes unique_cell_genes
0 Cytoplasmic 20 3322
1 Nuclear or nuclear edge 20 3233
2 Protrusion 19 1513
```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
a, b = 0.2, 0.8
dataset = "seqfish_fibroblast"
# Column order of all three heatmaps.
selected_genes = [
    'Dync1li2', 'Gxylt1', 'Zeb1', 'Trak2', 'Ptgfr', 'Kif1c', 'Arhgap11a', 'Scd1', 'Ppfia1', 'Net1',
    'Arhgap32', 'Sh3pxd2a', 'Ddr2', 'Naa50', 'Palld', 'Dynll2', 'Gdf11', 'Kctd10', 'Nid1', 'Fn1',
    'Col1a1', 'Fbln2', 'Cyr61', 'P4hb', 'Aebp1', 'Sdc4', 'Col6a2', 'Emp1', 'Itgb1', 'Col5a1',
    'Pdia6', 'Calu', 'Lox', 'Postn', 'Col5a2', 'Col3a1', 'Pdia3', 'Bgn', 'Tagln2', 'Myh9', 'Ddb1',
    'Ppp1ca', 'Hnrnpl', 'Hnrnpf', 'Actn1', 'Kpnb1', 'Fscn1', 'Cald1', 'Pcbp1', 'Tagln', 'Psat1',
    'Caprin1', 'Uba1', 'Snd1', 'Cyb5r3', 'Cap1', 'Ugdh', 'Hnrnpm', 'Ssrp1'
]
custom_colors = ['white', '#c6c0e0', '#a3a3c2']
# NOTE(review): this colormap is built but the heatmaps below use "GnBu".
custom_cmap = LinearSegmentedColormap.from_list('custom_blues', custom_colors)


def _draw_label_heatmap(ax, label_df, label_col, title):
    """Draw, per selected gene, the proportion of rows carrying each label.

    Args:
        ax: Matplotlib axes to draw on.
        label_df: DataFrame with a ``gene`` column and ``label_col``.
        label_col: Column whose values become the heatmap rows.
        title: Title of the panel.
    """
    label_counts = label_df.groupby(['gene', label_col]).size().unstack(fill_value=0)
    label_props = label_counts.div(label_counts.sum(axis=1), axis=0)
    # reindex both restricts to and orders by `selected_genes`; a selected
    # gene absent from the table becomes an all-NaN column in the heatmap.
    label_props_selected = label_props.reindex(selected_genes)
    sns.heatmap(label_props_selected.T, annot=False, cmap="GnBu", ax=ax, cbar_kws={'label': 'Proportion'},
                linewidths=0.5, linecolor='black')
    ax.set_title(title, fontsize=14)
    ax.tick_params(axis='x', labelsize=10)
    ax.tick_params(axis='y', labelsize=12)
    for tick_label in ax.get_yticklabels():
        tick_label.set_fontweight('bold')


# Create 3 subplots vertically
fig, axes = plt.subplots(nrows=3, figsize=(15, 8), constrained_layout=True)

# Panels 1 and 2 use the gene-level table (read once), panel 3 the graph-level one.
df1 = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}_genelevel.csv")
df2 = df1.copy()
df3 = pd.read_csv(f"../1.5_benchmark/method4_ours/{dataset}/ours_label_a{a}_b{b}.csv")

heatmap_panels = [
    (df1, 'groundtruth', "Groundtruth"),
    (df2, 'gene_level_pattern', "Gene-level Prediction"),
    (df3, 'graph_level_pattern', "Graph-level Prediction"),
]
for panel_ax, (panel_df, panel_col, panel_title) in zip(axes, heatmap_panels):
    _draw_label_heatmap(panel_ax, panel_df, panel_col, panel_title)
plt.show()
```