diff --git a/ext/pg_type_map.c b/ext/pg_type_map.c index 8f06cb12f..d5bb26d12 100644 --- a/ext/pg_type_map.c +++ b/ext/pg_type_map.c @@ -115,6 +115,45 @@ pg_typemap_s_allocate( VALUE klass ) return self; } +/* + * call-seq: + * res.query_param_encoders(params) + * + * Retrieve the encoders that are used to encode the given values to be submitted to the database server. + * The selection of the encoders is defined in the derived type map class. + * + * +params+ must be an Array of values to be encoded. + * It's like +params+ given to exec_params . + * + * Returns an Array with the same length as +params+. + * + */ +static VALUE +pg_typemap_query_param_encoders( VALUE self, VALUE params ) +{ + t_typemap *this = RTYPEDDATA_DATA( self ); + int nParams; + int i=0; + VALUE res; + + Check_Type(params, T_ARRAY); + + this->funcs.fit_to_query( self, params ); + + nParams = RARRAY_LENINT(params); + res = rb_ary_new2(nParams); + + for ( i = 0; i < nParams; i++ ) { + t_pg_coder *conv; + VALUE param_value = rb_ary_entry(params, i); + + /* Let the given typemap select a coder for this param */ + conv = this->funcs.typecast_query_param(this, param_value, i); + rb_ary_push(res, conv ? 
conv->coder_obj : Qnil); + } + return res; +} + /* * call-seq: * res.default_type_map = typemap @@ -194,6 +233,7 @@ init_pg_type_map(void) */ rb_cTypeMap = rb_define_class_under( rb_mPG, "TypeMap", rb_cObject ); rb_define_alloc_func( rb_cTypeMap, pg_typemap_s_allocate ); + rb_define_method( rb_cTypeMap, "query_param_encoders", pg_typemap_query_param_encoders, 1 ); rb_mDefaultTypeMappable = rb_define_module_under( rb_cTypeMap, "DefaultTypeMappable"); rb_define_method( rb_mDefaultTypeMappable, "default_type_map=", pg_typemap_default_type_map_set, 1 ); diff --git a/lib/pg/basic_type_map_based_on_result.rb b/lib/pg/basic_type_map_based_on_result.rb index 301bcbf5a..36d8d0dac 100644 --- a/lib/pg/basic_type_map_based_on_result.rb +++ b/lib/pg/basic_type_map_based_on_result.rb @@ -64,4 +64,9 @@ def initialize(connection_or_coder_maps, registry: nil) add_coder(coder) end end + + # Returns the PG::BasicTypeRegistry::CoderMapsBundle used to translate result OIDs to encoders. + def coder_maps_bundle + @coder_maps + end end diff --git a/lib/pg/basic_type_map_for_queries.rb b/lib/pg/basic_type_map_for_queries.rb index 3b0d492d9..926beb2e1 100644 --- a/lib/pg/basic_type_map_for_queries.rb +++ b/lib/pg/basic_type_map_for_queries.rb @@ -57,6 +57,11 @@ def initialize(connection_or_coder_maps, registry: nil, if_undefined: nil) init_encoders end + # Returns the PG::BasicTypeRegistry::CoderMapsBundle used to translate encoders to OIDs. + def coder_maps_bundle + @coder_maps + end + class UndefinedDefault def self.call(oid_name, format) raise UndefinedEncoder, "no encoder defined for type #{oid_name.inspect} format #{format}" @@ -177,8 +182,8 @@ def get_array_type(value) end DEFAULT_TYPE_MAP = PG.make_shareable({ - TrueClass => [1, 'bool', 'bool'], - FalseClass => [1, 'bool', 'bool'], + TrueClass => [0, 'bool', 'bool'], + FalseClass => [0, 'bool', 'bool'], # We use text format and no type OID for numbers, because setting the OID can lead # to unnecessary type conversions on server side. 
# Matches every SQL token in which a "$n" sequence must be kept as is (string literals,
# quoted identifiers, comments and dollar-quoted strings) as well as positional placeholders.
# Only a real placeholder is captured in the named group +placeholder+; embed_params relies
# on that group to tell both cases apart.
PLACEHOLDER_RE = /
  '(?:''|[^'])*' | # string literal
  "(?:""|[^"])*" | # quoted identifier
  --[^\n]* | # line comment
  \/\*.*?\*\/ | # block comment
  \$\$.*?\$\$ | # dollar-quoted string. E.g. $$ $1 $$
  \$(?<__dq_tag>[A-Za-z_][A-Za-z_0-9]*)\$.*?\$\k<__dq_tag>\$ | # named dollar-quoted string. E.g. $foo$ $1 $foo$
  (?<placeholder>\$(?:[1-9]\d*)) # placeholder we are interested in
