Mercurial > repos > galaxytrakr > plasmidtrakr
comparison predict_source.py @ 0:f25631df0e9f draft
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
| author | galaxytrakr |
|---|---|
| date | Wed, 29 Apr 2026 15:04:37 +0000 |
| parents | |
| children | 954eccb7cc48 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:f25631df0e9f |
|---|---|
| 1 import argparse | |
| 2 import sys | |
| 3 import os | |
| 4 import pandas as pd | |
| 5 import joblib | |
| 6 import warnings | |
| 7 | |
# scikit-learn raises a UserWarning when the feature-name metadata of the
# prediction input does not match the fitted model ("X has feature names, but
# ... was fitted without feature names" and the reverse case). Silence only
# those messages: a blanket UserWarning filter would also hide other warnings,
# such as scikit-learn's warning that a pickled model was saved with a
# different scikit-learn version.
warnings.filterwarnings(
    "ignore",
    message=r"X (has|does not have valid) feature names",
    category=UserWarning,
)
| 10 | |
# Column names given to the first five fields of a `mash screen` output line.
_MASH_SCREEN_COLUMNS = ['Identity', 'Shared_Hashes', 'Median_Multiplicity', 'P_value', 'Plasmid_ID']
# Column layout of the predictions TSV written to --output.
_RESULT_COLUMNS = ['Run', 'Predicted_Source', 'Confidence_Score']


def _load_bundle(bundle_path):
    """Load the joblib model bundle and return ``(model, training_features)``.

    The bundle is indexed with the keys ``'model'`` (the fitted classifier used
    for ``predict``/``predict_proba``) and ``'features'`` (the plasmid IDs used
    as the model's input columns). Any failure prints a FATAL message and exits
    with status 1.
    """
    print(f"Loading model bundle: {bundle_path}")
    try:
        # NOTE: joblib.load unpickles arbitrary objects; only load trusted bundles.
        bundle = joblib.load(bundle_path)
        rf_model = bundle['model']
        training_features = bundle['features']
        print(f"Successfully loaded model and {len(training_features)} features.")
    except Exception as e:
        print(f"FATAL: Error loading model bundle: {e}")
        sys.exit(1)
    return rf_model, training_features


def _read_mash_screen(input_path):
    """Parse headerless `mash screen` output into a DataFrame of raw strings.

    Each line is split on whitespace at most five times, so the trailing
    free-text query-comment (which may itself contain spaces) stays in a sixth
    field that is discarded, instead of changing the per-line column count.
    Blank lines are skipped; an empty file yields an empty DataFrame.

    Returns:
        DataFrame with columns ``_MASH_SCREEN_COLUMNS`` (values are str).

    Raises:
        ValueError: if a non-blank line has fewer than five fields.
    """
    records = []
    # The comment column is discarded, so undecodable bytes there are harmless.
    with open(input_path, encoding='utf-8', errors='replace') as handle:
        for line_no, line in enumerate(handle, start=1):
            fields = line.split(maxsplit=5)
            if not fields:
                continue
            if len(fields) < 5:
                raise ValueError(
                    "Input file format not recognized. Expected at least 5 columns "
                    f"for Mash output, but line {line_no} has {len(fields)}."
                )
            records.append(fields[:5])
    return pd.DataFrame(records, columns=_MASH_SCREEN_COLUMNS)


def _build_feature_matrix(hits, training_features):
    """Pivot hits to one row per Run and align the columns to ``training_features``.

    Plasmids not in ``training_features`` are dropped, and training plasmids
    with no hit are filled with 0, so the column order matches the model's.
    """
    matrix = hits.pivot_table(index='Run', columns='Plasmid_ID', values='Identity', fill_value=0)
    if matrix.columns.intersection(training_features).empty:
        print("Warning: none of the retained plasmid hits are features of the trained model; "
              "the prediction is based on an all-zero feature vector.")
    return matrix.reindex(columns=training_features, fill_value=0)


def main():
    """Predict the source of an isolate from its `mash screen` plasmid hits.

    Loads the model bundle, parses the Mash screen file, keeps hits with
    Identity >= --threshold, aligns them to the model's training features and
    writes a TSV (Run, Predicted_Source, Confidence_Score) to --output, where
    Confidence_Score is the highest class probability. If no hit passes the
    threshold, a header-only TSV is written and the script exits with status 0.
    """
    parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.")
    parser.add_argument("-i", "--input", required=True, help="Input headerless 'mash screen' output file for a single isolate (the isolate name is taken from the file name).")
    parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)")
    parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)")
    parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)")
    args = parser.parse_args()

    rf_model, training_features = _load_bundle(args.bundle)

    print(f"Loading and processing input data: {args.input}")
    try:
        df = _read_mash_screen(args.input)
        print("--> Standard headerless Mash output detected.")
        df['Identity'] = pd.to_numeric(df['Identity'], errors='coerce')
        # e.g. `mash dist` output puts an ID in the first column; fail loudly
        # instead of silently reporting "no hits".
        if not df.empty and df['Identity'].isna().all():
            raise ValueError("the first column is not numeric; expected 'mash screen' "
                             "output with Identity in the first column.")
        df.dropna(subset=['Identity'], inplace=True)
        # mash screen output has no isolate column, so the isolate name is
        # derived from the input file name.
        df['Run'] = os.path.splitext(os.path.basename(args.input))[0].strip()
    except Exception as e:
        print(f"FATAL: Error reading input file '{args.input}'. Error: {e}")
        sys.exit(1)

    print(f"Filtering features (Identity >= {args.threshold})...")
    filtered_df = df[df['Identity'] >= args.threshold]

    if filtered_df.empty:
        print("Warning: No plasmid hits met the identity threshold. Cannot make a prediction.")
        # Still write a header-only file so the promised output always exists.
        pd.DataFrame(columns=_RESULT_COLUMNS).to_csv(args.output, sep='\t', index=False)
        sys.exit(0)

    print("Aligning input features with the trained model...")
    aligned_matrix = _build_feature_matrix(filtered_df, training_features)

    print(f"Making predictions for {len(aligned_matrix)} isolate(s)...")
    predictions = rf_model.predict(aligned_matrix)
    # Confidence is the highest class probability for each isolate.
    confidence = rf_model.predict_proba(aligned_matrix).max(axis=1)

    results_df = pd.DataFrame({
        'Run': aligned_matrix.index,
        'Predicted_Source': predictions,
        'Confidence_Score': confidence,
    })

    results_df.to_csv(args.output, sep='\t', index=False)
    print(f"\n✅ Predictions complete! Saved to {args.output}")
    print("--- PREDICTION RESULTS ---")
    print(results_df.to_string(index=False))
| 94 | |
# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
