diff --git a/src/main.go b/src/main.go index 7646e3f..6a51a76 100644 --- a/src/main.go +++ b/src/main.go @@ -6,7 +6,7 @@ import ( "time" ) -const VERSION = "0.29.1" +const VERSION = "0.29.2" func main() { config := LoadConfig() diff --git a/src/query_handler_test.go b/src/query_handler_test.go index 037225f..9fd2cfd 100644 --- a/src/query_handler_test.go +++ b/src/query_handler_test.go @@ -130,7 +130,7 @@ func TestHandleQuery(t *testing.T) { "SELECT * FROM pg_catalog.pg_user": { "description": {"usename", "usesysid", "usecreatedb", "usesuper", "userepl", "usebypassrls", "passwd", "valuntil", "useconfig"}, "types": {Uint32ToString(pgtype.TextOID)}, - "values": {"bemidb", "10", "t", "t", "t", "t", "", "NULL", "NULL"}, + "values": {"bemidb", "10", "t", "t", "t", "t", "", "", ""}, }, "SELECT datid FROM pg_catalog.pg_stat_activity": { "description": {"datid"}, @@ -177,12 +177,12 @@ func TestHandleQuery(t *testing.T) { Uint32ToString(pgtype.BoolOID), Uint32ToString(pgtype.BoolOID), Uint32ToString(pgtype.Int8OID), - Uint32ToString(pgtype.TextOID), - Uint32ToString(pgtype.TextOID), + Uint32ToString(pgtype.Int4OID), + Uint32ToString(pgtype.Int4OID), Uint32ToString(pgtype.BoolOID), - Uint32ToString(pgtype.TextOID), + Uint32ToString(pgtype.Int4OID), }, - "values": {"10", "bemidb", "true", "true", "true", "true", "true", "false", "-1", "NULL", "NULL", "false", "NULL"}, + "values": {"10", "bemidb", "true", "true", "true", "true", "true", "false", "-1", "", "", "false", ""}, }, "SELECT * FROM pg_catalog.pg_inherits": { "description": {"inhrelid", "inhparent", "inhseqno", "inhdetachpending"}, @@ -661,7 +661,7 @@ func TestHandleQuery(t *testing.T) { "SELECT s.usename, r.rolconfig FROM pg_catalog.pg_shadow s LEFT JOIN pg_catalog.pg_roles r ON s.usename = r.rolname": { "description": {"usename", "rolconfig"}, "types": {Uint32ToString(pgtype.TextOID)}, - "values": {"bemidb", "NULL"}, + "values": {"bemidb", ""}, }, "SELECT a.oid, pd.description FROM pg_catalog.pg_roles a LEFT JOIN pg_catalog.pg_shdescription pd ON a.oid = pd.objoid": { "description": {"oid", "description"}, @@ -866,7 +866,7 @@ func TestHandleParseQuery(t *testing.T) { &pgproto3.ParseComplete{}, }) - remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10'::int8, 'FALSE'::bool, 'FALSE'::bool, 'TRUE'::bool, 'FALSE'::bool, 'bemidb-encrypted', 'NULL', 'NULL')) pg_shadow(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1" + remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10'::int8, 'FALSE'::bool, 'FALSE'::bool, 'TRUE'::bool, 'FALSE'::bool, 'bemidb-encrypted', NULL, NULL)) pg_shadow(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1" if preparedStatement.Query != remappedQuery { t.Errorf("Expected the prepared statement query to be %v, got %v", remappedQuery, preparedStatement.Query) } diff --git a/src/query_parser_utils.go b/src/query_parser_utils.go index f523f16..b5ea926 100644 --- a/src/query_parser_utils.go +++ b/src/query_parser_utils.go @@ -28,16 +28,27 @@ func (utils *QueryParserUtils) MakeSubselectWithRowsNode(tableName string, colum for _, row := range rowsValues { var rowList []*pgQuery.Node for _, val := range row { - constNode := pgQuery.MakeAConstStrNode(val, 0) - if _, err := strconv.ParseInt(val, 10, 64); err == nil { - constNode = parserType.MakeCaseTypeCastNode(constNode, "int8") + if val == "NULL" { + constNode := &pgQuery.Node{ + Node: &pgQuery.Node_AConst{ + AConst: &pgQuery.A_Const{ + Isnull: true, + }, + }, + } + rowList = append(rowList, constNode) } else { - valLower := strings.ToLower(val) - if valLower == "true" || valLower == "false" { - constNode = parserType.MakeCaseTypeCastNode(constNode, "bool") + constNode := pgQuery.MakeAConstStrNode(val, 0) + if _, err := strconv.ParseInt(val, 10, 64); err == nil { + constNode = parserType.MakeCaseTypeCastNode(constNode, "int8") + } else { + valLower := strings.ToLower(val) + if valLower == "true" || valLower == "false" { + constNode = parserType.MakeCaseTypeCastNode(constNode, "bool") + } } + rowList = append(rowList, constNode) } - rowList = append(rowList, constNode) } selectStmt.ValuesLists = append(selectStmt.ValuesLists, &pgQuery.Node{Node: &pgQuery.Node_List{List: &pgQuery.List{Items: rowList}}})