/mx
private_constant :PLACEHOLDER_RE

# Compiles your prepared SQL statement and the given positional arguments into plain SQL string.
#
# The resulting SQL string can be used with +conn.exec+ like the prepared SQL statement and parameters with +conn.exec_params+.
# +conn.exec_params+ is usually preferred because it's faster and safer.
# +embed_params+ is intended for debugging messages with positional parameters.
# It avoids manual insertion for later inspection in +psql+ or so.
#
# +sql+ is the statement containing <tt>$1</tt>, <tt>$2</tt>, ... placeholders.
# Placeholders inside of string literals, quoted identifiers, comments and dollar-quoted strings are not replaced.
# Placeholders without a corresponding entry in +params+ are left untouched.
#
# +params+ is an Array of values like the one given to +conn.exec_params+ .
# +sql+ is returned unchanged if it is empty.
# Each value is embedded as follows:
# * +nil+ (or a Hash with <tt>value: nil</tt>) becomes +NULL+ .
# * PG::BasicTypeMapForQueries::BinaryData is embedded through +escape_bytea+ .
# * A value for which +type_map+ selects an encoder is encoded, escaped and quoted, followed by a type cast if the encoder has an OID.
# * A Hash is embedded by its +:value+ , +:type+ and +:format+ entries.
# * Everything else is embedded as escaped and quoted +to_s+ .
#
# +type_map+ is the type map that selects the encoders. It defaults to +type_map_for_queries+ .
#
# +coder_maps_bundle+ is an optional object responding to +typenames_by_oid+ .
# It is used to translate type OIDs into type names for the cast.
# If omitted, the bundle of +type_map+ is used when it provides one, otherwise the type names are loaded from the database server and cached.
#
# Raises ArgumentError if a value would have to be embedded in binary format (other than bytea by a Hash with <tt>type: 17</tt> or no type)
# or if the type name of an OID cannot be determined.
#
# Example:
#   res = conn.embed_params('SELECT $1 AS a, $2 AS b, $3 AS c', [1, 2, nil])
#   # => "SELECT '1' AS a, '2' AS b, NULL AS c"
def embed_params(sql, params, type_map: type_map_for_queries, coder_maps_bundle: nil)
  return sql if params.empty?

  # Builds the "::typename" cast suffix of a type OID. Returns nil (no cast) for a missing or zero OID.
  oid_to_typecast = lambda do |oid|
    next unless oid && oid > 0

    by_oid = if coder_maps_bundle
      # Try to retrieve types from the method argument
      coder_maps_bundle.typenames_by_oid
    elsif type_map.respond_to?(:coder_maps_bundle)
      # Try to retrieve types from the current type map
      type_map.coder_maps_bundle.typenames_by_oid
    else
      # Use cached types or load and cache them from the database server
      @typenames_by_oid ||= PG::BasicTypeRegistry::CoderMapsBundle.new(self).typenames_by_oid
    end
    typename = by_oid[oid] || raise(ArgumentError, "cannot determine database type name of OID #{oid}")
    "::#{ typename }"
  end

  encoders = type_map.query_param_encoders(params)

  # One SQL literal per entry of params, in the same order
  literals = encoders.each_with_index.map do |enc, i|
    value = params[i]
    case value
    when nil
      'NULL'
    when PG::BasicTypeMapForQueries::BinaryData
      "'#{ escape_bytea(value) }'"
    else
      if enc
        raise ArgumentError, "binary encoded data from #{enc} cannot be inserted into SQL text" if enc.format != 0
        "'#{escape(enc.encode(value))}'#{oid_to_typecast[enc.oid]}"
      elsif Hash === value
        raw = value[:value]
        type_oid = value[:type]
        if raw.nil?
          'NULL'
        elsif value[:format] == 1
          # Binary format can only be written as text for bytea (OID 17) or when no type is given
          raise ArgumentError, "binary encoded data with OID #{type_oid} cannot be inserted into SQL text" if type_oid && type_oid != 17
          "'#{ escape_bytea(raw.to_s) }'#{oid_to_typecast[type_oid]}"
        else
          "'#{escape(raw.to_s)}'#{oid_to_typecast[type_oid]}"
        end
      else
        "'#{escape(value.to_s)}'"
      end
    end
  end

  sql.gsub(PLACEHOLDER_RE) do |matched|
    placeholder = Regexp.last_match[:placeholder]
    # Do not replace non-positional args string and pass it as is
    next matched unless placeholder

    index = placeholder[1..].to_i - 1
    # Keep a placeholder that has no param instead of silently replacing it by an empty string
    index < literals.size ? literals[index] : matched
  end
