Skip to content

Commit

Permalink
fix(starrocks): exp.Array generation, exp.Unnest alias (#3964)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD committed Aug 24, 2024
1 parent 0a9ba05 commit c1ac987
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
17 changes: 17 additions & 0 deletions sqlglot/dialects/starrocks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import (
approx_count_distinct_sql,
arrow_json_extract_sql,
build_timestamp_trunc,
rename_func,
unit_to_str,
inline_array_sql,
)
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import seq_get
Expand All @@ -26,6 +29,19 @@ class Parser(MySQL.Parser):
"REGEXP": exp.RegexpLike.from_arg_list,
}

def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
    """Parse UNNEST with the parent parser and default the alias column name.

    Args:
        with_alias: Forwarded unchanged to the parent `_parse_unnest`.

    Returns:
        Whatever the parent parser returned. If that node carries an alias
        that has no `columns`, the alias is updated in place to have a
        single column identifier named `unnest`.
    """
    parsed = super()._parse_unnest(with_alias=with_alias)
    if not parsed:
        return parsed

    table_alias = parsed.args.get("alias")
    if not table_alias or table_alias.args.get("columns"):
        # No alias at all, or the columns were spelled out explicitly
        return parsed

    # Starrocks defaults to naming the UNNEST column as "unnest"
    # if it's not otherwise specified
    table_alias.set("columns", [exp.to_identifier("unnest")])
    return parsed

class Generator(MySQL.Generator):
CAST_MAPPING = {}

Expand All @@ -38,6 +54,7 @@ class Generator(MySQL.Generator):

TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.DateDiff: lambda self, e: self.func(
"DATE_DIFF", unit_to_str(e), e.this, e.expression
Expand Down
8 changes: 6 additions & 2 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,14 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
)

for join in expression.args.get("joins") or []:
unnest = join.this
join_expr = join.this

is_lateral = isinstance(join_expr, exp.Lateral)

unnest = join_expr.this if is_lateral else join_expr

if isinstance(unnest, exp.Unnest):
alias = unnest.args.get("alias")
alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode

expression.args["joins"].remove(join)
Expand Down
22 changes: 22 additions & 0 deletions tests/dialects/test_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class TestStarrocks(Validator):
def test_identity(self):
    """Run `validate_identity` on each statement below, in order."""
    statements = (
        "SELECT CAST(`a`.`b` AS INT) FROM foo",
        "SELECT APPROX_COUNT_DISTINCT(a) FROM x",
        "SELECT [1, 2, 3]",
    )

    for statement in statements:
        self.validate_identity(statement)

def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")
Expand All @@ -28,3 +29,24 @@ def test_regex(self):
"mysql": "SELECT REGEXP_LIKE(abc, '%foo%')",
},
)

def test_unnest(self):
    """Exercise UNNEST aliasing and the Starrocks -> Spark EXPLODE rewrite."""
    # Alias given without a column list, paired with its `t(unnest)` form
    alias_only_sql = (
        "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t"
    )
    alias_with_column_sql = (
        "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)"
    )
    self.validate_identity(alias_only_sql, alias_with_column_sql)

    # Both the comma join and the CROSS JOIN LATERAL spelling are checked
    # against the same Spark output and against themselves for Starrocks
    spark_sql = "SELECT id, t.col FROM tbl LATERAL VIEW EXPLODE(scores) t AS col"
    starrocks_queries = (
        "SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)",
        "SELECT id, t.col FROM tbl CROSS JOIN LATERAL UNNEST(scores) AS t(col)",
    )

    for query in starrocks_queries:
        with self.subTest(f"Testing Starrocks roundtrip & transpilation of: {query}"):
            expected_by_dialect = {"starrocks": query, "spark": spark_sql}
            self.validate_all(query, write=expected_by_dialect)

0 comments on commit c1ac987

Please sign in to comment.