Skip to content

Commit 7c296bb

Browse files
authored
Merge pull request #7917 from dannon/alphagenome-ism-perf
alphagenome: vectorize post-processing across scoring tools
2 parents fc143b0 + 8559b90 commit 7c296bb

8 files changed

Lines changed: 533 additions & 171 deletions

tools/alphagenome/alphagenome_interval_predictor.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414

1515
import numpy as np
16+
import pandas as pd
1617
from alphagenome.data import genome
1718
from alphagenome.models import dna_client
1819

@@ -81,6 +82,71 @@ def extract_region_slice(values, interval_start, region_start, region_end):
8182
return values[offset_start:offset_end]
8283

8384

85+
def metadata_column(metadata, column, size):
86+
if metadata is not None and column in metadata.columns:
87+
values = metadata[column].astype(str).to_numpy()
88+
if len(values) >= size:
89+
return values[:size]
90+
padded = np.full(size, "", dtype=object)
91+
padded[:len(values)] = values
92+
return padded
93+
return np.full(size, "", dtype=object)
94+
95+
96+
def as_2d(values):
97+
values = np.asarray(values)
98+
if values.ndim == 1:
99+
return values.reshape(-1, 1)
100+
return values
101+
102+
103+
def binned_means(values, bin_size):
104+
starts = np.arange(0, values.shape[0], bin_size)
105+
ends = np.minimum(starts + bin_size, values.shape[0])
106+
sums = np.add.reduceat(values, starts, axis=0)
107+
return starts, ends, sums / (ends - starts)[:, None]
108+
109+
110+
def write_interval_output(outfile, chrom, start, end, name, otype, values, metadata,
111+
output_mode, bin_size):
112+
values = as_2d(values)
113+
num_tracks = values.shape[1]
114+
track_names = metadata_column(metadata, "track_name", num_tracks)
115+
ontology_curies = metadata_column(metadata, "ontology_curie", num_tracks)
116+
117+
if output_mode == "summary":
118+
df = pd.DataFrame({
119+
"chrom": np.repeat(chrom, num_tracks),
120+
"start": np.repeat(start, num_tracks),
121+
"end": np.repeat(end, num_tracks),
122+
"name": np.repeat(name, num_tracks),
123+
"output_type": np.repeat(otype, num_tracks),
124+
"track_name": track_names,
125+
"ontology_curie": ontology_curies,
126+
"mean_signal": np.mean(values, axis=0),
127+
"max_signal": np.max(values, axis=0),
128+
})
129+
else:
130+
if values.shape[0] == 0:
131+
return 0
132+
bin_starts, bin_ends, means = binned_means(values, bin_size)
133+
num_bins = len(bin_starts)
134+
df = pd.DataFrame({
135+
"chrom": np.repeat(chrom, num_tracks * num_bins),
136+
"bin_start": start + np.tile(bin_starts, num_tracks),
137+
"bin_end": start + np.tile(bin_ends, num_tracks),
138+
"region_name": np.repeat(name, num_tracks * num_bins),
139+
"output_type": np.repeat(otype, num_tracks * num_bins),
140+
"track_name": np.repeat(track_names, num_bins),
141+
"ontology_curie": np.repeat(ontology_curies, num_bins),
142+
"mean_signal": means.T.ravel(),
143+
})
144+
145+
df.to_csv(outfile, sep="\t", header=False, index=False,
146+
float_format="%.6f", lineterminator="\r\n")
147+
return len(df)
148+
149+
84150
def run(args):
85151
logging.info("AlphaGenome Interval Predictor v%s", __version__)
86152
logging.info("Input: %s", args.input)
@@ -172,41 +238,10 @@ def run(args):
172238
values, interval.start, start, end,
173239
)
174240

