#!/usr/bin/env python3

import numpy as np
import os
import pandas as pd
import sys
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import scipy.stats
from Bio import SeqIO
import argparse

def getOptimalK(data, ref_count):

    silhouette_scores = []

    kmeans_1 = KMeans(n_clusters=1, random_state=0, n_init='auto').fit(data)
    kmeans_2 = KMeans(n_clusters=2, random_state=0, n_init='auto').fit(data)

    # Compare k=1 vs. k=2 via inertia (silhouette is undefined for a single cluster)
    inertia_1 = kmeans_1.inertia_
    inertia_2 = kmeans_2.inertia_
    if inertia_1 > inertia_2:
        negative_movements = 1
    else:
        negative_movements = 0

    # Add k=2 data
    labels = kmeans_2.labels_
    score = silhouette_score(data, labels)
    silhouette_scores.append(score)
    prev_score = score

    for k in range(3, ref_count + 3):
        kmeans = KMeans(n_clusters=k, random_state=0, n_init='auto').fit(data)
        labels = kmeans.labels_
        score = silhouette_score(data, labels)

        if score < prev_score:
            negative_movements += 1
        else:
            negative_movements = 0

        silhouette_scores.append(score)

        # Stop if the silhouette score drops twice in a row
        if negative_movements == 2:
            break

        prev_score = score

    if (inertia_1 < inertia_2) and (silhouette_scores[0] > silhouette_scores[1]):
        optimal_k = 1
    else:
        optimal_k = np.argmax(silhouette_scores) + 2

    return optimal_k

def fasta_info(file_path):
    """Return [path, contig count, total bases, N50, N90, L50, L90] for a FASTA assembly."""
    records = list(SeqIO.parse(file_path, 'fasta'))
    contig_count = len(records)
    lengths = sorted([len(record) for record in records], reverse=True)
    assembly_bases = sum(lengths)

    cumulative_length = 0
    n50 = None
    n90 = None
    l50 = None
    l90 = None

    # Walk contigs from longest to shortest until 50%/90% of the bases are covered
    for i, length in enumerate(lengths, start=1):
        cumulative_length += length
        if cumulative_length >= assembly_bases * 0.5 and n50 is None:
            n50 = length
            l50 = i
        if cumulative_length >= assembly_bases * 0.9 and n90 is None:
            n90 = length
            l90 = i
        if n50 is not None and n90 is not None:
            break

    return [file_path, contig_count, assembly_bases, n50, n90, l50, l90]

parser = argparse.ArgumentParser(description='Choose reference isolates based on FASTA metrics and mean distances.')
parser.add_argument('--ref_count', type=int, default=1, help='Number of reference isolates to select')
parser.add_argument('--mash_triangle_file', type=str, help='Path to the mash triangle file')
parser.add_argument('--trim_name', nargs='?', const="", default="", type=str, help='Substring to strip from assembly filenames when deriving isolate IDs')
args = parser.parse_args()

ref_count = args.ref_count
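# NOTE: `mash triangle` writes a PHYLIP-style lower-triangular matrix: the first
# line holds the sequence count, and each following line holds a genome path plus
# its distances to all previously listed genomes. Both the sample-ID parsing and
# the matrix reconstruction below assume this layout.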
mash_triangle_file = os.path.abspath(args.mash_triangle_file)
trim_name = args.trim_name
ref_file = os.path.join(os.path.dirname(mash_triangle_file), 'CSP2_Ref_Selection.tsv')

# Get sample IDs from the first column of the mash triangle file
sample_df = pd.read_csv(mash_triangle_file, sep='\t', usecols=[0], skip_blank_lines=True).dropna()
sample_df = sample_df[sample_df[sample_df.columns[0]].str.strip() != '']
sample_df.columns = ['Path']
sample_df['Isolate_ID'] = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in sample_df['Path'].tolist()]
assembly_names = sample_df['Isolate_ID'].tolist()
num_isolates = sample_df.shape[0]

# Get FASTA metrics and z-scores for assembly size, contig count, and N50
metrics_df = pd.DataFrame(sample_df['Path'].apply(fasta_info).tolist(), columns=['Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90'])
metrics_df['Assembly_Bases_Zscore'] = metrics_df['Length'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['Contig_Count_Zscore'] = metrics_df['Contigs'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)
metrics_df['N50_Zscore'] = metrics_df['N50'].transform(scipy.stats.zscore).astype('float').round(3).fillna(0)

# Remove outliers: keep assemblies within 3 SD on N50, total bases, and contig count
inlier_df = metrics_df.loc[(metrics_df['N50_Zscore'] > -3) &
                           (metrics_df['Assembly_Bases_Zscore'] < 3) &
                           (metrics_df['Assembly_Bases_Zscore'] > -3) &
                           (metrics_df['Contig_Count_Zscore'] < 3)]

inlier_count = inlier_df.shape[0]
inlier_isolates = [os.path.splitext(os.path.basename(file))[0].replace(trim_name, '') for file in inlier_df['Path'].tolist()]

# If there are too few inliers, fail; if exactly enough, return them all
if ref_count > inlier_count:
    sys.exit("Error: Fewer inlier assemblies than requested references")
elif ref_count == inlier_count:
    print(",".join(inlier_df['Path'].tolist()))
    sys.exit(0)

# Left join inlier_df and sample_df
sample_df = inlier_df.merge(sample_df, on="Path", how='left')[['Isolate_ID', 'Path', 'Contigs', 'Length', 'N50', 'N90', 'L50', 'L90', 'N50_Zscore']]

# Create a symmetric distance matrix from the lower-triangular mash output
with open(mash_triangle_file) as mash_triangle:
    a = np.zeros((num_isolates, num_isolates))
    mash_triangle.readline()  # skip the sequence-count line
    mash_triangle.readline()  # skip the first genome row (it carries no distances)
    idx = 1
    for line in mash_triangle:
        tokens = line.split()
        distances = [float(token) for token in tokens[1:]]
        a[idx, 0:len(distances)] = distances
        a[0:len(distances), idx] = distances
        idx += 1

dist_df = pd.DataFrame(a, index=assembly_names, columns=assembly_names).loc[inlier_isolates, inlier_isolates]

# Get mean distances after masking the diagonal
mask = ~np.eye(dist_df.shape[0], dtype=bool)
mean_distances = dist_df.where(mask).mean().reset_index()
mean_distances.columns = ['Isolate_ID', 'Mean_Distance']

# Base_Score rewards a high N50 and penalizes isolates far from the population center
sample_df = sample_df.merge(mean_distances, on='Isolate_ID', how='left')
sample_df['Mean_Distance_Zscore'] = sample_df['Mean_Distance'].transform(scipy.stats.zscore).astype('float').round(3)
sample_df['Base_Score'] = sample_df['N50_Zscore'] - sample_df['Mean_Distance_Zscore'].fillna(0)

if ref_count == 1:
    print(",".join(sample_df.nlargest(1, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

optimal_k = getOptimalK(dist_df, ref_count)

# If the isolates form a single cluster, rank everything by Base_Score
if optimal_k == 1:
    print(",".join(sample_df.nlargest(ref_count, 'Base_Score')['Path'].tolist()))
    sys.exit(0)

kmeans = KMeans(n_clusters=optimal_k, random_state=0, n_init='auto').fit(dist_df)
clusters = kmeans.labels_

# Attach cluster labels and each cluster's proportion of the population
cluster_df = pd.DataFrame({'Isolate_ID': dist_df.index, 'Cluster': clusters}).merge(sample_df, on='Isolate_ID', how='left')
cluster_counts = cluster_df['Cluster'].value_counts().reset_index()
cluster_counts.columns = ['Cluster', 'count']
cluster_counts['Prop'] = cluster_counts['count'] / cluster_counts['count'].sum()
cluster_df = cluster_df.merge(cluster_counts[['Cluster', 'Prop']], on='Cluster')

# Grab top ref
final_ref_df = cluster_df.nlargest(1, 'Base_Score')
refs_chosen = final_ref_df['Isolate_ID'].tolist()

possible_refs = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].copy()

# Add references one at a time, favoring isolates distant from those already chosen;
# the distance reward is damped by cluster proportion so small clusters don't dominate
while len(refs_chosen) < ref_count:
    possible_refs['Mean_Ref_Distance'] = possible_refs['Isolate_ID'].apply(lambda isolate_id: np.mean(dist_df.loc[isolate_id, refs_chosen].values))
    possible_refs['Mean_Ref_Distance_Zscore'] = possible_refs['Mean_Ref_Distance'].transform(scipy.stats.zscore).astype('float').round(3)
    possible_refs['Sort_Score'] = possible_refs.apply(lambda row: (row['Base_Score'] + row['Mean_Ref_Distance_Zscore']) if row['Mean_Ref_Distance_Zscore'] <= 0 else (row['Base_Score'] + (row['Mean_Ref_Distance_Zscore'] * row['Prop'])), axis=1)

    final_ref_df = pd.concat([final_ref_df, possible_refs.nlargest(1, 'Sort_Score').drop(['Sort_Score', 'Mean_Ref_Distance', 'Mean_Ref_Distance_Zscore'], axis=1)])
    refs_chosen = final_ref_df['Isolate_ID'].tolist()
    possible_refs = possible_refs.loc[~possible_refs['Isolate_ID'].isin(refs_chosen)].copy()

non_ref_df = cluster_df.loc[~cluster_df['Isolate_ID'].isin(refs_chosen)].sort_values('Base_Score', ascending=False)
non_ref_df['Is_Ref'] = False
final_ref_df['Is_Ref'] = True

with open(ref_file, 'w') as f:
    pd.concat([final_ref_df, non_ref_df]).reset_index(drop=True).to_csv(f, index=False, sep="\t")

print(",".join(final_ref_df['Path'].tolist()))
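# Example invocation (file names are illustrative; `mash triangle` is the Mash
# subcommand that produces the lower-triangular distance matrix):
#   mash triangle *.fasta > mash_triangle.txt
#   python ref_selection.py --ref_count 3 --mash_triangle_file mash_triangle.txt --trim_name _contigs
# The chosen reference paths are printed as a comma-separated list, and the full
# ranking is written to CSP2_Ref_Selection.tsv next to the mash triangle file.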