CSP2/bin/chooseRefs.py @ 0:01431fa12065 ("planemo upload")

author:   rliterman
date:     Mon, 02 Dec 2024 10:40:55 -0500
parents:  (none)
children: 792274118b2e
#!/usr/bin/env python3
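"""Choose reference isolates for CSP2 from assembly metrics and Mash distances.

Given a `mash triangle` output file, this script computes per-assembly FASTA
metrics (contig count, total length, N50/N90, L50/L90), drops outlier
assemblies by z-score, and selects --ref_count references. When more than one
reference is requested and the distances support multiple clusters, KMeans
clustering guides the selection. The chosen paths are printed as a
comma-separated list; in the clustering path, a CSP2_Ref_Selection.tsv
summary is also written beside the input file.
"""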

import argparse
import os
import sys

import numpy as np
import pandas as pd
import scipy.stats
from Bio import SeqIO
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def getOptimalK(data, ref_count):
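    """Estimate an optimal KMeans cluster count from silhouette scores.

    Fits KMeans for k = 2 through ref_count + 2, recording the silhouette
    score at each k and stopping early after two consecutive decreases.
    Falls back to k = 1 when the k=1 fit has lower inertia than the k=2 fit
    and the k=2 silhouette beats k=3; otherwise returns the k with the
    highest silhouette score.
    """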

    silhouette_scores = []

    kmeans_1 = KMeans(n_clusters=1, random_state=0, n_init='auto').fit(data)
    kmeans_2 = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(data)

    # Compare k=1 vs. k=2 by inertia
    inertia_1 = kmeans_1.inertia_
    inertia_2 = kmeans_2.inertia_
    if inertia_1 > inertia_2:
        negative_movements = 1
    else:
        negative_movements = 0

    # Add k=2 data
    labels = kmeans_2.labels_
    score = silhouette_score(data, labels)
    silhouette_scores.append(score)
    prev_score = score

    for k in range(3, ref_count + 3):
        kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(data)
        labels = kmeans.labels_
        score = silhouette_score(data, labels)

        if score < prev_score:
            negative_movements += 1
        else:
            negative_movements = 0

        silhouette_scores.append(score)

        # Stop if two consecutive negative movements occur
        if negative_movements == 2:
            break

        prev_score = score

    if (inertia_1 < inertia_2) and (silhouette_scores[0] > silhouette_scores[1]):
        optimal_k = 1
    else:
        optimal_k = np.argmax(silhouette_scores) + 2

    return optimal_k

def fasta_info(file_path):
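    """Compute basic assembly metrics for a FASTA file.

    Returns [path, contig count, total bases, N50, N90, L50, L90], where
    N50/N90 are the contig lengths at which the cumulative length (largest
    contigs first) reaches 50%/90% of the assembly, and L50/L90 are the
    numbers of contigs needed to get there.
    """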
    records = list(SeqIO.parse(file_path, 'fasta'))
    contig_count = len(records)
    lengths = sorted([len(record) for record in records], reverse=True)
    assembly_bases = sum(lengths)

    cumulative_length = 0
    n50 = None
    n90 = None
    l50 = None
    l90 = None

    for i, length in enumerate(lengths, start=1):
        cumulative_length += length
        if cumulative_length >= assembly_bases * 0.5 and n50 is None:
            n50 = length
            l50 = i
        if cumulative_length >= assembly_bases * 0.9 and n90 is None:
            n90 = length
            l90 = i
        if n50 is not None and n90 is not None:
            break

    return [file_path, contig_count, assembly_bases, n50, n90, l50, l90]

parser = argparse.ArgumentParser(description='Choose reference isolates based on FASTA metrics and mean distances.')
parser.add_argument('--ref_count', type=int, help='Number of reference isolates to select')
parser.add_argument('--mash_triangle_file', type=str, help='Path to the mash triangle file')
parser.add_argument('--trim_name', type=str, help='Substring to strip from assembly file names when deriving isolate IDs')
args = parser.parse_args()
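
# Example invocation (hypothetical file names, shown for illustration only):
#   python chooseRefs.py --ref_count 3 \
#       --mash_triangle_file ./Mash/mash.triangle.txt \
#       --trim_name _contigs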

ref_count = args.ref_count
mash_triangle_file = os.path.abspath(args.mash_triangle_file)
trim_name = args.trim_name
ref_file = os.path.join(os.path.dirname(mash_triangle_file), 'CSP2_Ref_Selection.tsv')

# Get Sample IDs
sample_df = pd.read_csv(mash_triangle_file, sep='\t', usecols=[0], skip_blank_lines=True).dropna()
sample_df = sample_df[sample_df[sample_df.columns[0]].str.strip() != '']
sample_df.columns = ['Path']
sample_df['Isolate_ID'] = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in sample_df['Path'].tolist()]
assembly_names = sample_df['Isolate_ID'].tolist()
num_isolates = sample_df.shape[0]

# Get FASTA metrics
metrics_df = pd.DataFrame(sample_df['Path'].apply(fasta_info).tolist(), columns=['Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90'])
metrics_df['Assembly_Bases_Zscore'] = metrics_df['Length'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['Contig_Count_Zscore'] = metrics_df['Contigs'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['N50_Zscore'] = metrics_df['N50'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)

# Remove outliers: keep assemblies whose N50, total length, and contig count z-scores are unremarkable
inlier_df = metrics_df.loc[(metrics_df['N50_Zscore'] > -3) &
                           (metrics_df['Assembly_Bases_Zscore'] < 3) &
                           (metrics_df['Assembly_Bases_Zscore'] > -3) &
                           (metrics_df['Contig_Count_Zscore'] < 3)]

inlier_count = inlier_df.shape[0]
inlier_isolates = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in inlier_df['Path'].tolist()]

# If there are fewer inliers than requested references, fail; if exactly enough, use them all
if ref_count > inlier_count:
    sys.exit("Error: Fewer inlier assemblies than requested references")
elif ref_count == inlier_count:
    print(",".join(inlier_df['Path'].tolist()))
    sys.exit(0)

# Left join inlier_df to sample_df to recover isolate IDs for the inliers
sample_df = inlier_df.merge(sample_df, on="Path", how='left')[['Isolate_ID', 'Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90', 'N50_Zscore']]

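# The input is assumed to follow `mash triangle` output: a count line, then
# one row per sample holding distances to all previous samples, e.g.
# (hypothetical values):
#
#   3
#   sampleA.fasta
#   sampleB.fasta  0.01
#   sampleC.fasta  0.02  0.03
#
# The loop below mirrors each lower-triangle row into a full symmetric matrix.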
# Create distance matrix
with open(mash_triangle_file) as mash_triangle:
    a = np.zeros((num_isolates, num_isolates))
    mash_triangle.readline()  # skip the count line
    mash_triangle.readline()  # skip the first sample row (it carries no distances)
    idx = 1
    for line in mash_triangle:
        tokens = line.split()
        distances = [float(token) for token in tokens[1:]]
        a[idx, 0:len(distances)] = distances
        a[0:len(distances), idx] = distances
        idx += 1

dist_df = pd.DataFrame(a, index=assembly_names, columns=assembly_names).loc[inlier_isolates, inlier_isolates]
# Get mean distances after masking the diagonal
mask = ~np.eye(dist_df.shape[0], dtype=bool)
mean_distances = dist_df.where(mask).mean().reset_index()
mean_distances.columns = ['Isolate_ID', 'Mean_Distance']

sample_df = sample_df.merge(mean_distances, on='Isolate_ID', how='left')
sample_df['Mean_Distance_Zscore'] = sample_df['Mean_Distance'].transform(scipy.stats.zscore).astype('float').round(3)
# Base_Score favors contiguous assemblies (high N50) that sit close to the rest of the set (low mean distance)
sample_df['Base_Score'] = sample_df['N50_Zscore'] - sample_df['Mean_Distance_Zscore'].fillna(0)

if ref_count == 1:
    print(",".join(sample_df.nlargest(1, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

optimal_k = getOptimalK(dist_df, ref_count)

# A single cluster: take the top-scoring isolates outright
if optimal_k == 1:
    print(",".join(sample_df.nlargest(ref_count, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

kmeans = KMeans(n_clusters=optimal_k, random_state=0, n_init='auto').fit(dist_df)
clusters = kmeans.labels_

# Attach each isolate's cluster label and its cluster's share of all isolates
cluster_df = pd.DataFrame({'Isolate_ID': dist_df.index, 'Cluster': clusters}).merge(sample_df, on='Isolate_ID', how='left')
cluster_counts = cluster_df['Cluster'].value_counts().reset_index()
cluster_counts.columns = ['Cluster', 'count']
cluster_counts['Prop'] = cluster_counts['count'] / cluster_counts['count'].sum()
cluster_df = cluster_df.merge(cluster_counts[['Cluster', 'Prop']], on='Cluster')

# Grab top ref
final_ref_df = cluster_df.nlargest(1, 'Base_Score')
refs_chosen = final_ref_df['Isolate_ID'].tolist()

possible_refs = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].copy()

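# Greedy selection of the remaining references. Each pass re-scores the
# candidates: Base_Score plus the z-score of their mean distance to the refs
# chosen so far. Candidates closer than average to the chosen refs (z <= 0)
# take the full penalty, while the bonus for distant candidates is weighted
# by their cluster's proportion, presumably to favor picks from
# well-represented clusters.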
while len(refs_chosen) < ref_count:
    possible_refs['Mean_Ref_Distance'] = possible_refs['Isolate_ID'].apply(lambda isolate_id: np.mean(dist_df.loc[isolate_id, refs_chosen].values))
    possible_refs['Mean_Ref_Distance_Zscore'] = possible_refs['Mean_Ref_Distance'].transform(scipy.stats.zscore).astype('float').round(3)
    possible_refs['Sort_Score'] = possible_refs.apply(lambda row: (row['Base_Score'] + row['Mean_Ref_Distance_Zscore']) if row['Mean_Ref_Distance_Zscore'] <= 0 else (row['Base_Score'] + (row['Mean_Ref_Distance_Zscore'] * row['Prop'])), axis=1)

    final_ref_df = pd.concat([final_ref_df, possible_refs.nlargest(1, 'Sort_Score').drop(['Sort_Score', 'Mean_Ref_Distance', 'Mean_Ref_Distance_Zscore'], axis=1)])
    refs_chosen = final_ref_df['Isolate_ID'].tolist()
    possible_refs = possible_refs.loc[~possible_refs['Isolate_ID'].isin(refs_chosen)].copy()

# Write the full ranking (references flagged via Is_Ref) and print the chosen paths
non_ref_df = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].sort_values('Base_Score', ascending=False)
non_ref_df['Is_Ref'] = False
final_ref_df['Is_Ref'] = True
pd.concat([final_ref_df, non_ref_df]).reset_index(drop=True).to_csv(ref_file, index=False, sep="\t")

print(",".join(final_ref_df['Path'].tolist()))