Skip to content

Commit

Permalink
Fix!: mysql/tsql datetime precision, formatting, exp.AtTimeZone (#3951)
Browse files Browse the repository at this point in the history
* Fix!: mysql/tsql datetime precision, formatting, exp.AtTimeZone
* Remove exp.AtTimeZone -> CONVERT_TZ from the MySQL dialect and make it raise an unsupported warning instead
  • Loading branch information
erindru authored Aug 26, 2024
1 parent 22bb9a0 commit 8aec682
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 11 deletions.
17 changes: 14 additions & 3 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

Expand Down Expand Up @@ -1243,13 +1243,24 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:
)


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
datatype = (
def timestrtotime_sql(
self: Generator,
expression: exp.TimeStrToTime,
include_precision: bool = False,
) -> str:
datatype = exp.DataType.build(
exp.DataType.Type.TIMESTAMPTZ
if expression.args.get("zone")
else exp.DataType.Type.TIMESTAMP
)

if isinstance(expression.this, exp.Literal) and include_precision:
precision = subsecond_precision(expression.this.name)
if precision > 0:
datatype = exp.DataType.build(
datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
)

return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


Expand Down
11 changes: 9 additions & 2 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
strposition_to_locate_sql,
unit_to_var,
trim_sql,
timestrtotime_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -728,8 +729,10 @@ class Generator(generator.Generator):
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(
exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True)
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(
self,
e,
include_precision=not e.args.get("zone"),
),
exp.TimeToStr: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
Expand Down Expand Up @@ -1210,3 +1213,7 @@ def converttimezone_sql(self, expression: exp.ConvertTimezone) -> str:
dt = expression.args.get("timestamp")

return self.func("CONVERT_TZ", dt, from_tz, to_tz)

def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
self.unsupported("AT TIME ZONE is not supported by MySQL")
return self.sql(expression.this)
3 changes: 2 additions & 1 deletion sqlglot/dialects/trino.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from sqlglot import exp
from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql
from sqlglot.dialects.dialect import merge_without_target_sql, trim_sql, timestrtotime_sql
from sqlglot.dialects.presto import Presto


Expand All @@ -21,6 +21,7 @@ class Generator(Presto.Generator):
exp.ArraySum: lambda self,
e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.Merge: merge_without_target_sql,
exp.TimeStrToTime: lambda self, e: timestrtotime_sql(self, e, include_precision=True),
exp.Trim: trim_sql,
}

Expand Down
6 changes: 4 additions & 2 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ class TSQL(Dialect):
"HH": "%H",
"H": "%-H",
"h": "%-I",
"S": "%f",
"ffffff": "%f",
"yyyy": "%Y",
"yy": "%y",
}
Expand Down Expand Up @@ -984,7 +984,9 @@ def setitem_sql(self, expression: exp.SetItem) -> str:
return super().setitem_sql(expression)

def boolean_sql(self, expression: exp.Boolean) -> str:
if type(expression.parent) in BIT_TYPES:
if type(expression.parent) in BIT_TYPES or isinstance(
expression.find_ancestor(exp.Values, exp.Select), exp.Values
):
return "1" if expression.this else "0"

return "(1 = 1)" if expression.this else "(1 = 0)"
Expand Down
24 changes: 24 additions & 0 deletions sqlglot/time.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import typing as t
import datetime

# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
Expand Down Expand Up @@ -661,3 +662,26 @@ def format_time(
"Zulu",
)
}