end
# diff --git a/spec/pg/basic_type_map_based_on_result_spec.rb b/spec/pg/basic_type_map_based_on_result_spec.rb index 0d0666f24..a3cfdb289 100644 --- a/spec/pg/basic_type_map_based_on_result_spec.rb +++ b/spec/pg/basic_type_map_based_on_result_spec.rb @@ -13,6 +13,7 @@ maps = PG::BasicTypeRegistry::CoderMapsBundle.new(@conn).freeze tm = PG::BasicTypeMapBasedOnResult.new(maps) expect( tm.rm_coder(0, 16) ).to be_kind_of(PG::TextEncoder::Boolean) + expect( tm.coder_maps_bundle ).to eq(maps) end it "can be initialized with a custom type registry" do diff --git a/spec/pg/basic_type_map_for_queries_spec.rb b/spec/pg/basic_type_map_for_queries_spec.rb index 143afbf47..7f5c0bb8e 100644 --- a/spec/pg/basic_type_map_for_queries_spec.rb +++ b/spec/pg/basic_type_map_for_queries_spec.rb @@ -39,6 +39,7 @@ maps = PG::BasicTypeRegistry::CoderMapsBundle.new(@conn).freeze tm = PG::BasicTypeMapForQueries.new(maps) expect( tm[Integer] ).to be_kind_of(PG::TextEncoder::Integer) + expect( tm.coder_maps_bundle ).to eq(maps) end it "can be initialized with a custom type registry" do @@ -54,7 +55,7 @@ args = [] pr = proc { |*a| args << a } PG::BasicTypeMapForQueries.new(@conn, registry: regi, if_undefined: pr) - expect( args.first ).to eq( ["bool", 1] ) + expect( args.first ).to eq( ["bool", 0] ) end it "raises UndefinedEncoder for undefined types" do diff --git a/spec/pg/basic_type_map_for_results_spec.rb b/spec/pg/basic_type_map_for_results_spec.rb index d3176c43c..cee9b51a1 100644 --- a/spec/pg/basic_type_map_for_results_spec.rb +++ b/spec/pg/basic_type_map_for_results_spec.rb @@ -15,6 +15,7 @@ maps = PG::BasicTypeRegistry::CoderMapsBundle.new(@conn).freeze tm = PG::BasicTypeMapForResults.new(maps) expect( tm.rm_coder(0, 16) ).to be_kind_of(PG::TextDecoder::Boolean) + expect( tm.coder_maps_bundle ).to eq(maps) end it "can be initialized with a custom type registry" do diff --git a/spec/pg/connection_spec.rb b/spec/pg/connection_spec.rb index 1546344f7..e6cfb7512 100644 --- 
# Specs for PG::Connection#embed_params.
# Each example runs the compiled SQL through +exec+ and the original SQL and params
# through +exec_params+ and expects both to return the same rows.
describe :embed_params do

  # Embeds +params+ into +sql+, executes both variants on +conn+ and compares the results.
  # Returns the compiled SQL string, which is also used as the failure message.
  def embed_params_and_check(sql, params, conn: @conn)
    compiled = conn.embed_params(sql, params)

    res = conn.exec(compiled)
    res2 = conn.exec_params(sql, params)
    expect( res.to_a ).to eq( res2.to_a ), compiled
    compiled
  end

  # Yields with standard_conforming_strings and escape_string_warning set to +onoff+ on +conn+
  # and sets both back to "on" afterwards.
  def with_std_conf_strings(conn, onoff)
    conn.exec("SET standard_conforming_strings = #{onoff}")
    conn.exec("SET escape_string_warning = #{onoff}")
    yield
  ensure
    conn.exec("SET standard_conforming_strings = on")
    conn.exec("SET escape_string_warning = on")
  end

  describe "default type map" do
    it "compiles prepared sql into plain sql" do
      compiled = embed_params_and_check(<<~SQL, [1, "2", true, false, nil])
        -- this is one: $1
        /* this is another one: $1 */
        select $1::int as a, $2 as b, $3 as c, $4 as d, $5 as e, '$5' as f, $$ $6 $$ as g, -- this is two: $2
        $body$ $1 $body$ as h, t."$1", t."$2"
        from (select 10 as "$1", 20 as "$2") as t
      SQL

      # Placeholders within comments must survive unchanged
      aggregate_failures do
        expect(compiled).to include("-- this is one: $1")
        expect(compiled).to include("/* this is another one: $1 */")
        expect(compiled).to include("-- this is two: $2")
      end
    end

    it "escapes strings properly" do
      embed_params_and_check(<<~SQL, ["', '1"])
        select $1 as one
      SQL
    end

    context "with params as Hash" do

      ['on', 'off'].each do |stdconf|
        it "encodes values properly with std conforming strings=#{stdconf}" do
          with_std_conf_strings(@conn, stdconf) do
            params = [
              {value: "'\x1F\\".b, format: 1},
              {value: "'\0\xff\r\n\t1'".b, format: 1, type: 17},
              {value: "abc"},
              {value: 4},
              {value: 5, type: 23},
              {value: "{ 6, 7}", type: 1007},
              {value: false},
              {value: "\\x000102ff", type: 17},
              {value: nil}
            ]
            embed_params_and_check <<~SQL, params
              select $1::bytea as a, $2 as b, $3 as c, $4 as d, $5 as e, $6 as f, $7 as g, $8 as h, $9 as i
            SQL
          end
        end
      end
    end
  end

  describe "PG::TypeMapByClass type map" do
    before do
      @conn2 = PG.connect(@conninfo)
      @conn2.type_map_for_queries = PG::BasicTypeMapForQueries.new(@conn2)
      @conn2.type_map_for_results = PG::BasicTypeMapForResults.new(@conn2)
    end

    after do
      @conn2.close
    end

    it "compiles prepared sql into plain sql" do
      compiled = embed_params_and_check(<<~SQL, [1, "2", { foo: :bar }, [1], true, false, nil], conn: @conn2)
        -- this is one: $1
        /* this is another one: $1 */
        select $1::int as a, $2 as b, $3::json as c, $4::int[] as d, '$5' as e, $$ $6 $$ as f,
        $body$ $1 $body$ as g, -- this is two: $2
        t."$1", t."$2", $5 as h, $6 as i, $7 j
        from (select 10 as "$1", 20 as "$2") as t
      SQL

      # Placeholders within comments must survive unchanged
      aggregate_failures do
        expect(compiled).to include("-- this is one: $1")
        expect(compiled).to include("/* this is another one: $1 */")
        expect(compiled).to include("-- this is two: $2")
      end
    end

    it "escapes strings properly" do
      embed_params_and_check(<<~SQL, ["', '1"], conn: @conn2)
        select $1 as one
      SQL
    end

    ['on', 'off'].each do |stdconf|
      it "encodes binary strings properly with std conforming strings=#{stdconf}" do
        # The setting must be changed on the connection that runs the queries (@conn2),
        # otherwise the "off" variant is never exercised.
        with_std_conf_strings(@conn2, stdconf) do
          binary = PG::BasicTypeMapForQueries::BinaryData.new("''\0\xff\r\n\t'".b)
          embed_params_and_check(<<~SQL, [binary], conn: @conn2)
            select $1::bytea as one
          SQL
        end
      end
    end
  end
end