|
9 | 9 |
|
10 | 10 | import argparse |
11 | 11 | import csv |
| 12 | +import json |
12 | 13 | import logging |
13 | 14 | import os |
14 | 15 | import sys |
| 16 | +from types import SimpleNamespace |
15 | 17 |
|
16 | 18 | import numpy as np |
| 19 | +import pandas as pd |
17 | 20 | from alphagenome.data import genome |
18 | 21 | from alphagenome.models import dna_client |
19 | 22 | from alphagenome.models.variant_scorers import RECOMMENDED_VARIANT_SCORERS |
20 | 23 |
|
| 24 | + |
def _col_as_list(df, col):
    """Return column ``col`` of ``df`` as a list of strings.

    Args:
        df: A pandas DataFrame (e.g. the ``obs`` or ``var`` table of an AnnData).
        col: Name of the column to extract.

    Returns:
        One string per row of ``df``. Values are converted with
        ``astype(str)``; if ``col`` is not a column of ``df``, every entry is
        the empty string.
    """
    if col not in df.columns:
        # Missing column: keep the row count so callers can still index by row.
        return ["" for _ in range(len(df))]
    return df[col].astype(str).tolist()
| 30 | + |
| 31 | + |
# Column order of the ISM output TSV. _write_region_results builds its
# per-column buffers from this list, so the order here is the order on disk.
ISM_OUTPUT_COLUMNS = [
    # variant identity
    "region",
    "position",
    "ref_base",
    "alt_base",
    # gene metadata (AnnData obs)
    "gene_id",
    "gene_name",
    "gene_type",
    # scorer / track metadata (AnnData var)
    "scorer",
    "track_name",
    "ontology_curie",
    # scores
    "raw_score",
    "quantile_score",
]
| 38 | + |
| 39 | + |
def _write_chunked_tsv(outfile, columns):
    """Append one batch of rows to ``outfile`` as headerless TSV.

    Args:
        outfile: Open text handle (or path) accepted by ``DataFrame.to_csv``.
        columns: Mapping of column name -> equal-length array-like; the
            mapping's order is the column order written.

    Returns:
        The number of rows written.
    """
    frame = pd.DataFrame(columns)
    csv_options = {
        "sep": "\t",
        "header": False,  # the caller writes the header once, up front
        "index": False,
        "float_format": "%.6f",
        "na_rep": "",  # NaN cells become empty fields
        "lineterminator": "\r\n",  # same terminator csv.writer uses by default
    }
    frame.to_csv(outfile, **csv_options)
    return len(frame)
| 46 | + |
| 47 | + |
def _write_region_results(outfile, name, results, scorers_arg, flush_rows=200_000):
    """Flatten the ISM results of one region into TSV rows on ``outfile``.

    Args:
        outfile: Open handle the rows are appended to (no header is written).
        name: Region name, repeated in the ``region`` column of every row.
        results: Iterable of per-variant sequences; each inner item is an
            AnnData-like object exposing ``uns["variant"]``, ``obs``, ``var``,
            ``X`` and ``layers``.
        scorers_arg: Scorer names indexed by position in the inner sequence;
            positions past its end are labelled ``scorer_<index>``.
        flush_rows: Buffered-row threshold at which rows are written out.

    Returns:
        Total number of rows written, i.e. the number of non-NaN cells of
        ``X`` across all inputs.

    Note:
        obs/var metadata is read from the first AnnData seen for each scorer
        index and reused for later variants, so it is assumed to be identical
        across variants of a region for a given scorer.
    """
    buffered = {column: [] for column in ISM_OUTPUT_COLUMNS}
    counters = {"pending": 0, "written": 0}
    meta_cache = {}

    def _drain():
        # Write everything buffered so far as one DataFrame, then reset.
        if not counters["pending"]:
            return
        table = {column: np.concatenate(parts) for column, parts in buffered.items()}
        counters["written"] += _write_chunked_tsv(outfile, table)
        for parts in buffered.values():
            parts.clear()
        counters["pending"] = 0

    for variant_anndatas in results:
        for idx, adata in enumerate(variant_anndatas):
            variant = adata.uns["variant"]
            position = variant.position
            ref = variant.reference_bases
            alt = variant.alternate_bases
            if idx < len(scorers_arg):
                label = scorers_arg[idx]
            else:
                label = f"scorer_{idx}"

            meta = meta_cache.get(idx)
            if meta is None:
                # String conversion of obs/var is done once per scorer index.
                meta = (
                    np.asarray(_col_as_list(adata.obs, "gene_id")),
                    np.asarray(_col_as_list(adata.obs, "gene_name")),
                    np.asarray(_col_as_list(adata.obs, "gene_type")),
                    np.asarray(_col_as_list(adata.var, "name")),
                    np.asarray(_col_as_list(adata.var, "ontology_curie")),
                )
                meta_cache[idx] = meta
            gene_ids, gene_names, gene_types, track_names, track_curies = meta

            raw = np.asarray(adata.X, dtype=float)
            quantiles = adata.layers.get("quantiles", None)
            if quantiles is not None:
                quantiles = np.asarray(quantiles, dtype=float)

            # (gene, track) coordinates of every cell that has a raw score.
            gene_idx, track_idx = np.nonzero(~np.isnan(raw))
            n_valid = gene_idx.size
            if n_valid == 0:
                continue

            if quantiles is None:
                # No quantile layer: NaN is rendered as an empty field.
                quantile_cells = np.full(n_valid, np.nan)
            else:
                quantile_cells = quantiles[gene_idx, track_idx]

            block = {
                "region": np.full(n_valid, name, dtype=object),
                "position": np.full(n_valid, position),
                "ref_base": np.full(n_valid, ref, dtype=object),
                "alt_base": np.full(n_valid, alt, dtype=object),
                "gene_id": gene_ids[gene_idx],
                "gene_name": gene_names[gene_idx],
                "gene_type": gene_types[gene_idx],
                "scorer": np.full(n_valid, label, dtype=object),
                "track_name": track_names[track_idx],
                "ontology_curie": track_curies[track_idx],
                "raw_score": raw[gene_idx, track_idx],
                "quantile_score": quantile_cells,
            }
            for column, values in block.items():
                buffered[column].append(values)

            counters["pending"] += n_valid
            if counters["pending"] >= flush_rows:
                _drain()

    _drain()
    return counters["written"]
