88
99EPSILON = 0.1
1010
11- assert len (snakemake .params .vars ) == 2 , "only two variables are supported for now"
12-
11+ vars = snakemake .params .vars
1312mode = snakemake .wildcards .mode
1413
1514assert (
1615 snakemake .params .min_fold_change >= 1.0
1716), "min_fold_change must be greater than 1.0"
1817min_conservative_log2_fold_change = math .log2 (snakemake .params .min_fold_change )
1918
20- color_col = "case" if mode == "all" else snakemake .params .vars [1 ]
21-
2219data = pl .read_parquet (snakemake .input .data )
20+
21+ if len (vars ) > 2 :
22+ # combine vars[1:] into a single variable with ":" as separator
23+ data = data .with_columns (
24+ pl .concat_list (vars [1 :]).list .join (": " ).alias ("combined_var" ),
25+ )
26+ vars = [vars [0 ], "combined_var" ]
27+
28+ color_col = "case" if mode == "all" else vars [1 ]
29+
2330var_values = (
24- data .get_column (snakemake . params . vars [1 ]).unique (maintain_order = True ).to_list ()
31+ data .get_column (vars [1 ]).unique (maintain_order = True ).to_list ()
2532)
2633var_indexes = {value : i for i , value in enumerate (var_values )}
2734data = data .with_columns (
28- pl .col (snakemake . params . vars [1 ]).replace_strict (var_indexes ).alias ("index" ),
35+ pl .col (vars [1 ]).replace_strict (var_indexes ).alias ("index" ),
2936 pl .concat_list (snakemake .params .vars ).list .join (": " ).alias ("case" ),
3037)
3138
@@ -50,7 +57,7 @@ def get_conservative_log2_fold_change(ci) -> float:
5057 [
5158 pl .col (f"group_{ group } " ).list .get (i ).alias (f"{ varname } _{ group } " )
5259 for group in ["a" , "b" ]
53- for i , varname in enumerate (snakemake . params . vars )
60+ for i , varname in enumerate (vars )
5461 ],
5562 )
5663)
@@ -60,7 +67,7 @@ def get_conservative_log2_fold_change(ci) -> float:
6067 pl .Series (
6168 [
6269 var_indexes [value ]
63- for value in cis .get_column (f"{ snakemake . params . vars [1 ]} _{ group } " )
70+ for value in cis .get_column (f"{ vars [1 ]} _{ group } " )
6471 ]
6572 ).alias (f"index_{ group } " )
6673 for group in ["a" , "b" ]
@@ -102,18 +109,18 @@ def fmt_fold_change(effect) -> str:
102109 pl .struct ("conservative_log2_fold_change" , "brunner_munzel_adjusted_pvalue" )
103110 .map_elements (fmt_fold_change , return_dtype = str )
104111 .alias ("fold change" ),
105- pl .concat_list ([f"{ var } _a" for var in snakemake . params . vars ])
112+ pl .concat_list ([f"{ var } _a" for var in vars ])
106113 .list .join (": " )
107114 .alias ("case_a" ),
108- pl .concat_list ([f"{ var } _b" for var in snakemake . params . vars ])
115+ pl .concat_list ([f"{ var } _b" for var in vars ])
109116 .list .join (": " )
110117 .alias ("case_b" ),
111118)
112119
113120
114121# generate data frame with two rows for each group_a, group_b pair, one with
115- # the group_a values and the corresponding snakemake.params. vars[0] value and the corresponding index,
116- # and one with group_b values and the corresponding snakemake.params. vars[0] value and the corresponding index
122+ # the group_a values and the corresponding vars[0] value and the corresponding index,
123+ # and one with group_b values and the corresponding vars[0] value and the corresponding index
117124color_spec = alt .Color (color_col , type = "nominal" ).scale (
118125 domain = color_order , range = snakemake .params .color_scheme
119126)
@@ -140,12 +147,12 @@ def fmt_fold_change(effect) -> str:
140147)
141148
142149if mode == "selected" :
143- # add an underline for each variable in snakemake.params. vars[0]
150+ # add an underline for each variable in vars[0]
144151 # alt.X should be the value of the first value of case of the respective group in the data frame
145152 # alt.X2 should be the value of the last value of case of the respective group in the data frame
146- underline_data = data .group_by (snakemake . params . vars [0 ], maintain_order = True ).agg (
153+ underline_data = data .group_by (vars [0 ], maintain_order = True ).agg (
147154 [
148- pl .col (snakemake . params . vars [0 ]).first ().alias ("label" ),
155+ pl .col (vars [0 ]).first ().alias ("label" ),
149156 pl .col ("case" ).first ().alias ("x" ),
150157 pl .col ("case" ).last ().alias ("x2" ),
151158 ]
@@ -201,7 +208,7 @@ def swap_colname(col):
201208 how = "semi" ,
202209 on = [
203210 f"{ var } _{ group } "
204- for var in snakemake . params . vars
211+ for var in vars
205212 for group in ["a" , "b" ]
206213 ],
207214 )
0 commit comments