Skip to content

Commit c5d8028

Browse files
committed
Add PostgreSQL :: style casts
1 parent e437b16 commit c5d8028

File tree

4 files changed

+59
-9
lines changed

4 files changed

+59
-9
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,8 @@ primaryExpression
592592
| CASE operand=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
593593
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
594594
| CAST '(' expression AS type ')' #cast
595+
// This is a postgres extension to ANSI SQL, which allows for the use of "::" to cast
596+
| primaryExpression DOUBLE_COLON type #cast
595597
| TRY_CAST '(' expression AS type ')' #cast
596598
| ARRAY '[' (expression (',' expression)*)? ']' #arrayConstructor
597599
| value=primaryExpression '[' index=valueExpression ']' #subscript
@@ -1327,6 +1329,7 @@ LT: '<';
13271329
LTE: '<=';
13281330
GT: '>';
13291331
GTE: '>=';
1332+
DOUBLE_COLON: '::';
13301333

13311334
PLUS: '+';
13321335
MINUS: '-';

core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@
368368
import static io.trino.sql.tree.TableFunctionDescriptorArgument.nullDescriptorArgument;
369369
import static java.util.Locale.ENGLISH;
370370
import static java.util.Objects.requireNonNull;
371+
import static java.util.Objects.requireNonNullElse;
371372
import static java.util.stream.Collectors.joining;
372373
import static java.util.stream.Collectors.toList;
373374

@@ -2387,7 +2388,8 @@ public Node visitArrayConstructor(SqlBaseParser.ArrayConstructorContext context)
23872388
public Node visitCast(SqlBaseParser.CastContext context)
23882389
{
23892390
boolean isTryCast = context.TRY_CAST() != null;
2390-
return new Cast(getLocation(context), (Expression) visit(context.expression()), (DataType) visit(context.type()), isTryCast);
2391+
Expression expression = (Expression) visit(requireNonNullElse(context.expression(), context.primaryExpression()));
2392+
return new Cast(getLocation(context), expression, (DataType) visit(context.type()), isTryCast);
23912393
}
23922394

23932395
@Override

core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParser.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,51 @@ public void testCase()
13451345
Optional.of(new LongLiteral(location(1, 38), "3"))));
13461346
}
13471347

