diff scimap_spatial.py @ 3:8c55377d7f06 draft

planemo upload for repository https://github.com/goeckslab/tools-mti/tree/main/tools/scimap commit 49210b00535415865694ddbec16238d8cf5e6bb0
author goeckslab
date Wed, 26 Jun 2024 15:27:13 +0000
parents 8ca435ec19be
children
line wrap: on
line diff
--- a/scimap_spatial.py	Mon Jun 10 18:45:21 2024 +0000
+++ b/scimap_spatial.py	Wed Jun 26 15:27:13 2024 +0000
@@ -2,6 +2,7 @@
 import json
 import warnings
 
+import pandas as pd
 import scimap as sm
 from anndata import read_h5ad
 
@@ -28,6 +29,8 @@
     tool_func = getattr(sm.tl, tool)
 
     options = params['analyses']['options']
+
+    # tool specific pre-processing
     if tool == 'cluster':
         options['method'] = params['analyses']['method']
         subset_genes = options.pop('subset_genes')
@@ -38,15 +41,42 @@
         if sub_cluster_group:
             options['sub_cluster_group'] = \
                 [x.strip() for x in sub_cluster_group.split(',')]
+    elif tool == 'spatial_lda':
+        max_weight_assignment = options.pop('max_weight_assignment')
 
     for k, v in options.items():
         if v == '':
             options[k] = None
 
+    # tool execution
     tool_func(adata, **options)
 
+    # spatial LDA post-processing
     if tool == 'spatial_lda':
-        adata.uns.pop('spatial_lda_model')
+
+        if max_weight_assignment:
+            # assign cell to a motif based on maximum weight
+            adata.uns['spatial_lda']['neighborhood_motif'] = \
+                adata.uns['spatial_lda'].idxmax(axis=1)
+
+            # merge motif assignment into adata.obs
+            adata.obs = pd.merge(
+                adata.obs,
+                adata.uns['spatial_lda']['neighborhood_motif'],
+                left_index=True,
+                right_index=True
+            )
+
+        # write out LDA results as tabular files
+        # so they're accessible to Galaxy users
+        adata.uns['spatial_lda'].reset_index().to_csv(
+            'lda_weights.txt', sep='\t', index=False)
+        adata.uns['spatial_lda_probability'].T.reset_index(
+            names='motif').to_csv(
+                'lda_probabilities.txt', sep='\t', index=False)
+
+        if 'spatial_lda_model' in adata.uns:
+            adata.uns.pop('spatial_lda_model')
 
     adata.write(output)