Mercurial > repos > iuc > hyphy_absrel
diff scripts/infer_stasis_clusters.py @ 40:3f8261f0a826 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/tools/hyphy commit cee1ce4bd7d82088b9bf62403bc175c13223e020
| author | iuc |
|---|---|
| date | Wed, 11 Mar 2026 11:15:12 +0000 |
| parents | |
| children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/scripts/infer_stasis_clusters.py Wed Mar 11 11:15:12 2026 +0000 @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +""" +B-STILL Stasis Cluster Inference Tool +==================================== +Identifies regional footprints of extreme purifying selection (stasis) in B-STILL +JSON results using a FWER-controlled Hypergeometric Scan Statistic. + +Usage: + python3 infer_stasis_clusters.py input.json --ebf 10 --permutations 10000 --output results.json +""" + +import argparse +import json +import sys +import time + +import numpy as np +from scipy.stats import hypergeom + + +def get_sf_optimized(n, d, L, K, cache): + """Retrieves or computes Hypergeometric Survival Function value.""" + key = (n, d) + if key not in cache: + cache[key] = hypergeom.sf(n - 1, L, K, d) + return cache[key] + + +def scan_intervals(indices, L, K, max_size, sf_cache, threshold=None): + """ + Scans all possible intervals [i, j] anchored by stasis events. + Returns the minimum p-value if threshold is None, else returns all significant segments. + """ + best_p = 1.0 + segments = [] + num_events = len(indices) + + for n in range(3, min(max_size + 1, num_events + 1)): + for i in range(num_events - n + 1): + d = indices[i + n - 1] - indices[i] + 1 + p = get_sf_optimized(n, d, L, K, sf_cache) + + if threshold is None: + if p < best_p: + best_p = p + else: + if p <= threshold: + segments.append({ + "start": int(indices[i] + 1), + "end": int(indices[i + n - 1] + 1), + "p_value": p, + "k": n, + "d": int(d) + }) + + return best_p if threshold is None else segments + + +def merge_segments(segments, merge_dist=15): + """Merges overlapping or nearby significant segments.""" + if not segments: + return [] + segments.sort(key=lambda x: x['start']) + + merged = [] + curr = segments[0] + for next_s in segments[1:]: + if next_s['start'] <= curr['end'] + merge_dist: + curr['end'] = max(curr['end'], next_s['end']) + curr['p_value'] = min(curr['p_value'], next_s['p_value']) + curr['d'] = curr['end'] - curr['start'] + 1 + else: + merged.append(curr) + curr = next_s + merged.append(curr) + return merged + + +def main(): + parser = argparse.ArgumentParser(description="Infer stasis clusters from B-STILL JSON.") + parser.add_argument("input", help="Path to B-STILL JSON result file") + parser.add_argument("--ebf", type=float, default=10.0, help="EBF threshold for defining stasis sites (default: 10.0)") + parser.add_argument("--permutations", type=int, default=10000, help="Number of permutations for FWER control (default: 10000)") + parser.add_argument("--alpha", type=float, default=0.05, help="Family-wise error rate threshold (default: 0.05)") + parser.add_argument("--max-cluster", type=int, default=30, help="Maximum number of stasis sites per interval scan (default: 30)") + parser.add_argument("--merge", type=int, default=15, help="Distance in codons to merge adjacent clusters (default: 15)") + parser.add_argument("--output", help="Path to save results in JSON format") + + args = parser.parse_args() + + try: + with open(args.input, "r") as f: + data = json.load(f) + except Exception as e: + print("Error loading JSON: {0}".format(e)) + sys.exit(1) + + sites = data.get("MLE", {}).get("content", {}).get("0", []) + ebfs = [s[12] if (len(s) > 12 and isinstance(s[12], (int, float))) else 0 for s in sites] + L = len(ebfs) + + if L < 10: + print("Alignment too short for cluster analysis.") + sys.exit(0) + + stasis_indices = np.array([i for i, val in enumerate(ebfs) if val >= args.ebf]) + K = len(stasis_indices) + + print("--- B-STILL Cluster Inference ---") + print("Input: {0}".format(args.input)) + print("Gene Length (L): {0} codons".format(L)) + print("Stasis Sites (K): {0} (EBF >= {1})".format(K, args.ebf)) + + if K < 3: + print("Insufficient stasis sites to form clusters (minimum 3 required).") + sys.exit(0) + + print("Running {0} permutations for FWER control...".format(args.permutations)) + null_min_ps = [] + all_positions = np.arange(L) + sf_cache = {} + + start_time = time.time() + for i in range(args.permutations): + if i > 0 and i % 1000 == 0: + elapsed = time.time() - start_time + print(" Processed {0} permutations... ({1:.1f} per sec)".format(i, i / elapsed)) + shuffled = sorted(np.random.choice(all_positions, K, replace=False)) + min_p = scan_intervals(shuffled, L, K, args.max_cluster, sf_cache) + null_min_ps.append(min_p) + + crit_p = np.percentile(null_min_ps, args.alpha * 100) + print("Gene-specific Critical P-value (FWER {0}): {1:.2e}".format(args.alpha, crit_p)) + + print("Scanning observed sequence for significant clusters...") + raw_segments = scan_intervals(stasis_indices, L, K, args.max_cluster, sf_cache, threshold=crit_p) + + final_clusters = merge_segments(raw_segments, merge_dist=args.merge) + + for c in final_clusters: + c['k'] = sum(1 for idx in stasis_indices if c['start'] <= idx + 1 <= c['end']) + + print("\nFound {0} significant stasis clusters:".format(len(final_clusters))) + if final_clusters: + print("\nLegend:") + print(" k : Number of high-confidence stasis sites within the cluster") + print(" d : Total span of the cluster in codons") + print("\n{:<8} | {:<8} | {:<5} | {:<5} | {:<10}".format("Start", "End", "k", "d", "P-value")) + print("-" * 45) + for c in final_clusters: + print("{:<8} | {:<8} | {:<5} | {:<5} | {:.2e}".format(c['start'], c['end'], c['k'], c['d'], c['p_value'])) + + if args.output: + output_data = { + "input_file": args.input, + "parameters": vars(args), + "summary": { + "gene_length": L, + "total_stasis_sites": K, + "critical_p_value": float(crit_p), + "num_clusters": len(final_clusters) + }, + "clusters": final_clusters + } + with open(args.output, "w") as f: + json.dump(output_data, f, indent=4) + print("\nDetailed results saved to {0}".format(args.output)) + + +if __name__ == "__main__": + main()
