Skip to content

Commit 6eac225

Browse files
author
lin.zhang
committed
fix(starrocks): exp.Unnest transpilation
use `arrays_zip` to merge multiple `Lateral view explode`
1 parent 77b87d5 commit 6eac225

File tree

6 files changed

+53
-14
lines changed

6 files changed

+53
-14
lines changed

sqlglot/dialects/hive.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import typing as t
4+
from functools import partial
45

56
from sqlglot import exp, generator, parser, tokens, transforms
67
from sqlglot.dialects.dialect import (
@@ -485,7 +486,7 @@ class Generator(generator.Generator):
485486
[
486487
transforms.eliminate_qualify,
487488
transforms.eliminate_distinct_on,
488-
transforms.unnest_to_explode,
489+
partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False),
489490
]
490491
),
491492
exp.Property: _property_sql,

sqlglot/dialects/spark2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,13 @@ class Generator(Hive.Generator):
190190

191191
TRANSFORMS = {
192192
**Hive.Generator.TRANSFORMS,
193+
exp.Select: transforms.preprocess(
194+
[
195+
transforms.eliminate_qualify,
196+
transforms.eliminate_distinct_on,
197+
transforms.unnest_to_explode,
198+
]
199+
),
193200
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
194201
exp.ArraySum: lambda self,
195202
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",

sqlglot/expressions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5331,6 +5331,12 @@ class Explode(Func):
53315331
is_var_len_args = True
53325332

53335333

5334+
# https://spark.apache.org/docs/latest/api/sql/#inline
5335+
class Inline(Func):
5336+
arg_types = {"this": True, "expressions": False}
5337+
is_var_len_args = True
5338+
5339+
53345340
class ExplodeOuter(Explode):
53355341
pass
53365342

sqlglot/transforms.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import typing as t
44

55
from sqlglot import expressions as exp
6+
from sqlglot.expressions import Func
67
from sqlglot.helper import find_new_name, name_sequence
78

89
if t.TYPE_CHECKING:
@@ -296,12 +297,16 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
296297
return expression
297298

298299

299-
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
300+
def unnest_to_explode(
301+
expression: exp.Expression,
302+
unnest_using_arrays_zip: bool = True,
303+
support_unnest_func: bool = False,
304+
) -> exp.Expression:
300305
"""Convert cross join unnest into lateral view explode."""
301306
if isinstance(expression, exp.Select):
302307
from_ = expression.args.get("from")
303308

304-
if from_ and isinstance(from_.this, exp.Unnest):
309+
if not support_unnest_func and from_ and isinstance(from_.this, exp.Unnest):
305310
unnest = from_.this
306311
alias = unnest.args.get("alias")
307312
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
@@ -325,17 +330,30 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
325330

326331
if isinstance(unnest, exp.Unnest):
327332
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
328-
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
333+
has_multi_expr = len(unnest.expressions) > 1
334+
_udtf: type[Func] = exp.Posexplode if unnest.args.get("offset") else exp.Explode
329335

330336
expression.args["joins"].remove(join)
331337

338+
if unnest_using_arrays_zip and has_multi_expr:
339+
# Modify the logic to use arrays_zip if there are multiple expressions
340+
# Build arrays_zip with nested expressions correctly
341+
unnest.set(
342+
"expressions",
343+
[exp.Anonymous(this="arrays_zip", expressions=unnest.expressions)],
344+
)
345+
_udtf = exp.Inline
346+
332347
for e, column in zip(unnest.expressions, alias.columns if alias else []):
333348
expression.append(
334349
"laterals",
335350
exp.Lateral(
336-
this=udtf(this=e),
351+
this=_udtf(this=e),
337352
view=True,
338-
alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
353+
alias=exp.TableAlias(
354+
this=alias.this, # type: ignore
355+
columns=alias.columns if unnest_using_arrays_zip else [column], # type: ignore
356+
),
339357
),
340358
)
341359

tests/dialects/test_dialect.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,8 @@ def test_cross_join(self):
13991399
write={
14001400
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
14011401
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
1402-
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
1402+
"spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b",
1403+
"hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
14031404
},
14041405
)
14051406
self.validate_all(

tests/dialects/test_starrocks.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,32 @@ def test_unnest(self):
3636
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
3737
)
3838
self.validate_identity(
39-
"SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
40-
"SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
39+
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
40+
"SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
4141
)
4242
self.validate_all(
4343
r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
4444
write={
4545
"postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
46-
"spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS "
47-
"t(name, age)",
46+
"spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS t(name, age)",
4847
},
4948
)
5049
# Use UNNEST to convert into multiple columns
5150
# see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/
5251
self.validate_all(
5352
r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""",
5453
write={
55-
"postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS "
56-
"t(type, scores)",
57-
"spark": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT"
54+
"postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)",
55+
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
56+
"hive": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT"
5857
r"""('\\Q', ';'))) t AS type LATERAL VIEW EXPLODE(scores) t AS scores""",
58+
"databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
59+
},
60+
)
61+
self.validate_all(
62+
r"""SELECT id, t.type, t.scores FROM example_table CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
63+
write={
64+
"spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
5965
},
6066
)
6167

0 commit comments

Comments
 (0)