@@ -1,5 +1,6 @@
from Standard.Base import all
import Standard.Base.Errors.Illegal_Argument.Illegal_Argument
import Standard.Base.Errors.Illegal_State.Illegal_State

from Standard.Table import Aggregate_Column, Value_Type

@@ -30,7 +31,7 @@ from Standard.Database.Internal.JDBC_Connection import JDBC_Connection
The dialect for Redshift connections.
redshift : Redshift_Dialect
redshift =
Redshift_Dialect.Value Postgres_Connection.get_dialect.get_dialect_operations
Redshift_Dialect.Value make_dialect_operations

## ---
private: true
@@ -108,7 +109,31 @@ type Redshift_Dialect
Generates a SQL expression for a table literal.
make_table_literal : Vector (Vector Text) -> Vector Text -> Text -> SQL_Builder
make_table_literal self vecs column_names as_name =
Base_Generator.default_make_table_literal self.wrap_identifier vecs column_names as_name
# Amazon Redshift does not support using VALUES(...) as a table expression
# in the FROM clause. To represent inline literal tables, we therefore
# generate a SELECT ... UNION ALL SELECT ... construct instead, which
# Redshift accepts and which preserves the same semantics.
rows = vecs.transpose
if rows.is_empty then Error.throw (Illegal_Argument.Error "Cannot build a table literal with zero rows.") else
    first_row = rows.at 0

    wrapped_name = self.wrap_identifier as_name
    column_list = SQL_Builder.join ", " (column_names.map self.wrap_identifier) . paren

    cols = (first_row.zip column_names).map pair->
        SQL_Builder.interpolation pair.first ++ SQL_Builder.code " AS " ++ self.wrap_identifier pair.second

    # SELECT <v1> AS col1, <v2> AS col2, ...
    first_select = SQL_Builder.code "SELECT " ++ (SQL_Builder.join ", " cols)

    # UNION ALL SELECT <v1>, <v2>, ...
    rest_selects = rows.drop 1 . map row->
        SQL_Builder.code "SELECT " ++ (SQL_Builder.join ", " (row.map SQL_Builder.interpolation))

    union_body = SQL_Builder.join " UNION ALL " ([first_select] + rest_selects)

    SQL_Builder.code "(" ++ union_body ++ ") AS " ++ wrapped_name ++ column_list

## ---
private: true
@@ -180,7 +205,7 @@ type Redshift_Dialect
Add an extra cast to adjust the output type of aggregate operations. Some
DBs do CAST(SUM(x) AS FLOAT); others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
self.cast_op_type op_kind columns (SQL_IR_Expression.Operation op_kind (columns.map c->c.expression))
self.cast_op_type op_kind columns (SQL_IR_Expression.Operation op_kind (columns.map c->(c:Internal_Column).expression))

## ---
private: true
Expand All @@ -199,8 +224,10 @@ type Redshift_Dialect
---
check_aggregate_support : Aggregate_Column -> Boolean ! Unsupported_Database_Operation
check_aggregate_support self aggregate =
_ = aggregate
True
_ = self
case aggregate of
Aggregate_Column.Mode _ _ -> Error.throw (Unsupported_Database_Operation.Error "Mode")
_ -> True

## ---
private: true
@@ -232,9 +259,8 @@ type Redshift_Dialect
Dialect_Flag.Primary_Key_Allows_Nulls -> False
## TODO: Check if Redshift supports NaN
Dialect_Flag.Supports_Separate_NaN -> False
## TODO: Check if Redshift supports WITH clauses in nested queries
Dialect_Flag.Supports_Nested_With_Clause -> True
Dialect_Flag.Supports_Case_Sensitive_Columns -> True
Dialect_Flag.Supports_Nested_With_Clause -> False
Dialect_Flag.Supports_Case_Sensitive_Columns -> False
Dialect_Flag.Supports_Infinity -> True
Dialect_Flag.Case_Sensitive_Text_Comparison -> True
Dialect_Flag.Supports_Sort_Digits_As_Numbers -> False
@@ -301,3 +327,71 @@ type Redshift_Dialect
ensure_query_has_no_holes : JDBC_Connection -> Text -> Nothing ! Illegal_Argument
ensure_query_has_no_holes self jdbc:JDBC_Connection raw_sql:Text =
jdbc.ensure_query_has_no_holes raw_sql

## ---
private: true
---
make_dialect_operations =
postgres_ops = Postgres_Connection.get_dialect.get_dialect_operations
overrides = [["COUNT_DISTINCT", agg_count_distinct], ["COUNT_DISTINCT_INCLUDE_NULL", agg_count_distinct_include_null], ["GROUP_NUMBER_STANDARD_DEVIATION_POPULATION", make_group_number_standard_deviation "STDDEV_POP"], ["GROUP_NUMBER_STANDARD_DEVIATION_SAMPLE", make_group_number_standard_deviation "STDDEV_SAMP"]]
stats = [agg_median, agg_percentile]
my_mappings = overrides + stats
postgres_ops.extend_with my_mappings

## ---
private: true
---
agg_count_distinct args =
if args.is_empty then (Error.throw (Illegal_Argument.Error "COUNT_DISTINCT requires at least one argument.")) else
count = SQL_Builder.code "COUNT(DISTINCT " ++ SQL_Builder.join ", " args ++ ")"
are_nulls = args.map arg-> arg.paren ++ " IS NULL"
all_nulls_filter = SQL_Builder.code " FILTER (WHERE NOT (" ++ SQL_Builder.join " AND " are_nulls ++ "))"
(count ++ all_nulls_filter).paren
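
As a sketch, for two hypothetical argument columns x and y the builder above renders roughly:

    (COUNT(DISTINCT x, y) FILTER (WHERE NOT ((x) IS NULL AND (y) IS NULL)))

so rows whose every argument is NULL are excluded from the distinct count.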

## ---
private: true
---
agg_count_distinct_include_null args =
if args.is_empty then (Error.throw (Illegal_Argument.Error "COUNT_DISTINCT_INCLUDE_NULL requires at least one argument.")) else
SQL_Builder.code "COUNT(DISTINCT " ++ SQL_Builder.join ", " args ++ ")"

## ---
private: true
---
make_group_number_standard_deviation : Text -> Vector SQL_Builder -> SQL_Builder
make_group_number_standard_deviation (stddev_variant : Text) (arguments : Vector) =
if arguments.length < 1 then Error.throw (Illegal_State.Error "Wrong number of parameters in GROUP_NUMBER_STANDARD_DEVIATION IR. This is a bug in the Database library.") else
    stddev_column = arguments.at 0

    stddev = SQL_Builder.code stddev_variant ++ "(" ++ stddev_column ++ ") over ()"
    ## If the standard deviation is 0, we use 1 instead, to avoid division
       by zero; all z-scores then become zero and fall into the "Average"
       group. If the standard deviation is NULL, we use 1 for the same
       reason. (STDDEV_SAMP returns NULL for a single-value input.)
    stddev_safe = SQL_Builder.code "(case when " ++ stddev.paren ++ " = 0 or " ++ stddev.paren ++ " is null then 1 else " ++ stddev.paren ++ " end)"
    ## Redshift AVG on integer inputs returns BIGINT, so force floating point.
    mean = SQL_Builder.code "(avg(" ++ stddev_column.paren ++ " * 1.0) over ())"
    zscore = SQL_Builder.code "((" ++ stddev_column.paren ++ " - " ++ mean.paren ++ ") / " ++ stddev_safe.paren ++ ")"
    group_index = SQL_Builder.code "(floor(" ++ zscore.paren ++ " + 0.5))"
    whener gi op name = SQL_Builder.code "when " ++ group_index.paren ++ " " ++ op ++ gi.to_text ++ " then '" ++ name ++ "'"
    whens = Vector.build builder->
        builder.append (whener -2 "<=" "Very Low")
        builder.append (whener -1 "=" "Low")
        builder.append (whener 0 "=" "Average")
        builder.append (whener 1 "=" "High")
        builder.append (whener 2 ">=" "Very High")
    SQL_Builder.code "(case " ++ SQL_Builder.join " " whens ++ " end)"

## ---
private: true
---
agg_median = Base_Generator.lift_unary_op "MEDIAN" arg->
SQL_Builder.code "percentile_cont(0.5) WITHIN GROUP (ORDER BY " ++ arg ++ ")"

## ---
private: true
---
agg_percentile = Base_Generator.lift_binary_op "PERCENTILE" p-> expr->
SQL_Builder.code "percentile_cont(" ++ p ++ ") WITHIN GROUP (ORDER BY " ++ expr ++ ")"
@@ -23,6 +23,13 @@ type Redshift_Error_Mapper
is_null_primary_key_violation error =
error.java_exception.getMessage.contains "violates not-null constraint"

## ---
private: true
---
is_table_already_exists_error : SQL_Error -> Boolean
is_table_already_exists_error error =
error.java_exception.getMessage.match "ERROR: Relation .* already exists"

## ---
private: true
---
@@ -69,6 +69,11 @@ type Redshift_Details
types_record = JDBCDriverTypes.create "Redshift"
jdbc_connection = JDBC_Connection.create types_record self.jdbc_url properties

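# Redshift folds identifiers to lower case by default; enabling
# enable_case_sensitive_identifier preserves the case of quoted
# identifiers, so mixed-case table and column names round-trip.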
set_result = jdbc_connection.execute "SET enable_case_sensitive_identifier TO true"
set_result.if_not_error <|
stmt = set_result.at 1
stmt.close

# TODO [RW] can we inherit these from postgres?
encoding = parse_postgres_encoding (get_encoding_name jdbc_connection)
entity_naming_properties = Entity_Naming_Properties.from_jdbc_connection jdbc_connection encoding is_case_sensitive=True
@@ -203,7 +203,7 @@ type Connection
_ : Vector -> types
_ -> [types]
tm = self.type_mapping
result = get_tables_advanced self.jdbc_connection name_like database schema types_vector tm.column_fetcher_factory tm.sql_type_to_value_type all_fields
result = _get_tables_advanced self.jdbc_connection name_like database schema types_vector tm.column_fetcher_factory tm.sql_type_to_value_type all_fields
case include_hidden of
True -> result
False ->
@@ -672,11 +672,11 @@ private _check_statement_is_allowed connection stmt =
private: true
---
Get a metadata table for tables.
get_tables_advanced jdbc_connection name_like:Text|Nothing database:Text|Nothing schema:Text|Nothing types:Nothing|Vector column_fetcher_factory sql_type_to_value_type=(_->Nothing) all_fields:Boolean=False =
_get_tables_advanced jdbc_connection name_like:Text|Nothing database:Text|Nothing schema:Text|Nothing types:Nothing|Vector column_fetcher_factory sql_type_to_value_type=(_->Nothing) all_fields:Boolean=False =
name_dict = Dictionary.from_vector [["TABLE_CAT", "Database"], ["TABLE_SCHEM", "Schema"], ["TABLE_NAME", "Name"], ["TABLE_TYPE", "Type"], ["REMARKS", "Description"], ["TYPE_CAT", "Type Database"], ["TYPE_SCHEM", "Type Schema"], ["TYPE_NAME", "Type Name"]]
jdbc_connection.with_metadata metadata->
table = Managed_Resource.bracket (metadata.getTables database schema name_like types) .close result_set->
result_set_to_table result_set column_fetcher_factory sql_type_to_value_type
renamed = table.rename_columns name_dict error_on_missing_columns=False on_problems=..Ignore
renamed = table.rename_columns name_dict case_sensitivity=..Insensitive error_on_missing_columns=False on_problems=..Ignore
if all_fields then renamed else
renamed.select_columns ["Database", "Schema", "Name", "Type", "Description"]
@@ -3,7 +3,7 @@ import Standard.Base.Errors.Illegal_Argument.Illegal_Argument
import Standard.Base.Errors.Illegal_State.Illegal_State

import project.Internal.JDBC_Connection.JDBC_Connection
from project.Errors import Unsupported_Database_Encoding
from project.Errors import SQL_Error, Unsupported_Database_Encoding

polyglot java import org.enso.database.fetchers.ColumnFetcherFactory

@@ -12,7 +12,7 @@ polyglot java import org.enso.database.fetchers.ColumnFetcherFactory
---
get_encoding_name : JDBC_Connection -> Text
get_encoding_name jdbc_connection =
server_encoding = get_pragma_value jdbc_connection "SHOW server_encoding"
server_encoding = get_pragma_value jdbc_connection "SHOW server_encoding" . catch SQL_Error _-> "SQL_ASCII"
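# The catch above covers backends (notably Redshift) that may reject
# `SHOW server_encoding`; falling back to "SQL_ASCII" defers to the
# client_encoding lookup below.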
if server_encoding != "SQL_ASCII" then server_encoding else
get_pragma_value jdbc_connection "SHOW client_encoding"

8 changes: 4 additions & 4 deletions test/AWS_Tests/src/Redshift_Spec.enso
@@ -25,7 +25,7 @@ type Data
setup create_connection_fn = Data.Value <|
connection = create_connection_fn Nothing
tinfo = Name_Generator.random_name "Tinfo"
connection.execute_update 'CREATE TEMPORARY TABLE "'+tinfo+'" ("strs" VARCHAR, "ints" INTEGER, "bools" BOOLEAN, "reals" REAL)'
connection.execute_update 'CREATE TEMPORARY TABLE "'+tinfo+'" ("strs" VARCHAR, "ints" BIGINT, "bools" BOOLEAN, "reals" REAL)'
t = connection.query (SQL_Query.Table_Name tinfo)
row1 = ["a", Nothing, False, 1.2]
row2 = ["abc", Nothing, Nothing, 1.3]
@@ -48,7 +48,7 @@ add_redshift_specific_specs suite_builder create_connection_fn setup =
i = data.t.column_info
i.at "Column" . to_vector . should_equal ["strs", "ints", "bools", "reals"]
i.at "Items Count" . to_vector . should_equal [3, 1, 2, 3]
i.at "Value Type" . to_vector . should_equal [Value_Type.Char, Value_Type.Integer Bits.Bits_32, Value_Type.Boolean, Value_Type.Float Bits.Bits_32]
i.at "Value Type" . to_vector . should_equal [Value_Type.Char 256 True, Value_Type.Integer Bits.Bits_64, Value_Type.Boolean, Value_Type.Float Bits.Bits_32]

group_builder.specify "should infer standard types correctly" <|
data.t.at "strs" . value_type . is_text . should_be_true
@@ -75,8 +75,8 @@ add_database_specs suite_builder create_connection_fn =
materialize = .read

common_selection = Common_Table_Operations.Main.Test_Selection.Config run_advanced_edge_case_tests_by_default=False
aggregate_selection = Common_Table_Operations.Aggregate_Spec.Test_Selection.Config first_last_row_order=False aggregation_problems=False date_support=False
agg_in_memory_table = (enso_project.data / "data.csv") . read
aggregate_selection = Common_Table_Operations.Aggregate_Spec.Test_Selection.Config advanced_stats=False first_last_row_order=False aggregation_problems=False date_support=False multi_distinct=False
agg_in_memory_table = ((Project_Description.new enso_dev.Table_Tests).data / "data.csv") . read
agg_table_fn = _->
agg_in_memory_table.select_into_database_table default_connection.get (Name_Generator.random_name "Agg1") primary_key=Nothing temporary=True

@@ -298,21 +298,23 @@ add_group_number_specs suite_builder setup =
label_p2 = "Very High"

group_builder.specify "should add group number by standard deviation" <|
t = table_builder [['x', [3, 4, 5, 4, 3, 5, 10, 4, 3]], ['y', [1, 2, 2, 1, 2, 1, 1, 1, 1]]]
row_id = 0.up_to 9 . to_vector
t = table_builder [['x', [3, 4, 5, 4, 3, 5, 10, 4, 3]], ['y', [1, 2, 2, 1, 2, 1, 1, 1, 1]], ['row_id', row_id]]

g1 = t.add_group_number (..Standard_Deviation "x") "g"
g1 = t.add_group_number (..Standard_Deviation "x") "g" . sort 'row_id'
g1.at 'g' . to_vector . should_equal [label_n1, label_0, label_0, label_0, label_n1, label_0, label_p2, label_0, label_n1]

g3 = t.add_group_number (..Standard_Deviation "y") "g"
g3 = t.add_group_number (..Standard_Deviation "y") "g" . sort 'row_id'
g3.at 'g' . to_vector . should_equal [label_n1, label_p1, label_p1, label_n1, label_p1, label_n1, label_n1, label_n1, label_n1]

group_builder.specify "should respect standard deviation 'population' parameter" <|
t = table_builder [['x', [3, 4, 5, 4, 3, 5, 10, 4, 3, 5.67]]]
row_id = 0.up_to 10 . to_vector
t = table_builder [['x', [3, 4, 5, 4, 3, 5, 10, 4, 3, 5.67]], ['row_id', row_id]]

g0 = t.add_group_number (..Standard_Deviation "x" population=True) "g"
g0 = t.add_group_number (..Standard_Deviation "x" population=True) "g" . sort 'row_id'
g0.at 'g' . to_vector . should_equal [label_n1, label_0, label_0, label_0, label_n1, label_0, label_p2, label_0, label_n1, label_p1]

g1 = t.add_group_number (..Standard_Deviation "x" population=False) "g"
g1 = t.add_group_number (..Standard_Deviation "x" population=False) "g" . sort 'row_id'
g1.at 'g' . to_vector . should_equal [label_n1, label_0, label_0, label_0, label_n1, label_0, label_p2, label_0, label_n1, label_0]

if setup.is_database.not then
23 changes: 10 additions & 13 deletions test/Table_Tests/src/Database/Common/Common_Spec.enso
@@ -216,7 +216,7 @@
group_builder.specify "should allow direct read by execute_query" <|
name = data.t1.name
tmp = data.connection.execute_query ('SELECT * FROM "' + name + '"')
tmp.read . should_equal data.t1.read
tmp.read . should_equal_ignoring_order data.t1.read

Check failure on line 219 in test/Table_Tests/src/Database/Common/Common_Spec.enso
GitHub Actions / ⚙️ Checks / Standard Library JVM Tests and Native Tests (GraalVM CE) (linux, amd64)
[DuckDB In-Memory] Connection.query: should allow direct read by execute_query
An unexpected panic was thrown: (Illegal_Argument.Error 'Expected a Vector, but got a Table & ~In_Memory_Table (at /runner/_work/enso/enso/test/Table_Tests/src/Database/Common/Common_Spec.enso:219:13-63).' Nothing)

group_builder.specify "should allow to access a Table by an SQL query" <|
name = data.t1.name
@@ -268,7 +268,7 @@
(Table.new [["a", [100, 200]]]).select_into_database_table data.connection name temporary=True

# Reading a column that was kept will work OK
t1.at "a" . to_vector . should_equal [100, 200]
t1.at "a" . to_vector . should_equal_ignoring_order [100, 200]

# But reading the whole table will fail on the missing column:
m2 = t1.read
@@ -401,8 +401,6 @@
table.sort ([..Name order_column])

group_builder.specify "should allow counting group sizes and elements" <|
## Names set to lower case to avoid issue with Redshift where columns are
returned in lower case.
aggregates = [Aggregate_Column.Count "count", Aggregate_Column.Count_Not_Nothing "price" "count not nothing price", Aggregate_Column.Count_Nothing "price" "count nothing price"]

t1 = determinize_by "name" (data.t9.aggregate ["name"] aggregates . read)
@@ -417,21 +415,20 @@
t2.at "count nothing price" . to_vector . should_equal [5]

group_builder.specify "should allow simple arithmetic aggregations" <|
## Names set to lower case to avoid issue with Redshift where columns are
returned in lower case.
aggregates = [Aggregate_Column.Sum "price" "sum price", Aggregate_Column.Sum "quantity" "sum quantity", Aggregate_Column.Average "price" "avg price"]
eps = 0.000001
aggregates = [Aggregate_Column.Sum "price" "Sum price", Aggregate_Column.Sum "quantity" "Sum quantity", Aggregate_Column.Average "price" "Avg price"]
## TODO: we could also check the data types.

t1 = determinize_by "name" (data.t9.aggregate ["name"] aggregates . read)
t1.at "name" . to_vector . should_equal ["bar", "baz", "foo", "quux", "zzzz"]
t1.at "sum price" . to_vector . should_equal [100.5, 6.7, 1, Nothing, 2]
t1.at "sum quantity" . to_vector . should_equal [80, 40, 120, 70, 2]
t1.at "avg price" . to_vector . should_equal [50.25, 6.7, (1/3), Nothing, (2/5)]
t1.at "Sum price" . to_vector . should_equal [100.5, 6.7, 1, Nothing, 2]
t1.at "Sum quantity" . to_vector . should_equal [80, 40, 120, 70, 2]
t1.at "Avg price" . to_vector . should_equal [50.25, 6.7, (1/3), Nothing, (2/5)]

t2 = data.t9.aggregate [] aggregates . read
t2.at "sum price" . to_vector . should_equal [110.2]
t2.at "sum quantity" . to_vector . should_equal [312]
t2.at "avg price" . to_vector . should_equal [(110.2 / 11)]
t2.at "Sum price" . to_vector . at 0 . should_equal 110.2 epsilon=eps
t2.at "Sum quantity" . to_vector . at 0 . should_equal 312 epsilon=eps
t2.at "Avg price" . to_vector . at 0 . should_equal (110.2 / 11) epsilon=eps

suite_builder.group prefix+"Table.filter" group_builder->
data = Basic_Data.setup default_connection.get