| 128 | + |
| 129 | + |
def _load_mock_ism_results(path):
    """Load JSON describing mock AnnData inputs for CI testing of the post-processing path.

    The real --test-fixture path bypasses post-processing entirely (it just dumps
    pre-computed TSV rows). This loader constructs the minimal AnnData interface
    consumed by _write_region_results so CI can exercise the vectorized code
    path against controlled multi-region / NaN / missing-quantile cases.

    Expected JSON layout (keys read by this function)::

        {"scorers": [...],                      # optional, defaults to []
         "regions": [{"name": ...,
                      "variants": [{"position": ..., "reference_bases": ...,
                                    "alternate_bases": ...,
                                    "scorers": [{"X": [[...]],
                                                 "obs": {...},       # optional
                                                 "var": {...},       # optional
                                                 "quantiles": [[...]]  # optional
                                                 }]}]}]}

    Args:
        path: Path to the JSON file.

    Returns:
        A ``(regions, scorers)`` tuple where ``regions`` is a list of
        ``(region_name, results)`` pairs, ``results`` being a list (variants)
        of lists (scorers) of mock AnnData objects, and ``scorers`` is the
        top-level ``"scorers"`` list.

    Raises:
        KeyError: If a required key (``regions``, ``name``, ``variants``,
            ``position``, ``reference_bases``, ``alternate_bases``,
            ``scorers``, ``X``) is missing.
    """
    def _metadata_frame(columns, n_rows):
        # A real AnnData always carries one obs row per X row and one var row
        # per X column, even when it has no metadata columns. Without this,
        # an omitted "obs"/"var" yields a zero-row frame and indexing the
        # extracted metadata by gene/track position fails.
        frame = pd.DataFrame(columns)
        if frame.shape[1] == 0:
            frame = pd.DataFrame(index=range(n_rows))
        return frame

    class _MockAd:
        def __init__(self, obs, var, X, quantiles, variant):
            self.X = np.asarray(X, dtype=float)
            is_matrix = self.X.ndim == 2
            n_obs = self.X.shape[0] if is_matrix else 0
            n_vars = self.X.shape[1] if is_matrix else 0
            self.obs = _metadata_frame(obs, n_obs)
            self.var = _metadata_frame(var, n_vars)
            self.layers = (
                {"quantiles": np.asarray(quantiles, dtype=float)}
                if quantiles is not None
                else {}
            )
            self.uns = {"variant": SimpleNamespace(**variant)}

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        data = json.load(f)

    regions = []
    for region_data in data["regions"]:
        region_results = []
        for var_entry in region_data["variants"]:
            variant_meta = {
                "position": var_entry["position"],
                "reference_bases": var_entry["reference_bases"],
                "alternate_bases": var_entry["alternate_bases"],
            }
            region_results.append([
                _MockAd(
                    obs=scorer.get("obs", {}),
                    var=scorer.get("var", {}),
                    X=scorer["X"],
                    quantiles=scorer.get("quantiles"),
                    variant=variant_meta,
                )
                for scorer in var_entry["scorers"]
            ])
        regions.append((region_data["name"], region_results))
    return regions, data.get("scorers", [])
