diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 990991357e..0a80a9a465 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from functools import partial from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( @@ -485,7 +486,7 @@ class Generator(generator.Generator): [ transforms.eliminate_qualify, transforms.eliminate_distinct_on, - transforms.unnest_to_explode, + partial(transforms.unnest_to_explode, unnest_using_arrays_zip=False), ] ), exp.Property: _property_sql, diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py index 017a6f6d73..c254e94143 100644 --- a/sqlglot/dialects/spark2.py +++ b/sqlglot/dialects/spark2.py @@ -190,6 +190,13 @@ class Generator(Hive.Generator): TRANSFORMS = { **Hive.Generator.TRANSFORMS, + exp.Select: transforms.preprocess( + [ + transforms.eliminate_qualify, + transforms.eliminate_distinct_on, + transforms.unnest_to_explode, + ] + ), exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index d43c2e22f6..31b026c794 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5331,6 +5331,12 @@ class Explode(Func): is_var_len_args = True +# https://spark.apache.org/docs/latest/api/sql/#inline +class Inline(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + class ExplodeOuter(Explode): pass diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 1a67a2ddfc..ac3f78d6da 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -3,6 +3,7 @@ import typing as t from sqlglot import expressions as exp +from sqlglot.expressions import Func from sqlglot.helper import find_new_name, name_sequence if t.TYPE_CHECKING: @@ -296,12 +297,16 @@ def unqualify_unnest(expression: exp.Expression) -> exp.Expression: return expression -def unnest_to_explode(expression: exp.Expression) -> exp.Expression: +def unnest_to_explode( + expression: exp.Expression, + unnest_using_arrays_zip: bool = True, + support_unnest_func: bool = False, +) -> exp.Expression: """Convert cross join unnest into lateral view explode.""" if isinstance(expression, exp.Select): from_ = expression.args.get("from") - if from_ and isinstance(from_.this, exp.Unnest): + if not support_unnest_func and from_ and isinstance(from_.this, exp.Unnest): unnest = from_.this alias = unnest.args.get("alias") udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode @@ -325,17 +330,30 @@ def unnest_to_explode(expression: exp.Expression) -> exp.Expression: if isinstance(unnest, exp.Unnest): alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode + has_multi_expr = len(unnest.expressions) > 1 + _udtf: type[Func] = exp.Posexplode if unnest.args.get("offset") else exp.Explode expression.args["joins"].remove(join) + if unnest_using_arrays_zip and has_multi_expr: + # Modify the logic to use arrays_zip if there are multiple expressions + # Build arrays_zip with nested expressions correctly + unnest.set( + "expressions", + [exp.Anonymous(this="arrays_zip", expressions=unnest.expressions)], + ) + _udtf = exp.Inline + for e, column in zip(unnest.expressions, alias.columns if alias else []): expression.append( "laterals", exp.Lateral( - this=udtf(this=e), + this=_udtf(this=e), view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore + alias=exp.TableAlias( + this=alias.this, # type: ignore + columns=alias.columns if unnest_using_arrays_zip else [column], # type: ignore + ), ), ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 8f3c183343..7ddc8e5a2f 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1399,7 +1399,8 @@ def test_cross_join(self): write={ "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", - "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", + "spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b", + "hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", }, ) self.validate_all( diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index a0388cf042..5cc97ac941 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -36,15 +36,14 @@ def test_unnest(self): "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)", ) self.validate_identity( - "SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)", - "SELECT student, score, unnest.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)", + "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)", + "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)", ) self.validate_all( r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""", write={ "postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)", - "spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS " - "t(name, age)", + "spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS t(name, age)", }, ) # Use UNNEST to convert into multiple columns @@ -52,10 +51,17 @@ def test_unnest(self): self.validate_all( r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""", write={ - "postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS " - "t(type, scores)", - "spark": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT" + "postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)", + "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + "hive": "SELECT id, t.type, t.scores FROM example_table LATERAL VIEW EXPLODE(SPLIT(type, CONCAT" r"""('\\Q', ';'))) t AS type LATERAL VIEW EXPLODE(scores) t AS scores""", + "databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + }, + ) + self.validate_all( + r"""SELECT id, t.type, t.scores FROM example_table CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""", + write={ + "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", }, )