Mercurial > repos > iuc > hyphy_absrel
comparison scripts/infer_stasis_clusters.py @ 40:3f8261f0a826 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/tools/hyphy commit cee1ce4bd7d82088b9bf62403bc175c13223e020
| author | iuc |
|---|---|
| date | Wed, 11 Mar 2026 11:15:12 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| 39:106a316b5313 | 40:3f8261f0a826 |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 """ | |
| 3 B-STILL Stasis Cluster Inference Tool | |
| 4 ==================================== | |
| 5 Identifies regional footprints of extreme purifying selection (stasis) in B-STILL | |
| 6 JSON results using a FWER-controlled Hypergeometric Scan Statistic. | |
| 7 | |
| 8 Usage: | |
| 9 python3 infer_stasis_clusters.py input.json --ebf 10 --permutations 10000 --output results.json | |
| 10 """ | |
| 11 | |
| 12 import argparse | |
| 13 import json | |
| 14 import sys | |
| 15 import time | |
| 16 | |
| 17 import numpy as np | |
| 18 from scipy.stats import hypergeom | |
| 19 | |
| 20 | |
def get_sf_optimized(n, d, L, K, cache):
    """Return P(X >= n) for X ~ Hypergeometric(L, K, d), memoized.

    L (gene length) and K (total stasis sites) are fixed for a given run,
    so the pair (n, d) fully determines the value and serves as the cache key.
    """
    try:
        return cache[(n, d)]
    except KeyError:
        value = hypergeom.sf(n - 1, L, K, d)
        cache[(n, d)] = value
        return value
| 27 | |
| 28 | |
def scan_intervals(indices, L, K, max_size, sf_cache, threshold=None):
    """Scan every interval anchored by a pair of stasis events.

    indices  : sorted 0-based positions of stasis events.
    L, K     : alignment length and total number of stasis events.
    max_size : largest number of events per scanned interval.

    With threshold=None, returns the minimum p-value over all intervals
    (used for the permutation null). Otherwise returns a list of segment
    dicts for every interval with p <= threshold.
    """
    collecting = threshold is not None
    min_p = 1.0
    hits = []
    total = len(indices)

    for size in range(3, min(max_size, total) + 1):
        last = size - 1
        for start in range(total - size + 1):
            # Genomic span (in codons) covered by this run of `size` events.
            span = indices[start + last] - indices[start] + 1
            p = get_sf_optimized(size, span, L, K, sf_cache)

            if collecting:
                if p <= threshold:
                    hits.append({
                        "start": int(indices[start] + 1),
                        "end": int(indices[start + last] + 1),
                        "p_value": p,
                        "k": size,
                        "d": int(span),
                    })
            elif p < min_p:
                min_p = p

    return hits if collecting else min_p
| 57 | |
| 58 | |
def merge_segments(segments, merge_dist=15):
    """Merge significant segments that overlap or lie within merge_dist codons.

    segments   : list of dicts with 'start', 'end', 'p_value', 'k', 'd'
                 (1-based inclusive codon coordinates).
    merge_dist : two segments are fused when the next one starts no more
                 than merge_dist codons past the current one's end.

    Returns a new list of merged segment dicts. Fix: the original sorted the
    caller's list in place and mutated its dicts; we now work on copies so
    the input survives unmodified.
    """
    if not segments:
        return []

    # Copy each dict and sort the copies; the caller's data is untouched.
    ordered = sorted((dict(s) for s in segments), key=lambda s: s['start'])

    merged = []
    curr = ordered[0]
    for nxt in ordered[1:]:
        if nxt['start'] <= curr['end'] + merge_dist:
            curr['end'] = max(curr['end'], nxt['end'])
            curr['p_value'] = min(curr['p_value'], nxt['p_value'])
            # Span is recomputed here; 'k' is deliberately left stale because
            # the caller recounts stasis sites inside each final cluster.
            curr['d'] = curr['end'] - curr['start'] + 1
        else:
            merged.append(curr)
            curr = nxt
    merged.append(curr)
    return merged
