-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pg_auto_parameterize_in_array extension, for converting IN/NOT IN…
… to = ANY or != ALL for more types When I originally developed the pg_auto_parameterize, I only handled integer arrays in order to avoid having that extension depend on the pg_array extension. This extension depends on both and adds support for the additional types.
- Loading branch information
1 parent
c33ec61
commit ba28830
Showing
6 changed files
with
351 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# frozen-string-literal: true | ||
# | ||
# The pg_auto_parameterize_in_array extension builds on the pg_auto_parameterize | ||
# extension, adding support for handling additional types when converting from | ||
# IN to = ANY and NOT IN to != ALL: | ||
# | ||
# DB[:table].where(column: [1.0, 2.0, ...]) | ||
# # Without extension: column IN ($1::numeric, $2:numeric, ...) # bound variables: 1.0, 2.0, ... | ||
# # With extension: column = ANY($1::numeric[]) # bound variables: [1.0, 2.0, ...] | ||
# | ||
# This prevents the use of an unbounded number of bound variables based on the | ||
# size of the array, as well as using different SQL for different array sizes. | ||
# | ||
# The following types are supported when doing the conversions, with the database | ||
# type used: | ||
# | ||
# Float :: if any are infinite or NaN, double precision, otherwise numeric | ||
# BigDecimal :: numeric | ||
# Date :: date | ||
# Time :: timestamp (or timestamptz if pg_timestamptz extension is used) | ||
# DateTime :: timestamp (or timestamptz if pg_timestamptz extension is used) | ||
# Sequel::SQLTime :: time | ||
# Sequel::SQL::Blob :: bytea | ||
# | ||
# String values are also supported using the +text+ type, but only if the | ||
# +:treat_string_list_as_text_array+ Database option is used. This is because | ||
# treating strings as text can break programs, since the type for | ||
# literal strings in PostgreSQL is +unknown+, not +text+. | ||
# | ||
# The conversion is only done for single dimensional arrays that have more | ||
# than two elements, where all elements are of the same class (other than | ||
# nil values). | ||
# | ||
# Related module: Sequel::Postgres::AutoParameterizeInArray | ||
|
||
module Sequel | ||
module Postgres | ||
# Enable automatically parameterizing queries. | ||
module AutoParameterizeInArray | ||
# Transform column IN (...) expressions into column = ANY($) | ||
# and column NOT IN (...) expressions into column != ALL($) | ||
# using an array bound variable for the ANY/ALL argument, | ||
# if all values inside the predicate are of the same type and | ||
# the type is handled by the extension. | ||
# This is the same optimization PostgreSQL performs internally, | ||
# but this reduces the number of bound variables. | ||
def complex_expression_sql_append(sql, op, args) | ||
case op | ||
when :IN, :"NOT IN" | ||
l, r = args | ||
if auto_param?(sql) && (type = _bound_variable_type_for_array(r)) | ||
if op == :IN | ||
op = :"=" | ||
func = :ANY | ||
else | ||
op = :!= | ||
func = :ALL | ||
end | ||
args = [l, Sequel.function(func, Sequel.pg_array(r, type))] | ||
end | ||
end | ||
|
||
super | ||
end | ||
|
||
private | ||
|
||
# The bound variable type string to use for the bound variable array. | ||
# Returns nil if a bound variable should not be used for the array. | ||
def _bound_variable_type_for_array(r) | ||
return unless Array === r && r.size > 1 | ||
classes = r.map(&:class) | ||
classes.uniq! | ||
classes.delete(NilClass) | ||
return unless classes.size == 1 | ||
|
||
klass = classes[0] | ||
if klass == Integer | ||
# This branch is not taken on Ruby <2.4, because of the Fixnum/Bignum split. | ||
# However, that causes no problems as pg_auto_parameterize handles integer | ||
# arrays natively (though the SQL used is different) | ||
"int8" | ||
elsif klass == String | ||
"text" if db.typecast_value(:boolean, db.opts[:treat_string_list_as_text_array]) | ||
elsif klass == BigDecimal | ||
"numeric" | ||
elsif klass == Date | ||
"date" | ||
elsif klass == Time | ||
@db.cast_type_literal(Time) | ||
elsif klass == Float | ||
# PostgreSQL treats literal floats as numeric, not double precision | ||
# But older versions of PostgreSQL don't handle Infinity/NaN in numeric | ||
r.all?{|v| v.nil? || v.finite?} ? "numeric" : "double precision" | ||
elsif klass == Sequel::SQLTime | ||
"time" | ||
elsif klass == DateTime | ||
@db.cast_type_literal(DateTime) | ||
elsif klass == Sequel::SQL::Blob | ||
"bytea" | ||
end | ||
end | ||
end | ||
end | ||
|
||
Database.register_extension(:pg_auto_parameterize_in_array) do |db| | ||
db.extension(:pg_array, :pg_auto_parameterize) | ||
db.extend_datasets(Postgres::AutoParameterizeInArray) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
require File.join(File.dirname(File.expand_path(__FILE__)), "spec_helper") | ||
|
||
describe "pg_auto_parameterize_in_array extension" do | ||
before do | ||
@db = Sequel.connect('mock://postgres') | ||
@db.synchronize{|c| def c.escape_bytea(v) v*2 end} | ||
@db.opts[:treat_string_list_as_text_array] = 't' | ||
@db.extension :pg_auto_parameterize_in_array | ||
end | ||
|
||
types = [ | ||
["strings if treat_string_list_as_text_array Database option is true", proc{|x| x.to_s}, "text"], | ||
["BigDecimals", proc{|x| BigDecimal(x)}, "numeric"], | ||
["dates", proc{|x| Date.new(2021, x)}, "date"], | ||
["times", proc{|x| Time.local(2021, x)}, "timestamp"], | ||
["SQLTimes", proc{|x| Sequel::SQLTime.create(x, 0, 0)}, "time"], | ||
["datetimes", proc{|x| DateTime.new(2021, x)}, "timestamp"], | ||
["floats", proc{|x| Float(x)}, "numeric"], | ||
["blobs", proc{|x| Sequel.blob(x.to_s)}, "bytea"], | ||
] | ||
|
||
if RUBY_VERSION >= '2.4' | ||
types << ["integers", proc{|x| x}, "int8"] | ||
else | ||
it "should fallback to pg_auto_parameterize extension behavior when switching column IN/NOT IN to = ANY/!= ALL for integers" do | ||
v = [1, 2, 3] | ||
nv = [1, nil, 3] | ||
type = "int8" | ||
|
||
sql = @db[:table].where(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" = ANY(CAST($1 AS #{type}[])))' | ||
sql.args.must_equal ['{1,2,3}'] | ||
|
||
sql = @db[:table].where(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" = ANY(CAST($1 AS #{type}[])))' | ||
sql.args.must_equal ['{1,NULL,3}'] | ||
|
||
sql = @db[:table].exclude(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" != ALL(CAST($1 AS #{type}[])))' | ||
sql.args.must_equal ['{1,2,3}'] | ||
|
||
sql = @db[:table].exclude(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" != ALL(CAST($1 AS #{type}[])))' | ||
sql.args.must_equal ['{1,NULL,3}'] | ||
end | ||
end | ||
|
||
types.each do |desc, conv, type| | ||
it "should automatically switch column IN/NOT IN to = ANY/!= ALL for #{desc}" do | ||
v = [1,2,3].map(&conv) | ||
nv = (v + [nil]).freeze | ||
|
||
sql = @db[:table].where(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" = ANY($1::#{type}[]))' | ||
sql.args.must_equal [v] | ||
|
||
sql = @db[:table].where(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" = ANY($1::#{type}[]))' | ||
sql.args.must_equal [nv] | ||
|
||
sql = @db[:table].exclude(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" != ALL($1::#{type}[]))' | ||
sql.args.must_equal [v] | ||
|
||
sql = @db[:table].exclude(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" != ALL($1::#{type}[]))' | ||
sql.args.must_equal [nv] | ||
end | ||
end | ||
|
||
it "should automatically switch column IN/NOT IN to = ANY/!= ALL for infinite/NaN floats" do | ||
v = [1.0, 1.0/0.0, -1.0/0.0, 0.0/0.0] | ||
nv = (v + [nil]).freeze | ||
type = "double precision" | ||
|
||
sql = @db[:table].where(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM \"table\" WHERE ("a" = ANY($1::#{type}[]))' | ||
sql.args.must_equal [v] | ||
|
||
sql = @db[:table].where(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" = ANY($1::#{type}[]))' | ||
sql.args.must_equal [nv] | ||
|
||
sql = @db[:table].exclude(:a=>v).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" != ALL($1::#{type}[]))' | ||
sql.args.must_equal [v] | ||
|
||
sql = @db[:table].exclude(:a=>nv).sql | ||
sql.must_equal %'SELECT * FROM "table" WHERE ("a" != ALL($1::#{type}[]))' | ||
sql.args.must_equal [nv] | ||
end | ||
|
||
it "should not automatically switch column IN/NOT IN to = ANY/!= ALL for strings by default" do | ||
@db.opts.delete(:treat_string_list_as_text_array) | ||
v = %w'1 2' | ||
sql = @db[:table].where([:a, :b]=>v).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") IN ($1, $2))' | ||
sql.args.must_equal v | ||
|
||
sql = @db[:table].exclude([:a, :b]=>v).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") NOT IN ($1, $2))' | ||
sql.args.must_equal v | ||
end | ||
|
||
it "should not convert IN/NOT IN expressions that use unsupported types" do | ||
v = [Sequel.lit('1'), Sequel.lit('2')].freeze | ||
sql = @db[:table].where([:a, :b]=>v).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") IN (1, 2))' | ||
sql.args.must_be_nil | ||
|
||
sql = @db[:table].exclude([:a, :b]=>v).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") NOT IN (1, 2))' | ||
sql.args.must_be_nil | ||
end | ||
|
||
it "should not convert multiple column IN expressions" do | ||
sql = @db[:table].where([:a, :b]=>[[1.0, 2.0]]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") IN (($1::numeric, $2::numeric)))' | ||
sql.args.must_equal [1, 2] | ||
|
||
sql = @db[:table].exclude([:a, :b]=>[[1.0, 2.0]]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE (("a", "b") NOT IN (($1::numeric, $2::numeric)))' | ||
sql.args.must_equal [1, 2] | ||
end | ||
|
||
it "should not convert single value expressions" do | ||
sql = @db[:table].where(:a=>[1.0]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" IN ($1::numeric))' | ||
sql.args.must_equal [1] | ||
|
||
sql = @db[:table].where(:a=>[1.0]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" IN ($1::numeric))' | ||
sql.args.must_equal [1] | ||
end | ||
|
||
it "should not convert expressions with mixed types" do | ||
sql = @db[:table].where(:a=>[1, 2.0]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" IN ($1::int4, $2::numeric))' | ||
sql.args.must_equal [1, 2.0] | ||
|
||
sql = @db[:table].where(:a=>[1, 2.0]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" IN ($1::int4, $2::numeric))' | ||
sql.args.must_equal [1, 2.0] | ||
end | ||
|
||
it "should not convert other expressions" do | ||
sql = @db[:table].where(:a=>1).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" = $1::int4)' | ||
sql.args.must_equal [1] | ||
|
||
sql = @db[:table].where(:a=>@db[:table]).sql | ||
sql.must_equal 'SELECT * FROM "table" WHERE ("a" IN (SELECT * FROM "table"))' | ||
sql.args.must_be_nil | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters