---
title: "Dynamic UDF Rewriting with Predicate Pushdowns"
author: "Hussain Sultan"
date: "2025-02-04"
categories:
  - blog
  - case study
  - machine learning
  - ecosystem
image: images/tree-pruning.png
---

## Introduction

In an ideal world, deploying machine learning models within SQL queries would
be as simple as calling a built-in function. Unfortunately, many ML predictions
live inside **User-Defined Functions (UDFs)** that traditional SQL planners
can't modify, preventing optimizations like predicate pushdowns.

This blog post will showcase how you can **prune decision tree models based on
query filters** by dynamically rewriting your expression using **Ibis** and
**quickgrove**, an experimental
[GBDT](https://developers.google.com/machine-learning/decision-forests/intro-to-gbdt)
inference library built in Rust. We'll also show how
[LetSQL](https://github.com/letsql/letsql) can simplify this pattern further
and integrate seamlessly into your ML workflows.


## ML models meet SQL

When you deploy machine learning models (like a gradient-boosted trees model
from XGBoost) in a data warehouse, you typically wrap them in a UDF, something
like:

```sql
SELECT
  my_udf_predict(carat, depth, color, clarity, ...)
FROM diamonds
WHERE color_i < 1 AND clarity_vvs2 < 1
```

The challenge is that **SQL planners don't know what's happening inside the
UDF**. Even if you filter on `color_i < 1`, the full model, including tree
paths the filter makes unreachable, is evaluated for every row. With
tree-based models, entire branches can never be taken under such a filter, so
the ideal scenario is to prune those unnecessary branches *before* evaluating
the model.


## Smart UDFs with Ibis

**Ibis** is known for letting you write engine-agnostic deferred expressions in
Python without losing the power of underlying engines like Spark, DuckDB, or
BigQuery. Meanwhile, quickgrove provides a mechanism to prune Gradient Boosted
Decision Tree (GBDT) models based on known filter conditions.

**Key Ideas**:

1. **Prune decision trees** by removing branches that can never be reached,
   given the known filters
2. **Rewrite expressions** so the pruned model replaces the original in the
   query plan, skipping unnecessary computation

### Understanding tree pruning

![Tree Pruning](images/tree-pruning.png)

Take a simple example: a decision tree that splits on `color_i < 1`. If your
query also has the predicate `color_i < 1`, any branch that requires
`color_i >= 1` can never be reached. By **removing** that branch, the tree
becomes smaller and faster to evaluate, especially when you have hundreds of
trees (as in many gradient-boosted models).

**Reference**: Check out the [Raven
optimizer](https://arxiv.org/pdf/2206.00136) paper. It demonstrates how to
prune nodes in query plans for tree-based inference; we take a similar approach
here for **forests** (GBDTs) using **Ibis**.
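
To make the idea concrete, here is a minimal, self-contained sketch of pruning
a toy decision tree given a known `feature < value` predicate. It uses plain
Python dataclasses, not quickgrove's internal representation, and the node
layout and helper names are illustrative only:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """A toy decision-tree node: a split on `feature < threshold`, or a leaf."""
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None   # taken when value < threshold
    right: Optional["Node"] = None  # taken when value >= threshold
    value: Optional[float] = None   # leaf prediction


def prune(node: Node, feature: str, upper_bound: float) -> Node:
    """Drop branches that are unreachable when `feature < upper_bound` always holds."""
    if node.value is not None:  # leaf: nothing to prune
        return node
    left = prune(node.left, feature, upper_bound)
    right = prune(node.right, feature, upper_bound)
    if node.feature == feature and node.threshold >= upper_bound:
        # Every row satisfies feature < upper_bound <= threshold, so the
        # `feature >= threshold` branch can never be taken.
        return left
    return Node(node.feature, node.threshold, left, right)


# A tree that splits on color_i < 1; the filter color_i < 1 makes the right
# branch unreachable, so pruning collapses the tree to its left leaf.
tree = Node("color_i", 1.0,
            left=Node(value=0.3),
            right=Node("carat", 0.5, Node(value=0.8), Node(value=1.2)))
print(prune(tree, "color_i", 1.0))  # -> Node(..., value=0.3)
```

Quickgrove applies the same idea across every tree in a gradient-boosted
forest, which is where the savings add up.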

### Quickgrove: prunable GBDT models

Quickgrove is an experimental package that can load GBDT JSON models and
provides a `.prune(...)` API to remove unreachable branches. For example:

```python
# pip install quickgrove
import quickgrove

model = quickgrove.json_load("diamonds_model.json")  # Load an XGBoost model
model.prune([quickgrove.Feature("color_i") < 0.2])   # Prune based on a known predicate
```

Once pruned, the model is leaner to evaluate. Note that the gains depend
heavily on the model's splits and how they interact with the pushed-down
predicates.


## Scalar PyArrow UDFs in Ibis

::: {.column-margin}
Please note that we are using our own modified DataFusion backend. The
DataFusion and DuckDB backends behave differently: DuckDB expects a
`ChunkedArray` while DataFusion UDFs expect `ArrayRef`. We are working on
extending quickgrove to work with the DuckDB backend.
:::

We'll define a simple Ibis UDF that calls our `model.predict_arrays` under the
hood:

```python
import ibis
import ibis.expr.datatypes as dt

ibis.set_backend("datafusion")

@ibis.udf.scalar.pyarrow
def predict_gbdt(
    carat: dt.float64,
    depth: dt.float64,
    # ... other features ...
) -> dt.float32:
    array_list = [carat, depth, ...]
    return model.predict_arrays(array_list)
```

Currently, UDFs are opaque to Ibis. We need to teach Ibis how to rewrite a UDF
based on the predicates it knows about.


## Making Ibis UDFs predicate-aware

Here's the general process:

1. **Collect predicates** from the user's filter (e.g. `x < 0.3`).
2. **Prune** the model based on those predicates (removing unreachable tree
   branches).
3. **Rewrite** a new UDF that references the pruned model, preserving the rest
   of the query plan.

### 1. Collecting predicates

```python
from typing import List

from ibis.expr.operations import Field, Filter, Less, Literal


def collect_predicates(filter_op: Filter) -> List[dict]:
    """Extract 'column < value' predicates from a Filter operation."""
    predicates = []
    for pred in filter_op.predicates:
        if isinstance(pred, Less) and isinstance(pred.left, Field):
            if isinstance(pred.right, Literal):
                predicates.append({
                    "column": pred.left.name,
                    "op": "Less",
                    "value": pred.right.value
                })
    return predicates
```
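
To see the shape of what this produces, here is a small, hypothetical usage
sketch (the table schema is illustrative): applying `collect_predicates` to the
`Filter` node of a filtered expression yields plain dictionaries the pruning
step can consume.

```python
import ibis

t = ibis.table({"carat": "float64", "color_i": "float64"}, name="diamonds")
filtered = t.filter(t.color_i < 1, t.carat < 0.3)

# `filtered.op()` is the Filter operation a rewrite rule would see.
print(collect_predicates(filtered.op()))
# Expected shape (values mirror the literals above):
# [{'column': 'color_i', 'op': 'Less', 'value': 1},
#  {'column': 'carat', 'op': 'Less', 'value': 0.3}]
```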

### 2. Pruning the model and creating a new UDF

```python
import functools

import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
from ibis.expr.operations import ScalarUDF


def create_pruned_udf(original_udf, model, predicates):
    """Create a new UDF backed by a model pruned with the collected predicates."""
    from quickgrove import Feature

    # Prune the model using the 'column < value' predicates we collected
    pruned_model = model.prune([
        Feature(pred["column"]) < pred["value"]
        for pred in predicates
        if pred["op"] == "Less" and pred["value"] is not None
    ])
    # For simplicity, assume the pruned model keeps the same input features.

    def fn_from_arrays(*arrays):
        return pruned_model.predict_arrays(list(arrays))

    # Metadata for a dynamically constructed ScalarUDF subclass
    meta = {
        "dtype": dt.float32,
        "__input_type__": "pyarrow",
        "__func__": property(lambda self: fn_from_arrays),
        "__config__": FrozenDict(volatility="immutable"),
        "__udf_namespace__": original_udf.__module__,
        "__module__": original_udf.__module__,
        "__func_name__": original_udf.__name__ + "_pruned",
    }

    # Create a new ScalarUDF node type on the fly; `fields` (not shown) maps
    # argument names to Ibis argument specs, mirroring the original UDF's signature.
    node = type(original_udf.__name__ + "_pruned", (ScalarUDF,), {**fields, **meta})

    @functools.wraps(fn_from_arrays)
    def construct(*args, **kwargs):
        return node(*args, **kwargs).to_expr()

    construct.fn = fn_from_arrays
    return construct
```
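
Before wiring this into a rewrite rule, here is a hypothetical manual usage
sketch, assuming the `predict_gbdt` UDF and `model` from above and a diamonds
table `t` like the one used in the full example later. The predicate values and
feature columns are illustrative only; the point is that the returned
constructor behaves like a regular Ibis UDF.

```python
# Predicates in the shape produced by collect_predicates above
predicates = [{"column": "color_i", "op": "Less", "value": 1.0}]

# Build a drop-in replacement for the original UDF
predict_gbdt_pruned = create_pruned_udf(predict_gbdt, model, predicates)

# Use it exactly like the original, passing the same feature columns
expr = t.mutate(prediction=predict_gbdt_pruned(t.carat, t.depth))  # ... plus the remaining features
```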

### 3. Rewriting the plan

Now we use an Ibis rewrite rule (or a custom function) to **detect filters** on
the expression, prune the model, and produce a new project/filter node.

```python
from ibis.expr.operations import Field, Filter, Less, Project
# `p` and `replace` are Ibis's internal pattern/rewrite helpers; their exact
# import location may vary across Ibis versions.
from ibis.expr.rewrites import p
from ibis.common.patterns import replace


@replace(p.Filter)
def prune_gbdt_model(filter_op, original_udf, model):
    """Rewrite rule that prunes the GBDT model based on filter predicates."""
    predicates = collect_predicates(filter_op)
    if not predicates:
        # Nothing to prune if there are no relevant predicates
        return filter_op

    # In a real implementation you'd match on a ScalarUDF node and check that
    # it wraps a quickgrove model before rewriting.
    pruned_udf = create_pruned_udf(original_udf, model, predicates)

    parent_op = filter_op.parent

    # Build a new projection that calls the pruned UDF
    new_values = {}
    for name, value in parent_op.values.items():
        # If this column calls the UDF, swap in the pruned version
        if name == "prediction":
            # For brevity, pass the same columns to the pruned UDF
            new_values[name] = pruned_udf(value.op().args[0], value.op().args[1])
        else:
            new_values[name] = value

    new_project = Project(parent_op.parent, new_values)

    # Re-attach the filter conditions on top of the new projection
    new_predicates = []
    for pred in filter_op.predicates:
        if isinstance(pred, Less) and isinstance(pred.left, Field):
            new_predicates.append(
                Less(Field(new_project, pred.left.name), pred.right)
            )
        else:
            new_predicates.append(pred)

    return Filter(parent=new_project, predicates=new_predicates)
```

### Diff

For a query like the following:

```python
expr = (
    t.mutate(prediction=predict_gbdt(t.carat, t.depth, ...))
    .filter(
        (t["clarity_vvs2"] < 1),
        (t["color_i"] < 1),
        (t["color_j"] < 1)
    )
    .select("prediction")
)
```

the rewritten UDF call differs as shown in the diff below. Notice that with
pruning we can drop some of the columns projected into the UDF, e.g.
`color_i`, `color_j`, and `clarity_vvs2`. The underlying engine (e.g.
DataFusion) may optimize this further when pulling data for the UDF; we cannot
completely drop these columns from the query expression, since the filter
still references them.

```shell
- predict_gbdt_3(
+ predict_gbdt_pruned(
    carat, depth, table, x, y, z,
    cut_good, cut_ideal, cut_premium, cut_very_good,
-   color_e, color_f, color_g, color_h, color_i, color_j,
+   color_e, color_f, color_g, color_h,
    clarity_if, clarity_si1, clarity_si2, clarity_vs1,
-   clarity_vs2, clarity_vvs1, clarity_vvs2
+   clarity_vs2, clarity_vvs1
)
```

## Putting it all together

The complete example can be found
[here](https://github.com/letsql/trusty/blob/main/python/examples/ibis_filter_condition.py).

```python
# 1. Load your dataset into Ibis
t = ibis.read_csv("diamonds_data.csv")

# 2. Build the expression: predict, then filter
expr = (
    t.mutate(prediction=predict_gbdt(t.carat, t.depth, ...))
    .filter(
        (t["clarity_vvs2"] < 1),
        (t["color_i"] < 1),
        (t["color_j"] < 1)
    )
    .select("prediction")
)

# 3. Apply your custom optimization
optimized_expr = prune_gbdt_model(expr.op(), predict_gbdt, model)

# 4. Execute the optimized query
result = optimized_expr.to_expr().execute()
```

When this is done, the model inside `predict_gbdt` will be **pruned** based on
the expression's filter conditions. This can yield significant speedups on
large datasets (see @tbl-perf).


## Performance impact

[Here](https://github.com/letsql/quickgrove/blob/main/python/examples/ibis_filter_condition_bench.py)
are the benchmark results, run on an Apple M2 Mac Mini (8 cores / 8 GB memory)
with a model trained with 100 trees of depth 6 and the following filter
conditions:

```
_.carat < 1,
_.clarity_vvs2 < 1,
_.color_i < 1,
_.color_j < 1,
```

Benchmark results:

| File Size | Regular (s) | Optimized (s) | Improvement |
| --- | --- | --- | --- |
| 5M | 0.82 ±0.02 | 0.67 ±0.02 | 18.0% |
| 25M | 4.16 ±0.01 | 3.46 ±0.05 | 16.7% |
| 100M | 16.80 ±0.17 | 14.07 ±0.11 | 16.3% |
: Performance improvements {#tbl-perf}

**Key takeaway**: As data volume grows, skipping unneeded tree branches
translates to real compute savings, though the gains depend heavily on how
relevant the filter conditions are to the model's splits.


## LetSQL: simplifying UDF rewriting

[LetSQL](https://letsql.com/) makes advanced UDF rewriting and multi-engine
pipelines much simpler. It builds on the same ideas we explored here but wraps
them in a higher-level API.

Here's a quick glimpse of how LetSQL might simplify the pattern:

```python
# pip install letsql

import letsql as ls
from letsql.expr.ml import make_quickgrove_udf, rewrite_quickgrove_expression

model_path = "xgboost_model.json"
predict_udf = make_quickgrove_udf(model_path)

# df is a pandas DataFrame with the diamonds features
t = ls.memtable(df).mutate(pred=predict_udf.on_expr).filter(ls._.carat < 1)
optimized_t = rewrite_quickgrove_expression(t)

result = ls.execute(optimized_t)
```

The complete example can be found
[here](https://github.com/letsql/letsql/blob/main/examples/quickgrove_udf.py).
With LetSQL, you get a **shorter, more declarative approach** to the same
optimization logic we manually coded with Ibis. It abstracts away the gritty
parts of rewriting your query plan.


## Best practices & considerations

- **Predicate Types**: We demonstrated `column < value` logic. You can extend
  it to handle `<=`, `>`, `BETWEEN`, or even categorical splits (see the sketch
  after this list).
- **Quickgrove limitations**: Quickgrove only supports a handful of objective
  functions and, most notably, does not have categorical support yet, even
  though categorical variables are in theory better candidates for pruning
  based on filter conditions. It only supports the XGBoost format.
- **Model Format**: XGBoost JSON is straightforward to parse. Other formats
  (e.g. LightGBM, scikit-learn trees) require similar logic or conversion
  steps.
- **Edge Cases**: If the filter references columns that are not model features,
  or if multiple filters combine in more complex ways, your rewriting logic may
  need more robust parsing.
- **When to Use**: This approach is beneficial when queries often filter on the
  same columns your trees split on. For purely ad hoc queries or rarely used
  filters, the overhead of rewriting might outweigh the benefit.
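
For the first point, here is a rough sketch of what extending
`collect_predicates` to a few more comparison operators could look like. Note
that the pruning step would also need to build matching quickgrove conditions;
operator support beyond `<` on the quickgrove side is an assumption here.

```python
from ibis.expr.operations import Field, Greater, GreaterEqual, Less, LessEqual, Literal

# Map Ibis comparison ops to labels the pruning logic understands.
SUPPORTED_OPS = {
    Less: "Less",
    LessEqual: "LessEqual",
    Greater: "Greater",
    GreaterEqual: "GreaterEqual",
}


def collect_predicates_ext(filter_op):
    """Extract simple `column <op> literal` predicates from a Filter operation."""
    predicates = []
    for pred in filter_op.predicates:
        op_name = SUPPORTED_OPS.get(type(pred))
        if (
            op_name is not None
            and isinstance(pred.left, Field)
            and isinstance(pred.right, Literal)
        ):
            predicates.append(
                {"column": pred.left.name, "op": op_name, "value": pred.right.value}
            )
    return predicates
```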

## Conclusion

Combining **Ibis** with a prune-friendly framework like quickgrove lets you
optimize large-scale ML inference inside your data workflows. By **pushing
filter predicates down into your decision trees**, you can speed up queries
significantly.

With LetSQL, you can streamline this entire process, especially if you're
looking for an out-of-the-box solution that integrates with multiple engines
and ships batteries-included features like caching and aggregate/window UDFs.
For next steps, consider experimenting with more complex models, exploring
different tree pruning strategies, or even extending this pattern to other ML
models beyond GBDTs.

- **Try it out**: Explore the Ibis documentation to learn how to build custom
  UDFs.
- **Dive deeper**: Check out [quickgrove](https://github.com/letsql/trusty) or
  read the Raven optimizer [paper](https://arxiv.org/pdf/2206.00136).
- **Experiment with LetSQL**: If you need a polished solution for dynamic ML
  UDF rewriting, [LetSQL](https://github.com/letsql/letsql) may be just the
  ticket.

---

## Resources

- **Raven paper**: [End-to-end Optimization of Machine Learning Prediction
  Queries](https://arxiv.org/pdf/2206.00136)
- **Ibis + Torch**: [Ibis Project Blog Post](https://ibis-project.org/posts/torch/)
- [Multi-Engine Data Stack with Ibis](https://www.letsql.com/posts/multi-engine-data-stack-ibis/)