changeset 0:f25631df0e9f draft

planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
author galaxytrakr
date Wed, 29 Apr 2026 15:04:37 +0000
parents
children fe9ff3859d68
files plasmidtrakr.xml predict_source.py tool_data/plasmidtrakr.loc.sample tool_data_table_conf.xml.sample
diffstat 4 files changed, 190 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/plasmidtrakr.xml	Wed Apr 29 15:04:37 2026 +0000
@@ -0,0 +1,72 @@
+<tool id="plasmidtrakr" name="Predict Isolate Source" version="0.1.0">
+    <description>Predicts isolate source from plasmid profiles using a trained machine learning model</description>
+
+    <requirements>
+        <requirement type="package" version="1.5.3">pandas</requirement>
+        <requirement type="package" version="1.2.2">scikit-learn</requirement>
+    </requirements>
+
+    <!-- $__tool_directory__ resolves to the directory that contains this tool XML -->
+    <version_command>
+        python '$__tool_directory__/predict_source.py' --version
+    </version_command>
+
+    <!-- Columns of the selected data table row are read through .fields (e.g. .fields.path) -->
+    <command detect_errors="exit_code"><![CDATA[
+        python '$__tool_directory__/predict_source.py'
+            -i '$mash_input'
+            -b '$model_selection.fields.path'
+            -t '$threshold'
+            -o '$prediction_output'
+    ]]></command>
+
+    <inputs>
+        <param name="mash_input" type="data" format="tabular" label="Mash Screen Output" help="The tabular output file from the Galaxy 'mash screen' tool."/>
+
+        <param name="model_selection" type="select" label="Select Prediction Model" help="Choose which trained model to use for prediction.">
+            <options from_data_table="plasmidtrakr_models">
+                <validator type="no_options" message="No prediction models are configured. Please contact your Galaxy administrator." />
+            </options>
+        </param>
+
+        <param name="threshold" type="float" value="0.95" min="0" max="1" label="Mash Identity Threshold" help="Filter plasmid hits below this identity. Must match the threshold used for model training."/>
+    </inputs>
+
+    <outputs>
+        <data name="prediction_output" format="tabular" label="Prediction for ${on_string} using ${model_selection.fields.name}" />
+    </outputs>
+
+    <!-- Help text is reStructuredText and is rendered by Galaxy below the tool form -->
+    <help><![CDATA[
+**What it does**
+
+This tool takes the list of plasmid hits from the Galaxy **mash screen** tool and uses a pre-trained **machine learning model** to predict the original source of the isolate.
+
+**Workflow for Genome Assemblies**
+
+1.  Go to the **mash screen** tool in Galaxy.
+2.  In the **"Single or Paired-end reads"** dropdown, select **"Single"**.
+3.  For the **"Select fastq dataset"** input, provide your **genome assembly FASTA file**.
+4.  Run the `mash screen` job against the appropriate plasmid database.
+5.  Use the tabular output from that job as the input for **this prediction tool**.
+6.  Select the desired prediction model from the dropdown menu.
+7.  Execute to get your prediction.
+
+**Output**
+
+A tabular file containing the isolate ID, the predicted source, and a confidence score.
+    ]]></help>
+
+    <citations>
+        <citation type="bibtex">
+            @misc{strain_2026_plasmidtrakr,
+                author = {Strain, Errol},
+                title = {PlasmidTrakr: A tool for predicting isolate source from plasmid profiles},
+                year = {2026},
+                publisher = {GitHub},
+                journal = {GitHub repository},
+                howpublished = {\url{https://github.com/estrain/plasmidtrakr}}
+            }
+        </citation>
+    </citations>
+</tool>
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/predict_source.py	Wed Apr 29 15:04:37 2026 +0000
@@ -0,0 +1,96 @@
+import argparse
+import sys
+import os
+import pandas as pd
+import joblib
+import warnings
+
+# Suppress scikit-learn warnings about feature names if they pop up.
+# NOTE: this filter ignores every UserWarning for the whole process, not only
+# the ones raised by scikit-learn.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+def main():
+    parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.")
+    parser.add_argument("-i", "--input", required=True, help="Input Mash screen/dist file for one or more isolates.")
+    
+    # --- KEY FIX 1: Replaced -m and -f with a single -b (bundle) argument ---
+    parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)")
+    
+    parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)")
+    parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)")
+    args = parser.parse_args()
+
+    print(f"Loading model bundle: {args.bundle}")
+    
+    try:
+        # --- KEY FIX 2: Load the dictionary and extract both pieces ---
+        bundle = joblib.load(args.bundle)
+        rf_model = bundle['model']
+        training_features = bundle['features']
+        print(f"Successfully loaded model and {len(training_features)} features.")
+    except Exception as e:
+        print(f"FATAL: Error loading model bundle: {e}")
+        sys.exit(1)
+
+    print(f"Loading and processing input data: {args.input}")
+    
+    try:
+        df = pd.read_csv(args.input, sep='\s+', header=None, engine='python')
+
+        # Your format is from 'mash screen', where the columns are:
+        # Identity, Shared-hashes, Median-multiplicity, P-value, Query-ID
+        if len(df.columns) >= 5:
+            print("--> Standard headerless Mash output detected.")
+            # Keep only the first 5 columns to be safe
+            df = df.iloc[:, :5]
+            df.columns = ['Identity', 'Shared_Hashes', 'Median_Multiplicity', 'P_value', 'Plasmid_ID']
+            
+            # The 'Identity' is already the first column, just convert it to numeric
+            df['Identity'] = pd.to_numeric(df['Identity'], errors='coerce')
+            
+            # We need to manually add the 'Run' column. For screen output, the Query-ID (isolate name)
+            # is not present in the file itself. We must get it from the filename.
+            run_id = os.path.splitext(os.path.basename(args.input))[0]
+            df['Run'] = run_id
+        else:
+            print(f"FATAL: Input file format not recognized. Expected at least 5 columns for Mash output, but got {len(df.columns)}.")
+            sys.exit(1)
+            
+        df.dropna(subset=['Identity'], inplace=True)
+        df['Run'] = df['Run'].astype(str).str.strip()
+        
+    except Exception as e: 
+        print(f"FATAL: Error reading input file '{args.input}'. Error: {e}")
+        sys.exit(1)
+
+    print(f"Filtering features (Identity >= {args.threshold})...")
+    filtered_df = df[df['Identity'] >= args.threshold].copy()
+    
+    if filtered_df.empty:
+        print("Warning: No plasmid hits met the identity threshold. Cannot make a prediction.")
+        sys.exit(0)
+
+    new_data_matrix = filtered_df.pivot_table(index='Run', columns='Plasmid_ID', values='Identity', fill_value=0)
+    
+    print("Aligning input features with the trained model...")
+    aligned_matrix = pd.DataFrame(0, index=new_data_matrix.index, columns=training_features)
+    common_plasmids = new_data_matrix.columns.intersection(training_features)
+    aligned_matrix[common_plasmids] = new_data_matrix[common_plasmids]
+
+    print(f"Making predictions for {len(aligned_matrix)} isolate(s)...")
+    predictions = rf_model.predict(aligned_matrix)
+    probabilities = rf_model.predict_proba(aligned_matrix)
+    max_probs = probabilities.max(axis=1)
+
+    results_df = pd.DataFrame({
+        'Run': aligned_matrix.index,
+        'Predicted_Source': predictions,
+        'Confidence_Score': max_probs
+    })
+
+    results_df.to_csv(args.output, sep='\t', index=False)
+    print(f"\n✅ Predictions complete! Saved to {args.output}")
+    print("--- PREDICTION RESULTS ---")
+    print(results_df.to_string(index=False))
+
+if __name__ == "__main__":
+    main()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tool_data/plasmidtrakr.loc.sample	Wed Apr 29 15:04:37 2026 +0000
@@ -0,0 +1,15 @@
+# This is a sample file that enables the PlasmidTrakr tool
+# to use a directory of bundled machine learning models.
+# The plasmidtrakr.loc file has this format (white space is a TAB):
+#   
+# value <tab> name <tab> path 
+# <unique_id>   <display_name>   <bundle_path>
+#
+# For example, if you have the biological source model (v1) stored in
+# /galaxy/tool-data/plasmidtrakr/rf_source_model_bundle.joblib,
+# then the entry would look like this:
+#
+# rf_source_v1    Biological Source (v1)    /galaxy/tool-data/plasmidtrakr/rf_source_model_bundle.joblib
+
+#rf_source_v1	Biological Source (v1)	/path/to/your/tool-data/rf_plasmid_model_bundle.joblib
+#rf_country_v1	Geographic Origin - Country (v1)	/path/to/your/tool-data/rf_country_model_bundle.joblib
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/tool_data_table_conf.xml.sample	Wed Apr 29 15:04:37 2026 +0000
@@ -0,0 +1,7 @@
+<tables>
+    <!-- PlasmidTrakr model bundles: one row per model with the columns value, name and path; lines starting with # are ignored -->
+    <table name="plasmidtrakr_models" comment_char="#">
+        <columns>value, name, path</columns>
+        <file path="tool-data/plasmidtrakr.loc" />
+    </table>
+</tables>
\ No newline at end of file