#!/usr/bin/env python3

import argparse
import os
import sys

import numpy as np
import pandas as pd
import scipy.stats
from Bio import SeqIO
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

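# Pick k for KMeans by scanning silhouette scores for k = 2..ref_count+2,
# stopping early after two consecutive drops; k = 1 is only chosen when its
# inertia beats k = 2 and the k = 2 silhouette already exceeds the k = 3 one.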
def getOptimalK(data, ref_count):

    silhouette_scores = []

    kmeans_1 = KMeans(n_clusters=1, random_state=0, n_init='auto').fit(data)
    kmeans_2 = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(data)

    # Compare k=1 vs. k=2 via inertia (silhouette is undefined for k=1)
    inertia_1 = kmeans_1.inertia_
    inertia_2 = kmeans_2.inertia_
    if inertia_1 > inertia_2:
        negative_movements = 1
    else:
        negative_movements = 0

    # Add k=2 data
    labels = kmeans_2.labels_
    score = silhouette_score(data, labels)
    silhouette_scores.append(score)
    prev_score = score

    for k in range(3, ref_count + 3):
        kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(data)
        labels = kmeans.labels_
        score = silhouette_score(data, labels)

        if score < prev_score:
            negative_movements += 1
        else:
            negative_movements = 0

        silhouette_scores.append(score)

        # Stop if two consecutive negative movements occur
        if negative_movements == 2:
            break

        prev_score = score

    if (inertia_1 < inertia_2) and (silhouette_scores[0] > silhouette_scores[1]):
        optimal_k = 1
    else:
        # silhouette_scores[0] corresponds to k=2
        optimal_k = np.argmax(silhouette_scores) + 2

    return optimal_k

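# N50/N90: the contig length at which the running total (longest-first)
# first reaches 50%/90% of the assembly; L50/L90: how many contigs that takes.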
def fasta_info(file_path):
    records = list(SeqIO.parse(file_path, 'fasta'))
    contig_count = len(records)
    lengths = sorted([len(record) for record in records], reverse=True)
    assembly_bases = sum(lengths)

    cumulative_length = 0
    n50 = None
    n90 = None
    l50 = None
    l90 = None

    for i, length in enumerate(lengths, start=1):
        cumulative_length += length
        if cumulative_length >= assembly_bases * 0.5 and n50 is None:
            n50 = length
            l50 = i
        if cumulative_length >= assembly_bases * 0.9 and n90 is None:
            n90 = length
            l90 = i
        if n50 is not None and n90 is not None:
            break

    return [file_path, contig_count, assembly_bases, n50, n90, l50, l90]

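# Example invocation (script and file names are illustrative):
#   ref_selection.py --ref_count 3 --mash_triangle_file mash_triangle.txt --trim_name _contigs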
parser = argparse.ArgumentParser(description='Choose reference isolates based on FASTA metrics and mean distances.')
parser.add_argument('--ref_count', type=int, default=1, help='Number of reference isolates to select')
parser.add_argument('--mash_triangle_file', type=str, required=True, help='Path to the mash triangle file')
parser.add_argument('--trim_name', nargs='?', const="", default="", type=str, help='Substring to trim from isolate file names')
args = parser.parse_args()

ref_count = args.ref_count
mash_triangle_file = os.path.abspath(args.mash_triangle_file)
trim_name = args.trim_name
ref_file = os.path.join(os.path.dirname(mash_triangle_file), 'CSP2_Ref_Selection.tsv')

# Get sample IDs
sample_df = pd.read_csv(mash_triangle_file, sep='\t', usecols=[0], skip_blank_lines=True).dropna()
sample_df = sample_df[sample_df[sample_df.columns[0]].str.strip() != '']
sample_df.columns = ['Path']
sample_df['Isolate_ID'] = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in sample_df['Path'].tolist()]
assembly_names = sample_df['Isolate_ID'].tolist()
num_isolates = sample_df.shape[0]

# Get FASTA metrics
metrics_df = pd.DataFrame(sample_df['Path'].apply(fasta_info).tolist(), columns=['Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90'])
metrics_df['Assembly_Bases_Zscore'] = metrics_df['Length'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['Contig_Count_Zscore'] = metrics_df['Contigs'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['N50_Zscore'] = metrics_df['N50'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)

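# Inlier screen: keep isolates within 3 standard deviations on each metric,
# one-sided where only one tail signals trouble (low N50 and high contig
# counts flag fragmented assemblies); assembly length is screened both ways.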
inlier_df = metrics_df.loc[(metrics_df['N50_Zscore'] > -3) &
                           (metrics_df['Assembly_Bases_Zscore'] < 3) &
                           (metrics_df['Assembly_Bases_Zscore'] > -3) &
                           (metrics_df['Contig_Count_Zscore'] < 3)]

inlier_count = inlier_df.shape[0]
inlier_isolates = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in inlier_df['Path'].tolist()]

# If there are not enough (or exactly enough) inliers, the script is done
if ref_count > inlier_count:
    sys.exit("Error: Fewer inlier assemblies than requested references")
elif ref_count == inlier_count:
    print(",".join(inlier_df['Path'].tolist()))
    sys.exit(0)

# Left join metrics_df and inlier_df
sample_df = inlier_df.merge(sample_df, on="Path", how='left')[['Isolate_ID', 'Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90', 'N50_Zscore']]

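# mash triangle output is lower-triangular: line 1 holds the sample count,
# and each later line holds a sample path followed by distances to all
# previously listed samples (the first sample line carries no distances).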
# Create distance matrix
with open(mash_triangle_file) as mash_triangle:
    a = np.zeros((num_isolates, num_isolates))
    mash_triangle.readline()  # sample count
    mash_triangle.readline()  # first sample row (no distances)
    idx = 1
    for line in mash_triangle:
        tokens = line.split()
        distances = [float(token) for token in tokens[1:]]
        a[idx, 0:len(distances)] = distances
        a[0:len(distances), idx] = distances
        idx += 1

dist_df = pd.DataFrame(a, index=assembly_names, columns=assembly_names).loc[inlier_isolates, inlier_isolates]

# Get mean distances after masking the diagonal
mask = ~np.eye(dist_df.shape[0], dtype=bool)
mean_distances = dist_df.where(mask).mean().reset_index()
mean_distances.columns = ['Isolate_ID', 'Mean_Distance']

sample_df = sample_df.merge(mean_distances, on='Isolate_ID', how='left')
sample_df['Mean_Distance_Zscore'] = sample_df['Mean_Distance'].transform(scipy.stats.zscore).astype('float').round(3)
sample_df['Base_Score'] = sample_df['N50_Zscore'] - sample_df['Mean_Distance_Zscore'].fillna(0)

# A single reference: take the best Base_Score and exit
if ref_count == 1:
    print(",".join(sample_df.nlargest(1, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

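# Note: KMeans below treats each row of the distance matrix as a feature
# vector (each isolate described by its distances to all others), a
# pragmatic embedding rather than true distance-based clustering.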
optimal_k = getOptimalK(dist_df, ref_count)

if optimal_k == 1:
    print(",".join(sample_df.nlargest(ref_count, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

kmeans = KMeans(n_clusters=optimal_k, random_state=0, n_init='auto').fit(dist_df)
clusters = kmeans.labels_

cluster_df = pd.DataFrame({'Isolate_ID': dist_df.index, 'Cluster': clusters}).merge(sample_df, on='Isolate_ID', how='left')
cluster_counts = cluster_df['Cluster'].value_counts().reset_index()
cluster_counts.columns = ['Cluster', 'count']
cluster_counts['Prop'] = cluster_counts['count'] / cluster_counts['count'].sum()
cluster_df = cluster_df.merge(cluster_counts[['Cluster', 'Prop']], on='Cluster')

# Grab the top ref by Base_Score
final_ref_df = cluster_df.nlargest(1, 'Base_Score')
refs_chosen = final_ref_df['Isolate_ID'].tolist()

possible_refs = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].copy()

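# Greedily add references: each round rescores the remaining candidates by
# Base_Score plus a z-scored mean distance to the refs already chosen;
# candidates closer than average are penalized outright, while distance
# bonuses are scaled by the candidate's cluster proportion.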
while len(refs_chosen) < ref_count:
    possible_refs['Mean_Ref_Distance'] = possible_refs['Isolate_ID'].apply(lambda isolate_id: np.mean(dist_df.loc[isolate_id, refs_chosen].values))
    # fillna(0) guards against an all-NaN z-score (e.g., one candidate left),
    # which would otherwise make nlargest return nothing and the loop spin
    possible_refs['Mean_Ref_Distance_Zscore'] = possible_refs['Mean_Ref_Distance'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
    possible_refs['Sort_Score'] = possible_refs.apply(lambda row: (row['Base_Score'] + row['Mean_Ref_Distance_Zscore']) if row['Mean_Ref_Distance_Zscore'] <= 0 else (row['Base_Score'] + (row['Mean_Ref_Distance_Zscore'] * row['Prop'])), axis=1)

    final_ref_df = pd.concat([final_ref_df, possible_refs.nlargest(1, 'Sort_Score').drop(['Sort_Score', 'Mean_Ref_Distance', 'Mean_Ref_Distance_Zscore'], axis=1)])
    refs_chosen = final_ref_df['Isolate_ID'].tolist()
    possible_refs = possible_refs.loc[~possible_refs['Isolate_ID'].isin(refs_chosen)].copy()

non_ref_df = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].sort_values('Base_Score', ascending=False)
non_ref_df['Is_Ref'] = False
final_ref_df['Is_Ref'] = True

with open(ref_file, 'w') as f:
    pd.concat([final_ref_df, non_ref_df]).reset_index(drop=True).to_csv(f, index=False, sep="\t")

print(",".join(final_ref_df['Path'].tolist()))