175-
num_tracks = region_values.shape[1] if region_values.ndim > 1 else 1
176-
if region_values.ndim == 1:
177-
region_values = region_values.reshape(-1, 1)
178-
179-
for track_idx in range(num_tracks):
180-
track_vals = region_values[:, track_idx]
181-
track_name = ""
182-
ontology_curie = ""
183-
if metadata is not None and len(metadata) > track_idx:
184-
row = metadata.iloc[track_idx]
185-
track_name = str(row.get("track_name", ""))
186-
ontology_curie = str(row.get("ontology_curie", ""))
187-
188-
if args.output_mode == "summary":
189-
mean_sig = float(np.mean(track_vals))
190-
max_sig = float(np.max(track_vals))
191-
writer.writerow([
192-
chrom, start, end, name, otype,
193-
track_name, ontology_curie,
194-
f"{mean_sig:.6f}", f"{max_sig:.6f}",
195-
])
196-
else:
197-
# Binned mode
198-
region_len = region_values.shape[0]
199-
bin_size = args.bin_size
200-
for bin_start_offset in range(0, region_len, bin_size):
201-
bin_end_offset = min(bin_start_offset + bin_size, region_len)
202-
bin_vals = track_vals[bin_start_offset:bin_end_offset]
203-
mean_sig = float(np.mean(bin_vals))
204-
writer.writerow([
205-
chrom, start + bin_start_offset,
206-
start + bin_end_offset, name, otype,
207-
track_name, ontology_curie,
208-
f"{mean_sig:.6f}",
209-
])
241+
write_interval_output(
242+
outfile, chrom, start, end, name, otype, region_values,
243+
metadata, args.output_mode, args.bin_size,
244+
)
210245

211246
stats["predicted"] += 1
212247

tools/alphagenome/alphagenome_ism_scanner.py

Lines changed: 176 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,170 @@
99

1010
import argparse
1111
import csv
12+
import json
1213
import logging
1314
import os
1415
import sys
16+
from types import SimpleNamespace
1517

1618
import numpy as np
19+
import pandas as pd
1720
from alphagenome.data import genome
1821
from alphagenome.models import dna_client
1922
from alphagenome.models.variant_scorers import RECOMMENDED_VARIANT_SCORERS
2023

