Skip to content

Commit 7ca265e

Browse files
committed
Allow explicit schema injection to rivertest.Require* test functions (#926)
Here, resolve #907 by letting an explicit schema be injected into `rivertest.Require*` assertions in a similar way that one can be used in a client. This approach adds a schema in `RequireInsertedOpts`. This comment does a good job of highlight all the potential approaches for adding a schema [1], and unfortunately none of them are all that great. I implemented one other version of this (a variant of option 2 in that list), which as some advantages, but in the end it just ended up ballooning the API out to an uncomfortable degree. The worst part about adding schema to `RequireInsertedOpts` is its interact with the `RequireMany*` functions, where each expectation can set its own schema, and it's not clear what would happen if different expectations set different schemas. I resolved this ambiguity by making it an error to mix and match schemas. Assertions are allowed to send a schema in only the first position like: jobs := requireManyInserted(ctx, bundle.mockT, bundle.driver, []ExpectedJob{ {Args: &Job1Args{String: "foo"}, Opts: bundle.schemaOpts}, {Args: &Job1Args{String: "bar"}}, }) Or send the same schema in all positions: jobs := requireManyInserted(ctx, bundle.mockT, bundle.driver, []ExpectedJob{ {Args: &Job1Args{String: "foo"}, Opts: bundle.schemaOpts}, {Args: &Job1Args{String: "bar"}, Opts: bundle.schemaOpts}, }) But they aren't allowed to set a schema only in position other than the first, or mix and match schemas between expectations. Fixes #907. [1] #907 (comment)
1 parent 1c1c851 commit 7ca265e

11 files changed

Lines changed: 535 additions & 219 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1414
- Preliminary River driver for SQLite (`riverdriver/riversqlite`). This driver seems to produce good results as judged by the test suite, but so far has minimal real world vetting. Try it and let us know how it works out. [PR #870](https://github.com/riverqueue/river/pull/870).
1515
- CLI `river migrate-get` now takes a `--schema` option to inject a custom schema into dumped migrations and schema comments are hidden if `--schema` option isn't provided. [PR #903](https://github.com/riverqueue/river/pull/903).
1616
- Added `riverlog.NewMiddlewareCustomContext` that makes the use of `riverlog` job-persisted logging possible with non-slog loggers. [PR #919](https://github.com/riverqueue/river/pull/919).
17+
- Added `RequireInsertedOpts.Schema`, allowing an explicit schema to be set when asserting on job inserts with `rivertest`. [PR #926](https://github.com/riverqueue/river/pull/926).
18+
- Added `JobListParams.Where`, which provides an escape hatch for job listing that runs arbitrary SQL with named parameters. [PR #933](https://github.com/riverqueue/river/pull/933).
1719

1820
### Changed
1921

client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,7 +2017,7 @@ func (c *Client[TTx]) JobList(ctx context.Context, params *JobListParams) (*JobL
20172017
}
20182018
params.schema = c.config.Schema
20192019

2020-
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" {
2020+
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled {
20212021
return nil, errJobListParamsMetadataNotSupportedSQLite
20222022
}
20232023

@@ -2052,7 +2052,7 @@ func (c *Client[TTx]) JobListTx(ctx context.Context, tx TTx, params *JobListPara
20522052
}
20532053
params.schema = c.config.Schema
20542054

2055-
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataFragment != "" {
2055+
if c.driver.DatabaseName() == databaseNameSQLite && params.metadataCalled {
20562056
return nil, errJobListParamsMetadataNotSupportedSQLite
20572057
}
20582058

client_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3776,6 +3776,86 @@ func Test_Client_JobList(t *testing.T) {
37763776
require.Equal(t, []int64{job3.ID, job2.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
37773777
})
37783778

3779+
t.Run("ArbitraryWhereRawSQL", func(t *testing.T) {
3780+
t.Parallel()
3781+
3782+
client, bundle := setup(t)
3783+
3784+
var (
3785+
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema})
3786+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
3787+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
3788+
)
3789+
3790+
listRes, err := client.JobList(ctx, NewJobListParams().Where(`jsonb_path_query_first(metadata, '$.foo') = '"bar"'::jsonb`))
3791+
require.NoError(t, err)
3792+
require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
3793+
})
3794+
3795+
t.Run("ArbitraryWhereNamedParams", func(t *testing.T) {
3796+
t.Parallel()
3797+
3798+
client, bundle := setup(t)
3799+
3800+
var (
3801+
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"foo": "bar"}`), Schema: bundle.schema})
3802+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
3803+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Metadata: []byte(`{"baz": "value"}`), Schema: bundle.schema})
3804+
)
3805+
3806+
listRes, err := client.JobList(ctx, NewJobListParams().Where("jsonb_path_query_first(metadata, @json_query) = @json_val", NamedArgs{
3807+
"json_query": "$.foo",
3808+
"json_val": `"bar"`,
3809+
}))
3810+
require.NoError(t, err)
3811+
require.Equal(t, []int64{job1.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
3812+
})
3813+
3814+
t.Run("ArbitraryWhereMultipleNamedParams", func(t *testing.T) {
3815+
t.Parallel()
3816+
3817+
client, bundle := setup(t)
3818+
3819+
var (
3820+
job1 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
3821+
job2 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
3822+
job3 = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
3823+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
3824+
)
3825+
3826+
listRes, err := client.JobList(ctx, NewJobListParams().Where("id IN (@id1, @id2, @id3)",
3827+
NamedArgs{"id1": job1.ID},
3828+
NamedArgs{"id2": job2.ID},
3829+
NamedArgs{"id3": job3.ID},
3830+
))
3831+
require.NoError(t, err)
3832+
require.Equal(t, []int64{job1.ID, job2.ID, job3.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
3833+
})
3834+
3835+
t.Run("ArbitraryWhereMultipleClauses", func(t *testing.T) {
3836+
t.Parallel()
3837+
3838+
client, bundle := setup(t)
3839+
3840+
var (
3841+
job = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{
3842+
MaxAttempts: ptrutil.Ptr(27),
3843+
Queue: ptrutil.Ptr("custom_queue"),
3844+
Schema: bundle.schema,
3845+
State: ptrutil.Ptr(rivertype.JobStateDiscarded),
3846+
})
3847+
_ = testfactory.Job(ctx, t, bundle.exec, &testfactory.JobOpts{Schema: bundle.schema})
3848+
)
3849+
3850+
listRes, err := client.JobList(ctx, NewJobListParams().
3851+
Where("kind = @kind", NamedArgs{"kind": job.Kind}).
3852+
Where("max_attempts = @max_attempts", NamedArgs{"max_attempts": job.MaxAttempts}).
3853+
Where("queue = @queue", NamedArgs{"queue": job.Queue}),
3854+
)
3855+
require.NoError(t, err)
3856+
require.Equal(t, []int64{job.ID}, sliceutil.Map(listRes.Jobs, func(job *rivertype.JobRow) int64 { return job.ID }))
3857+
})
3858+
37793859
t.Run("WithCancelledContext", func(t *testing.T) {
37803860
t.Parallel()
37813861

driver_client_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,33 @@ func ExerciseClient[TTx any](ctx context.Context, t *testing.T,
332332
require.Equal(t, job.ID, listRes.Jobs[0].ID)
333333
})
334334

335+
t.Run("JobListTxWhere", func(t *testing.T) {
336+
t.Parallel()
337+
338+
client, bundle := setup(t)
339+
340+
tx, execTx := beginTx(ctx, t, bundle)
341+
342+
job := testfactory.Job(ctx, t, execTx, &testfactory.JobOpts{
343+
Metadata: []byte(`{"foo":"bar","bar":"baz"}`),
344+
Schema: bundle.schema,
345+
})
346+
347+
listParams := NewJobListParams()
348+
349+
if client.driver.DatabaseName() == databaseNameSQLite {
350+
listParams = listParams.Where("metadata ->> @json_path = @json_val", NamedArgs{"json_path": "$.foo", "json_val": "bar"})
351+
} else {
352+
// "bar" is quoted in this branch because `jsonb_path_query_first` needs to be compared to a JSON value
353+
listParams = listParams.Where("jsonb_path_query_first(metadata, @json_path) = @json_val", NamedArgs{"json_path": "$.foo", "json_val": `"bar"`})
354+
}
355+
356+
listRes, err := client.JobListTx(ctx, tx, listParams)
357+
require.NoError(t, err)
358+
require.Len(t, listRes.Jobs, 1)
359+
require.Equal(t, job.ID, listRes.Jobs[0].ID)
360+
})
361+
335362
t.Run("QueueGet", func(t *testing.T) {
336363
t.Parallel()
337364

internal/dblist/db_list.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,27 @@ type JobListOrderBy struct {
2525
}
2626

2727
type JobListParams struct {
28-
Conditions string
2928
IDs []int64
3029
Kinds []string
3130
LimitCount int32
32-
NamedArgs map[string]any
3331
OrderBy []JobListOrderBy
3432
Priorities []int16
3533
Queues []string
3634
Schema string
3735
States []rivertype.JobState
36+
Where []WherePredicate
37+
}
38+
39+
type WherePredicate struct {
40+
NamedArgs map[string]any
41+
SQL string
3842
}
3943

4044
func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListParams, sqlFragmentColumnIn func(column string, values any) (string, any, error)) ([]*rivertype.JobRow, error) {
41-
var whereBuilder strings.Builder
45+
var (
46+
namedArgs = make(map[string]any)
47+
whereBuilder strings.Builder
48+
)
4249

4350
orderBy := make([]JobListOrderBy, len(params.OrderBy))
4451
for i, o := range params.OrderBy {
@@ -48,11 +55,6 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara
4855
}
4956
}
5057

51-
namedArgs := params.NamedArgs
52-
if namedArgs == nil {
53-
namedArgs = make(map[string]any)
54-
}
55-
5658
// Writes an `AND` to connect SQL predicates as long as this isn't the first
5759
// predicate.
5860
writeAndAfterFirst := func() {
@@ -122,9 +124,22 @@ func JobList(ctx context.Context, exec riverdriver.Executor, params *JobListPara
122124
namedArgs[column] = arg
123125
}
124126

125-
if params.Conditions != "" {
127+
for _, where := range params.Where {
126128
writeAndAfterFirst()
127-
whereBuilder.WriteString(params.Conditions)
129+
130+
whereBuilder.WriteString(where.SQL)
131+
for name, val := range where.NamedArgs {
132+
expectedSymbol := "@" + name
133+
if !strings.Contains(where.SQL, expectedSymbol) {
134+
return nil, fmt.Errorf("expected %q to contain named arg symbol %s", where.SQL, expectedSymbol)
135+
}
136+
137+
if _, ok := namedArgs[name]; ok {
138+
return nil, fmt.Errorf("named argument %s already registered", expectedSymbol)
139+
}
140+
141+
namedArgs[name] = val
142+
}
128143
}
129144

130145
// A condition of some kind is needed, so given no others write one that'll

internal/dblist/db_list_test.go

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ func TestJobListNoJobs(t *testing.T) {
5555
bundle := setup()
5656

5757
_, err := JobList(ctx, bundle.exec, &JobListParams{
58-
Conditions: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo",
59-
NamedArgs: pgx.NamedArgs{"foo": "bar"},
6058
States: []rivertype.JobState{rivertype.JobStateCompleted},
6159
LimitCount: 1,
6260
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderAsc}},
61+
Where: []WherePredicate{
62+
{NamedArgs: map[string]any{"foo": "bar"}, SQL: "queue = 'test' AND priority = 1 AND args->>'foo' = @foo"},
63+
},
6364
}, bundle.driver.SQLFragmentColumnIn)
6465
require.NoError(t, err)
6566
})
@@ -148,11 +149,12 @@ func TestJobListWithJobs(t *testing.T) {
148149
bundle := setup(t)
149150

150151
params := &JobListParams{
151-
Conditions: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb",
152152
LimitCount: 2,
153-
NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2},
154153
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
155154
States: []rivertype.JobState{rivertype.JobStateAvailable},
155+
Where: []WherePredicate{
156+
{NamedArgs: map[string]any{"paths1": []string{"job_num"}, "value1": 2}, SQL: "jsonb_extract_path(args, VARIADIC @paths1::text[]) = @value1::jsonb"},
157+
},
156158
}
157159

158160
execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
@@ -164,7 +166,7 @@ func TestJobListWithJobs(t *testing.T) {
164166
})
165167
})
166168

167-
t.Run("ConditionsWithIDs", func(t *testing.T) {
169+
t.Run("WhereWithIDs", func(t *testing.T) {
168170
t.Parallel()
169171
bundle := setup(t)
170172
job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2]
@@ -188,7 +190,7 @@ func TestJobListWithJobs(t *testing.T) {
188190
})
189191
})
190192

191-
t.Run("ConditionsWithIDsAndPriorities", func(t *testing.T) {
193+
t.Run("WhereWithIDsAndPriorities", func(t *testing.T) {
192194
t.Parallel()
193195
bundle := setup(t)
194196
job1, job2, job3 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2]
@@ -207,17 +209,19 @@ func TestJobListWithJobs(t *testing.T) {
207209
})
208210
})
209211

210-
t.Run("ConditionsWithKinds", func(t *testing.T) {
212+
t.Run("WhereWithKinds", func(t *testing.T) {
211213
t.Parallel()
212214

213215
bundle := setup(t)
214216

215217
params := &JobListParams{
216-
Conditions: "finalized_at IS NULL",
217218
LimitCount: 2,
218219
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
219220
Kinds: []string{"alternate_kind"},
220221
States: []rivertype.JobState{rivertype.JobStateAvailable},
222+
Where: []WherePredicate{
223+
{SQL: "finalized_at IS NULL"},
224+
},
221225
}
222226

223227
execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
@@ -229,7 +233,7 @@ func TestJobListWithJobs(t *testing.T) {
229233
})
230234
})
231235

232-
t.Run("ConditionsWithPriorities", func(t *testing.T) {
236+
t.Run("WhereWithPriorities", func(t *testing.T) {
233237
t.Parallel()
234238
bundle := setup(t)
235239
_, job2, job3, _, job5 := bundle.jobs[0], bundle.jobs[1], bundle.jobs[2], bundle.jobs[3], bundle.jobs[4]
@@ -246,17 +250,19 @@ func TestJobListWithJobs(t *testing.T) {
246250
})
247251
})
248252

249-
t.Run("ConditionsWithQueues", func(t *testing.T) {
253+
t.Run("WhereWithQueues", func(t *testing.T) {
250254
t.Parallel()
251255

252256
bundle := setup(t)
253257

254258
params := &JobListParams{
255-
Conditions: "finalized_at IS NULL",
256259
LimitCount: 2,
257260
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
258261
Queues: []string{"priority"},
259262
States: []rivertype.JobState{rivertype.JobStateAvailable},
263+
Where: []WherePredicate{
264+
{SQL: "finalized_at IS NULL"},
265+
},
260266
}
261267

262268
execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
@@ -274,10 +280,11 @@ func TestJobListWithJobs(t *testing.T) {
274280
bundle := setup(t)
275281

276282
params := &JobListParams{
277-
Conditions: "metadata @> @metadata_filter::jsonb",
278283
LimitCount: 2,
279-
NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`},
280284
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
285+
Where: []WherePredicate{
286+
{NamedArgs: map[string]any{"metadata_filter": `{"some_key": "some_value"}`}, SQL: "metadata @> @metadata_filter::jsonb"},
287+
},
281288
}
282289

283290
execTest(ctx, t, bundle, params, func(jobs []*rivertype.JobRow, err error) {
@@ -288,4 +295,39 @@ func TestJobListWithJobs(t *testing.T) {
288295
require.Equal(t, []int64{job3.ID}, returnedIDs)
289296
})
290297
})
298+
299+
t.Run("NamedArgNotPresentInQueryError", func(t *testing.T) {
300+
t.Parallel()
301+
302+
bundle := setup(t)
303+
304+
params := &JobListParams{
305+
LimitCount: 2,
306+
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
307+
Where: []WherePredicate{
308+
{NamedArgs: map[string]any{"not_present": "foo"}, SQL: "1"},
309+
},
310+
}
311+
312+
_, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn)
313+
require.EqualError(t, err, `expected "1" to contain named arg symbol @not_present`)
314+
})
315+
316+
t.Run("DuplicateNamedArgError", func(t *testing.T) {
317+
t.Parallel()
318+
319+
bundle := setup(t)
320+
321+
params := &JobListParams{
322+
LimitCount: 2,
323+
OrderBy: []JobListOrderBy{{Expr: "id", Order: SortOrderDesc}},
324+
Where: []WherePredicate{
325+
{NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"},
326+
{NamedArgs: map[string]any{"duplicate": "foo"}, SQL: "duplicate = @duplicate"},
327+
},
328+
}
329+
330+
_, err := JobList(ctx, bundle.exec, params, bundle.driver.SQLFragmentColumnIn)
331+
require.EqualError(t, err, "named argument @duplicate already registered")
332+
})
291333
}

0 commit comments

Comments
 (0)