Mercurial > repos > galaxytrakr > plasmidtrakr
changeset 0:f25631df0e9f draft
planemo upload commit 25e4c800a5358b8615dac18ea5e908e31c534020
| author | galaxytrakr |
|---|---|
| date | Wed, 29 Apr 2026 15:04:37 +0000 |
| parents | |
| children | fe9ff3859d68 |
| files | plasmidtrakr.xml predict_source.py tool_data/plasmidtrakr.loc.sample tool_data_table_conf.xml.sample |
| diffstat | 4 files changed, 190 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/plasmidtrakr.xml Wed Apr 29 15:04:37 2026 +0000 @@ -0,0 +1,72 @@ +<tool id="plasmidtrakr" name="Predict Isolate Source" version="0.1.0"> + <description>Predicts isolate source from plasmid profiles using a trained machine learning model</description> + + <requirements> + <requirement type="package" version="1.5.3">pandas</requirement> + <requirement type="package" version="1.2.2">scikit-learn</requirement> + </requirements> + + <!-- FIXED: Added $ before __tool_directory__ --> + <version_command> + python '$__tool_directory__/predict_source.py' --version + </version_command> + + <!-- FIXED: Added $ before __tool_directory__ --> + <command detect_errors="exit_code"><![CDATA[ + python '$__tool_directory__/predict_source.py' + -i '$mash_input' + -b '$model_selection.path' + -t '$threshold' + -o '$prediction_output' + ]]></command> + + <inputs> + <param name="mash_input" type="data" format="tabular" label="Mash Screen Output" help="The tabular output file from the Galaxy 'mash screen' tool."/> + + <param name="model_selection" type="select" label="Select Prediction Model" help="Choose which trained model to use for prediction."> + <options from_data_table="plasmidtrakr_models"> + <validator type="no_options" message="No prediction models are configured. Please contact your Galaxy administrator." /> + </options> + </param> + + <param name="threshold" type="float" value="0.95" label="Mash Identity Threshold" help="Filter plasmid hits below this identity. Must match the threshold used for model training."/> + </inputs> + + <outputs> + <data name="prediction_output" format="tabular" label="Prediction for ${on_string} using ${model_selection.name}" /> + </outputs> + + <!-- FIXED: Cleaned up Markdown formatting in the help block (removed backslashes) --> + <help><![CDATA[ +**What it does** + +This tool takes the list of plasmid hits from the Galaxy **mash screen** tool and uses a pre-trained **machine learning model** to predict the original source of the isolate. + +**Workflow for Genome Assemblies** + +1. Go to the **mash screen** tool in Galaxy. +2. In the **"Single or Paired-end reads"** dropdown, select **"Single"**. +3. For the **"Select fastq dataset"** input, provide your **genome assembly FASTA file**. +4. Run the `mash screen` job against the appropriate plasmid database. +5. Use the tabular output from that job as the input for **this prediction tool**. +6. Select the desired prediction model from the dropdown menu. +7. Execute to get your prediction. + +**Output** + +A tabular file containing the isolate ID, the predicted source, and a confidence score. + ]]></help> + + <citations> + <citation type="bibtex"> + @misc{strain_2026_plasmidtrakr, + author = {Strain, Errol}, + title = {PlasmidTrakr: A tool for predicting isolate source from plasmid profiles}, + year = {2026}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/estrain/plasmidtrakr}} + } + </citation> + </citations> +</tool>
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/predict_source.py Wed Apr 29 15:04:37 2026 +0000 @@ -0,0 +1,96 @@ +import argparse +import sys +import os +import pandas as pd +import joblib +import warnings + +# Suppress scikit-learn warnings about feature names if they pop up +warnings.filterwarnings("ignore", category=UserWarning) + +def main(): + parser = argparse.ArgumentParser(description="Predict the source of an isolate using a trained Random Forest model and Mash distances.") + parser.add_argument("-i", "--input", required=True, help="Input Mash screen/dist file for one or more isolates.") + + # --- KEY FIX 1: Replaced -m and -f with a single -b (bundle) argument --- + parser.add_argument("-b", "--bundle", required=True, help="Path to the bundled model and features (.joblib file)") + + parser.add_argument("-t", "--threshold", type=float, default=0.95, help="Mash identity threshold (default: 0.95)") + parser.add_argument("-o", "--output", default="predictions.tsv", help="Output file for predictions (default: predictions.tsv)") + args = parser.parse_args() + + print(f"Loading model bundle: {args.bundle}") + + try: + # --- KEY FIX 2: Load the dictionary and extract both pieces --- + bundle = joblib.load(args.bundle) + rf_model = bundle['model'] + training_features = bundle['features'] + print(f"Successfully loaded model and {len(training_features)} features.") + except Exception as e: + print(f"FATAL: Error loading model bundle: {e}") + sys.exit(1) + + print(f"Loading and processing input data: {args.input}") + + try: + df = pd.read_csv(args.input, sep='\s+', header=None, engine='python') + + # Your format is from 'mash screen', where the columns are: + # Identity, Shared-hashes, Median-multiplicity, P-value, Query-ID + if len(df.columns) >= 5: + print("--> Standard headerless Mash output detected.") + # Keep only the first 5 columns to be safe + df = df.iloc[:, :5] + df.columns = ['Identity', 'Shared_Hashes', 'Median_Multiplicity', 'P_value', 'Plasmid_ID'] + + # The 'Identity' is already the first column, just convert it to numeric + df['Identity'] = pd.to_numeric(df['Identity'], errors='coerce') + + # We need to manually add the 'Run' column. For screen output, the Query-ID (isolate name) + # is not present in the file itself. We must get it from the filename. + run_id = os.path.splitext(os.path.basename(args.input))[0] + df['Run'] = run_id + else: + print(f"FATAL: Input file format not recognized. Expected at least 5 columns for Mash output, but got {len(df.columns)}.") + sys.exit(1) + + df.dropna(subset=['Identity'], inplace=True) + df['Run'] = df['Run'].astype(str).str.strip() + + except Exception as e: + print(f"FATAL: Error reading input file '{args.input}'. Error: {e}") + sys.exit(1) + + print(f"Filtering features (Identity >= {args.threshold})...") + filtered_df = df[df['Identity'] >= args.threshold].copy() + + if filtered_df.empty: + print("Warning: No plasmid hits met the identity threshold. Cannot make a prediction.") + sys.exit(0) + + new_data_matrix = filtered_df.pivot_table(index='Run', columns='Plasmid_ID', values='Identity', fill_value=0) + + print("Aligning input features with the trained model...") + aligned_matrix = pd.DataFrame(0, index=new_data_matrix.index, columns=training_features) + common_plasmids = new_data_matrix.columns.intersection(training_features) + aligned_matrix[common_plasmids] = new_data_matrix[common_plasmids] + + print(f"Making predictions for {len(aligned_matrix)} isolate(s)...") + predictions = rf_model.predict(aligned_matrix) + probabilities = rf_model.predict_proba(aligned_matrix) + max_probs = probabilities.max(axis=1) + + results_df = pd.DataFrame({ + 'Run': aligned_matrix.index, + 'Predicted_Source': predictions, + 'Confidence_Score': max_probs + }) + + results_df.to_csv(args.output, sep='\t', index=False) + print(f"\n✅ Predictions complete! Saved to {args.output}") + print("--- PREDICTION RESULTS ---") + print(results_df.to_string(index=False)) + +if __name__ == "__main__": + main()
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tool_data/plasmidtrakr.loc.sample Wed Apr 29 15:04:37 2026 +0000 @@ -0,0 +1,15 @@ +# This is a sample file that enables the PlasmidTrakr tool +# to use a directory of bundled machine learning models. +# The plasmidtrakr_models.loc file has this format (white space is a TAB): +# +# value <tab> name <tab> path +# <unique_id> <display_name> <bundle_path> +# +# For example, if you have the biological source model (v1) stored in +# /galaxy/tool-data/plasmidtrakr/rf_source_model_bundle.joblib, +# then the entry would look like this: +# +# rf_source_v1 Biological Source (v1) /galaxy/tool-data/plasmidtrakr/rf_source_model_bundle.joblib + +#rf_source_v1 Biological Source (v1) /path/to/your/tool-data/rf_plasmid_model_bundle.joblib +#rf_country_v1 Geographic Origin - Country (v1) /path/to/your/tool-data/rf_country_model_bundle.joblib
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/tool_data_table_conf.xml.sample Wed Apr 29 15:04:37 2026 +0000 @@ -0,0 +1,7 @@ +<tables> + <!-- Location of ML models for PlasmidTrakr --> + <table name="plasmidtrakr_models" comment_char="#"> + <columns>value, name, path</columns> + <file path="tool-data/plasmidtrakr.loc" /> + </table> +</tables> \ No newline at end of file
