Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go/adbc/driver/snowflake): add quoted identifier ignore case option #1800

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 39 additions & 2 deletions go/adbc/driver/snowflake/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ type connectionImpl struct {
ctor gosnowflake.Connector
sqldb *sql.DB

activeTransaction bool
useHighPrecision bool
activeTransaction bool
useHighPrecision bool
quotedIdentIgnoreCase bool
}

// Uniquely identify a constraint based on the dbName, schema, and tblName
Expand Down Expand Up @@ -1337,10 +1338,46 @@ func (c *connectionImpl) SetOption(key, value string) error {
}
}
return nil
case OptionQuotedIdentifiersIgnoreCase:
switch value {
case adbc.OptionValueEnabled, adbc.OptionValueDisabled:
c.quotedIdentIgnoreCase = value == adbc.OptionValueEnabled
q := "ALTER SESSION SET QUOTED_IDENTIFIERS_IGNORE_CASE = " + value
if _, err := c.cn.ExecContext(context.Background(), q, nil); err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}

if _, err := c.sqldb.ExecContext(context.Background(), q); err != nil {
return errToAdbcErr(adbc.StatusInternal, err)
}
default:
return adbc.Error{
Msg: "[Snowflake] invalid value for option " + key + ": " + value,
Code: adbc.StatusInvalidArgument,
}
}
return nil
default:
return adbc.Error{
Msg: "[Snowflake] unknown connection option " + key + ": " + value,
Code: adbc.StatusInvalidArgument,
}
}
}

func (c *connectionImpl) GetOption(key string) (string, error) {
switch key {
case OptionUseHighPrecision:
if c.useHighPrecision {
return adbc.OptionValueEnabled, nil
}
return adbc.OptionValueDisabled, nil
case OptionQuotedIdentifiersIgnoreCase:
if c.quotedIdentIgnoreCase {
return adbc.OptionValueEnabled, nil
}
return adbc.OptionValueDisabled, nil
default:
return c.db.GetOption(key)
}
}
14 changes: 11 additions & 3 deletions go/adbc/driver/snowflake/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,17 @@ const (
// scale will return a Float64 column.
OptionUseHighPrecision = "adbc.snowflake.sql.client_option.use_high_precision"

OptionApplicationName = "adbc.snowflake.sql.client_option.app_name"
OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify"
OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode"
// OptionQuotedIdentifiersIgnoreCase refers to the corresponding snowflake session
// parameter (https://docs.snowflake.com/en/sql-reference/parameters#label-quoted-identifiers-ignore-case)
// which controls whether or not the case of quoted identifiers will be preserved (default)
// or will be ignored (storing and resolving as uppercase).
// Because functionality such as bulk ingest and other options will automatically add quotes
// to identifiers by default, this option can be set to TRUE to ensure that the casing will
// be ignored for that functionality despite the fact that we wrap it in quotes.
OptionQuotedIdentifiersIgnoreCase = "adbc.snowflake.sql.quoted_identifiers_ignore_case"
OptionApplicationName = "adbc.snowflake.sql.client_option.app_name"
OptionSSLSkipVerify = "adbc.snowflake.sql.client_option.tls_skip_verify"
OptionOCSPFailOpenMode = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode"
// specify the token to use for OAuth or other forms of authentication
OptionAuthToken = "adbc.snowflake.sql.client_option.auth_token"
// specify the OKTAUrl to use for OKTA Authentication
Expand Down
120 changes: 120 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/google/uuid"
"github.com/snowflakedb/gosnowflake"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
Expand Down Expand Up @@ -2031,3 +2032,122 @@ func (suite *SnowflakeTests) TestMetadataOnlyQuery() {
// all the rows from each record in the stream.
suite.Equal(n, recv)
}

func TestSnowflakeQuotedIdentIgnoreCase(t *testing.T) {
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)

sc := arrow.NewSchema([]arrow.Field{
{
Name: "col_int64", Type: arrow.PrimitiveTypes.Int64,
Nullable: true,
},
{
Name: "col_list", Type: arrow.ListOf(arrow.BinaryTypes.String),
Nullable: true,
},
}, nil)

bldr := array.NewRecordBuilder(mem, sc)
defer bldr.Release()

bldr.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)

listbldr := bldr.Field(1).(*array.ListBuilder)
listvalbldr := listbldr.ValueBuilder().(*array.StringBuilder)
listbldr.Append(true)
listvalbldr.Append("one")
listbldr.Append(true)
listvalbldr.Append("two")
listbldr.Append(true)
listvalbldr.Append("three")

rec := bldr.NewRecord()
defer rec.Release()

expectedSchema := arrow.NewSchema([]arrow.Field{
{
Name: "col_int64", Type: arrow.PrimitiveTypes.Int64,
Nullable: true,
},
{
Name: "col_list", Type: arrow.BinaryTypes.String,
Nullable: true,
},
}, nil)

expectedRecord, _, err := array.RecordFromJSON(mem, expectedSchema, bytes.NewReader([]byte(`
[
{
"col_int64": 1,
"col_list": "[\n \"one\"\n]"
},
{
"col_int64": 2,
"col_list": "[\n \"two\"\n]"
},
{
"col_int64": 3,
"col_list": "[\n \"three\"\n]"
}
]
`)))
require.NoError(t, err)
defer expectedRecord.Release()

