Skip to content

Commit 443548d

Browse files
feat: support more than two variables
1 parent 1c63119 commit 443548d

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

workflow/scripts/plot_distributions.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,31 @@
88

99
EPSILON = 0.1
1010

11-
assert len(snakemake.params.vars) == 2, "only two variables are supported for now"
12-
11+
vars = snakemake.params.vars
1312
mode = snakemake.wildcards.mode
1413

1514
assert (
1615
snakemake.params.min_fold_change >= 1.0
1716
), "min_fold_change must be greater than 1.0"
1817
min_conservative_log2_fold_change = math.log2(snakemake.params.min_fold_change)
1918

20-
color_col = "case" if mode == "all" else snakemake.params.vars[1]
21-
2219
data = pl.read_parquet(snakemake.input.data)
20+
21+
if len(vars) > 2:
22+
# combine vars[1:] into a single variable with ": " as separator
23+
data = data.with_columns(
24+
pl.concat_list(vars[1:]).list.join(": ").alias("combined_var"),
25+
)
26+
vars = [vars[0], "combined_var"]
27+
28+
color_col = "case" if mode == "all" else vars[1]
29+
2330
var_values = (
24-
data.get_column(snakemake.params.vars[1]).unique(maintain_order=True).to_list()
31+
data.get_column(vars[1]).unique(maintain_order=True).to_list()
2532
)
2633
var_indexes = {value: i for i, value in enumerate(var_values)}
2734
data = data.with_columns(
28-
pl.col(snakemake.params.vars[1]).replace_strict(var_indexes).alias("index"),
35+
pl.col(vars[1]).replace_strict(var_indexes).alias("index"),
2936
pl.concat_list(snakemake.params.vars).list.join(": ").alias("case"),
3037
)
3138

@@ -50,7 +57,7 @@ def get_conservative_log2_fold_change(ci) -> float:
5057
[
5158
pl.col(f"group_{group}").list.get(i).alias(f"{varname}_{group}")
5259
for group in ["a", "b"]
53-
for i, varname in enumerate(snakemake.params.vars)
60+
for i, varname in enumerate(vars)
5461
],
5562
)
5663
)
@@ -60,7 +67,7 @@ def get_conservative_log2_fold_change(ci) -> float:
6067
pl.Series(
6168
[
6269
var_indexes[value]
63-
for value in cis.get_column(f"{snakemake.params.vars[1]}_{group}")
70+
for value in cis.get_column(f"{vars[1]}_{group}")
6471
]
6572
).alias(f"index_{group}")
6673
for group in ["a", "b"]
@@ -102,18 +109,18 @@ def fmt_fold_change(effect) -> str:
102109
pl.struct("conservative_log2_fold_change", "brunner_munzel_adjusted_pvalue")
103110
.map_elements(fmt_fold_change, return_dtype=str)
104111
.alias("fold change"),
105-
pl.concat_list([f"{var}_a" for var in snakemake.params.vars])
112+
pl.concat_list([f"{var}_a" for var in vars])
106113
.list.join(": ")
107114
.alias("case_a"),
108-
pl.concat_list([f"{var}_b" for var in snakemake.params.vars])
115+
pl.concat_list([f"{var}_b" for var in vars])
109116
.list.join(": ")
110117
.alias("case_b"),
111118
)
112119

113120

114121
# generate data frame with two rows for each group_a, group_b pair, one with
115-
# the group_a values and the corresponding snakemake.params.vars[0] value and the corresponding index,
116-
# and one with group_b values and the corresponding snakemake.params.vars[0] value and the corresponding index
122+
# the group_a values and the corresponding vars[0] value and the corresponding index,
123+
# and one with group_b values and the corresponding vars[0] value and the corresponding index
117124
color_spec = alt.Color(color_col, type="nominal").scale(
118125
domain=color_order, range=snakemake.params.color_scheme
119126
)
@@ -140,12 +147,12 @@ def fmt_fold_change(effect) -> str:
140147
)
141148

142149
if mode == "selected":
143-
# add an underline for each variable in snakemake.params.vars[0]
150+
# add an underline for each distinct value of vars[0]
144151
# alt.X should be the first value of case within the respective group in the data frame
145152
# alt.X2 should be the last value of case within the respective group in the data frame
146-
underline_data = data.group_by(snakemake.params.vars[0], maintain_order=True).agg(
153+
underline_data = data.group_by(vars[0], maintain_order=True).agg(
147154
[
148-
pl.col(snakemake.params.vars[0]).first().alias("label"),
155+
pl.col(vars[0]).first().alias("label"),
149156
pl.col("case").first().alias("x"),
150157
pl.col("case").last().alias("x2"),
151158
]
@@ -201,7 +208,7 @@ def swap_colname(col):
201208
how="semi",
202209
on=[
203210
f"{var}_{group}"
204-
for var in snakemake.params.vars
211+
for var in vars
205212
for group in ["a", "b"]
206213
],
207214
)

0 commit comments

Comments
 (0)