annotate predict_source.py @ 7:cc937b6c75b5 draft

planemo upload commit 021b26b29595f4052f55b7a51bb84e8bcb0898ad
author galaxytrakr
date Wed, 29 Apr 2026 20:25:19 +0000
parents 954eccb7cc48
children
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
1 import argparse
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
2 import sys
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
3 import os
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
4 import pandas as pd
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
5 import joblib
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
6 import warnings
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
7
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
# scikit-learn emits a UserWarning when an estimator fitted with (or without)
# feature names is given input without (or with) them. Silence only those
# messages: a blanket UserWarning filter would also hide scikit-learn's
# InconsistentVersionWarning (a UserWarning subclass), which flags a model
# bundle being unpickled under a different scikit-learn version than the one
# it was trained with.
# `message` is a case-insensitive regex matched against the start of the text.
warnings.filterwarnings("ignore", message=r".*feature names", category=UserWarning)
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
10
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
def _fatal(message):
    """Print ``FATAL: <message>`` to stderr and exit with status 1."""
    print(f"FATAL: {message}", file=sys.stderr)
    sys.exit(1)


def _read_mash_screen(path):
    """Parse a headerless ``mash screen`` table into a DataFrame.

    Each line is split on whitespace at most five times and only the first
    five fields are kept, mapped to Identity, Shared_Hashes,
    Median_Multiplicity, P_value and Plasmid_ID. Bounding the split keeps a
    trailing free-text query comment (which may contain spaces) from spilling
    into extra columns; a plain ``\\s+`` separator gives such lines varying
    field counts, which pandas can reject as malformed. Lines with fewer than
    five fields are skipped.

    Returns ``(df, max_fields)``: ``df`` holds the kept fields as strings,
    and ``max_fields`` is the largest field count seen on any line (at most
    six), used by the caller to report unrecognised input.
    """
    columns = ['Identity', 'Shared_Hashes', 'Median_Multiplicity', 'P_value', 'Plasmid_ID']
    rows = []
    max_fields = 0
    with open(path, encoding='utf-8', errors='replace') as handle:
        for line in handle:
            fields = line.split(None, len(columns))
            max_fields = max(max_fields, len(fields))
            if len(fields) >= len(columns):
                rows.append(fields[:len(columns)])
    return pd.DataFrame(rows, columns=columns), max_fields


def main():
    """Predict an isolate's source from its ``mash screen`` plasmid hits.

    Command-line entry point:

    1. Load a joblib bundle (a mapping with ``'model'`` -- the trained
       classifier -- and ``'features'`` -- the plasmid IDs it was trained on).
    2. Parse the Mash screen table and label every hit with one Run id, taken
       from ``--run-id`` or, by default, the input file name without extension.
    3. Keep hits with Identity >= ``--threshold`` and pivot them into a
       Run x Plasmid_ID matrix of identities (missing pairs are 0).
    4. Reorder that matrix to the training features (absent ones are 0),
       predict, and write a TSV with Run, Predicted_Source and
       Confidence_Score (the highest class probability from
       ``predict_proba``).

    Exits with status 1 if the bundle or input cannot be read or the input
    has fewer than five columns. If no hit passes the threshold, a
    header-only TSV is written (so no stale output is left behind) and the
    script exits with status 0.
    """
    parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.")
    parser.add_argument('--version', action='version', version='PlasmidTrakr v0.1.0')

    parser.add_argument("-i", "--input", required=True, help="Input 'mash screen' output file for a single isolate.")
    parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)")
    parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)")
    parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)")
    parser.add_argument("--run-id", default=None,
                        help="Isolate name to report in the Run column (default: input file name without its extension)")
    args = parser.parse_args()

    result_columns = ['Run', 'Predicted_Source', 'Confidence_Score']

    print(f"Loading model bundle: {args.bundle}")
    try:
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code -- only load bundles from trusted sources.
        bundle = joblib.load(args.bundle)
        rf_model = bundle['model']
        training_features = bundle['features']
        print(f"Successfully loaded model and {len(training_features)} features.")
    except Exception as e:  # any failure here is fatal for the tool
        _fatal(f"Error loading model bundle: {e}")

    print(f"Loading and processing input data: {args.input}")
    try:
        df, n_fields = _read_mash_screen(args.input)
    except OSError as e:
        _fatal(f"Error reading input file '{args.input}'. Error: {e}")

    # mash screen columns: Identity, Shared-hashes, Median-multiplicity,
    # P-value, Query-ID (the matched plasmid) and an optional comment.
    if n_fields < 5:
        _fatal(f"Input file format not recognized. Expected at least 5 columns for Mash output, but got {n_fields}.")
    print("--> Standard headerless Mash output detected.")

    # Non-numeric identities (e.g. a stray header line) become NaN and are dropped.
    df['Identity'] = pd.to_numeric(df['Identity'], errors='coerce')
    # The screened isolate is not named inside mash screen output, so it comes
    # from --run-id or, failing that, from the input file name.
    run_id = args.run_id or os.path.splitext(os.path.basename(args.input))[0]
    df['Run'] = str(run_id).strip()
    df = df.dropna(subset=['Identity'])

    print(f"Filtering features (Identity >= {args.threshold})...")
    filtered_df = df[df['Identity'] >= args.threshold].copy()

    if filtered_df.empty:
        print("Warning: No plasmid hits met the identity threshold. Cannot make a prediction.")
        # Still write a (header-only) table so an old output file is not
        # mistaken for this run's result.
        pd.DataFrame(columns=result_columns).to_csv(args.output, sep='\t', index=False)
        sys.exit(0)

    new_data_matrix = filtered_df.pivot_table(index='Run', columns='Plasmid_ID', values='Identity', fill_value=0)

    print("Aligning input features with the trained model...")
    common_plasmids = new_data_matrix.columns.intersection(training_features)
    if len(common_plasmids) == 0:
        print("Warning: None of the matched plasmids are features of the trained model; "
              "the prediction is based on an all-zero feature vector.")
    # Same column order as training; plasmids never seen in this isolate get 0
    # and plasmids unknown to the model are dropped.
    aligned_matrix = new_data_matrix.reindex(columns=training_features, fill_value=0)

    print(f"Making predictions for {len(aligned_matrix)} isolate(s)...")
    predictions = rf_model.predict(aligned_matrix)
    max_probs = rf_model.predict_proba(aligned_matrix).max(axis=1)

    results_df = pd.DataFrame({
        'Run': aligned_matrix.index,
        'Predicted_Source': predictions,
        'Confidence_Score': max_probs,
    }, columns=result_columns)

    results_df.to_csv(args.output, sep='\t', index=False)
    print(f"\n✅ Predictions complete! Saved to {args.output}")
    print("--- PREDICTION RESULTS ---")
    print(results_df.to_string(index=False))
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
93
f25631df0e9f planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff changeset
if __name__ == "__main__":
    # Run the command-line interface only when this file is executed as a
    # script, not when it is imported as a module.
    main()