| 174 | + |
| 175 | + |
21 | 176 | __version__ = "0.6.1" |
22 | 177 |
|
23 | 178 | ORGANISM_MAP = { |
@@ -91,6 +246,22 @@ def run(args): |
91 | 246 | logging.info("Fixture mode: wrote %d rows to %s", len(fixture_data["rows"]), args.output) |
92 | 247 | return |
93 | 248 |
|
| 249 | + if args.mock_ism_results: |
| 250 | + mock_regions, mock_scorers = _load_mock_ism_results(args.mock_ism_results) |
| 251 | + with open(args.output, "w", newline="") as outfile: |
| 252 | + writer = csv.writer(outfile, delimiter="\t") |
| 253 | + writer.writerow([ |
| 254 | + "region", "position", "ref_base", "alt_base", |
| 255 | + "gene_id", "gene_name", "gene_type", |
| 256 | + "scorer", "track_name", "ontology_curie", |
| 257 | + "raw_score", "quantile_score", |
| 258 | + ]) |
| 259 | + row_count = 0 |
| 260 | + for name, results in mock_regions: |
| 261 | + row_count += _write_region_results(outfile, name, results, mock_scorers) |
| 262 | + logging.info("Mock mode: wrote %d rows to %s", row_count, args.output) |
| 263 | + return |
| 264 | + |
94 | 265 | api_key = args.api_key or os.environ.get("ALPHAGENOME_API_KEY") |
95 | 266 | if not api_key and not args.local_model: |
96 | 267 | logging.error("No API key provided. Set ALPHAGENOME_API_KEY or use --api-key") |
@@ -145,47 +316,8 @@ def run(args): |
145 | 316 | max_workers=args.max_workers, |
146 | 317 | ) |
147 | 318 |
|
148 | | - # results is list[list[AnnData]] — outer=variants (3*width), inner=scorers |
149 | | - # Each AnnData has: uns['variant'] with position/ref/alt, |
150 | | - # X for raw scores, layers['quantiles'], obs for genes, var for tracks |
151 | | - for var_results in results: |
152 | | - for scorer_idx, ad in enumerate(var_results): |
153 | | - variant_obj = ad.uns["variant"] |
154 | | - pos = variant_obj.position |
155 | | - ref_base = variant_obj.reference_bases |
156 | | - alt_base = variant_obj.alternate_bases |
157 | | - scorer_name = args.scorers[scorer_idx] if scorer_idx < len(args.scorers) else f"scorer_{scorer_idx}" |
158 | | - |
159 | | - raw_scores = ad.X # shape (n_genes, n_tracks) |
160 | | - quantile_scores = ad.layers.get("quantiles", None) |
161 | | - |
162 | | - for gene_idx in range(ad.n_obs): |
163 | | - gene_row = ad.obs.iloc[gene_idx] |
164 | | - gene_id = str(gene_row.get("gene_id", "")) |
165 | | - gene_name = str(gene_row.get("gene_name", "")) |
166 | | - gene_type = str(gene_row.get("gene_type", "")) |
167 | | - |
168 | | - for track_idx in range(ad.n_vars): |
169 | | - track_row = ad.var.iloc[track_idx] |
170 | | - track_name = str(track_row.get("name", "")) |
171 | | - ontology_curie = str(track_row.get("ontology_curie", "")) |
172 | | - |
173 | | - raw = float(raw_scores[gene_idx, track_idx]) |
174 | | - if np.isnan(raw): |
175 | | - continue |
176 | | - quant = "" |
177 | | - if quantile_scores is not None: |
178 | | - q = float(quantile_scores[gene_idx, track_idx]) |
179 | | - if not np.isnan(q): |
180 | | - quant = f"{q:.6f}" |
181 | | - |
182 | | - writer.writerow([ |
183 | | - name, pos, ref_base, alt_base, |
184 | | - gene_id, gene_name, gene_type, |
185 | | - scorer_name, track_name, ontology_curie, |
186 | | - f"{raw:.6f}", quant, |
187 | | - ]) |
188 | | - row_count += 1 |
| 319 | + # results is list[list[AnnData]] -- outer=variants (3*width), inner=scorers |
| 320 | + row_count += _write_region_results(outfile, name, results, args.scorers) |
189 | 321 |
|
190 | 322 | stats["scored"] += 1 |
191 | 323 | logging.info("Region %s: %d ISM variants scored", name, len(results)) |
@@ -228,6 +360,9 @@ def parse_arguments(): |
228 | 360 | parser.add_argument("--local-model", action="store_true") |
229 | 361 | parser.add_argument("--test-fixture", default=None, |
230 | 362 | help="Test fixture JSON for CI testing (bypasses API)") |
| 363 | + parser.add_argument("--mock-ism-results", default=None, |
| 364 | + help="Mock AnnData JSON for CI testing the post-processing " |
| 365 | + "path (bypasses API but exercises the vectorized loop)") |
231 | 366 | parser.add_argument("--verbose", action="store_true") |
232 | 367 | parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}") |
233 | 368 | return parser.parse_args() |
|
0 commit comments