From c1bfdbaa9ea95a094f9dd6ec230928ebef2e13f1 Mon Sep 17 00:00:00 2001 From: "lin.zhang" Date: Mon, 26 Aug 2024 22:24:03 +0800 Subject: [PATCH] Fixes #3962 fix some comment --- sqlglot/dialects/starrocks.py | 7 ++---- sqlglot/expressions.py | 4 +-- sqlglot/transforms.py | 43 +++++++++++++++++++++----------- tests/dialects/test_starrocks.py | 12 +++++---- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 7fa8d69074..0bfc856602 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -36,12 +36,9 @@ def _parse_unnest(self, with_alias: bool = True) -> t.Optional[exp.Unnest]: alias = unnest.args.get("alias") if not alias: - # Starrocks defaults to naming the table alias as "unnest" - alias = exp.TableAlias( - this=exp.to_identifier("unnest"), columns=[exp.to_identifier("unnest")] - ) + alias = exp.TableAlias(columns=[exp.to_identifier("unnest")]) unnest.set("alias", alias) - if alias and not alias.args.get("columns"): + if not alias.args.get("columns"): # Starrocks defaults to naming the UNNEST column as "unnest" # if it's not otherwise specified alias.set("columns", [exp.to_identifier("unnest")]) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 31b026c794..ec051960c4 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5333,9 +5333,7 @@ class Explode(Func): # https://spark.apache.org/docs/latest/api/sql/#inline class Inline(Func): - arg_types = {"this": True, "expressions": False} - is_var_len_args = True - + pass class ExplodeOuter(Explode): pass diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 91fe61e322..4cc767e792 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -3,7 +3,6 @@ import typing as t from sqlglot import expressions as exp -from sqlglot.expressions import Func from sqlglot.helper import find_new_name, name_sequence if t.TYPE_CHECKING: @@ -303,17 +302,35 @@ def unnest_to_explode( 
support_unnest_func: bool = False, ) -> exp.Expression: """Convert cross join unnest into lateral view explode.""" + + def handle_unnest(u: exp.Unnest, _has_multi_expr: bool) -> None: + if unnest_using_arrays_zip and _has_multi_expr: + # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions + u.set( + "expressions", + [exp.Anonymous(this="ARRAYS_ZIP", expressions=u.expressions)], + ) + + def udtf_type(u: exp.Unnest, _has_multi_expr: bool): + return ( + exp.Posexplode + if u.args.get("offset") + else (exp.Inline if unnest_using_arrays_zip and _has_multi_expr else exp.Explode) + ) + if isinstance(expression, exp.Select): from_ = expression.args.get("from") if not support_unnest_func and from_ and isinstance(from_.this, exp.Unnest): unnest = from_.this alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode + has_multi_expr = len(unnest.expressions) > 1 + handle_unnest(unnest, has_multi_expr) this, *expressions = unnest.expressions + unnest.replace( exp.Table( - this=udtf( + this=udtf_type(unnest, has_multi_expr)( this=this, expressions=expressions, ), @@ -329,26 +346,22 @@ def unnest_to_explode( unnest = join_expr.this if is_lateral else join_expr if isinstance(unnest, exp.Unnest): - alias = join_expr.args.get("alias") if is_lateral else unnest.args.get("alias") + alias = ( + join_expr.args.get("alias") + if is_lateral and join_expr.args.get("alias") + else unnest.args.get("alias") + ) + # handle_unnest may change the number of unnest.expressions, so record it before calling it has_multi_expr = len(unnest.expressions) > 1 - _udtf: type[Func] = exp.Posexplode if unnest.args.get("offset") else exp.Explode + handle_unnest(unnest, has_multi_expr) expression.args["joins"].remove(join) - if unnest_using_arrays_zip and has_multi_expr: - # Modify the logic to use arrays_zip if there are multiple expressions - # Build arrays_zip with nested expressions correctly - unnest.set( - "expressions", - [exp.Anonymous(this="ARRAYS_ZIP",
expressions=unnest.expressions)], - ) - _udtf = exp.Inline - for e, column in zip(unnest.expressions, alias.columns if alias else []): expression.append( "laterals", exp.Lateral( - this=_udtf(this=e), + this=udtf_type(unnest, has_multi_expr)(this=e), view=True, alias=exp.TableAlias( this=alias.this, # type: ignore diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 5cc97ac941..96552163e7 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -35,15 +35,17 @@ def test_unnest(self): "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t", "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)", ) - self.validate_identity( + self.validate_all( "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)", - "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)", + write={ + "spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) AS unnest", + }, ) self.validate_all( r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""", write={ "postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)", - "spark": "SELECT * FROM EXPLODE(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27)) AS t(name, age)", + "spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)", }, ) # Use UNNEST to convert into multiple columns @@ -59,9 +61,9 @@ def test_unnest(self): }, ) self.validate_all( - r"""SELECT id, t.type, t.scores FROM example_table CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""", + r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""", write={ - "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL 
VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + "spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", }, )