1348+
@Test
1349+
public void testCast()
1350+
{
1351+
assertThat(expression("CAST(1 AS BIGINT)"))
1352+
.isEqualTo(new Cast(location(1, 1),
1353+
new LongLiteral(location(1, 6), "1"),
1354+
simpleType(location(1, 11), "BIGINT")));
1355+
1356+
assertThat(expression("1::BIGINT"))
1357+
.isEqualTo(new Cast(location(1, 1),
1358+
new LongLiteral(location(1, 1), "1"),
1359+
simpleType(location(1, 4), "BIGINT")));
1360+
1361+
assertThat(expression("-3::BIGINT"))
1362+
.isEqualTo(new Cast(location(1, 1),
1363+
new LongLiteral(location(1, 1), "-3"),
1364+
simpleType(location(1, 5), "BIGINT")));
1365+
1366+
assertThat(expression("3*'4'::BIGINT"))
1367+
.isEqualTo(new ArithmeticBinaryExpression(
1368+
location(1, 2),
1369+
ArithmeticBinaryExpression.Operator.MULTIPLY,
1370+
new LongLiteral(location(1, 1), "3"),
1371+
new Cast(
1372+
location(1, 3),
1373+
new StringLiteral(location(1, 3), "4"),
1374+
simpleType(location(1, 8), "BIGINT"))));
1375+
1376+
assertThat(expression("CAST(ROW(11, 12) AS ROW(COL0 INTEGER, COL1 INTEGER))"))
1377+
.isEqualTo(new Cast(location(1, 1),
1378+
new Row(location(1, 6), Lists.newArrayList(new LongLiteral(location(1, 10), "11"), new LongLiteral(location(1, 14), "12"))),
1379+
rowType(
1380+
location(1, 21),
1381+
field(location(1, 25), "COL0", simpleType(location(1, 30), "INTEGER")),
1382+
field(location(1, 39), "COL1", simpleType(location(1, 44), "INTEGER")))));
1383+
1384+
assertThat(expression("ROW(11, 12)::ROW(COL0 INTEGER, COL1 INTEGER)"))
1385+
.isEqualTo(new Cast(location(1, 1),
1386+
new Row(location(1, 1), Lists.newArrayList(new LongLiteral(location(1, 5), "11"), new LongLiteral(location(1, 9), "12"))),
1387+
rowType(
1388+
location(1, 14),
1389+
field(location(1, 18), "COL0", simpleType(location(1, 23), "INTEGER")),
1390+
field(location(1, 32), "COL1", simpleType(location(1, 37), "INTEGER")))));
1391+
}
1392+
13481393
@Test
13491394
public void testSearchedCase()
13501395
{

core/trino-parser/src/test/java/io/trino/sql/parser/TestSqlParserErrorHandling.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ private static Stream<Arguments> expressions()
3333
{
3434
return Stream.of(
3535
Arguments.of("", "line 1:1: mismatched input '<EOF>'. Expecting: <expression>"),
36-
Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', '[', '||', <EOF>, <predicate>"));
36+
Arguments.of("1 + 1 x", "line 1:7: mismatched input 'x'. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', '[', '||', <EOF>, <predicate>"));
3737
}
3838

3939
private static Stream<Arguments> statements()
@@ -67,7 +67,7 @@ private static Stream<Arguments> statements()
6767
Arguments.of("select 1x from dual",
6868
"line 1:8: identifiers must not start with a digit; surround the identifier with double quotes"),
6969
Arguments.of("select fuu from dual order by fuu order by fuu",
70-
"line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', <EOF>, <predicate>"),
70+
"line 1:35: mismatched input 'order'. Expecting: '%', '*', '+', ',', '-', '.', '/', '::', 'AND', 'ASC', 'AT', 'DESC', 'FETCH', 'LIMIT', 'NULLS', 'OFFSET', 'OR', '[', '||', <EOF>, <predicate>"),
7171
Arguments.of("select fuu from dual limit 10 order by fuu",
7272
"line 1:31: mismatched input 'order'. Expecting: <EOF>"),
7373
Arguments.of("select CAST(12223222232535343423232435343 AS BIGINT)",
@@ -99,7 +99,7 @@ private static Stream<Arguments> statements()
9999
Arguments.of("SELECT x() over (ROWS select) FROM t",
100100
"line 1:23: mismatched input 'select'. Expecting: ')', 'BETWEEN', 'CURRENT', 'GROUPS', 'MEASURES', 'ORDER', 'PARTITION', 'RANGE', 'ROWS', 'UNBOUNDED', <expression>"),
101101
Arguments.of("SELECT X() OVER (ROWS UNBOUNDED) FROM T",
102-
"line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', <predicate>, <string>"),
102+
"line 1:32: mismatched input ')'. Expecting: '%', '(', '*', '+', '-', '->', '.', '/', '::', 'AND', 'AT', 'FOLLOWING', 'OR', 'OVER', 'PRECEDING', '[', '||', <predicate>, <string>"),
103103
Arguments.of("SELECT a FROM x ORDER BY (SELECT b FROM t WHERE ",
104104
"line 1:49: mismatched input '<EOF>'. Expecting: <expression>"),
105105
Arguments.of("SELECT a FROM a AS x TABLESAMPLE x ",
@@ -134,7 +134,7 @@ private static Stream<Arguments> statements()
134134
Arguments.of("SELECT a FROM \"\".s.t",
135135
"line 1:15: Zero-length delimited identifier not allowed"),
136136
Arguments.of("WITH t AS (SELECT 1 SELECT t.* FROM t",
137-
"line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " +
137+
"line 1:21: mismatched input 'SELECT'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AS', 'AT', 'EXCEPT', 'FETCH', 'FROM', " +
138138
"'GROUP', 'HAVING', 'INTERSECT', 'LIMIT', 'OFFSET', 'OR', 'ORDER', 'UNION', 'WHERE', 'WINDOW', '[', '||', " +
139139
"<identifier>, <predicate>"),
140140
Arguments.of("SHOW CATALOGS LIKE '%$_%' ESCAPE",
@@ -160,9 +160,9 @@ private static Stream<Arguments> statements()
160160
Arguments.of("SELECT * FROM t FOR VERSION AS OF TIMESTAMP WHERE",
161161
"line 1:50: mismatched input '<EOF>'. Expecting: <expression>"),
162162
Arguments.of("SELECT ROW(DATE '2022-10-10', DOUBLE 12.0)",
163-
"line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', <predicate>, <string>"),
163+
"line 1:38: mismatched input '12.0'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'ORDER', 'OVER', 'PRECISION', '[', '||', <predicate>, <string>"),
164164
Arguments.of("VALUES(DATE 2)",
165-
"line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', 'AND', 'AT', 'OR', 'OVER', '[', '||', <predicate>, <string>"),
165+
"line 1:13: mismatched input '2'. Expecting: '%', '(', ')', '*', '+', ',', '-', '->', '.', '/', '::', 'AND', 'AT', 'OR', 'OVER', '[', '||', <predicate>, <string>"),
166166
Arguments.of("SELECT count(DISTINCT *) FROM (VALUES 1)",
167167
"line 1:23: mismatched input '*'. Expecting: <expression>"));
168168
}
@@ -182,7 +182,7 @@ public void testPossibleExponentialBacktracking()
182182
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " +
183183
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9 * " +
184184
"1 * 2 * 3 * 4 * 5 * 6 * 7 * 8 * 9",
185-
"line 1:375: mismatched input '<EOF>'. Expecting: '%', '*', '+', '-', '.', '/', 'AND', 'AT', 'OR', 'THEN', '[', '||', <predicate>");
185+
"line 1:375: mismatched input '<EOF>'. Expecting: '%', '*', '+', '-', '.', '/', '::', 'AND', 'AT', 'OR', 'THEN', '[', '||', <predicate>");
186186
}
187187

188188
@Test
@@ -212,7 +212,7 @@ public void testPossibleExponentialBacktracking2()
212212
"OR (f()\n" +
213213
"OR (f()\n" +
214214
"GROUP BY id",
215-
"line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', <predicate>");
215+
"line 24:1: mismatched input 'GROUP'. Expecting: '%', ')', '*', '+', ',', '-', '.', '/', '::', 'AND', 'AT', 'FILTER', 'IGNORE', 'OR', 'OVER', 'RESPECT', '[', '||', <predicate>");
216216
}
217217

218218
@ParameterizedTest

0 commit comments

Comments
 (0)