def subsecond_precision(timestamp_literal: str) -> int:
"""
Given an ISO-8601 timestamp literal, eg '2023-01-01 12:13:14.123456+00:00'
figure out its subsecond precision so we can construct types like DATETIME(6)
Note that in practice, this is either 3 or 6 digits (3 = millisecond precision, 6 = microsecond precision)
- 6 is the maximum because strftime's '%f' formats to microseconds and almost every database supports microsecond precision in timestamps
- Except Presto/Trino which in most cases only supports millisecond precision but will still honour '%f' and format to microseconds (replacing the remaining 3 digits with 0's)
- Python prior to 3.11 only supports 0, 3 or 6 digits in a timestamp literal. Any other amounts will throw a 'ValueError: Invalid isoformat string:' error
"""
try:
parsed = datetime.datetime.fromisoformat(timestamp_literal)
subsecond_digit_count = len(str(parsed.microsecond).rstrip("0"))
precision = 0
if subsecond_digit_count > 3:
precision = 6
elif subsecond_digit_count > 0:
precision = 3
return precision
except ValueError:
return 0
27 changes: 25 additions & 2 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,14 +655,30 @@ def test_time(self):
"doris": "CAST('2020-01-01' AS DATETIME)",
},
)
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01 12:13:14.123456+00:00')",
write={
"mysql": "CAST('2020-01-01 12:13:14.123456+00:00' AS DATETIME(6))",
"trino": "CAST('2020-01-01 12:13:14.123456+00:00' AS TIMESTAMP(6))",
"presto": "CAST('2020-01-01 12:13:14.123456+00:00' AS TIMESTAMP)",
},
)
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01 12:13:14.123-08:00', 'America/Los_Angeles')",
write={
"mysql": "TIMESTAMP('2020-01-01 12:13:14.123-08:00')",
"trino": "CAST('2020-01-01 12:13:14.123-08:00' AS TIMESTAMP(3) WITH TIME ZONE)",
"presto": "CAST('2020-01-01 12:13:14.123-08:00' AS TIMESTAMP WITH TIME ZONE)",
},
)
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01 12:13:14-08:00', 'America/Los_Angeles')",
write={
"bigquery": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)",
"databricks": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMPTZ)",
"tsql": "CAST('2020-01-01 12:13:14-08:00' AS DATETIMEOFFSET) AT TIME ZONE 'UTC'",
"mysql": "CAST('2020-01-01 12:13:14-08:00' AS DATETIME)",
"mysql": "TIMESTAMP('2020-01-01 12:13:14-08:00')",
"postgres": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMPTZ)",
"redshift": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP WITH TIME ZONE)",
"snowflake": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMPTZ)",
Expand All @@ -683,7 +699,7 @@ def test_time(self):
"databricks": "CAST(col AS TIMESTAMP)",
"duckdb": "CAST(col AS TIMESTAMPTZ)",
"tsql": "CAST(col AS DATETIMEOFFSET) AT TIME ZONE 'UTC'",
"mysql": "CAST(col AS DATETIME)",
"mysql": "TIMESTAMP(col)",
"postgres": "CAST(col AS TIMESTAMPTZ)",
"redshift": "CAST(col AS TIMESTAMP WITH TIME ZONE)",
"snowflake": "CAST(col AS TIMESTAMPTZ)",
Expand Down Expand Up @@ -722,6 +738,13 @@ def test_time(self):
"doris": "DATE_FORMAT(x, '%Y-%m-%d')",
},
)
self.validate_all(
"TIME_TO_STR(a, '%Y-%m-%d %H:%M:%S.%f')",
write={
"redshift": "TO_CHAR(a, 'YYYY-MM-DD HH24:MI:SS.US')",
"tsql": "FORMAT(a, 'yyyy-MM-dd HH:mm:ss.ffffff')",
},
)
self.validate_all(
"TIME_TO_TIME_STR(x)",
write={
Expand Down
59 changes: 59 additions & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import unittest
import sys

from sqlglot import expressions as exp
from sqlglot.dialects.mysql import MySQL
from tests.dialects.test_dialect import Validator
Expand Down Expand Up @@ -637,6 +640,53 @@ def test_mysql_time(self):
},
)

# No timezone, make sure DATETIME captures the correct precision
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.123456+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.123456+00:00' AS DATETIME(6))",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.123+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.123+00:00' AS DATETIME(3))",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15+00:00' AS DATETIME)",
)

