diff COBRAxy/utils/model_utils.py @ 490:c6ea189ea7e9 draft default tip

Uploaded
author francesco_lapi
date Mon, 29 Sep 2025 15:13:21 +0000
parents a6e45049c1b9
children
--- a/COBRAxy/utils/model_utils.py	Mon Sep 29 10:33:26 2025 +0000
+++ b/COBRAxy/utils/model_utils.py	Mon Sep 29 15:13:21 2025 +0000
@@ -18,6 +18,100 @@
 import utils.rule_parsing  as rulesUtils
 import utils.reaction_parsing as reactionUtils
 from cobra import Model as cobraModel, Reaction, Metabolite
+import sys
+
+
+############################ check_methods ####################################
+def gene_type(l :str, name :str) -> str:
+    """
+    Determine the type of gene ID.
+
+    Args:
+        l (str): The gene identifier to check.
+        name (str): The name of the dataset, used in error messages.
+
+    Returns:
+        str: The type of gene ID ('hugo_id', 'ensembl_gene_id', 'symbol', or 'entrez_id').
+
+    Raises:
+        SystemExit: Aborts execution via sys.exit() if the gene ID type is not supported.
+    """
+    if check_hgnc(l):
+        return 'hugo_id'
+    elif check_ensembl(l):
+        return 'ensembl_gene_id'
+    elif check_symbol(l):
+        return 'symbol'
+    elif check_entrez(l):
+        return 'entrez_id'
+    else:
+        sys.exit('Execution aborted:\n' +
+                 'gene ID type in ' + name + ' not supported. Supported ID ' +
+                 'types are: HUGO ID, Ensembl ID, HUGO symbol, Entrez ID\n')
+
+def check_hgnc(l :str) -> bool:
+    """
+    Check if a gene identifier follows the HGNC ID format ('HGNC:' followed by digits).
+
+    Args:
+        l (str): The gene identifier to check.
+
+    Returns:
+        bool: True if the gene identifier follows the HGNC format, False otherwise.
+    """
+    return len(l) > 5 and l.upper().startswith('HGNC:') and l[5:].isdigit()
+
+def check_ensembl(l :str) -> bool:
+    """
+    Check if a gene identifier follows the Ensembl format.
+
+    Args:
+        l (str): The gene identifier to check.
+
+    Returns:
+        bool: True if the gene identifier follows the Ensembl format, False otherwise.
+    """
+    return l.upper().startswith('ENS')
+ 
+
+def check_symbol(l :str) -> bool:
+    """
+    Check if a gene identifier looks like a gene symbol
+    (an alphabetic first character followed by alphanumeric characters).
+
+    Args:
+        l (str): The gene identifier to check.
+
+    Returns:
+        bool: True if the gene identifier follows the symbol format, False otherwise.
+    """
+    return len(l) > 0 and l[0].isalpha() and l[1:].isalnum()
+
+def check_entrez(l :str) -> bool:
+    """
+    Check if a gene identifier follows the Entrez ID format (digits only).
+
+    Args:
+        l (str): The gene identifier to check.
+
+    Returns:
+        bool: True if the gene identifier follows the Entrez ID format, False otherwise.
+    """ 
+    if len(l) > 0:
+        return l.isdigit()
+    else: 
+        return False
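+# Illustrative sanity checks for the helpers above (example IDs, not from a
+# real dataset):
+#   gene_type('HGNC:5', 'ds')           -> 'hugo_id'
+#   gene_type('ENSG00000139618', 'ds')  -> 'ensembl_gene_id'
+#   gene_type('BRCA2', 'ds')            -> 'symbol'
+#   gene_type('675', 'ds')              -> 'entrez_id'
+# Note that check_symbol() rejects single-character strings and symbols that
+# contain '-' (e.g. 'HLA-A'), because it requires l[1:].isalnum().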
 
 ################################- DATA GENERATION -################################
 ReactionId = str
@@ -506,110 +600,96 @@
 
 def _simplify_boolean_expression(expr: str) -> str:
     """
-    Simplify a boolean expression by removing duplicates and redundancies.
-    Handles expressions with 'and' and 'or'.
+    Simplify a boolean expression by removing duplicate terms while preserving semantics.
+    Deduplicates plain 'and'/'or' chains (including inside parentheses) and leaves
+    mixed-operator expressions untouched, since rewriting those could change their meaning.
     """
     if not expr or not expr.strip():
         return expr
     
-    # normalize operators
+    # Normalize operators and whitespace
     expr = expr.replace(' AND ', ' and ').replace(' OR ', ' or ')
+    expr = ' '.join(expr.split())  # Normalize whitespace
     
-    # recursive helper to process expressions
-    def process_expression(s: str) -> str:
-        s = s.strip()
-        if not s:
-            return s
-            
-    # handle parentheses
-        while '(' in s:
-            # find the innermost parentheses
-            start = -1
-            for i, c in enumerate(s):
-                if c == '(':
-                    start = i
-                elif c == ')' and start != -1:
-                    # process inner content
-                    inner = s[start+1:i]
-                    processed_inner = process_expression(inner)
-                    s = s[:start] + processed_inner + s[i+1:]
-                    break
-            else:
-                break
-        
-    # split by 'or' at top level
-        or_parts = []
-        current_part = ""
-        paren_count = 0
-        
-        tokens = s.split()
-        i = 0
-        while i < len(tokens):
-            token = tokens[i]
-            if token == 'or' and paren_count == 0:
-                if current_part.strip():
-                    or_parts.append(current_part.strip())
-                current_part = ""
-            else:
-                if token.count('(') > token.count(')'):
-                    paren_count += token.count('(') - token.count(')')
-                elif token.count(')') > token.count('('):
-                    paren_count -= token.count(')') - token.count('(')
-                current_part += token + " "
-            i += 1
+    def simplify_parentheses_content(match_obj):
+        """Helper function to simplify content within parentheses."""
+        content = match_obj.group(1)  # Content inside parentheses
         
-        if current_part.strip():
-            or_parts.append(current_part.strip())
-        
-    # process each OR part
-        processed_or_parts = []
-        for or_part in or_parts:
-            # split by 'and' within each OR part
-            and_parts = []
-            current_and = ""
-            paren_count = 0
+        # Only simplify if it's a pure OR or pure AND chain
+        if ' or ' in content and ' and ' not in content:
+            # Pure OR chain - safe to deduplicate
+            parts = [p.strip() for p in content.split(' or ') if p.strip()]
+            unique_parts = []
+            seen = set()
+            for part in parts:
+                if part not in seen:
+                    unique_parts.append(part)
+                    seen.add(part)
             
-            and_tokens = or_part.split()
-            j = 0
-            while j < len(and_tokens):
-                token = and_tokens[j]
-                if token == 'and' and paren_count == 0:
-                    if current_and.strip():
-                        and_parts.append(current_and.strip())
-                    current_and = ""
-                else:
-                    if token.count('(') > token.count(')'):
-                        paren_count += token.count('(') - token.count(')')
-                    elif token.count(')') > token.count('('):
-                        paren_count -= token.count(')') - token.count('(')
-                    current_and += token + " "
-                j += 1
+            if len(unique_parts) == 1:
+                return unique_parts[0]  # Remove unnecessary parentheses for single items
+            else:
+                return '(' + ' or '.join(unique_parts) + ')'
+                
+        elif ' and ' in content and ' or ' not in content:
+            # Pure AND chain - safe to deduplicate  
+            parts = [p.strip() for p in content.split(' and ') if p.strip()]
+            unique_parts = []
+            seen = set()
+            for part in parts:
+                if part not in seen:
+                    unique_parts.append(part)
+                    seen.add(part)
             
-            if current_and.strip():
-                and_parts.append(current_and.strip())
-            
-            # deduplicate AND parts
-            unique_and_parts = list(dict.fromkeys(and_parts))  # preserves order
-            
-            if len(unique_and_parts) == 1:
-                processed_or_parts.append(unique_and_parts[0])
-            elif len(unique_and_parts) > 1:
-                processed_or_parts.append(" and ".join(unique_and_parts))
+            if len(unique_parts) == 1:
+                return unique_parts[0]  # Remove unnecessary parentheses for single items
+            else:
+                return '(' + ' and '.join(unique_parts) + ')'
+        else:
+            # Mixed operators or single item - return with parentheses as-is
+            return '(' + content + ')'
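+    # For example, '(HGNC:5 or HGNC:5 or HGNC:7)' becomes '(HGNC:5 or HGNC:7)',
+    # while '(A and B or C)' is returned unchanged because its operators are mixed.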
+    
+    def remove_duplicates_simple(parts_str: str, separator: str) -> str:
+        """Remove duplicates from a simple chain of operations."""
+        # Split on the spaced operator so tokens that merely contain 'and'/'or'
+        # as substrings are not broken apart
+        parts = [p.strip() for p in parts_str.split(f' {separator} ') if p.strip()]
         
-    # deduplicate OR parts
-        unique_or_parts = list(dict.fromkeys(processed_or_parts))
+        # Remove duplicates while preserving order
+        unique_parts = []
+        seen = set()
+        for part in parts:
+            if part not in seen:
+                unique_parts.append(part)
+                seen.add(part)
         
-        if len(unique_or_parts) == 1:
-            return unique_or_parts[0]
-        elif len(unique_or_parts) > 1:
-            return " or ".join(unique_or_parts)
+        if len(unique_parts) == 1:
+            return unique_parts[0]
         else:
-            return ""
+            return f' {separator} '.join(unique_parts)
     
     try:
-        return process_expression(expr)
+        # First, simplify content within parentheses
+        # This handles cases like (A or A) -> A and (B and B) -> B
+        expr_simplified = re.sub(r'\(([^()]+)\)', simplify_parentheses_content, expr)
+        
+        # Check if the resulting expression has mixed operators  
+        has_and = ' and ' in expr_simplified
+        has_or = ' or ' in expr_simplified
+        
+        # Only simplify top-level if it's pure AND or pure OR
+        if has_and and not has_or and '(' not in expr_simplified:
+            # Pure AND chain at top level - safe to deduplicate
+            return remove_duplicates_simple(expr_simplified, 'and')
+        elif has_or and not has_and and '(' not in expr_simplified:
+            # Pure OR chain at top level - safe to deduplicate  
+            return remove_duplicates_simple(expr_simplified, 'or')
+        else:
+            # Mixed operators or remaining parentheses: return the version whose
+            # parenthesized content has already been cleaned
+            return expr_simplified
+            
     except Exception:
-    # if simplification fails, return the original expression
+        # If anything goes wrong, return the original expression
         return expr
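+# Behaviour sketch (inputs invented for illustration):
+#   _simplify_boolean_expression('(A or A) and B')         -> 'A and B'
+#   _simplify_boolean_expression('A or A or B')            -> 'A or B'
+#   _simplify_boolean_expression('(A and B) or (A and B)') -> '(A and B) or (A and B)'
+# The last case is deliberately left alone: once operators are mixed at the top
+# level, deduplication is skipped rather than risk changing the rule's meaning.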
 
 # ---------- Main public function ----------
@@ -618,7 +698,7 @@
                          target_nomenclature: str,
                          source_nomenclature: str = 'hgnc_id',
                          allow_many_to_one: bool = False,
-                         logger: Optional[logging.Logger] = None) -> 'cobra.Model':
+                         logger: Optional[logging.Logger] = None) -> Tuple['cobra.Model', Dict[str, str]]:
     """
     Translate model genes from source_nomenclature to target_nomenclature using a mapping table.
     mapping_df should contain columns enabling mapping (e.g., ensg, hgnc_id, hgnc_symbol, entrez).
@@ -630,6 +710,11 @@
         source_nomenclature: Current source key in the model (default 'hgnc_id').
         allow_many_to_one: If True, allow many-to-one mappings and handle duplicates in GPRs.
         logger: Optional logger.
+    
+    Returns:
+        Tuple containing:
+        - Translated COBRA model
+        - Dictionary mapping reaction IDs to translation issue descriptions
     """
     if logger is None:
         logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -701,6 +786,9 @@
     stats = {'translated': 0, 'one_to_one': 0, 'one_to_many': 0, 'not_found': 0, 'simplified_gprs': 0}
     unmapped = []
     multi = []
+    
+    # Dictionary to store translation issues per reaction
+    reaction_translation_issues = {}
 
     original_genes = {g.id for g in model_copy.genes}
     logger.info(f"Original genes count: {len(original_genes)}")
@@ -709,7 +797,10 @@
     for rxn in model_copy.reactions:
         gpr = rxn.gene_reaction_rule
         if gpr and gpr.strip():
-            new_gpr = _translate_gpr(gpr, gene_mapping, stats, unmapped, multi, logger)
+            new_gpr, rxn_issues = _translate_gpr(gpr, gene_mapping, stats, unmapped, multi, logger, track_issues=True)
+            if rxn_issues:
+                reaction_translation_issues[rxn.id] = rxn_issues
+            
             if new_gpr != gpr:
                 simplified_gpr = _simplify_boolean_expression(new_gpr)
                 if simplified_gpr != new_gpr:
@@ -725,7 +816,7 @@
     _log_translation_statistics(stats, unmapped, multi, original_genes, model_copy.genes, logger)
 
     logger.info("Translation finished")
-    return model_copy
+    return model_copy, reaction_translation_issues
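+# Hypothetical caller showing the new two-value return (argument values here
+# are placeholders, not prescriptions):
+#   model_t, issues = translate_genes(model, mapping_df,
+#                                     target_nomenclature='entrez_id',
+#                                     source_nomenclature='hgnc_id')
+#   for rxn_id, msg in issues.items():
+#       logger.warning(f"{rxn_id}: {msg}")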
 
 
 # ---------- helper functions ----------
@@ -762,10 +853,11 @@
                    stats: Dict[str, int],
                    unmapped_genes: List[str],
                    multi_mapping_genes: List[Tuple[str, List[str]]],
-                   logger: logging.Logger) -> str:
+                   logger: logging.Logger,
+                   track_issues: bool = False) -> Union[str, Tuple[str, str]]:
     """
     Translate genes inside a GPR string using gene_mapping.
-    Returns new GPR string.
+    Returns the new GPR string; if track_issues=True, returns a (new_gpr, issues) tuple,
+    where issues is a semicolon-separated summary of one-to-many and many-to-one mappings.
     """
     # Generic token pattern: letters, digits, :, _, -, ., (captures HGNC:1234, ENSG000..., symbols)
     token_pattern = r'\b[A-Za-z0-9:_.-]+\b'
@@ -775,6 +867,7 @@
     tokens = [t for t in tokens if t not in logical]
 
     new_gpr = gpr_string
+    issues = []
 
     for token in sorted(set(tokens), key=lambda x: -len(x)):  # longer tokens first to avoid partial replacement
         norm = _normalize_gene_id(token)
@@ -788,6 +881,8 @@
                 stats['one_to_many'] += 1
                 multi_mapping_genes.append((token, targets))
                 replacement = "(" + " or ".join(targets) + ")"
+                if track_issues:
+                    issues.append(f"{token} -> {' or '.join(targets)}")
 
             pattern = r'\b' + re.escape(token) + r'\b'
             new_gpr = re.sub(pattern, replacement, new_gpr)
@@ -797,7 +892,32 @@
                 unmapped_genes.append(token)
             logger.debug(f"Token not found in mapping (left as-is): {token}")
 
-    return new_gpr
+    # Check for many-to-one cases (multiple source genes mapping to same target)
+    if track_issues:
+        # Build reverse mapping to detect many-to-one cases from original tokens
+        original_to_target = {}
+        
+        for orig_token in tokens:
+            norm = _normalize_gene_id(orig_token)
+            if norm in gene_mapping:
+                targets = gene_mapping[norm]
+                for target in targets:
+                    if target not in original_to_target:
+                        original_to_target[target] = []
+                    if orig_token not in original_to_target[target]:
+                        original_to_target[target].append(orig_token)
+        
+        # Identify many-to-one mappings in this specific GPR
+        for target, original_genes in original_to_target.items():
+            if len(original_genes) > 1:
+                issues.append(f"{' or '.join(original_genes)} -> {target}")
+    
+    issue_text = "; ".join(issues) if issues else ""
+    
+    if track_issues:
+        return new_gpr, issue_text
+    else:
+        return new_gpr
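+# Worked example of the issue tracking, assuming _normalize_gene_id() lowercases
+# tokens and one-to-one hits replace the token directly:
+#   _translate_gpr('HGNC:5 and HGNC:7',
+#                  {'hgnc:5': ['111'], 'hgnc:7': ['222', '333']},
+#                  stats, [], [], logger, track_issues=True)
+#   -> ('111 and (222 or 333)', 'HGNC:7 -> 222 or 333')
+# Many-to-one cases are reported too: if 'HGNC:1 or HGNC:2' both map to '999',
+# the issue text is 'HGNC:1 or HGNC:2 -> 999' and the duplicated '999 or 999'
+# is later collapsed by _simplify_boolean_expression() in translate_genes().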
 
 
 def _update_model_genes(model: 'cobra.Model', logger: logging.Logger):
@@ -874,4 +994,6 @@
     if multi_mapping_genes:
         logger.info(f"Multi-mapping examples ({len(multi_mapping_genes)}):")
         for orig, targets in multi_mapping_genes[:10]:
-            logger.info(f"  {orig} -> {', '.join(targets)}")
\ No newline at end of file
+            logger.info(f"  {orig} -> {', '.join(targets)}")