Skip to content

fix(starrocks): exp.Unnest transpilation #3966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing as t
from functools import partial

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Expand Down Expand Up @@ -190,6 +191,16 @@ def _parse(args: t.List[exp.Expression]) -> exp.Expression:
return _parse


def _select_sql(self: Hive.Generator, expression: exp.Expression) -> str:
    """Generate SELECT SQL for Hive after applying the dialect's pre-generation transforms.

    Args:
        self: The Hive generator producing the SQL.
        expression: The SELECT expression to transform and render.

    Returns:
        The generated SQL string.
    """
    # unnest_to_explode is bound per-call because it needs this generator
    # instance: with unnest_using_arrays_zip=False it reports multi-expression
    # UNNESTs as unsupported via generator.unsupported(...).
    unnest_transform = partial(
        transforms.unnest_to_explode,
        unnest_using_arrays_zip=False,
        generator=self,
    )
    preprocessor = transforms.preprocess(
        [
            transforms.eliminate_qualify,
            transforms.eliminate_distinct_on,
            unnest_transform,
        ]
    )
    return preprocessor(self, expression)


class Hive(Dialect):
ALIAS_POST_TABLESAMPLE = True
IDENTIFIERS_CAN_START_WITH_DIGIT = True
Expand Down Expand Up @@ -481,13 +492,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.Group: transforms.preprocess([transforms.unalias_group]),
exp.Select: transforms.preprocess(
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.Select: _select_sql,
exp.Property: _property_sql,
exp.AnyValue: rename_func("FIRST"),
exp.ApproxDistinct: approx_count_distinct_sql,
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ class Generator(Hive.Generator):
e.args["replacement"],
e.args.get("position"),
),
exp.Select: transforms.preprocess(
[
transforms.eliminate_qualify,
transforms.eliminate_distinct_on,
transforms.unnest_to_explode,
]
),
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
Expand Down
8 changes: 7 additions & 1 deletion sqlglot/dialects/starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
if unnest:
alias = unnest.args.get("alias")

if alias and not alias.args.get("columns"):
if not alias:
# Starrocks defaults to naming the table alias as "unnest"
alias = exp.TableAlias(
this=exp.to_identifier("unnest"), columns=[exp.to_identifier("unnest")]
)
unnest.set("alias", alias)
elif not alias.args.get("columns"):
# Starrocks defaults to naming the UNNEST column as "unnest"
# if it's not otherwise specified
alias.set("columns", [exp.to_identifier("unnest")])
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5331,6 +5331,11 @@ class Explode(Func):
is_var_len_args = True


# https://spark.apache.org/docs/latest/api/sql/#inline
class Inline(Func):
    """AST node for Spark SQL's INLINE table function (see the link above)."""


class ExplodeOuter(Explode):
    """Marker subclass of Explode for the *_OUTER variant; adds no members of its own."""

Expand Down
61 changes: 51 additions & 10 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing as t

from sqlglot import expressions as exp
from sqlglot import expressions as exp, generator
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -296,19 +296,50 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
return expression


def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
def unnest_to_explode(
expression: exp.Expression,
unnest_using_arrays_zip: bool = True,
generator: t.Optional[generator.Generator] = None,
) -> exp.Expression:
"""Convert cross join unnest into lateral view explode."""

def _unnest_zip_exprs(
u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
) -> t.List[exp.Expression]:
if has_multi_expr:
if not unnest_using_arrays_zip:
if generator:
generator.unsupported(
f"Multiple expressions in UNNEST are not supported in "
f"{generator.dialect.__module__.split('.')[-1].upper()}"
)
else:
# Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
zip_exprs: t.List[exp.Expression] = [
exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
]
u.set("expressions", zip_exprs)
return zip_exprs
return unnest_exprs

def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
if u.args.get("offset"):
return exp.Posexplode
return exp.Inline if has_multi_expr else exp.Explode

if isinstance(expression, exp.Select):
from_ = expression.args.get("from")

if from_ and isinstance(from_.this, exp.Unnest):
unnest = from_.this
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
this, *expressions = unnest.expressions
exprs = unnest.expressions
has_multi_expr = len(exprs) > 1
this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

unnest.replace(
exp.Table(
this=udtf(
this=_udtf_type(unnest, has_multi_expr)(
this=this,
expressions=expressions,
),
Expand All @@ -324,18 +355,28 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
unnest = join_expr.this if is_lateral else join_expr

if isinstance(unnest, exp.Unnest):
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
if is_lateral:
alias = join_expr.args.get("alias")
else:
alias = unnest.args.get("alias")
exprs = unnest.expressions
# The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
has_multi_expr = len(exprs) > 1
exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

expression.args["joins"].remove(join)

for e, column in zip(unnest.expressions, alias.columns if alias else []):
alias_cols = alias.columns if alias else []
for e, column in zip(exprs, alias_cols):
expression.append(
"laterals",
exp.Lateral(
this=udtf(this=e),
this=_udtf_type(unnest, has_multi_expr)(this=e),
view=True,
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
alias=exp.TableAlias(
this=alias.this, # type: ignore
columns=alias_cols if unnest_using_arrays_zip else [column], # type: ignore
),
),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,7 +1399,7 @@ def test_cross_join(self):
write={
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
"spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b",
},
)
self.validate_all(
Expand Down
33 changes: 33 additions & 0 deletions tests/dialects/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,39 @@ def test_unnest(self):
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t",
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
)
self.validate_all(
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
write={
"spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) unnest AS unnest",
"starrocks": "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
},
)
self.validate_all(
r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
write={
"postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
"spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)",
"starrocks": "SELECT * FROM UNNEST(['John', 'Jane', 'Jim', 'Jamie'], [24, 25, 26, 27]) AS t(name, age)",
},
)
# Use UNNEST to convert into multiple columns
# see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/
self.validate_all(
r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""",
write={
"postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)",
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
"databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
"starrocks": r"""SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
},
)
self.validate_all(
r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
write={
"spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
"starrocks": r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
},
)

lateral_explode_sqls = [
"SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)",
Expand Down
Loading