# With timezone, make sure the TIMESTAMP constructor is used
# also TIMESTAMP doesnt have the subsecond precision truncation issue that DATETIME does so we dont need to TIMESTAMP(6)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15-08:00', 'America/Los_Angeles')",
write_sql="SELECT TIMESTAMP('2023-01-01 13:14:15-08:00')",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15-08:00', 'America/Los_Angeles')",
write_sql="SELECT TIMESTAMP('2023-01-01 13:14:15-08:00')",
)

@unittest.skipUnless(
sys.version_info >= (3, 11),
"Python 3.11 relaxed datetime.fromisoformat() parsing with regards to microseconds",
)
def test_mysql_time_python311(self):
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.12345+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.12345+00:00' AS DATETIME(6))",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.1234+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.1234+00:00' AS DATETIME(6))",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.12+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.12+00:00' AS DATETIME(3))",
)
self.validate_identity(
"SELECT TIME_STR_TO_TIME('2023-01-01 13:14:15.1+00:00')",
write_sql="SELECT CAST('2023-01-01 13:14:15.1+00:00' AS DATETIME(3))",
)

def test_mysql(self):
self.validate_all(
"SELECT CONCAT('11', '22')",
Expand Down Expand Up @@ -1192,3 +1242,12 @@ def test_timestamp_trunc(self):
"mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
},
)

def test_at_time_zone(self):
with self.assertLogs() as cm:
# Check AT TIME ZONE doesnt discard the column name and also raises a warning
self.validate_identity(
"SELECT foo AT TIME ZONE 'UTC'",
write_sql="SELECT foo",
)
assert "AT TIME ZONE is not supported" in cm.output[0]
6 changes: 6 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,12 @@ def test_tsql(self):
},
)

# Check that TRUE and FALSE dont get expanded to (1=1) or (1=0) when used in a VALUES expression
self.validate_identity(
"SELECT val FROM (VALUES ((TRUE), (FALSE), (NULL))) AS t(val)",
write_sql="SELECT val FROM (VALUES ((1), (0), (NULL))) AS t(val)",
)

def test_option(self):
possible_options = [
"HASH GROUP",
Expand Down
22 changes: 21 additions & 1 deletion tests/test_time.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import sys

from sqlglot.time import format_time
from sqlglot.time import format_time, subsecond_precision


class TestTime(unittest.TestCase):
Expand All @@ -12,3 +13,22 @@ def test_format_time(self):
self.assertEqual(format_time("aa", mapping), "c")
self.assertEqual(format_time("aaada", mapping), "cbdb")
self.assertEqual(format_time("da", mapping), "db")

def test_subsecond_precision(self):
self.assertEqual(6, subsecond_precision("2023-01-01 12:13:14.123456+00:00"))
self.assertEqual(3, subsecond_precision("2023-01-01 12:13:14.123+00:00"))
self.assertEqual(0, subsecond_precision("2023-01-01 12:13:14+00:00"))
self.assertEqual(0, subsecond_precision("2023-01-01 12:13:14"))
self.assertEqual(0, subsecond_precision("garbage"))

@unittest.skipUnless(
sys.version_info >= (3, 11),
"Python 3.11 relaxed datetime.fromisoformat() parsing with regards to microseconds",
)
def test_subsecond_precision_python311(self):
# ref: https://docs.python.org/3/whatsnew/3.11.html#datetime
self.assertEqual(6, subsecond_precision("2023-01-01 12:13:14.123456789+00:00"))
self.assertEqual(6, subsecond_precision("2023-01-01 12:13:14.12345+00:00"))
self.assertEqual(6, subsecond_precision("2023-01-01 12:13:14.1234+00:00"))
self.assertEqual(3, subsecond_precision("2023-01-01 12:13:14.12+00:00"))
self.assertEqual(3, subsecond_precision("2023-01-01 12:13:14.1+00:00"))

0 comments on commit 8aec682

Please sign in to comment.