diff --git a/packages/pg-protocol/src/buffer-writer.ts b/packages/pg-protocol/src/buffer-writer.ts index cebb0d9ed..bd6b6b429 100644 --- a/packages/pg-protocol/src/buffer-writer.ts +++ b/packages/pg-protocol/src/buffer-writer.ts @@ -82,4 +82,9 @@ export class Writer { this.buffer = Buffer.allocUnsafe(this.size) return result } + + public clear(): void { + this.offset = 5 + this.headerPosition = 0 + } } diff --git a/packages/pg-protocol/src/outbound-serializer.test.ts b/packages/pg-protocol/src/outbound-serializer.test.ts index 0d3e387e4..d19a8c9af 100644 --- a/packages/pg-protocol/src/outbound-serializer.test.ts +++ b/packages/pg-protocol/src/outbound-serializer.test.ts @@ -273,4 +273,62 @@ describe('serializer', () => { const expected = new BufferList().addInt16(1234).addInt16(5678).addInt32(3).addInt32(4).join(true) assert.deepEqual(actual, expected) }) + + describe('bind error recovery', () => { + const throwingMapper = () => { + throw new Error('valueMapper error') + } + + it('produces correct bind output after a valueMapper exception', () => { + assert.throws(() => { + serialize.bind({ + values: ['fail'], + valueMapper: throwingMapper, + }) + }, /valueMapper error/) + + const actual = serialize.bind({ + portal: 'bang', + statement: 'woo', + values: ['1', 'hi', null, 'zing'], + }) + const expectedBuffer = new BufferList() + .addCString('bang') + .addCString('woo') + .addInt16(4) + .addInt16(0) + .addInt16(0) + .addInt16(0) + .addInt16(0) + .addInt16(4) + .addInt32(1) + .add(Buffer.from('1')) + .addInt32(2) + .add(Buffer.from('hi')) + .addInt32(-1) + .addInt32(4) + .add(Buffer.from('zing')) + .addInt16(1) + .addInt16(0) + .join(true, 'B') + assert.deepEqual(actual, expectedBuffer) + }) + + it('produces correct output from other serializer methods after a failed bind', () => { + assert.throws(() => { + serialize.bind({ + values: ['fail'], + valueMapper: throwingMapper, + }) + }, /valueMapper error/) + + const parseActual = serialize.parse({ text: '!' }) + const parseExpected = new BufferList().addCString('').addCString('!').addInt16(0).join(true, 'P') + assert.deepEqual(parseActual, parseExpected) + + const queryActual = serialize.query('select 1') + const queryExpected = new BufferList().addCString('select 1').join(true, 'Q') + assert.deepEqual(queryActual, queryExpected) + }) + }) }) diff --git a/packages/pg-protocol/src/serializer.ts b/packages/pg-protocol/src/serializer.ts index bb0441f56..4ad42ba06 100644 --- a/packages/pg-protocol/src/serializer.ts +++ b/packages/pg-protocol/src/serializer.ts @@ -152,7 +152,13 @@ const bind = (config: BindOpts = {}): Buffer => { writer.addCString(portal).addCString(statement) writer.addInt16(len) - writeValues(values, config.valueMapper) + try { + writeValues(values, config.valueMapper) + } catch (err) { + writer.clear() + paramWriter.clear() + throw err + } writer.addInt16(len) writer.add(paramWriter.flush()) diff --git a/packages/pg/lib/query.js b/packages/pg/lib/query.js index 64aab5ff2..04e1c1d65 100644 --- a/packages/pg/lib/query.js +++ b/packages/pg/lib/query.js @@ -228,6 +228,10 @@ class Query extends EventEmitter { valueMapper: utils.prepareValue, }) } catch (err) { + // we should close parse to avoid leaking connections + connection.close({ type: 'S', name: this.name }) + connection.sync() + this.handleError(err, connection) return } diff --git a/packages/pg/test/unit/client/throw-in-bind-tests.js b/packages/pg/test/unit/client/throw-in-bind-tests.js new file mode 100644 index 000000000..8b460b9e4 --- /dev/null +++ b/packages/pg/test/unit/client/throw-in-bind-tests.js @@ -0,0 +1,86 @@ +'use strict' +const helper = require('./test-helper') +const Query = require('../../../lib/query') +const assert = require('assert') + +const suite = new helper.Suite() + +const bindError = new Error('TEST: Throw in bind') + +const setupClient = function () { + const client = helper.client() + const con = client.connection + const calls = { parse: 0, sync: 0, describe: 0, execute: 0, close: 0 } + + con.parse = function () { + calls.parse++ + } + con.bind = function () { + throw bindError + } + con.describe = function () { + calls.describe++ + assert.fail('describe should not be called when bind throws') + } + con.execute = function () { + calls.execute++ + assert.fail('execute should not be called when bind throws') + } + con.close = function () { + calls.close++ + } + con.sync = function () { + calls.sync++ + } + + return { client, con, calls } +} + +suite.test('calls callback with error when bind throws', function (done) { + const { client, con, calls } = setupClient() + con.emit('readyForQuery') + client.query( + new Query({ + text: 'select $1', + values: ['x'], + callback: function (err) { + assert.equal(err, bindError) + assert.equal(calls.sync, 1, 'sync should be called once') + assert.equal(calls.describe, 0, 'describe should not be called') + assert.equal(calls.execute, 0, 'execute should not be called') + done() + }, + }) + ) +}) + +suite.test('emits error event when bind throws (no callback)', function (done) { + const { client, con, calls } = setupClient() + con.emit('readyForQuery') + const query = new Query({ + text: 'select $1', + values: ['x'], + }) + query.on('error', function (err) { + assert.equal(err, bindError) + assert.equal(calls.sync, 1, 'sync should be called once') + done() + }) + client.query(query) +}) + +suite.test('send close when bind throws', function (done) { + const { client, con, calls } = setupClient() + con.emit('readyForQuery') + client.query( + new Query({ + text: 'select $1', + values: ['x'], + callback: function (err) { + assert.equal(err, bindError) + assert.equal(calls.close, 1, 'close should be called') + done() + }, + }) + ) +})