diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 990991357e..ad9d54337c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import typing as t
+from functools import partial
 
 from sqlglot import exp, generator, parser, tokens, transforms
 from sqlglot.dialects.dialect import (
@@ -190,6 +191,16 @@ def _parse(args: t.List[exp.Expression]) -> exp.Expression:
     return _parse
 
 
+def _select_sql(self: Hive.Generator, expression: exp.Expression) -> str:
+    return transforms.preprocess(
+        [
+            transforms.eliminate_qualify,
+            transforms.eliminate_distinct_on,
+            partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False, generator=self),
+        ]
+    )(self, expression)
+
+
 class Hive(Dialect):
     ALIAS_POST_TABLESAMPLE = True
     IDENTIFIERS_CAN_START_WITH_DIGIT = True
@@ -481,13 +492,7 @@ class Generator(generator.Generator):
         TRANSFORMS = {
             **generator.Generator.TRANSFORMS,
             exp.Group: transforms.preprocess([transforms.unalias_group]),
-            exp.Select: transforms.preprocess(
-                [
-                    transforms.eliminate_qualify,
-                    transforms.eliminate_distinct_on,
-                    transforms.unnest_to_explode,
-                ]
-            ),
+            exp.Select: _select_sql,
             exp.Property: _property_sql,
             exp.AnyValue: rename_func("FIRST"),
             exp.ApproxDistinct: approx_count_distinct_sql,
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index 017a6f6d73..4e64fbc1c8 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -230,6 +230,13 @@ class Generator(Hive.Generator):
                 e.args["replacement"],
                 e.args.get("position"),
             ),
+            exp.Select: transforms.preprocess(
+                [
+                    transforms.eliminate_qualify,
+                    transforms.eliminate_distinct_on,
+                    transforms.unnest_to_explode,
+                ]
+            ),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
             exp.TimestampTrunc: lambda self, e: self.func("DATE_TRUNC", unit_to_str(e), e.this),
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index c1b1e4cacb..49f31ca1e9 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -35,7 +35,13 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]:
         if unnest:
             alias = unnest.args.get("alias")
 
-            if alias and not alias.args.get("columns"):
+            if not alias:
+                # Starrocks defaults to naming the table alias as "unnest"
+                alias = exp.TableAlias(
+                    this=exp.to_identifier("unnest"), columns=[exp.to_identifier("unnest")]
+                )
+                unnest.set("alias", alias)
+            elif not alias.args.get("columns"):
                 # Starrocks defaults to naming the UNNEST column as "unnest"
                 # if it's not otherwise specified
                 alias.set("columns", [exp.to_identifier("unnest")])
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index d43c2e22f6..bc45c61152 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -5331,6 +5331,11 @@ class Explode(Func):
     is_var_len_args = True
 
 
+# https://spark.apache.org/docs/latest/api/sql/#inline
+class Inline(Func):
+    pass
+
+
 class ExplodeOuter(Explode):
     pass
 
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 1a67a2ddfc..229562ef76 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -2,7 +2,7 @@
 
 import typing as t
 
-from sqlglot import expressions as exp
+from sqlglot import expressions as exp, generator
 from sqlglot.helper import find_new_name, name_sequence
 
 if t.TYPE_CHECKING:
@@ -296,19 +296,50 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
     return expression
 
 
-def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
+def unnest_to_explode(
+    expression: exp.Expression,
+    unnest_using_arrays_zip: bool = True,
+    generator: t.Optional[generator.Generator] = None,
+) -> exp.Expression:
     """Convert cross join unnest into lateral view explode."""
+
+    def _unnest_zip_exprs(
+        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
+    ) -> t.List[exp.Expression]:
+        if has_multi_expr:
+            if not unnest_using_arrays_zip:
+                if generator:
+                    generator.unsupported(
+                        f"Multiple expressions in UNNEST are not supported in "
+                        f"{generator.dialect.__module__.split('.')[-1].upper()}"
+                    )
+            else:
+                # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
+                zip_exprs: t.List[exp.Expression] = [
+                    exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
+                ]
+                u.set("expressions", zip_exprs)
+                return zip_exprs
+        return unnest_exprs
+
+    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
+        if u.args.get("offset"):
+            return exp.Posexplode
+        return exp.Inline if has_multi_expr else exp.Explode
+
     if isinstance(expression, exp.Select):
         from_ = expression.args.get("from")
 
         if from_ and isinstance(from_.this, exp.Unnest):
             unnest = from_.this
             alias = unnest.args.get("alias")
-            udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
-            this, *expressions = unnest.expressions
+            exprs = unnest.expressions
+            has_multi_expr = len(exprs) > 1
+            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
+
             unnest.replace(
                 exp.Table(
-                    this=udtf(
+                    this=_udtf_type(unnest, has_multi_expr)(
                         this=this,
                         expressions=expressions,
                     ),
@@ -324,18 +355,28 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
             unnest = join_expr.this if is_lateral else join_expr
 
             if isinstance(unnest, exp.Unnest):
-                alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias")
-                udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode
+                if is_lateral:
+                    alias = join_expr.args.get("alias")
+                else:
+                    alias = unnest.args.get("alias")
+                exprs = unnest.expressions
+                # The number of unnest.expressions is changed by _unnest_zip_exprs, so record it here
+                has_multi_expr = len(exprs) > 1
+                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)
 
                 expression.args["joins"].remove(join)
 
-                for e, column in zip(unnest.expressions, alias.columns if alias else []):
+                alias_cols = alias.columns if alias else []
+                for e, column in zip(exprs, alias_cols):
                     expression.append(
                         "laterals",
                         exp.Lateral(
-                            this=udtf(this=e),
+                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                             view=True,
-                            alias=exp.TableAlias(this=alias.this, columns=[column]),  # type: ignore
+                            alias=exp.TableAlias(
+                                this=alias.this,  # type: ignore
+                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
+                            ),
                         ),
                     )
 
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 8f3c183343..d34c319dde 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1399,7 +1399,7 @@ def test_cross_join(self):
             write={
                 "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
                 "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
-                "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
+                "spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b",
             },
         )
         self.validate_all(
diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py
index fa9a2ccc4d..b5144f8b6e 100644
--- a/tests/dialects/test_starrocks.py
+++ b/tests/dialects/test_starrocks.py
@@ -35,6 +35,39 @@ def test_unnest(self):
             "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t",
             "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
         )
+        self.validate_all(
+            "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
+            write={
+                "spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) unnest AS unnest",
+                "starrocks": "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
+            },
+        )
+        self.validate_all(
+            r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
+            write={
+                "postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
+                "spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)",
+                "starrocks": "SELECT * FROM UNNEST(['John', 'Jane', 'Jim', 'Jamie'], [24, 25, 26, 27]) AS t(name, age)",
+            },
+        )
+        # Use UNNEST to convert into multiple columns
+        # see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/
+        self.validate_all(
+            r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""",
+            write={
+                "postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)",
+                "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+                "databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+                "starrocks": r"""SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
+            },
+        )
+        self.validate_all(
+            r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
+            write={
+                "spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+                "starrocks": r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
+            },
+        )
 
         lateral_explode_sqls = [
             "SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)",