diff --git a/docs/posts/udf-rewriting/images/tree-pruning.png b/docs/posts/udf-rewriting/images/tree-pruning.png
index fb92063ead55..e9672dc4060e 100644
Binary files a/docs/posts/udf-rewriting/images/tree-pruning.png and b/docs/posts/udf-rewriting/images/tree-pruning.png differ
diff --git a/docs/posts/udf-rewriting/index.qmd b/docs/posts/udf-rewriting/index.qmd
index 6e8c2f612797..2baf61b516cf 100644
--- a/docs/posts/udf-rewriting/index.qmd
+++ b/docs/posts/udf-rewriting/index.qmd
@@ -7,13 +7,14 @@ categories:
 - case study
 - machine learning
 - ecosystem
+image: images/tree-pruning.png
 ---
 
 ## Introduction
 
-In an ideal world, deploying machine learning models within SQL queries would be as simple as calling a built-in function. Unfortunately, many ML predictions live inside **User-Defined Functions (UDFs)** that traditional SQL optimizers treat as black boxes. This effectively prevents advanced optimizations like predicate pushdown, leading to significant performance overhead when running large-scale inference.
+In an ideal world, deploying machine learning models optimally within SQL queries would be as simple as calling a built-in function. Unfortunately, many ML predictions live inside **User-Defined Functions (UDFs)** whose expressions traditional SQL optimizers cannot modify. This effectively prevents advanced optimizations like predicate pushdown, leading to significant performance overhead when running large-scale inference.
 
-In this blog post, we’ll showcase how you can **prune decision tree models based on query filters** by dynamically rewriting your UDF using **Ibis** and **quickgrove**,an experimental XGBoost inference library built in Rust. We'll also show how [LetSQL](https://github.com/letsql/letsql) can simplify this pattern further and integrate seamlessly into your ML workflows.
+In this blog post, we’ll showcase how you can **prune decision tree models based on query filters** by dynamically rewriting your expression using **Ibis** and **quickgrove**, an experimental XGBoost inference library built in Rust. We'll also show how [LetSQL](https://github.com/letsql/letsql) can simplify this pattern further and integrate seamlessly into your ML workflows.
 
 ---
 
@@ -28,7 +29,7 @@ FROM diamonds
 WHERE color_i < 1 AND clarity_vvs2 < 1
 ```
 
-The problem is that **SQL optimizers don’t know what’s happening inside the UDF**. Even if you filter `color_i < 1`, the full model itself is still evaluated for every row meeting the filter condition. With tree-based models, however, entire branches might never be evaluated at all if they exceed certain thresholds—so the ideal scenario is to prune those unnecessary branches *before* evaluating them.
+The challenge is that **SQL optimizers don’t know what’s happening inside the UDF**. Even if you filter `color_i < 1`, the full model, including skippable tree paths, is evaluated for every row. With tree-based models, entire branches might never be evaluated at all — so the ideal scenario is to prune those unnecessary branches *before* evaluating them.
 
 ### Why It Matters
 
@@ -44,21 +45,20 @@ The problem is that **SQL optimizers don’t know what’s happening inside the
 
 **Key Ideas**:
 
-1. **Create UDFs that are predicate-aware** — So we can rewrite them if the upstream query includes filters.
-2. **Prune decision trees** — Removing branches that can never be reached, given the known filters.
-3. **Inject the pruned model** back into your query plan to skip unnecessary computations.
+1. **Prune decision trees** by removing branches that can never be reached, given the known filters.
+2. **Rewrite expressions** to inject the pruned model into the query plan and skip unnecessary computations.
 
 ### Understanding Tree Pruning
 
 ![Tree Pruning](images/tree-pruning.png)
 
-Take a simple example: a decision tree that splits on `x < 0.3`. If your query also has a predicate `x < 0.2`, any branches assuming `x >= 0.3` will never be evaluated. By **removing** that branch, the tree becomes smaller and faster to evaluate—especially when you have hundreds of trees (as in many gradient-boosted models).
+Take a simple example: a decision tree that splits on `color_i < 1`. If your query also has a predicate `color_i < 1`, any branches with feature `color_i >= 1` will never be evaluated. By **removing** that branch, the tree becomes smaller and faster to evaluate—especially when you have hundreds of trees (as in many gradient-boosted models).
 
 **Reference**: Check out the [Raven optimizer](https://arxiv.org/pdf/2206.00136) paper. It demonstrates how you can prune nodes in query plans for tree-based inference, so we’re taking a similar approach here for **forests** (GBDTs) using **Ibis**.
 
 ---
 
-### Enter quickgrove: Prune-able GBM Models
+### Quickgrove: Prune-able GBM Models
 
 Quickgrove is an experimental package that can load XGBoost JSON models and provides a `.prune(...)` API to remove unreachable branches. For example:
 
@@ -67,10 +67,10 @@ Quickgrove is an experimental package that can load XGBoost JSON models and pro
 ```python
 import quickgrove
 
 model = quickgrove.json_load("diamonds_model.json")  # Load an XGBoost model
-model.prune([quickgrove.Feature("carat") < 0.2]) # Prune based on known predicate
+model.prune([quickgrove.Feature("color_i") < 0.2]) # Prune based on known predicate
 ```
 
-Once pruned, the model is leaner to evaluate but the results heavily depends on the the model and how the predicates interact with the trees within the model.
+Once pruned, the model is leaner to evaluate. The gains, however, depend heavily on the model’s splits and how they interact with the pushed-down predicates.
 
 ---
 
@@ -94,7 +94,7 @@ def predict_gbdt(
     # ... other features ...
 ) -> dt.float32:
     array_list = [carat, depth, ...]
-    return model.predict_arrays(array_list
+    return model.predict_arrays(array_list)
 ```
 
 In its default form, `predict_gbdt` is a black box. Now we need Ibis to “understand” it enough to let us swap it out for a pruned version under the right conditions.
 
@@ -107,7 +107,7 @@ Here’s the general process:
 
 1. **Collect predicates** from the user’s filter (e.g. `x < 0.3`).
 2. **Prune** the model based on those predicates (removing unreachable tree branches).
-3. **Inject** a new UDF that references the pruned model, preserving the rest of the query plan.
+3. **Rewrite** the expression with a new UDF that references the pruned model, preserving the rest of the query plan.
 
 ### 1. Collecting Predicates
 
@@ -180,6 +180,7 @@ Now we use an Ibis rewrite rule (or a custom function) to **detect filters** on
 ```python
 from ibis.expr.operations import Project
 
+@replace(p.Filter)
 def prune_gbdt_model(filter_op, original_udf, model):
     """Rewrite rule to prune GBDT model based on filter predicates."""
 
@@ -187,7 +188,8 @@ def prune_gbdt_model(filter_op, original_udf, model):
     if not predicates:
         # Nothing to prune if no relevant predicates
         return filter_op
-
+    # In a real implementation you'd want to match on a ScalarUDF and ensure
+    # that the model instance is the one implemented with quickgrove.
     pruned_udf, required_features = create_pruned_udf(original_udf, model, predicates)
 
     parent_op = filter_op.parent
@@ -216,27 +218,38 @@ def prune_gbdt_model(filter_op, original_udf, model):
     return Filter(parent=new_project, predicates=new_predicates)
 ```
 
-## Diff
+### Diff
+
+For a query like the following:
 
-The following columns were removes from the function signature since they are no longer required in the pruned version of the mode.
+```python
+expr = (
+    t.mutate(prediction=predict_gbdt(t.carat, t.depth, ...))
+    .filter(
+        (t["clarity_vvs2"] < 1),
+        (t["color_i"] < 1),
+        (t["color_j"] < 1)
+    )
+    .select("prediction")
+)
+```
+
+See the diff below. Notice that with pruning we are also able to drop some of the UDF’s input projections, i.e. `color_i`, `color_j`, and `clarity_vvs2`. The underlying engine, e.g. DataFusion, may optimize this further when pulling data for UDFs. We cannot completely drop these columns from the query expression.
 
-```python
+```shell
- predict_gbdt_3(
+ predict_gbdt_pruned(
     carat, depth, table, x, y, z,
     cut_good, cut_ideal, cut_premium, cut_very_good,
-    color_e, color_f, color_g, color_h, color_i, color_j,
+-    color_e, color_f, color_g, color_h, color_i, color_j,
++    color_e, color_f, color_g, color_h,
     clarity_if, clarity_si1, clarity_si2, clarity_vs1,
-    clarity_vs2, clarity_vvs1, clarity_vvs2
+    clarity_vs2, clarity_vvs1
 )
 ```
 
-> Note: The above is a conceptual example. In a real implementation, you’ll wire this into a full Ibis rewrite pass so it automatically triggers whenever relevant filters are present in your query expression.
->
-
 ---
 
 ## Putting It All Together
 
@@ -311,7 +324,7 @@ optimized_t = rewrite_quickgrove_expression(t)
 
 result = ls.execute(optimized_t)
 ```
-
+The complete example can be found [here](https://github.com/letsql/letsql/blob/main/examples/quickgrove_udf.py).
 
 With LetSQL, you get a **shorter, more declarative approach** to the same optimization logic we manually coded with Ibis. It abstracts away the gritty parts of rewriting your query plan.
 
 ---
 
@@ -330,7 +343,7 @@ With LetSQL, you get a **shorter, more declarative approach** to the same optimi
 
 Combining **Ibis** with a prune-friendly framework like quickgrove lets you automatically optimize large-scale ML inference inside SQL queries. By **pushing filter predicates down into your decision trees**, you skip unnecessary computations and speed up queries significantly.
-**And with LetSQL**, you can streamline this entire process—especially if you’re looking for an out-of-the-box solution that integrates with multiple engines and query languages. As next steps, consider experimenting with more complex models, exploring different tree pruning strategies, or even extending this pattern to other ML models beyond GBDTs.
+**And with LetSQL**, you can streamline this entire process—especially if you’re looking for an out-of-the-box solution that integrates with multiple engines and ships batteries-included features like caching and aggregate/window UDFs. As next steps, consider experimenting with more complex models, exploring different tree pruning strategies, or even extending this pattern to other ML models beyond GBDTs.
 
 - **Try it out**: Explore the Ibis documentation to learn how to build custom UDFs.
 - **Dive deeper**: Check out [quickgrove](https://github.com/letsql/trusty) or read the Raven optimizer [paper](https://arxiv.org/pdf/2206.00136).
 
@@ -342,5 +355,4 @@ Combining **Ibis** with a prune-friendly framework like quickgrove lets you auto
 
 - **Paper**: [End-to-end Optimization of Machine Learning Prediction Queries (Raven)](https://arxiv.org/pdf/2206.00136)
 - **Ibis + Torch**: [Ibis Project Blog Post](https://ibis-project.org/posts/torch/)
-- **quickgrove**: [GitHub Repository](https://github.com/letsql/quickgrove)
 - **LetSQL**: [Documentation](https://docs.letsql.com)
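
---

Reviewer note: to make the pruning idea in the patched post concrete, here is a minimal, self-contained sketch in plain Python. The dict-based tree structure and the `prune`/`predict` helpers are hypothetical toys for illustration only — quickgrove does this over XGBoost forests in Rust — but the core rule is the same: under a pushed-down predicate `feature < bound`, any branch guarded by `feature >= threshold` with `bound <= threshold` is unreachable and can be dropped.

```python
# Toy sketch of predicate-driven tree pruning (hypothetical structure,
# not quickgrove's actual implementation).

def prune(node, feature, bound):
    """Drop branches unreachable when every row satisfies `feature < bound`."""
    if "leaf" in node:
        return node
    left = prune(node["left"], feature, bound)
    right = prune(node["right"], feature, bound)
    if node["feature"] == feature and bound <= node["threshold"]:
        # feature < bound <= threshold, so the right branch
        # (taken when feature >= threshold) can never fire.
        return left
    return {**node, "left": left, "right": right}

def predict(node, row):
    """Walk the tree: go left when row[feature] < threshold."""
    while "leaf" not in node:
        node = node["left"] if row[node["feature"]] < node["threshold"] else node["right"]
    return node["leaf"]

# Mirrors the post's example: a split on color_i combined with WHERE color_i < 1.
tree = {
    "feature": "color_i", "threshold": 1.0,
    "left": {
        "feature": "carat", "threshold": 0.5,
        "left": {"leaf": 100.0}, "right": {"leaf": 300.0},
    },
    "right": {"leaf": 900.0},
}

pruned = prune(tree, "color_i", 1.0)
row = {"color_i": 0.0, "carat": 0.7}  # a row that satisfies the pushed-down filter
assert predict(pruned, row) == predict(tree, row)
```

Note the asymmetry this exposes: the pruned tree only agrees with the original on rows that actually satisfy the predicate, which is exactly why the rewrite must stay tied to the query's filter rather than being applied globally.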