comparison predict_source.py @ 2:954eccb7cc48 draft

planemo upload commit 8da69fa6fb55921a1b005486e59563f0c6956c2e
author galaxytrakr
date Wed, 29 Apr 2026 16:33:55 +0000
parents f25631df0e9f
children
comparison
equal deleted inserted replaced
1:fe9ff3859d68 2:954eccb7cc48
8 # Suppress scikit-learn warnings about feature names if they pop up 8 # Suppress scikit-learn warnings about feature names if they pop up
9 warnings.filterwarnings("ignore", category=UserWarning) 9 warnings.filterwarnings("ignore", category=UserWarning)
10 10
11 def main(): 11 def main():
12 parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.") 12 parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.")
13 parser.add_argument('--version', action='version', version='PlasmidTrakr v0.1.0')
14
13 parser.add_argument("-i", "--input", required=True, help="Input Mash screen/dist file for one or more isolates.") 15 parser.add_argument("-i", "--input", required=True, help="Input Mash screen/dist file for one or more isolates.")
14
15 # --- KEY FIX 1: Replaced -m and -f with a single -b (bundle) argument ---
16 parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)") 16 parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)")
17
18 parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)") 17 parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)")
19 parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)") 18 parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)")
20 args = parser.parse_args() 19 args = parser.parse_args()
21 20
22 print(f"Loading model bundle: {args.bundle}") 21 print(f"Loading model bundle: {args.bundle}")