From 7342765302aa07e4f49ffd5cda050ff81445850e Mon Sep 17 00:00:00 2001 From: Brandur Date: Sat, 13 Jun 2026 08:31:02 -0500 Subject: [PATCH] Add `Config.HardStopTimeout` to perform a "hard stop" setting jobs errored Here, add a new `Config.HardStopTimeout` on top of the existing `SoftStopTimeout` whose job it is to recover badly behaving job as much as possible before coming to a full stop. Currently, if a client is stopping and is running jobs that don't respond to context cancellation, those jobs end up getting left in a `running` state, which means that they won't be recoverable again until they're rescued an hour later. `HardStopTimeout` engages after soft stop, and has each producer perform a "hard stop", which means to have it set any jobs still running to an error state. Because they're errored, they'll get to run immediately the next time a client starts up. Ideally, users don't need to depend on this functionality since the "correct" behavior would be to make sure that all jobs are able to respond to context cancellation, so we make this new feature optional. --- client.go | 125 ++++++++++++---- client_test.go | 136 ++++++++++++++++- ..._graceful_shutdown_stop_and_cancel_test.go | 8 +- example_graceful_shutdown_test.go | 6 +- internal/jobexecutor/job_executor.go | 14 ++ producer.go | 141 ++++++++++++++++-- producer_test.go | 81 ++++++++++ 7 files changed, 452 insertions(+), 59 deletions(-) diff --git a/client.go b/client.go index 3f29e545..6417d2dc 100644 --- a/client.go +++ b/client.go @@ -221,6 +221,18 @@ type Config struct { // Jobs may have their own specific hooks by implementing JobArgsWithHooks. Hooks []rivertype.Hook + // HardStopTimeout is the maximum amount of time that the client will wait + // after job contexts are cancelled during shutdown before forcing jobs still + // running to an errored state. This hard stop phase lets jobs be retried + // immediately on the next client start instead of waiting for rescue. + // + // The timer starts only after a soft stop has begun by cancelling job + // contexts, like after SoftStopTimeout elapses, StopAndCancel is called, or + // the Start context is cancelled without SoftStopTimeout configured. + // + // Defaults to no timeout (hard stop disabled). + HardStopTimeout time.Duration + // Logger is the structured logger to use for logging purposes. If none is // specified, logs will be emitted to STDOUT with messages at warn level // or higher. @@ -330,11 +342,9 @@ type Config struct { Schema string // SoftStopTimeout is the maximum amount of time that the client will wait - // for running jobs to finish during a stop before their contexts are - // cancelled. After the timeout elapses, the client escalates to a hard stop - // by cancelling the context of all running jobs. This applies regardless of - // how stop is initiated — whether by calling Stop, StopAndCancel, or by - // cancelling the context passed to Start. + // for running jobs to finish during a graceful stop before entering soft + // stop by cancelling job contexts. This applies when stop is initiated by + // calling Stop or by cancelling the context passed to Start. // // In combination with signal.NotifyContext on the context passed to Start, // this can simplify graceful stop to: @@ -345,12 +355,12 @@ type Config struct { // if err := client.Start(ctx); err != nil { ... } // <-client.Stopped() // - // The signal cancels the Start context, which initiates a soft stop. If + // The signal cancels the Start context, which initiates a graceful stop. If // running jobs haven't finished after SoftStopTimeout, their contexts are - // automatically cancelled to trigger a hard stop. + // cancelled. // - // StopAndCancel bypasses the timeout entirely and cancels job contexts - // immediately. + // StopAndCancel cancels job contexts immediately instead of waiting for + // SoftStopTimeout. // // Defaults to no timeout (wait indefinitely for jobs to finish). SoftStopTimeout time.Duration @@ -468,6 +478,7 @@ func (c *Config) WithDefaults() *Config { ErrorHandler: c.ErrorHandler, FetchCooldown: cmp.Or(c.FetchCooldown, FetchCooldownDefault), FetchPollInterval: cmp.Or(c.FetchPollInterval, FetchPollIntervalDefault), + HardStopTimeout: c.HardStopTimeout, ID: valutil.ValOrDefaultFunc(c.ID, func() string { return defaultClientID(time.Now().UTC()) }), Hooks: c.Hooks, JobInsertMiddleware: c.JobInsertMiddleware, @@ -515,6 +526,9 @@ func (c *Config) validate() error { if c.FetchPollInterval < c.FetchCooldown { return fmt.Errorf("FetchPollInterval cannot be shorter than FetchCooldown (%s)", c.FetchCooldown) } + if c.HardStopTimeout < 0 { + return errors.New("HardStopTimeout cannot be less than zero") + } if len(c.ID) > 100 { return errors.New("ID cannot be longer than 100 characters") } @@ -547,6 +561,9 @@ func (c *Config) validate() error { if c.Schema != "" && !postgresSchemaNameRE.MatchString(c.Schema) { return errors.New("Schema name can only contain letters, numbers, and underscores, and must start with a letter or underscore") } + if c.SoftStopTimeout < 0 { + return errors.New("SoftStopTimeout cannot be less than zero") + } for queue, queueConfig := range c.Queues { if err := queueConfig.validate(queue, c.FetchCooldown, c.FetchPollInterval); err != nil { @@ -1048,10 +1065,12 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client // A graceful shutdown stops fetching new jobs but allows any previously fetched // jobs to complete. This can be initiated with the Stop method. // -// A more abrupt shutdown can be achieved by either cancelling the provided -// context or by calling StopAndCancel. This will not only stop fetching new -// jobs, but will also cancel the context for any currently-running jobs. If -// using StopAndCancel, there's no need to also call Stop. +// A soft stop cancels job contexts after fetching has stopped. It can be +// initiated by calling StopAndCancel, by cancelling the provided context when +// SoftStopTimeout is not configured, or by waiting for SoftStopTimeout to elapse +// during graceful stop. If HardStopTimeout is configured, jobs still running +// after that timeout will be forced into an errored state. If using +// StopAndCancel, there's no need to also call Stop. func (c *Client[TTx]) Start(ctx context.Context) error { fetchCtx, shouldStart, started, stopped := c.baseStartStop.StartInit(ctx) if !shouldStart { @@ -1065,9 +1084,13 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // sure to take a channel reference before finishing stopped. c.stopped = c.baseStartStop.StoppedUnsafe() - producersAsServices := func() []startstop.Service { + producers := func() []*producer { + return maputil.Values(c.producersByQueueName) + } + + producersAsServices := func(producers []*producer) []startstop.Service { return sliceutil.Map( - maputil.Values(c.producersByQueueName), + producers, func(p *producer) startstop.Service { return p }, ) } @@ -1121,8 +1144,8 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // We use separate contexts for fetching and working to allow for a // graceful stop. When SoftStopTimeout is configured, the work context // is detached from the start context so that cancelling the start - // context initiates a soft stop (with timeout escalation) rather than - // an immediate hard stop. When SoftStopTimeout is not configured, the + // context initiates a graceful stop (with timeout escalation) rather + // than an immediate soft stop. When SoftStopTimeout is not configured, the // work context inherits from the start context to preserve the // existing behavior where cancelling the start context is equivalent // to StopAndCancel. @@ -1145,7 +1168,7 @@ func (c *Client[TTx]) Start(ctx context.Context) error { for _, producer := range c.producersByQueueName { if err := producer.StartWorkContext(fetchCtx, workCtx); err != nil { workCancel(err) - startstop.StopAllParallel(producersAsServices()...) + startstop.StopAllParallel(producersAsServices(producers())...) stopServicesOnError() return err } @@ -1167,7 +1190,7 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // Generate producer services while c.queues.startStopMu.Lock() is still // held. This is used for WaitAllStarted below, but don't use it elsewhere // because new producers may have been added while the client is running. - producerServices := producersAsServices() + producerServices := producersAsServices(producers()) go func() { // Wait for all subservices to start up before signaling our own start. @@ -1194,22 +1217,57 @@ func (c *Client[TTx]) Start(ctx context.Context) error { c.queues.startStopMu.Lock() defer c.queues.startStopMu.Unlock() + producerList := producers() + + hardStopTimerCtx, hardStopTimerCancel := context.WithCancel(context.WithoutCancel(ctx)) + defer hardStopTimerCancel() + + startHardStopTimer := sync.OnceFunc(func() { + if c.config.HardStopTimeout <= 0 { + return + } + + go func() { + timer := time.NewTimer(c.config.HardStopTimeout) + defer timer.Stop() + + select { + case <-timer.C: + c.baseService.Logger.WarnContext(ctx, c.baseService.Name+": Hard stop timeout; setting remaining jobs to errored", slog.Duration("hard_stop_timeout", c.config.HardStopTimeout)) + for _, producer := range producerList { + producer.hardStop() + } + case <-hardStopTimerCtx.Done(): + } + }() + }) + + workCtx := c.queues.workCtx + go func() { + select { + case <-workCtx.Done(): + startHardStopTimer() + case <-hardStopTimerCtx.Done(): + } + }() + // If SoftStopTimeout is configured, start a timer that will cancel - // the work context (escalating to a hard stop) if producers don't - // finish in time. StopAndCancel also calls workCancel, in which case - // this timer is a harmless no-op because the context is already done. + // the work context if producers don't finish in time. Once the work + // context is cancelled, the optional hard stop timer starts. if c.config.SoftStopTimeout > 0 { softStopTimer := time.AfterFunc(c.config.SoftStopTimeout, func() { c.baseService.Logger.WarnContext(ctx, c.baseService.Name+": Soft stop timeout; cancelling remaining job contexts", slog.Duration("soft_stop_timeout", c.config.SoftStopTimeout)) c.workCancel(rivercommon.ErrStop) + startHardStopTimer() }) defer softStopTimer.Stop() } // On stop, have the producers stop fetching first of all. c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": Stopping producers") - startstop.StopAllParallel(producersAsServices()...) + startstop.StopAllParallel(producersAsServices(producerList)...) c.baseService.Logger.DebugContext(ctx, c.baseService.Name+": All producers stopped") + hardStopTimerCancel() c.workCancel(rivercommon.ErrStop) @@ -1238,12 +1296,14 @@ func (c *Client[TTx]) Start(ctx context.Context) error { // complete before exiting. If the provided context is done before shutdown has // completed, Stop will return immediately with the context's error. // -// If SoftStopTimeout is configured, running job contexts will be automatically -// cancelled after the timeout elapses, escalating to a hard stop. This also +// If SoftStopTimeout is configured, jobs still running after the timeout +// elapses have their contexts cancelled. If HardStopTimeout is also configured, +// jobs still running after that second timeout are forced into an errored state +// so they can be retried immediately on the next client start. This also // applies when stop is initiated by cancelling the context passed to Start. // -// There's no need to call this method if a hard stop has already been initiated -// by cancelling the context passed to Start or by calling StopAndCancel. +// There's no need to call this method if shutdown has already been initiated by +// cancelling the context passed to Start or by calling StopAndCancel. func (c *Client[TTx]) Stop(ctx context.Context) error { shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() if !shouldStop { @@ -1262,10 +1322,11 @@ func (c *Client[TTx]) Stop(ctx context.Context) error { // StopAndCancel shuts down the client and cancels all work in progress. It is a // more aggressive stop than Stop because the contexts for any in-progress jobs -// are cancelled. However, it still waits for jobs to complete before returning, -// even though their contexts are cancelled. If the provided context is done -// before shutdown has completed, StopAndCancel will return immediately with the -// context's error. +// are cancelled immediately. If HardStopTimeout is configured, jobs that still +// remain running after the timeout are hard-stopped; otherwise, StopAndCancel +// waits for jobs to complete even though their contexts are cancelled. If the +// provided context is done before shutdown has completed, StopAndCancel will +// return immediately with the context's error. // // This can also be initiated by cancelling the context passed to Start. There is // no need to call this method if the context passed to Start is cancelled @@ -1277,7 +1338,7 @@ func (c *Client[TTx]) Stop(ctx context.Context) error { // graceful stop semantics without requiring manual orchestration of Stop and // StopAndCancel. func (c *Client[TTx]) StopAndCancel(ctx context.Context) error { - c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Hard stop started; cancelling all work") + c.baseService.Logger.InfoContext(ctx, c.baseService.Name+": Soft stop started; cancelling all work") c.workCancel(rivercommon.ErrStop) shouldStop, stopped, finalizeStop := c.baseStartStop.StopInit() diff --git a/client_test.go b/client_test.go index 9352dcbb..a8a3bfa8 100644 --- a/client_test.go +++ b/client_test.go @@ -2538,6 +2538,98 @@ func Test_Client_StopAndCancel(t *testing.T) { }) } +func Test_Client_HardStopTimeout(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + type JobArgs struct { + testutil.JobArgsReflectKind[JobArgs] + } + + setup := func(t *testing.T, configFunc func(config *Config)) (*Client[pgx.Tx], *rivertype.JobRow, chan struct{}, chan struct{}, func()) { + t.Helper() + + config := newTestConfig(t, "") + configFunc(config) + + jobContextDoneChan := make(chan struct{}) + jobReleasedChan := make(chan struct{}) + jobStartedChan := make(chan struct{}) + releaseJobChan := make(chan struct{}) + releaseJob := sync.OnceFunc(func() { close(releaseJobChan) }) + + AddWorker(config.Workers, WorkFunc(func(ctx context.Context, job *Job[JobArgs]) error { + close(jobStartedChan) + <-ctx.Done() + close(jobContextDoneChan) + <-releaseJobChan + close(jobReleasedChan) + return nil + })) + + client := runNewTestClient(ctx, t, config) + t.Cleanup(releaseJob) + + insertRes, err := client.Insert(ctx, JobArgs{}, nil) + require.NoError(t, err) + + riversharedtest.WaitOrTimeout(t, jobStartedChan) + + return client, insertRes.Job, jobContextDoneChan, jobReleasedChan, releaseJob + } + + requireHardStoppedAvailable := func(t *testing.T, client *Client[pgx.Tx], jobID int64) { + t.Helper() + + jobAfter, err := client.JobGet(ctx, jobID) + require.NoError(t, err) + require.Equal(t, rivertype.JobStateAvailable, jobAfter.State) + require.Len(t, jobAfter.Errors, 1) + require.Equal(t, producerHardStopError, jobAfter.Errors[0].Error) + require.Equal(t, 1, jobAfter.Errors[0].Attempt) + require.Empty(t, jobAfter.Errors[0].Trace) + require.Nil(t, jobAfter.FinalizedAt) + } + + t.Run("AfterSoftStopTimeout", func(t *testing.T) { + t.Parallel() + + client, job, jobContextDoneChan, jobReleasedChan, releaseJob := setup(t, func(config *Config) { + config.HardStopTimeout = 100 * time.Millisecond + config.SoftStopTimeout = 100 * time.Millisecond + }) + + stopCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + require.NoError(t, client.Stop(stopCtx)) + + riversharedtest.WaitOrTimeout(t, jobContextDoneChan) + requireHardStoppedAvailable(t, client, job.ID) + + releaseJob() + riversharedtest.WaitOrTimeout(t, jobReleasedChan) + }) + + t.Run("AfterStopAndCancel", func(t *testing.T) { + t.Parallel() + + client, job, jobContextDoneChan, jobReleasedChan, releaseJob := setup(t, func(config *Config) { + config.HardStopTimeout = 100 * time.Millisecond + }) + + stopCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + require.NoError(t, client.StopAndCancel(stopCtx)) + + riversharedtest.WaitOrTimeout(t, jobContextDoneChan) + requireHardStoppedAvailable(t, client, job.ID) + + releaseJob() + riversharedtest.WaitOrTimeout(t, jobReleasedChan) + }) +} + func Test_Client_SoftStopTimeout(t *testing.T) { t.Parallel() @@ -2547,7 +2639,7 @@ func Test_Client_SoftStopTimeout(t *testing.T) { testutil.JobArgsReflectKind[JobArgs] } - t.Run("EscalatesToHardStopAfterTimeout", func(t *testing.T) { + t.Run("CancelsJobsAfterTimeout", func(t *testing.T) { t.Parallel() config := newTestConfig(t, "") @@ -2569,8 +2661,8 @@ func Test_Client_SoftStopTimeout(t *testing.T) { riversharedtest.WaitOrTimeout(t, jobStartedChan) - // Stop initiates a soft stop. The job won't finish on its own, but - // SoftStopTimeout should escalate to a hard stop after 100ms. + // Stop initiates a graceful stop. The job won't finish on its own, but + // SoftStopTimeout should cancel its context after 100ms. require.NoError(t, client.Stop(ctx)) // Verify the job's context was indeed cancelled. @@ -2605,7 +2697,7 @@ func Test_Client_SoftStopTimeout(t *testing.T) { require.NoError(t, client.Stop(ctx)) }) - t.Run("ContextCancellationEscalatesAfterTimeout", func(t *testing.T) { + t.Run("StartContextCancellationCancelsJobsAfterTimeout", func(t *testing.T) { t.Parallel() config := newTestConfig(t, "") @@ -2640,8 +2732,8 @@ func Test_Client_SoftStopTimeout(t *testing.T) { riversharedtest.WaitOrTimeout(t, jobStartedChan) - // Cancel the start context. This should initiate a soft stop, then - // escalate to hard stop after SoftStopTimeout. + // Cancel the start context. This should initiate a graceful stop, then + // cancel job contexts after SoftStopTimeout. startCtxCancel() riversharedtest.WaitOrTimeout(t, client.Stopped()) @@ -8291,6 +8383,22 @@ func Test_NewClient_Validations(t *testing.T) { }, wantErr: fmt.Errorf("FetchPollInterval cannot be shorter than FetchCooldown (%s)", 20*time.Millisecond), }, + { + name: "HardStopTimeout cannot be negative", + configFunc: func(config *Config) { + config.HardStopTimeout = -1 + }, + wantErr: errors.New("HardStopTimeout cannot be less than zero"), + }, + { + name: "HardStopTimeout may be overridden", + configFunc: func(config *Config) { + config.HardStopTimeout = 23 * time.Second + }, + validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper + require.Equal(t, 23*time.Second, client.config.HardStopTimeout) + }, + }, { name: "FetchPollInterval cannot be less than MinFetchPollInterval", configFunc: func(config *Config) { config.FetchPollInterval = time.Millisecond - 1 }, @@ -8486,6 +8594,22 @@ func Test_NewClient_Validations(t *testing.T) { }, wantErr: errors.New("Schema name can only contain letters, numbers, and underscores, and must start with a letter or underscore"), }, + { + name: "SoftStopTimeout cannot be negative", + configFunc: func(config *Config) { + config.SoftStopTimeout = -1 + }, + wantErr: errors.New("SoftStopTimeout cannot be less than zero"), + }, + { + name: "SoftStopTimeout may be overridden", + configFunc: func(config *Config) { + config.SoftStopTimeout = 23 * time.Second + }, + validateResult: func(t *testing.T, client *Client[pgx.Tx]) { //nolint:thelper + require.Equal(t, 23*time.Second, client.config.SoftStopTimeout) + }, + }, { name: "Queues can be nil when Workers is also nil", configFunc: func(config *Config) { diff --git a/example_graceful_shutdown_stop_and_cancel_test.go b/example_graceful_shutdown_stop_and_cancel_test.go index 30217c2d..86aea931 100644 --- a/example_graceful_shutdown_stop_and_cancel_test.go +++ b/example_graceful_shutdown_stop_and_cancel_test.go @@ -20,8 +20,8 @@ import ( // Example_gracefulShutdownStopCancel demonstrates graceful stop with explicit // fallback to StopAndCancel. When a SIGINT/SIGTERM arrives, Stop initiates a -// soft stop. If running jobs don't finish before the soft stop context expires, -// StopAndCancel cancels their contexts (hard stop). This example is intended to +// graceful stop. If running jobs don't finish before the graceful stop context +// expires, StopAndCancel cancels their contexts. This example is intended to // demonstrate advanced use of StopAndCancel. Generally, prefer the method shown // in Example_gracefulShutdown over the one here. func Example_gracefulShutdownStopAndCancel() { @@ -59,8 +59,8 @@ func Example_gracefulShutdownStopAndCancel() { } // Use signal.NotifyContext to detect SIGINT/SIGTERM, but don't pass the - // signal context to Start. Cancelling the Start context cancels running job - // contexts immediately, which is equivalent to StopAndCancel. + // signal context to Start. Cancelling the Start context would cancel running + // job contexts immediately, which is equivalent to StopAndCancel. signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/example_graceful_shutdown_test.go b/example_graceful_shutdown_test.go index bf99b8ac..72cb4d60 100644 --- a/example_graceful_shutdown_test.go +++ b/example_graceful_shutdown_test.go @@ -43,8 +43,8 @@ func (w *WaitsForCancelOnlyWorker) Work(ctx context.Context, job *river.Job[Wait // Example_gracefulShutdown demonstrates graceful stop using SoftStopTimeout. // When a SIGINT/SIGTERM arrives, the start context is cancelled, which -// initiates a soft stop. If running jobs don't finish within the configured -// SoftStopTimeout, their contexts are automatically cancelled (hard stop). +// initiates a graceful stop. If running jobs don't finish within the configured +// SoftStopTimeout, their contexts are automatically cancelled. func Example_gracefulShutdown() { ctx := context.Background() @@ -77,7 +77,7 @@ func Example_gracefulShutdown() { } // Use signal.NotifyContext to cancel the start context on SIGINT/SIGTERM. - // When the signal fires, the client initiates a soft stop. If running jobs + // When the signal fires, the client initiates a graceful stop. If running jobs // don't finish within SoftStopTimeout, their contexts are cancelled. signalCtx, stop := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) defer stop() diff --git a/internal/jobexecutor/job_executor.go b/internal/jobexecutor/job_executor.go index dab44f39..8bf9e852 100644 --- a/internal/jobexecutor/job_executor.go +++ b/internal/jobexecutor/job_executor.go @@ -9,6 +9,7 @@ import ( "log/slog" "runtime" "strings" + "sync/atomic" "time" "github.com/tidwall/gjson" @@ -110,6 +111,7 @@ type JobExecutor struct { ClientRetryPolicy ClientRetryPolicy DefaultClientRetryPolicy ClientRetryPolicy ErrorHandler ErrorHandler + hardStopped atomic.Bool HookLookupByJob *hooklookup.JobHookLookup HookLookupGlobal hooklookup.HookLookupInterface JobRow *rivertype.JobRow @@ -148,6 +150,12 @@ func (e *JobExecutor) Execute(ctx context.Context) { res.Err = context.Cause(ctx) } + // Hard-stopped jobs have already been moved out of running by the producer. + if e.hardStopped.Load() { + e.ProducerCallbacks.JobDone(e.JobRow) + return + } + var multiJobErrors withJobsAndErrorsByID if res.Err != nil { multiJobErrors, _ = res.Err.(withJobsAndErrorsByID) @@ -167,6 +175,12 @@ func (e *JobExecutor) Execute(ctx context.Context) { e.ProducerCallbacks.JobDone(e.JobRow) } +// HardStop suppresses completion reporting for a job that's been forcibly +// errored by its producer during shutdown. +func (e *JobExecutor) HardStop() { + e.hardStopped.Store(true) +} + // Executes the job, handling a panic if necessary (and various other error // conditions). The named return value is so that we can still return a value in // case of a panic. diff --git a/producer.go b/producer.go index dbd40653..b365606f 100644 --- a/producer.go +++ b/producer.go @@ -34,6 +34,7 @@ import ( ) const ( + producerHardStopError = "job stopped because River client hard stopped" producerReportIntervalDefault = time.Minute queuePollIntervalDefault = 2 * time.Second queueReportIntervalDefault = 10 * time.Minute @@ -187,8 +188,10 @@ type producer struct { exec riverdriver.Executor errorHandler jobexecutor.ErrorHandler fetchLimiter *chanutil.DebouncedChan - state riverpilot.ProducerState + hardStopCh chan struct{} // signals that a "hard stop" has been initiated (set all running jobs to errored regardless of whether they stopped cleanly) + hardStopOnce *sync.Once // closes hardStopChan exactly once pilot riverpilot.Pilot + state riverpilot.ProducerState workers *Workers // Receives job IDs to cancel. Written by notifier goroutine, only read from @@ -202,7 +205,7 @@ type producer struct { // Receives completed jobs from workers. Written by completed workers, only // read from main goroutine. - jobResultCh chan *rivertype.JobRow + jobResultCh chan *producerJobResult jobTimeout time.Duration @@ -240,7 +243,9 @@ func newProducer(archetype *baseservice.Archetype, exec riverdriver.Executor, pi config: config.mustValidate(), exec: exec, errorHandler: errorHandler, - jobResultCh: make(chan *rivertype.JobRow, config.MaxWorkers), + hardStopCh: make(chan struct{}), + hardStopOnce: &sync.Once{}, + jobResultCh: make(chan *producerJobResult, config.MaxWorkers), jobTimeout: config.JobTimeout, pilot: pilot, queueControlCh: make(chan *controlEventPayload, 100), @@ -279,6 +284,9 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { return nil } + p.hardStopCh = make(chan struct{}) + p.hardStopOnce = &sync.Once{} + isExpectedShutdownError := func(err error) bool { return errors.Is(err, startstop.ErrStop) || strings.HasSuffix(err.Error(), "conn closed") || fetchCtx.Err() != nil } @@ -415,7 +423,7 @@ func (p *producer) StartWorkContext(fetchCtx, workCtx context.Context) error { p.fetchAndRunLoop(fetchCtx, workCtx) p.Logger.DebugContext(workCtx, p.Name+": Entering shutdown loop", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) - p.executorShutdownLoop() + p.executorShutdownLoop(context.WithoutCancel(fetchCtx)) p.Logger.DebugContext(workCtx, p.Name+": Shutdown loop exited, awaiting subroutines", slog.String("queue", p.config.Queue), slog.Int64("id", p.id.Load())) cancelSubroutines(fmt.Errorf("producer stopped: %w", startstop.ErrStop)) @@ -470,6 +478,11 @@ type insertPayload struct { Queue string `json:"queue"` } +type producerJobResult struct { + executor *jobexecutor.JobExecutor + job *rivertype.JobRow +} + func (p *producer) handleControlNotification(workCtx context.Context) func(notifier.NotificationTopic, string) { return func(topic notifier.NotificationTopic, payload string) { var decoded controlEventPayload @@ -663,12 +676,29 @@ func (p *producer) innerFetchLoop(workCtx context.Context, fetchResultCh chan pr } } -func (p *producer) executorShutdownLoop() { +func (p *producer) executorShutdownLoop(ctx context.Context) { // No more jobs will be fetched or executed. However, we must wait for all // in-progress jobs to complete. for len(p.activeJobs) != 0 { - result := <-p.jobResultCh - p.removeActiveJob(result) + select { + case result := <-p.jobResultCh: + p.removeActiveJob(result) + case <-p.hardStopCh: + p.drainJobResults() + p.hardStopActiveJobs(ctx) + return + } + } +} + +func (p *producer) drainJobResults() { + for { + select { + case result := <-p.jobResultCh: + p.removeActiveJob(result) + default: + return + } } } @@ -728,11 +758,90 @@ func (p *producer) addActiveJob(id int64, executor *jobexecutor.JobExecutor) { p.activeJobs[id] = executor } -func (p *producer) removeActiveJob(job *rivertype.JobRow) { - delete(p.activeJobs, job.ID) +func (p *producer) hardStop() { + p.hardStopOnce.Do(func() { close(p.hardStopCh) }) +} + +func (p *producer) hardStopActiveJobs(ctx context.Context) { + if len(p.activeJobs) == 0 { + return + } + + now := p.Time.Now() + params := &riverdriver.JobSetStateIfRunningManyParams{ + Attempt: make([]*int, 0, len(p.activeJobs)), + ErrData: make([][]byte, 0, len(p.activeJobs)), + FinalizedAt: make([]*time.Time, 0, len(p.activeJobs)), + ID: make([]int64, 0, len(p.activeJobs)), + MetadataDoMerge: make([]bool, 0, len(p.activeJobs)), + MetadataUpdates: make([][]byte, 0, len(p.activeJobs)), + Now: &now, + ScheduledAt: make([]*time.Time, 0, len(p.activeJobs)), + Schema: p.config.Schema, + State: make([]rivertype.JobState, 0, len(p.activeJobs)), + } + + for _, executor := range p.activeJobs { + p.hardStop() + + job := executor.JobRow + errData, err := json.Marshal(rivertype.AttemptError{ + At: now, + Attempt: job.Attempt, + Error: producerHardStopError, + }) + if err != nil { + panic(fmt.Errorf("error serializing hard stop error: %w", err)) + } + + var setStateParams *riverdriver.JobSetStateIfRunningParams + if job.Attempt >= job.MaxAttempts { + setStateParams = riverdriver.JobSetStateDiscarded(job.ID, now, errData, nil) + } else { + setStateParams = riverdriver.JobSetStateErrorAvailable(job.ID, now, errData, nil) + } + + params.Attempt = append(params.Attempt, setStateParams.Attempt) + params.ErrData = append(params.ErrData, setStateParams.ErrData) + params.FinalizedAt = append(params.FinalizedAt, setStateParams.FinalizedAt) + params.ID = append(params.ID, setStateParams.ID) + params.MetadataDoMerge = append(params.MetadataDoMerge, setStateParams.MetadataDoMerge) + params.MetadataUpdates = append(params.MetadataUpdates, setStateParams.MetadataUpdates) + params.ScheduledAt = append(params.ScheduledAt, setStateParams.ScheduledAt) + params.State = append(params.State, setStateParams.State) + } + + timeoutCtx, cancel := context.WithTimeout(ctx, rivercommon.HotOperationTimeout) + defer cancel() + + if _, err := p.pilot.JobSetStateIfRunningMany(timeoutCtx, p.exec, params); err != nil { + p.Logger.ErrorContext(ctx, p.Name+": Error setting hard-stopped jobs to errored", slog.String("err", err.Error()), slog.Int("num_jobs", len(params.ID)), slog.String("queue", p.config.Queue)) + } else { + p.Logger.WarnContext(ctx, p.Name+": Hard-stopped running jobs", slog.Int("num_jobs", len(params.ID)), slog.String("queue", p.config.Queue)) + } + + numActiveJobs := len(p.activeJobs) + for _, executor := range p.activeJobs { + if p.state != nil { + p.state.JobFinish(executor.JobRow) + } + } + p.activeJobs = make(map[int64]*jobexecutor.JobExecutor) + p.numJobsActive.Add(-int32(numActiveJobs)) //nolint:gosec +} + +func (p *producer) removeActiveJob(result *producerJobResult) { + // Ignore stale results from executors hard-stopped out of active tracking. + if activeExecutor := p.activeJobs[result.job.ID]; activeExecutor != result.executor { + return + } + + delete(p.activeJobs, result.job.ID) p.numJobsActive.Add(-1) p.numJobsRan.Add(1) - p.state.JobFinish(job) + if p.state != nil { + p.state.JobFinish(result.job) + } } func (p *producer) maybeCancelJob(ctx context.Context, id int64) { @@ -822,7 +931,8 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. // jobCancel will always be called by the executor to prevent leaks. jobCtx, jobCancel := context.WithCancelCause(workCtx) - executor := baseservice.Init(&p.Archetype, &jobexecutor.JobExecutor{ + var executor *jobexecutor.JobExecutor + executor = baseservice.Init(&p.Archetype, &jobexecutor.JobExecutor{ CancelFunc: jobCancel, ClientJobTimeout: p.jobTimeout, ClientRetryPolicy: p.retryPolicy, @@ -838,7 +948,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. Stuck func() Unstuck func() }{ - JobDone: p.handleWorkerDone, + JobDone: func(jobRow *rivertype.JobRow) { p.handleWorkerDone(executor, jobRow) }, Stuck: func() { p.numJobsStuck.Add(1) }, Unstuck: func() { p.numJobsStuck.Add(-1) }, }, @@ -859,8 +969,11 @@ func (p *producer) maxJobsToFetch() int { return p.config.MaxWorkers - int(p.numJobsActive.Load()) } -func (p *producer) handleWorkerDone(job *rivertype.JobRow) { - p.jobResultCh <- job +func (p *producer) handleWorkerDone(executor *jobexecutor.JobExecutor, job *rivertype.JobRow) { + p.jobResultCh <- &producerJobResult{ + executor: executor, + job: job, + } } func (p *producer) pollForSettingChanges(ctx context.Context, wg *sync.WaitGroup, lastPaused bool, lastMetadata []byte) { diff --git a/producer_test.go b/producer_test.go index c03bd766..8b1464c6 100644 --- a/producer_test.go +++ b/producer_test.go @@ -12,6 +12,7 @@ import ( "github.com/riverqueue/river/internal/hooklookup" "github.com/riverqueue/river/internal/jobcompleter" + "github.com/riverqueue/river/internal/jobexecutor" "github.com/riverqueue/river/internal/maintenance" "github.com/riverqueue/river/internal/middlewarelookup" "github.com/riverqueue/river/internal/notifier" @@ -161,6 +162,86 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { } } +func TestProducer_HardStopActiveJobs(t *testing.T) { + t.Parallel() + + ctx := context.Background() + require := require.New(t) + + var ( + archetype = riversharedtest.BaseServiceArchetype(t) + dbPool = riversharedtest.DBPool(ctx, t) + driver = riverpgxv5.New(dbPool) + exec = driver.GetExecutor() + pilot = &riverpilot.StandardPilot{} + schema = riverdbtest.TestSchema(ctx, t, driver, nil) + ) + + completer := jobcompleter.NewInlineCompleter(archetype, schema, exec, pilot, make(chan []jobcompleter.CompleterJobUpdated, 10)) + + producer := newProducer(archetype, exec, pilot, &producerConfig{ + ClientID: testClientID, + Completer: completer, + ErrorHandler: newTestErrorHandler(), + FetchCooldown: FetchCooldownDefault, + FetchPollInterval: FetchPollIntervalDefault, + HookLookupByJob: hooklookup.NewJobHookLookup(), + HookLookupGlobal: hooklookup.NewHookLookup(nil), + JobTimeout: JobTimeoutDefault, + MaxWorkers: 10, + MiddlewareLookupGlobal: middlewarelookup.NewMiddlewareLookup(nil), + Queue: rivercommon.QueueDefault, + QueuePollInterval: queuePollIntervalDefault, + QueueReportInterval: queueReportIntervalDefault, + RetryPolicy: &DefaultClientRetryPolicy{}, + SchedulerInterval: maintenance.JobSchedulerIntervalDefault, + Schema: schema, + StaleProducerRetentionPeriod: time.Minute, + Workers: NewWorkers(), + }) + + runningState := rivertype.JobStateRunning + retryableJob := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ + Attempt: ptrutil.Ptr(1), + MaxAttempts: ptrutil.Ptr(3), + Schema: schema, + State: &runningState, + }) + discardedJob := testfactory.Job(ctx, t, exec, &testfactory.JobOpts{ + Attempt: ptrutil.Ptr(3), + MaxAttempts: ptrutil.Ptr(3), + Schema: schema, + State: &runningState, + }) + + producer.addActiveJob(retryableJob.ID, &jobexecutor.JobExecutor{JobRow: retryableJob}) + producer.addActiveJob(discardedJob.ID, &jobexecutor.JobExecutor{JobRow: discardedJob}) + + producer.hardStop() + producer.executorShutdownLoop(ctx) + + require.Empty(producer.activeJobs) + require.Zero(producer.numJobsActive.Load()) + + retryableJobAfter, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: retryableJob.ID, Schema: schema}) + require.NoError(err) + require.Equal(rivertype.JobStateAvailable, retryableJobAfter.State) + require.Len(retryableJobAfter.Errors, 1) + require.Equal(producerHardStopError, retryableJobAfter.Errors[0].Error) + require.Equal(retryableJob.Attempt, retryableJobAfter.Errors[0].Attempt) + require.Empty(retryableJobAfter.Errors[0].Trace) + require.Nil(retryableJobAfter.FinalizedAt) + + discardedJobAfter, err := exec.JobGetByID(ctx, &riverdriver.JobGetByIDParams{ID: discardedJob.ID, Schema: schema}) + require.NoError(err) + require.Equal(rivertype.JobStateDiscarded, discardedJobAfter.State) + require.Len(discardedJobAfter.Errors, 1) + require.Equal(producerHardStopError, discardedJobAfter.Errors[0].Error) + require.Equal(discardedJob.Attempt, discardedJobAfter.Errors[0].Attempt) + require.Empty(discardedJobAfter.Errors[0].Trace) + require.NotNil(discardedJobAfter.FinalizedAt) +} + func TestProducer_PollOnly(t *testing.T) { t.Parallel()