Skip to content

Commit

Permalink
Add support for null types in sys tables
Browse files Browse the repository at this point in the history
  • Loading branch information
arjunlol committed Jan 3, 2025
1 parent e2b6869 commit 2514060
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"
)

const VERSION = "0.29.1"
const VERSION = "0.29.2"

func main() {
config := LoadConfig()
Expand Down
14 changes: 7 additions & 7 deletions src/query_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestHandleQuery(t *testing.T) {
"SELECT * FROM pg_catalog.pg_user": {
"description": {"usename", "usesysid", "usecreatedb", "usesuper", "userepl", "usebypassrls", "passwd", "valuntil", "useconfig"},
"types": {Uint32ToString(pgtype.TextOID)},
"values": {"bemidb", "10", "t", "t", "t", "t", "", "NULL", "NULL"},
"values": {"bemidb", "10", "t", "t", "t", "t", "", "", ""},
},
"SELECT datid FROM pg_catalog.pg_stat_activity": {
"description": {"datid"},
Expand Down Expand Up @@ -177,12 +177,12 @@ func TestHandleQuery(t *testing.T) {
Uint32ToString(pgtype.BoolOID),
Uint32ToString(pgtype.BoolOID),
Uint32ToString(pgtype.Int8OID),
Uint32ToString(pgtype.TextOID),
Uint32ToString(pgtype.TextOID),
Uint32ToString(pgtype.Int4OID),
Uint32ToString(pgtype.Int4OID),
Uint32ToString(pgtype.BoolOID),
Uint32ToString(pgtype.TextOID),
Uint32ToString(pgtype.Int4OID),
},
"values": {"10", "bemidb", "true", "true", "true", "true", "true", "false", "-1", "NULL", "NULL", "false", "NULL"},
"values": {"10", "bemidb", "true", "true", "true", "true", "true", "false", "-1", "", "", "false", ""},
},
"SELECT * FROM pg_catalog.pg_inherits": {
"description": {"inhrelid", "inhparent", "inhseqno", "inhdetachpending"},
Expand Down Expand Up @@ -661,7 +661,7 @@ func TestHandleQuery(t *testing.T) {
"SELECT s.usename, r.rolconfig FROM pg_catalog.pg_shadow s LEFT JOIN pg_catalog.pg_roles r ON s.usename = r.rolname": {
"description": {"usename", "rolconfig"},
"types": {Uint32ToString(pgtype.TextOID)},
"values": {"bemidb", "NULL"},
"values": {"bemidb", ""},
},
"SELECT a.oid, pd.description FROM pg_catalog.pg_roles a LEFT JOIN pg_catalog.pg_shdescription pd ON a.oid = pd.objoid": {
"description": {"oid", "description"},
Expand Down Expand Up @@ -866,7 +866,7 @@ func TestHandleParseQuery(t *testing.T) {
&pgproto3.ParseComplete{},
})

remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10'::int8, 'FALSE'::bool, 'FALSE'::bool, 'TRUE'::bool, 'FALSE'::bool, 'bemidb-encrypted', 'NULL', 'NULL')) pg_shadow(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1"
remappedQuery := "SELECT usename, passwd FROM (VALUES ('bemidb', '10'::int8, 'FALSE'::bool, 'FALSE'::bool, 'TRUE'::bool, 'FALSE'::bool, 'bemidb-encrypted', NULL, NULL)) pg_shadow(usename, usesysid, usecreatedb, usesuper, userepl, usebypassrls, passwd, valuntil, useconfig) WHERE usename = $1"
if preparedStatement.Query != remappedQuery {
t.Errorf("Expected the prepared statement query to be %v, got %v", remappedQuery, preparedStatement.Query)
}
Expand Down
25 changes: 18 additions & 7 deletions src/query_parser_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,27 @@ func (utils *QueryParserUtils) MakeSubselectWithRowsNode(tableName string, colum
for _, row := range rowsValues {
var rowList []*pgQuery.Node
for _, val := range row {
constNode := pgQuery.MakeAConstStrNode(val, 0)
if _, err := strconv.ParseInt(val, 10, 64); err == nil {
constNode = parserType.MakeCaseTypeCastNode(constNode, "int8")
if val == "NULL" {
constNode := &pgQuery.Node{
Node: &pgQuery.Node_AConst{
AConst: &pgQuery.A_Const{
Isnull: true,
},
},
}
rowList = append(rowList, constNode)
} else {
valLower := strings.ToLower(val)
if valLower == "true" || valLower == "false" {
constNode = parserType.MakeCaseTypeCastNode(constNode, "bool")
constNode := pgQuery.MakeAConstStrNode(val, 0)
if _, err := strconv.ParseInt(val, 10, 64); err == nil {
constNode = parserType.MakeCaseTypeCastNode(constNode, "int8")
} else {
valLower := strings.ToLower(val)
if valLower == "true" || valLower == "false" {
constNode = parserType.MakeCaseTypeCastNode(constNode, "bool")
}
}
rowList = append(rowList, constNode)
}
rowList = append(rowList, constNode)
}
selectStmt.ValuesLists = append(selectStmt.ValuesLists,
&pgQuery.Node{Node: &pgQuery.Node_List{List: &pgQuery.List{Items: rowList}}})
Expand Down

0 comments on commit 2514060

Please sign in to comment.