Skip to content

Commit c1ac987

Browse files
authored
fix(starrocks): exp.Array generation, exp.Unnest alias (#3964)
1 parent 0a9ba05 commit c1ac987

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

sqlglot/dialects/starrocks.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from __future__ import annotations
22

3+
import typing as t
4+
35
from sqlglot import exp
46
from sqlglot.dialects.dialect import (
57
approx_count_distinct_sql,
68
arrow_json_extract_sql,
79
build_timestamp_trunc,
810
rename_func,
911
unit_to_str,
12+
inline_array_sql,
1013
)
1114
from sqlglot.dialects.mysql import MySQL
1215
from sqlglot.helper import seq_get
@@ -26,6 +29,19 @@ class Parser(MySQL.Parser):
2629
"REGEXP": exp.RegexpLike.from_arg_list,
2730
}
2831

32+
def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
33+
unnest = super()._parse_unnest(with_alias=with_alias)
34+
35+
if unnest:
36+
alias = unnest.args.get("alias")
37+
38+
if alias and not alias.args.get("columns"):
39+
# Starrocks defaults to naming the UNNEST column as "unnest"
40+
# if it's not otherwise specified
41+
alias.set("columns", [exp.to_identifier("unnest")])
42+
43+
return unnest
44+
2945
class Generator(MySQL.Generator):
3046
CAST_MAPPING = {}
3147

@@ -38,6 +54,7 @@ class Generator(MySQL.Generator):
3854

3955
TRANSFORMS = {
4056
**MySQL.Generator.TRANSFORMS,
57+
exp.Array: inline_array_sql,
4158
exp.ApproxDistinct: approx_count_distinct_sql,
4259
exp.DateDiff: lambda self, e: self.func(
4360
"DATE_DIFF", unit_to_str(e), e.this, e.expression

sqlglot/transforms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,14 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
317317
)
318318

319319
for join in expression.args.get("joins") or []:
320-
unnest = join.this
320+
join_expr = join.this
321+
322+
is_lateral = isinstance(join_expr, exp.Lateral)
323+
324+
unnest = join_expr.this if is_lateral else join_expr
321325

322326
if isinstance(unnest, exp.Unnest):
323-
alias = unnest.args.get("alias")
327+
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
324328
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
325329

326330
expression.args["joins"].remove(join)

tests/dialects/test_starrocks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ class TestStarrocks(Validator):
77
def test_identity(self):
88
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
99
self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x")
10+
self.validate_identity("SELECT [1, 2, 3]")
1011

1112
def test_time(self):
1213
self.validate_identity("TIMESTAMP('2022-01-01')")
@@ -28,3 +29,24 @@ def test_regex(self):
2829
"mysql": "SELECT REGEXP_LIKE(abc, '%foo%')",
2930
},
3031
)
32+
33+
def test_unnest(self):
34+
self.validate_identity(
35+
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t",
36+
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
37+
)
38+
39+
lateral_explode_sqls = [
40+
"SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)",
41+
"SELECT id, t.col FROM tbl CROSS JOIN LATERAL UNNEST(scores) AS t(col)",
42+
]
43+
44+
for sql in lateral_explode_sqls:
45+
with self.subTest(f"Testing Starrocks roundtrip & transpilation of: {sql}"):
46+
self.validate_all(
47+
sql,
48+
write={
49+
"starrocks": sql,
50+
"spark": "SELECT id, t.col FROM tbl LATERAL VIEW EXPLODE(scores) t AS col",
51+
},
52+
)

0 commit comments

Comments
 (0)