24+
25+
def _col_as_list(df, col):
26+
"""Extract a column as a list of strings; empty strings if the column is missing."""
27+
if col in df.columns:
28+
return df[col].astype(str).tolist()
29+
return [""] * len(df)
30+
31+
32+
ISM_OUTPUT_COLUMNS = [
33+
"region", "position", "ref_base", "alt_base",
34+
"gene_id", "gene_name", "gene_type",
35+
"scorer", "track_name", "ontology_curie",
36+
"raw_score", "quantile_score",
37+
]
38+
39+
40+
def _write_chunked_tsv(outfile, columns):
41+
df = pd.DataFrame(columns)
42+
df.to_csv(outfile, sep="\t", header=False, index=False,
43+
float_format="%.6f", na_rep="",
44+
lineterminator="\r\n")
45+
return len(df)
46+
47+
48+
def _write_region_results(outfile, name, results, scorers_arg, flush_rows=200_000):
49+
"""Post-process ISM results for one region and stream TSV rows to outfile.
50+
51+
Returns the number of non-NaN rows written. obs/var metadata is identical
52+
across variants for a given scorer, so extract once per (region, scorer)
53+
then flatten non-NaN cells into a DataFrame and let pandas' C path handle
54+
CSV formatting.
55+
"""
56+
chunks = {col: [] for col in ISM_OUTPUT_COLUMNS}
57+
pending_rows = 0
58+
per_scorer_meta = {}
59+
row_count = 0
60+
61+
def flush_pending():
62+
nonlocal pending_rows, row_count
63+
if pending_rows == 0:
64+
return
65+
row_count += _write_chunked_tsv(
66+
outfile,
67+
{col: np.concatenate(values) for col, values in chunks.items()},
68+
)
69+
for values in chunks.values():
70+
values.clear()
71+
pending_rows = 0
72+
73+
for var_results in results:
74+
for scorer_idx, ad in enumerate(var_results):
75+
variant_obj = ad.uns["variant"]
76+
pos = variant_obj.position
77+
ref_base = variant_obj.reference_bases
78+
alt_base = variant_obj.alternate_bases
79+
scorer_name = (
80+
scorers_arg[scorer_idx]
81+
if scorer_idx < len(scorers_arg)
82+
else f"scorer_{scorer_idx}"
83+
)
84+
85+
if scorer_idx not in per_scorer_meta:
86+
per_scorer_meta[scorer_idx] = (
87+
np.asarray(_col_as_list(ad.obs, "gene_id")),
88+
np.asarray(_col_as_list(ad.obs, "gene_name")),
89+
np.asarray(_col_as_list(ad.obs, "gene_type")),
90+
np.asarray(_col_as_list(ad.var, "name")),
91+
np.asarray(_col_as_list(ad.var, "ontology_curie")),
92+
)
93+
gene_ids, gene_names, gene_types, track_names, track_curies = per_scorer_meta[scorer_idx]
94+
95+
raw_scores = np.asarray(ad.X, dtype=float)
96+
quantile_scores = ad.layers.get("quantiles", None)
97+
if quantile_scores is not None:
98+
quantile_scores = np.asarray(quantile_scores, dtype=float)
99+
100+
valid_mask = ~np.isnan(raw_scores)
101+
gi, ti = np.where(valid_mask)
102+
if gi.size == 0:
103+
continue
104+
105+
if quantile_scores is not None:
106+
q_flat = quantile_scores[gi, ti]
107+
else:
108+
q_flat = np.full(gi.size, np.nan)
109+
110+
size = gi.size
111+
chunks["region"].append(np.full(size, name, dtype=object))
112+
chunks["position"].append(np.full(size, pos))
113+
chunks["ref_base"].append(np.full(size, ref_base, dtype=object))
114+
chunks["alt_base"].append(np.full(size, alt_base, dtype=object))
115+
chunks["gene_id"].append(gene_ids[gi])
116+
chunks["gene_name"].append(gene_names[gi])
117+
chunks["gene_type"].append(gene_types[gi])
118+
chunks["scorer"].append(np.full(size, scorer_name, dtype=object))
119+
chunks["track_name"].append(track_names[ti])
120+
chunks["ontology_curie"].append(track_curies[ti])
121+
chunks["raw_score"].append(raw_scores[gi, ti])
122+
chunks["quantile_score"].append(q_flat)
123+
pending_rows += size
124+
if pending_rows >= flush_rows:
125+
flush_pending()
126+
flush_pending()
127+
return row_count
128+
129+
130+
def _load_mock_ism_results(path):
131+
"""Load JSON describing mock AnnData inputs for CI testing of the post-processing path.
132+
133+
The real --test-fixture path bypasses post-processing entirely (it just dumps
134+
pre-computed TSV rows). This loader constructs the minimal AnnData interface
135+
consumed by _write_region_results so CI can exercise the vectorized code
136+
path against controlled multi-region / NaN / missing-quantile cases.
137+
"""
138+
class _MockAd:
139+
def __init__(self, obs, var, X, quantiles, variant):
140+
self.obs = pd.DataFrame(obs)
141+
self.var = pd.DataFrame(var)
142+
self.X = np.asarray(X, dtype=float)
143+
self.layers = (
144+
{"quantiles": np.asarray(quantiles, dtype=float)}
145+
if quantiles is not None
146+
else {}
147+
)
148+
self.uns = {"variant": SimpleNamespace(**variant)}
149+
150+
with open(path) as f:
151+
data = json.load(f)
152+
153+
regions = []
154+
for region_data in data["regions"]:
155+
region_results = []
156+
for var_entry in region_data["variants"]:
157+
variant_meta = {
158+
"position": var_entry["position"],
159+
"reference_bases": var_entry["reference_bases"],
160+
"alternate_bases": var_entry["alternate_bases"],
161+
}
162+
region_results.append([
163+
_MockAd(
164+
obs=scorer.get("obs", {}),
165+
var=scorer.get("var", {}),
166+
X=scorer["X"],
167+
quantiles=scorer.get("quantiles"),
168+
variant=variant_meta,
169+
)
170+
for scorer in var_entry["scorers"]
171+
])
172+
regions.append((region_data["name"], region_results))
173+
return regions, data.get("scorers", [])
174+
175+
21176
__version__ = "0.6.1"
22177

