Mercurial > repos > galaxytrakr > plasmidtrakr
annotate predict_source.py @ 6:b226c7a8580d draft
planemo upload commit ebe76ee27375408d989dc1cd612b20639a38b260
| author | galaxytrakr |
|---|---|
| date | Wed, 29 Apr 2026 18:17:28 +0000 |
| parents | 954eccb7cc48 |
| children |
| rev | line source |
|---|---|
|
0
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
1 import argparse |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
2 import sys |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
3 import os |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
4 import pandas as pd |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
5 import joblib |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
6 import warnings |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
7 |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
# Silence every UserWarning for the lifetime of the process. The intent is to
# hide scikit-learn's feature-name warnings, should they be raised.
warnings.simplefilter("ignore", category=UserWarning)
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
10 |
|
f25631df0e9f
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
galaxytrakr
parents:
diff
changeset
|
def _identity_threshold(value):
    """argparse ``type`` callable: parse an identity cutoff in the range 0..1.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a float or lies outside
            the closed interval [0, 1] (NaN is rejected as well).
    """
    try:
        threshold = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {value!r}") from None
    if not 0.0 <= threshold <= 1.0:
        raise argparse.ArgumentTypeError(
            f"threshold must be between 0 and 1, got {value}"
        )
    return threshold


def _fail(message):
    """Print a ``FATAL:`` diagnostic to stderr and exit with status 1."""
    print(f"FATAL: {message}", file=sys.stderr)
    sys.exit(1)


def _read_mash_screen(input_path):
    """Load a headerless, whitespace-delimited Mash screen table.

    The first five columns are interpreted as Identity, Shared-hashes,
    Median-multiplicity, P-value and Query-ID (stored as ``Plasmid_ID``); any
    further columns are discarded. The file does not name the isolate, so the
    ``Run`` column is filled with the input file's base name minus its
    extension. Rows whose Identity is not numeric are dropped.

    Exits with status 1 (via ``_fail``) when the table has fewer than five
    columns; any pandas parsing error propagates to the caller.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    df = pd.read_csv(input_path, sep=r'\s+', header=None, engine='python')

    if len(df.columns) < 5:
        _fail(
            "Input file format not recognized. Expected at least 5 columns "
            f"for Mash output, but got {len(df.columns)}."
        )

    print("--> Standard headerless Mash output detected.")
    # Copy so the column assignments below never operate on a view.
    df = df.iloc[:, :5].copy()
    df.columns = ['Identity', 'Shared_Hashes', 'Median_Multiplicity', 'P_value', 'Plasmid_ID']
    df['Identity'] = pd.to_numeric(df['Identity'], errors='coerce')
    df['Run'] = os.path.splitext(os.path.basename(input_path))[0].strip()
    return df.dropna(subset=['Identity'])


def main():
    """Command-line entry point: predict an isolate's source from Mash hits.

    Steps:
      1. Load a joblib bundle holding a fitted model (``'model'``) and the
         ordered feature names it was trained on (``'features'``).
      2. Read the Mash screen table given by ``--input``.
      3. Keep hits with Identity >= ``--threshold`` and pivot them into a
         Run x Plasmid_ID matrix aligned to the training features.
      4. Write Run / Predicted_Source / Confidence_Score to ``--output`` as
         TSV and echo the table to stdout.

    Exit status: 0 on success (including the "no hits above threshold" case,
    for which a header-only output file is written), 1 on a load/parse
    failure, 2 on invalid command-line arguments.
    """
    parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.")
    parser.add_argument('--version', action='version', version='PlasmidTrakr v0.1.0')

    parser.add_argument("-i", "--input", required=True, help="Input Mash screen/dist file for one or more isolates.")
    parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)")
    parser.add_argument("-t", "--threshold", type=_identity_threshold, default=0.95, help="Mash identity threshold (default: 0.95)")
    parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)")
    args = parser.parse_args()

    print(f"Loading model bundle: {args.bundle}")
    try:
        # SECURITY: joblib.load unpickles the file and can execute arbitrary
        # code; only pass bundles from a trusted source.
        bundle = joblib.load(args.bundle)
        rf_model = bundle['model']
        training_features = bundle['features']
    except Exception as e:  # top-level CLI boundary: report and exit
        _fail(f"Error loading model bundle: {e}")
    print(f"Successfully loaded model and {len(training_features)} features.")

    print(f"Loading and processing input data: {args.input}")
    try:
        df = _read_mash_screen(args.input)
    except Exception as e:  # top-level CLI boundary: report and exit
        _fail(f"Error reading input file '{args.input}'. Error: {e}")

    print(f"Filtering features (Identity >= {args.threshold})...")
    filtered_df = df[df['Identity'] >= args.threshold]

    result_columns = ['Run', 'Predicted_Source', 'Confidence_Score']
    if filtered_df.empty:
        print("Warning: No plasmid hits met the identity threshold. Cannot make a prediction.")
        # Still produce a (header-only) output so a successful exit never
        # leaves the output missing or holding stale results from a prior run.
        pd.DataFrame(columns=result_columns).to_csv(args.output, sep='\t', index=False)
        sys.exit(0)

    # pivot_table averages duplicate (Run, Plasmid_ID) pairs (default aggfunc).
    new_data_matrix = filtered_df.pivot_table(index='Run', columns='Plasmid_ID', values='Identity', fill_value=0)

    print("Aligning input features with the trained model...")
    # Reindex to the training columns: plasmids unseen in training are
    # dropped, training plasmids absent from this input become 0.
    aligned_matrix = new_data_matrix.reindex(columns=training_features, fill_value=0)

    print(f"Making predictions for {len(aligned_matrix)} isolate(s)...")
    predictions = rf_model.predict(aligned_matrix)
    # Confidence is the highest class probability reported for each row.
    max_probs = rf_model.predict_proba(aligned_matrix).max(axis=1)

    results_df = pd.DataFrame(
        {
            'Run': aligned_matrix.index,
            'Predicted_Source': predictions,
            'Confidence_Score': max_probs,
        },
        columns=result_columns,
    )

    results_df.to_csv(args.output, sep='\t', index=False)
    print(f"\n✅ Predictions complete! Saved to {args.output}")
    print("--- PREDICTION RESULTS ---")
    print(results_df.to_string(index=False))


if __name__ == "__main__":
    main()
