Mercurial > repos > bgruening > vpt_segment
changeset 3:0811613196a2 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/vpt commit 1180d1e85b63beb3d2e7fd6e9c73b054a9348e7f
| author | bgruening |
|---|---|
| date | Mon, 11 Aug 2025 08:53:09 +0000 |
| parents | a88a1e78702e |
| children | |
| files | macros.xml safetensors_convert_pytorch.py vpt_segment.xml |
| diffstat | 3 files changed, 139 insertions(+), 20 deletions(-) [+] |
line wrap: on
line diff
--- a/macros.xml Sun Jun 22 12:25:48 2025 +0000 +++ b/macros.xml Mon Aug 11 08:53:09 2025 +0000 @@ -1,10 +1,10 @@ <macros> <token name="@TOOL_VERSION@">1.3.0</token> - <token name="@VERSION_SUFFIX@">2</token> + <token name="@VERSION_SUFFIX@">3</token> <token name="@PROFILE@">23.0</token> <xml name="requirements"> <requirements> - <container type="docker">quay.io/bgruening/vpt:1.3.0-1</container> + <container type="docker">quay.io/bgruening/vpt:1.3.0-2</container> <yield/> </requirements> </xml> @@ -17,7 +17,7 @@ </creator> </xml> <token name="@CMD@"><![CDATA[ - mkdir -p 'input/images' 'output/' && + mkdir -p 'input/images' 'input/model' 'output/' && #for $image in $input_images: ln -s '$image' 'input/images/${image.element_identifier}.${image.ext}' && #end for @@ -25,7 +25,6 @@ ]]></token> <token name="@COMMON_ARGS@"><![CDATA[ --processes \${GALAXY_SLOTS:-10} - --verbose --log-file 'output/log' ]]> </token>
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/safetensors_convert_pytorch.py Mon Aug 11 08:53:09 2025 +0000 @@ -0,0 +1,113 @@ +#!/usr/bin/env python + +import argparse +import os +import sys + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + + +def convert_pt_to_safetensors(input_path, output_path): + """Convert PyTorch model to SafeTensors format.""" + print(f"Loading PyTorch model from: {input_path}") + + try: + # Load the PyTorch model + state_dict = torch.load(input_path) + + print(f"Converting {len(state_dict)} tensors to SafeTensors format...") + + # Save as SafeTensors + save_file(state_dict, output_path) + print(f"Successfully converted to SafeTensors: {output_path}") + + except Exception as e: + print(f"Error converting PyTorch to SafeTensors: {e}") + sys.exit(1) + + +def convert_safetensors_to_pt(input_path, output_path): + """Convert SafeTensors to PyTorch format.""" + print(f"Loading SafeTensors model from: {input_path}") + + try: + state_dict = {} + + # Load SafeTensors + with safe_open(input_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + + print(f"Converting {len(state_dict)} tensors to PyTorch format...") + + # Save as PyTorch + torch.save(state_dict, output_path) + print(f"Successfully converted to PyTorch: {output_path}") + + except Exception as e: + print(f"Error converting SafeTensors to PyTorch: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Convert between PyTorch and SafeTensors", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python converter.py --convert pt_safe -i model.pt -o model.safetensors + python converter.py --convert safe_pt -i model.safetensors -o model.pt + """ + ) + + parser.add_argument( + "--convert", + required=True, + choices=["pt_safe", "safe_pt"], + help=( + "Conversion direction: 'pt_safe' (PyTorch to SafeTensors) " + "or 'safe_pt' (SafeTensors to PyTorch)" + ) + ) + + parser.add_argument( + "-i", "--input", + required=True, + help="Input model file path" + ) + + parser.add_argument( + "-o", "--output", + required=True, + help="Output model file path" + ) + + args = parser.parse_args() + + # Validate input file exists + if not os.path.exists(args.input): + print(f"Error: Input file does not exist: {args.input}") + sys.exit(1) + + # Create output directory if it doesn't exist + output_dir = os.path.dirname(args.output) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + print(f"Created output directory: {output_dir}") + + # Perform conversion + if args.convert == "pt_safe": + + convert_pt_to_safetensors(args.input, args.output) + + elif args.convert == "safe_pt": + + convert_safetensors_to_pt(args.input, args.output) + + print("Conversion completed successfully!") + + +if __name__ == "__main__": + main()
--- a/vpt_segment.xml Sun Jun 22 12:25:48 2025 +0000 +++ b/vpt_segment.xml Mon Aug 11 08:53:09 2025 +0000 @@ -5,10 +5,19 @@ </macros> <expand macro="requirements" /> <expand macro="creator" /> + <required_files> + <include path="safetensors_convert_pytorch.py"/> + </required_files> <command detect_errors="exit_code"><![CDATA[ @CMD@ ln -s '$detected_transcripts' 'input/detected_transcripts.csv' && ln -s '$vpt_config_file' 'input/vpt_config_file.json' && + #for $i, $task in enumerate($vpt_config.segmentation_tasks.segmentation_task): + #if $task.segmentation_properties.custom_weights: + ln -s '$task.segmentation_properties.custom_weights' 'input/model/custom_weights_${i}.safetensors' && + python '$__tool_directory__/safetensors_convert_pytorch.py' --convert safe_pt -i 'input/model/custom_weights_${i}.safetensors' -o 'input/model/custom_weights_${i}.pth' && + #end if + #end for ### segmentation vpt @COMMON_ARGS@ @@ -105,7 +114,7 @@ #end if "model_dimensions": "$task.segmentation_properties.model_dimensions", #if $task.segmentation_properties.custom_weights: - "custom_weights": "$task.segmentation_properties.custom_weights", + "custom_weights": "input/model/custom_weights_${i}.pth", #else: "custom_weights": null, #end if @@ -351,7 +360,7 @@ <option value="2D">2D</option> <option value="3D">3D</option> </param> - <param argument="--custom_weights" type="data" format="data" optional="true" label="Custom Cellpose weights file"/> + <param argument="--custom_weights" type="data" format="safetensors" optional="true" label="Custom Cellpose weights file"/> <section name="channel_map" title="Channel map" expanded="true" help="Specify the channel map for the segmentation task with the corresponding task input data"> <conditional name="channel_red_conditional"> <param name="red" type="select" label="The stain that will be used for red channel"> @@ -553,21 +562,20 @@ <assert_contents> <has_text_matching expression="EntityID,fov,volume,center_x,center_y,min_x,min_y,max_x,max_y,anisotropy,transcript_count,perimeter_area_ratio,solidity"/> <has_text_matching expression=",752.854533232[0-9]+,315.600965641[0-9]+,5.50107977386[0-9]+,310.1669732237[0-9]+,"/> - <has_text_matching expression=",416.540438981[0-9]+,68.7110642977[0-9]+,9.67291568732[0-9]+,64.4947481953[0-9]+,"/> + <has_text_matching expression=",777.13824141[0-9]+,70.4487975801[0-9]+,1.976992703078[0-9]+,65.5065863409[0-9]+,"/> </assert_contents> </element> <element name="detected_transcripts"> <assert_contents> <has_text_matching expression=",barcode_id,global_x,global_y,global_z,x,y,fov,gene,transcript_id,cell_id"/> - <has_text_matching expression="86,10,285.00012,10.13364,0.0,816.0,153.0,0,AKAP11,ENST00000025301.3"/> - <has_text_matching expression="229,10,274.95612,35.72964,0.0,723.0,390.0,0,AKAP11,ENST00000025301.3"/> + <has_text_matching expression="63,10,370.95007,5.520504,0.0,1611.833,110.285774,0,AKAP11,ENST00000025301.3"/> + <has_text_matching expression="68,10,355.46716,6.4616404,0.0,1468.4725,119.0,0,AKAP11,ENST00000025301.3"/> </assert_contents> </element> <element name="metrics"> <assert_contents> <has_text_matching expression=",Cell count,Cell volume"/> - <has_text_matching expression="All cells,814,1572.4,1533.3,735.8,654.0,92.1,108.0,96.4,68.2,0.0"/> - <has_text_matching expression="Cells after filtering,700,1694.1,1667.9,851.1,810.5,104.5,111.0,95.9,63.2,0.06"/> + <has_n_lines n="3"/> </assert_contents> </element> <element name="sum_signals"> @@ -762,8 +770,7 @@ <element name="metrics"> <assert_contents> <has_text_matching expression=",Cell count,Cell volume"/> - <has_text_matching expression="All cells,811,1588.0,1549.9,740.5,669.0,92.2,109.0,96.7,68.6,0.0"/> - <has_text_matching expression="Cells after filtering,698,1711.6,1705.9,856.1,817.5,104.6,111.0,96.2,63.7,0.06"/> + <has_n_lines n="3"/> </assert_contents> </element> <element name="sum_signals"> @@ -834,7 +841,7 @@ <section name="segmentation_properties"> <param name="model" value="null"/> <param name="model_dimensions" value="2D"/> - <param name="custom_weights" location="https://zenodo.org/records/15319018/files/CP_20230830_093420"/> + <param name="custom_weights" location="https://zenodo.org/records/16784480/files/CP_20230830_093420.safetensors"/> <section name="channel_map"> <conditional name="channel_red_conditional"> <param name="red" value="Cellbound1"/> @@ -905,14 +912,14 @@ <element name="cell_by_gene"> <assert_contents> <has_text_matching expression="cell,AKAP11,CBX5,CCDC113"/> - <has_text_matching expression="6,11,4,0,5,2,1,0,1,2,1,2"/> - <has_text_matching expression="0,1,0,0,0,0,1,0,0,4"/> + <has_text_matching expression="6,10,4,0,5,2,1,0,1,2"/> + <has_text_matching expression="0,1,0,0,0,0,1,0,0,4,0"/> </assert_contents> </element> <element name="cell_metadata"> <assert_contents> <has_text_matching expression="EntityID,fov,volume,center_x,center_y,min_x,min_y,max_x,max_y,anisotropy,transcript_count,perimeter_area_ratio,solidity"/> - <has_text_matching expression="168.0683921161[0-9]+,246.3592447472[0-9]+,-1.507707574380[0-9]+,243.9322371166[0-9]+"/> + <has_text_matching expression="167.0695914980[0-9]+,246.371389688[0-9]+,-1.502969079760[0-9]+,243.9322371166[0-9]+"/> <has_text_matching expression="1540.721726928[0-9]+,412.175889439[0-9]+,2.951632912099[0-9]+,404.3500198495[0-9]+"/> </assert_contents> </element> @@ -926,14 +933,13 @@ <element name="metrics"> <assert_contents> <has_text_matching expression=",Cell count,Cell volume"/> - <has_text_matching expression="All cells,93[0-9]+,121\d*(\.\d+)?,118\d*(\.\d+)?,60\d*(\.\d+)?,49\d*(\.\d+)?,7\d*(\.\d+)?,10\d*(\.\d+)?,9\d*(\.\d+)?,6\d*(\.\d+)?,0.0"/> - <has_text_matching expression="Cells after filtering,659,1532.9,1553.5,850.2,837.0,104.8,112.0,90.2,53.8,0.15"/> + <has_n_lines n="3"/> </assert_contents> </element> <element name="sum_signals"> <assert_contents> <has_text_matching expression="Cellbound2_raw,Cellbound2_high_pass"/> - <has_text_matching expression="73446567.0,1605103.666715[0-9]+"/> + <has_text_matching expression="72859278.0,1502895.008866[0-9]+"/> <has_text_matching expression="1642190888.0,16154095.44810[0-9]+"/> </assert_contents> </element> @@ -972,6 +978,7 @@ <has_text_matching expression="null"/> <has_text_matching expression="CLAHE"/> <has_text_matching expression="harmonize"/> + <has_text_matching expression="input/model/custom_weights_0.pth"/> </assert_contents> </element> </output_collection>