| 77 | |
| 78 | |
def main():
    """Command-line entry point.

    Parses arguments, loads a B-STILL JSON result, derives per-site EBF
    values, calibrates a gene-specific critical p-value by permutation
    (FWER control), then scans the observed stasis sites for significant
    clusters, printing a report and optionally saving JSON output.
    """
    parser = argparse.ArgumentParser(description="Infer stasis clusters from B-STILL JSON.")
    parser.add_argument("input", help="Path to B-STILL JSON result file")
    parser.add_argument("--ebf", type=float, default=10.0, help="EBF threshold for defining stasis sites (default: 10.0)")
    parser.add_argument("--permutations", type=int, default=10000, help="Number of permutations for FWER control (default: 10000)")
    parser.add_argument("--alpha", type=float, default=0.05, help="Family-wise error rate threshold (default: 0.05)")
    parser.add_argument("--max-cluster", type=int, default=30, help="Maximum number of stasis sites per interval scan (default: 30)")
    parser.add_argument("--merge", type=int, default=15, help="Distance in codons to merge adjacent clusters (default: 15)")
    parser.add_argument("--output", help="Path to save results in JSON format")

    args = parser.parse_args()

    try:
        with open(args.input, "r") as f:
            data = json.load(f)
    except Exception as e:
        # Fix: diagnostics belong on stderr so stdout stays clean for results.
        print("Error loading JSON: {0}".format(e), file=sys.stderr)
        sys.exit(1)

    # Per-site MLE rows live under MLE/content/"0"; column 12 is assumed to
    # hold the stasis EBF — TODO(review): confirm against the B-STILL JSON
    # schema. Missing or non-numeric entries are treated as 0 (not stasis).
    sites = data.get("MLE", {}).get("content", {}).get("0", [])
    ebfs = [s[12] if (len(s) > 12 and isinstance(s[12], (int, float))) else 0 for s in sites]
    L = len(ebfs)

    if L < 10:
        print("Alignment too short for cluster analysis.")
        sys.exit(0)

    # 0-based positions of high-confidence stasis sites.
    stasis_indices = np.array([i for i, val in enumerate(ebfs) if val >= args.ebf])
    K = len(stasis_indices)

    print("--- B-STILL Cluster Inference ---")
    print("Input: {0}".format(args.input))
    print("Gene Length (L): {0} codons".format(L))
    print("Stasis Sites (K): {0} (EBF >= {1})".format(K, args.ebf))

    if K < 3:
        # scan_intervals requires at least 3 events per interval.
        print("Insufficient stasis sites to form clusters (minimum 3 required).")
        sys.exit(0)

    # Null distribution: minimum scan p-value when K sites are placed
    # uniformly at random. The SF cache is shared with the observed scan
    # (valid because L and K are identical in both).
    print("Running {0} permutations for FWER control...".format(args.permutations))
    null_min_ps = []
    all_positions = np.arange(L)
    sf_cache = {}

    start_time = time.time()
    for i in range(args.permutations):
        if i > 0 and i % 1000 == 0:
            elapsed = time.time() - start_time
            print("  Processed {0} permutations... ({1:.1f} per sec)".format(i, i / elapsed))
        shuffled = sorted(np.random.choice(all_positions, K, replace=False))
        min_p = scan_intervals(shuffled, L, K, args.max_cluster, sf_cache)
        null_min_ps.append(min_p)

    # alpha-quantile of the null minima = gene-specific FWER threshold.
    crit_p = np.percentile(null_min_ps, args.alpha * 100)
    print("Gene-specific Critical P-value (FWER {0}): {1:.2e}".format(args.alpha, crit_p))

    print("Scanning observed sequence for significant clusters...")
    raw_segments = scan_intervals(stasis_indices, L, K, args.max_cluster, sf_cache, threshold=crit_p)

    final_clusters = merge_segments(raw_segments, merge_dist=args.merge)

    # Recount k per merged cluster (merging leaves 'k' stale); +1 converts
    # 0-based site indices to the clusters' 1-based coordinates.
    for c in final_clusters:
        c['k'] = sum(1 for idx in stasis_indices if c['start'] <= idx + 1 <= c['end'])

    print("\nFound {0} significant stasis clusters:".format(len(final_clusters)))
    if final_clusters:
        print("\nLegend:")
        print("  k : Number of high-confidence stasis sites within the cluster")
        print("  d : Total span of the cluster in codons")
        print("\n{:<8} | {:<8} | {:<5} | {:<5} | {:<10}".format("Start", "End", "k", "d", "P-value"))
        print("-" * 45)
        for c in final_clusters:
            print("{:<8} | {:<8} | {:<5} | {:<5} | {:.2e}".format(c['start'], c['end'], c['k'], c['d'], c['p_value']))

    if args.output:
        output_data = {
            "input_file": args.input,
            "parameters": vars(args),
            "summary": {
                "gene_length": L,
                "total_stasis_sites": K,
                "critical_p_value": float(crit_p),
                "num_clusters": len(final_clusters)
            },
            "clusters": final_clusters
        }
        with open(args.output, "w") as f:
            json.dump(output_data, f, indent=4)
        print("\nDetailed results saved to {0}".format(args.output))
| 168 | |
| 169 | |
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
