Skip to content

Commit

Permalink
fix(starrocks): exp.Unnest transpilation
Browse files Browse the repository at this point in the history
Use `arrays_zip` to merge multiple `LATERAL VIEW EXPLODE` clauses into one.
  • Loading branch information
lin.zhang committed Aug 26, 2024
1 parent 77b87d5 commit 6eac225
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 14 deletions.
3 changes: 2 additions & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
from functools import partial

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Expand Down Expand Up @@ -485,7 +486,7 @@ class Generator(generator.Generator):
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False),
]
),
exp.Property: _property_sql,
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ class Generator(Hive.Generator):

TRANSFORMS = {
**Hive.Generator.TRANSFORMS,
exp.Select: transforms.preprocess(
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5331,6 +5331,12 @@ class Explode(Func):
is_var_len_args = True


# https://spark.apache.org/docs/latest/api/sql/#inline
class Inline(Func):
    """Spark's INLINE table function (explodes an array of structs into rows/columns).

    ``this`` is the required first argument; ``expressions`` collects any
    further arguments, since the function is variadic.
    """

    arg_types = {"this": True, "expressions": False}
    is_var_len_args = True


class ExplodeOuter(Explode):
    """OUTER variant of :class:`Explode`; inherits its argument spec unchanged."""

    pass

Expand Down
28 changes: 23 additions & 5 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t

from sqlglot import expressions as exp
from sqlglot.expressions import Func
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -296,12 +297,16 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def unnest_to_explode(
expression: exp.Expression,
unnest_using_arrays_zip: bool = True,
support_unnest_func: bool = False,
) -> exp.Expression:
"""Convert cross join unnest into lateral view explode."""
if isinstance(expression, exp.Select):
from_ = expression.args.get("from")

if from_ and isinstance(from_.this, exp.Unnest):
if not support_unnest_func and from_ and isinstance(from_.this, exp.Unnest):
unnest = from_.this
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
Expand All @@ -325,17 +330,30 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:

if isinstance(unnest, exp.Unnest):
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
has_multi_expr = len(unnest.expressions) > 1
_udtf: type[Func] = exp.Posexplode if unnest.args.get("offset") else exp.Explode

expression.args["joins"].remove(join)

if unnest_using_arrays_zip and has_multi_expr:
# Modify the logic to use arrays_zip if there are multiple expressions
# Build arrays_zip with nested expressions correctly
unnest.set(
"expressions",
[exp.Anonymous(this="arrays_zip", expressions=unnest.expressions)],
)
_udtf = exp.Inline

for e, column in zip(unnest.expressions, alias.columns if alias else []):
expression.append(
"laterals",
exp.Lateral(
this=udtf(this=e),
this=_udtf(this=e),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
alias=exp.TableAlias(
this=alias.this, # type: ignore
columns=alias.columns if unnest_using_arrays_zip else [column], # type: ignore
),
),
)

Expand Down
3 changes: 2 additions & 1 deletion tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,8 @@ def test_cross_join(self):
write={
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
"spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b",
"hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
},
)
self.validate_all(
Expand Down
20 changes: 13 additions & 7 deletions tests/dialects/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,26 +36,32 @@ def test_unnest(self):
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
)
self.validate_identity(
"SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
"SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
)
self.validate_all(
r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
write={
"postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
"spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS "
"t(name, age)",
"spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS t(name, age)",
},
)
# Use UNNEST to convert into multiple columns
# see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/
self.validate_all(
r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""",
write={
"postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS "
"t(type, scores)",
"spark": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT"
"postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)",
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
"hive": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT"
r"""('\\Q', ';'))) t AS type LATERAL VIEW EXPLODE(scores) t AS scores""",
"databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
},
)
self.validate_all(
r"""SELECT id, t.type, t.scores FROM example_table CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
write={
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
},
)

Expand Down

0 comments on commit 6eac225

Please sign in to comment.