Skip to content

Commit

Permalink
Fixes #3962
Browse files Browse the repository at this point in the history
fix some comment
  • Loading branch information
lin.zhang committed Aug 26, 2024
1 parent 8261ff5 commit c1bfdba
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 28 deletions.
7 changes: 2 additions & 5 deletions sqlglot/dialects/starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
alias = unnest.args.get("alias")

if not alias:
# Starrocks defaults to naming the table alias as "unnest"
alias = exp.TableAlias(
this=exp.to_identifier("unnest"), columns=[exp.to_identifier("unnest")]
)
alias = exp.TableAlias(columns=[exp.to_identifier("unnest")])
unnest.set("alias", alias)
if alias and not alias.args.get("columns"):
if not alias.args.get("columns"):
# Starrocks defaults to naming the UNNEST column as "unnest"
# if it's not otherwise specified
alias.set("columns", [exp.to_identifier("unnest")])
Expand Down
4 changes: 1 addition & 3 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5333,9 +5333,7 @@ class Explode(Func):

# https://spark.apache.org/docs/latest/api/sql/#inline
class Inline(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True

pass

class ExplodeOuter(Explode):
pass
Expand Down
43 changes: 28 additions & 15 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing as t

from sqlglot import expressions as exp
from sqlglot.expressions import Func
from sqlglot.helper import find_new_name, name_sequence

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -303,17 +302,35 @@ def unnest_to_explode(
support_unnest_func: bool = False,
) -> exp.Expression:
"""Convert cross join unnest into lateral view explode."""

def handle_unnest(u: exp.Unnest, _has_multi_expr: bool) -> None:
if unnest_using_arrays_zip and _has_multi_expr:
# Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
u.set(
"expressions",
[exp.Anonymous(this="ARRAYS_ZIP", expressions=u.expressions)],
)

def udtf_type(u: exp.Unnest, _has_multi_expr: bool):
return (
exp.Posexplode
if u.args.get("offset")
else (exp.Inline if unnest_using_arrays_zip and _has_multi_expr else exp.Explode)
)

if isinstance(expression, exp.Select):
from_ = expression.args.get("from")

if not support_unnest_func and from_ and isinstance(from_.this, exp.Unnest):
unnest = from_.this
alias = unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
has_multi_expr = len(unnest.expressions) > 1
handle_unnest(unnest, has_multi_expr)
this, *expressions = unnest.expressions

unnest.replace(
exp.Table(
this=udtf(
this=udtf_type(unnest, has_multi_expr)(
this=this,
expressions=expressions,
),
Expand All @@ -329,26 +346,22 @@ def unnest_to_explode(
unnest = join_expr.this if is_lateral else join_expr

if isinstance(unnest, exp.Unnest):
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
alias = (
join_expr.args.get("alias")
if is_lateral and join_expr.args.get("alias")
else unnest.args.get("alias")
)
# The number of unnest.expressions will be changed by handle_unnest, we need to record it here
has_multi_expr = len(unnest.expressions) > 1
_udtf: type[Func] = exp.Posexplode if unnest.args.get("offset") else exp.Explode
handle_unnest(unnest, has_multi_expr)

expression.args["joins"].remove(join)

if unnest_using_arrays_zip and has_multi_expr:
# Modify the logic to use arrays_zip if there are multiple expressions
# Build arrays_zip with nested expressions correctly
unnest.set(
"expressions",
[exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest.expressions)],
)
_udtf = exp.Inline

for e, column in zip(unnest.expressions, alias.columns if alias else []):
expression.append(
"laterals",
exp.Lateral(
this=_udtf(this=e),
this=udtf_type(unnest, has_multi_expr)(this=e),
view=True,
alias=exp.TableAlias(
this=alias.this, # type: ignore
Expand Down
12 changes: 7 additions & 5 deletions tests/dialects/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ def test_unnest(self):
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t",
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
)
self.validate_identity(
self.validate_all(
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
write={
"spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) AS unnest",
},
)
self.validate_all(
r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
write={
"postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
"spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS t(name, age)",
"spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)",
},
)
# Use UNNEST to convert into multiple columns
Expand All @@ -59,9 +61,9 @@ def test_unnest(self):
},
)
self.validate_all(
r"""SELECT id, t.type, t.scores FROM example_table CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
write={
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
"spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
},
)

Expand Down

0 comments on commit c1bfdba

Please sign in to comment.