23178
ORGANISM_MAP = {
@@ -91,6 +246,22 @@ def run(args):
91246
logging.info("Fixture mode: wrote %d rows to %s", len(fixture_data["rows"]), args.output)
92247
return
93248

249+
if args.mock_ism_results:
250+
mock_regions, mock_scorers = _load_mock_ism_results(args.mock_ism_results)
251+
with open(args.output, "w", newline="") as outfile:
252+
writer = csv.writer(outfile, delimiter="\t")
253+
writer.writerow([
254+
"region", "position", "ref_base", "alt_base",
255+
"gene_id", "gene_name", "gene_type",
256+
"scorer", "track_name", "ontology_curie",
257+
"raw_score", "quantile_score",
258+
])
259+
row_count = 0
260+
for name, results in mock_regions:
261+
row_count += _write_region_results(outfile, name, results, mock_scorers)
262+
logging.info("Mock mode: wrote %d rows to %s", row_count, args.output)
263+
return
264+
94265
api_key = args.api_key or os.environ.get("ALPHAGENOME_API_KEY")
95266
if not api_key and not args.local_model:
96267
logging.error("No API key provided. Set ALPHAGENOME_API_KEY or use --api-key")
@@ -145,47 +316,8 @@ def run(args):
145316
max_workers=args.max_workers,
146317
)
147318

148-
# results is list[list[AnnData]] — outer=variants (3*width), inner=scorers
149-
# Each AnnData has: uns['variant'] with position/ref/alt,
150-
# X for raw scores, layers['quantiles'], obs for genes, var for tracks
151-
for var_results in results:
152-
for scorer_idx, ad in enumerate(var_results):
153-
variant_obj = ad.uns["variant"]
154-
pos = variant_obj.position
155-
ref_base = variant_obj.reference_bases
156-
alt_base = variant_obj.alternate_bases
157-
scorer_name = args.scorers[scorer_idx] if scorer_idx < len(args.scorers) else f"scorer_{scorer_idx}"
158-
159-
raw_scores = ad.X # shape (n_genes, n_tracks)
160-
quantile_scores = ad.layers.get("quantiles", None)
161-
162-
for gene_idx in range(ad.n_obs):
163-
gene_row = ad.obs.iloc[gene_idx]
164-
gene_id = str(gene_row.get("gene_id", ""))
165-
gene_name = str(gene_row.get("gene_name", ""))
166-
gene_type = str(gene_row.get("gene_type", ""))
167-
168-
for track_idx in range(ad.n_vars):
169-
track_row = ad.var.iloc[track_idx]
170-
track_name = str(track_row.get("name", ""))
171-
ontology_curie = str(track_row.get("ontology_curie", ""))
172-
173-
raw = float(raw_scores[gene_idx, track_idx])
174-
if np.isnan(raw):
175-
continue
176-
quant = ""
177-
if quantile_scores is not None:
178-
q = float(quantile_scores[gene_idx, track_idx])
179-
if not np.isnan(q):
180-
quant = f"{q:.6f}"
181-
182-
writer.writerow([
183-
name, pos, ref_base, alt_base,
184-
gene_id, gene_name, gene_type,
185-
scorer_name, track_name, ontology_curie,
186-
f"{raw:.6f}", quant,
187-
])
188-
row_count += 1
319+
# results is list[list[AnnData]] -- outer=variants (3*width), inner=scorers
320+
row_count += _write_region_results(outfile, name, results, args.scorers)
189321

190322
stats["scored"] += 1
191323
logging.info("Region %s: %d ISM variants scored", name, len(results))
@@ -228,6 +360,9 @@ def parse_arguments():
228360
parser.add_argument("--local-model", action="store_true")
229361
parser.add_argument("--test-fixture", default=None,
230362
help="Test fixture JSON for CI testing (bypasses API)")
363+
parser.add_argument("--mock-ism-results", default=None,
364+
help="Mock AnnData JSON for CI testing the post-processing "
365+
"path (bypasses API but exercises the vectorized loop)")
231366
parser.add_argument("--verbose", action="store_true")
232367
parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
233368
return parser.parse_args()

0 commit comments

Comments
 (0)