withQuirks(t, func(q *SnowflakeQuirks) {
drv := q.SetupDriver(t)
opts := q.DatabaseOptions()
// initialize connection with this session parameter set so that ingest
// and DropTable will both ignore the casing for the quoted identifiers
opts[driver.OptionQuotedIdentifiersIgnoreCase] = adbc.OptionValueEnabled

db, err := drv.NewDatabase(opts)
require.NoError(t, err)
defer db.Close()

ctx := context.Background()
cnxn, err := db.Open(ctx)
require.NoError(t, err)
defer cnxn.Close()

require.NoError(t, q.DropTable(cnxn, "bulk_ingest_list"))

stmt, err := cnxn.NewStatement()
require.NoError(t, err)
defer stmt.Close()

require.NoError(t, stmt.Bind(ctx, rec))
require.NoError(t, stmt.SetOption(adbc.OptionKeyIngestTargetTable, "bulk_ingest_list"))
n, err := stmt.ExecuteUpdate(ctx)
require.NoError(t, err)
assert.EqualValues(t, 3, n)

// disable the quoted identifiers option to get the default behavior back
require.NoError(t, cnxn.(adbc.GetSetOptions).
SetOption(driver.OptionQuotedIdentifiersIgnoreCase, adbc.OptionValueDisabled))
// with the option disabled this query should error because wrapping with quotes
// would preserve the case and the table wouldn't exist
require.NoError(t, stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_list" order by col_int64 ASC`))
_, _, err = stmt.ExecuteQuery(ctx)
assert.Error(t, err)

// confirm that our ingested table is using uppercase because of our option usage
require.NoError(t, stmt.SetSqlQuery("SELECT * FROM BULK_INGEST_LIST order by col_int64 ASC"))

rdr, n, err := stmt.ExecuteQuery(ctx)
require.NoError(t, err)
defer rdr.Release()

assert.EqualValues(t, 3, n)
assert.True(t, rdr.Next())
result := rdr.Record()
assert.Truef(t, array.RecordEqual(expectedRecord, result), "expected: %s\ngot: %s", expectedRecord, result)
logicalTypeList, ok := result.Schema().Field(1).Metadata.GetValue("logicalType")
assert.True(t, ok)
assert.Equal(t, "ARRAY", logicalTypeList)

assert.False(t, rdr.Next())
require.NoError(t, rdr.Err())
})
}
31 changes: 29 additions & 2 deletions go/adbc/driver/snowflake/snowflake_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ var (
}
)

const quotedIdentifiersIgnoreCase = "QUOTED_IDENTIFIERS_IGNORE_CASE"

type databaseImpl struct {
driverbase.DatabaseImplBase
cfg *gosnowflake.Config
Expand Down Expand Up @@ -130,6 +132,12 @@ func (d *databaseImpl) GetOption(key string) (string, error) {
return adbc.OptionValueEnabled, nil
}
return adbc.OptionValueDisabled, nil
case OptionQuotedIdentifiersIgnoreCase:
v, exists := d.cfg.Params[quotedIdentifiersIgnoreCase]
if !exists {
return adbc.OptionValueDisabled, nil
}
return *v, nil
default:
val, ok := d.cfg.Params[key]
if ok {
Expand Down Expand Up @@ -427,6 +435,17 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
Code: adbc.StatusInvalidArgument,
}
}
case OptionQuotedIdentifiersIgnoreCase:
switch v {
case adbc.OptionValueEnabled, adbc.OptionValueDisabled:
d.cfg.Params[quotedIdentifiersIgnoreCase] = &v
default:
return adbc.Error{
Msg: fmt.Sprintf("Invalid value for database option '%s': '%s'",
OptionQuotedIdentifiersIgnoreCase, v),
Code: adbc.StatusInvalidArgument,
}
}
default:
d.cfg.Params[k] = &v
}
Expand All @@ -445,15 +464,23 @@ func (d *databaseImpl) Open(ctx context.Context) (adbc.Connection, error) {
return nil, errToAdbcErr(adbc.StatusIO, err)
}

var quoteIgnoreCase bool

v, exists := d.cfg.Params[quotedIdentifiersIgnoreCase]
if exists {
quoteIgnoreCase = *v == adbc.OptionValueEnabled
}

conn := &connectionImpl{
cn: cn.(snowflakeConn),
db: d, ctor: connector,
sqldb: sql.OpenDB(connector),
// default enable high precision
// SetOption(OptionUseHighPrecision, adbc.OptionValueDisabled) to
// get Int64/Float64 instead
useHighPrecision: d.useHighPrecision,
ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase),
useHighPrecision: d.useHighPrecision,
ConnectionImplBase: driverbase.NewConnectionImplBase(&d.DatabaseImplBase),
quotedIdentIgnoreCase: quoteIgnoreCase,
}

return driverbase.NewConnectionBuilder(conn).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,14 @@ class DatabaseOptions(enum.Enum):
OCSP_FAIL_OPEN_MODE = "adbc.snowflake.sql.client_option.ocsp_fail_open_mode"
PORT = "adbc.snowflake.sql.uri.port"
PROTOCOL = "adbc.snowflake.sql.uri.protocol"
#: Control the QUOTED_IDENTIFIERS_IGNORE_CASE snowflake session parameter as
#: described by snowflake parameter docs for #label-quoted-identifiers-ignore-case
#: This defaults to false as per the Snowflake documentation. This is
#: important for managing the table names created when using bulk_ingest
#: since we will wrap any identifiers in quotes by default. Behavior is not
#: defined when mixing this with manually running ALTER SESSION queries to
#: set the variable.
QUOTED_IDENTIFIERS_IGNORE_CASE = "adbc.snowflake.sql.quoted_identifiers_ignore_case"
REGION = "adbc.snowflake.sql.region"
#: request retry timeout EXCLUDING network roundtrip and reading http response
#: use format like http://pkg.go.dev/time#ParseDuration such as
Expand Down