Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions internal/cmd/diagnose.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"time"

"github.com/spf13/cobra"
Expand Down Expand Up @@ -54,8 +55,11 @@ Do not pass go test -trace; use diagnose --trace instead. With --shuffle-seed, a
iterSetup := hooks.BuildIterationHook(runnerOpts, shell, hooks.PhaseSetup)
iterTeardown := hooks.BuildIterationHook(runnerOpts, shell, hooks.PhaseTeardown)

runCtx, stopRun := runner.NewDiagnoseRunContext(context.WithoutCancel(cmd.Context()))
defer stopRun()

state, start, resultsDir, runErr := runner.RunIterations(
cmd.Context(),
runCtx,
conf,
out,
args,
Expand All @@ -66,12 +70,12 @@ Do not pass go test -trace; use diagnose --trace instead. With --shuffle-seed, a

finishResourceCleanup(cmd, &runErr, cleanup)
if runErr != nil {
if cmd.Context().Err() == nil {
if runCtx.Err() == nil {
return runErr
}
}

return runner.FinishDiagnoseAnalysis(cmd.Context(), conf, out, args, state, start, resultsDir)
return runner.FinishDiagnoseAnalysis(runCtx, conf, out, args, state, start, resultsDir)
}),
}

Expand Down
9 changes: 5 additions & 4 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,11 @@ func NewRootCommand(runnerOpts hooks.RunOptions) *cobra.Command {
return rootCmd
}

// Execute runs the root command. A SIGINT or SIGTERM cancels the context so
// long-running subcommands (notably `diagnose`) can stop cleanly and still write
// their post-run analysis. A second signal hits the default handler and
// force-exits.
// Execute runs the root command. A SIGINT or SIGTERM cancels cmd.Context() on
// the first signal so generic subcommands (go test, gotestsum) stop promptly.
// diagnose installs its own two-stage context: first signal finishes in-flight
// iterations; second signal hard-cancels. A second root signal hits the default
// handler and force-exits when diagnose is not running.
func Execute(opts ...hooks.Option) {
if err := runExecute(opts...); err != nil {
os.Exit(1)
Expand Down
130 changes: 130 additions & 0 deletions internal/runner/diagnose_cancel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package runner

import (
"context"
"os"
"os/signal"
"sync"
"sync/atomic"
"syscall"
)

type diagnoseCancelKey struct{}

type diagnoseCancelState struct {
softStop atomic.Bool
hardCancel context.CancelFunc
softStopCh chan struct{}
softOnce sync.Once
}

// NewDiagnoseRunContext returns a context for diagnose iteration runs with
// two-stage cancellation. The first SIGINT/SIGTERM requests a graceful stop
// (finish in-flight iterations, do not enqueue new ones). The second signal
// hard-cancels the context. Pass context.WithoutCancel(cmd.Context()) as parent
// so the root CLI signal handler does not cancel iteration execution on the
// first press.
func NewDiagnoseRunContext(parent context.Context) (context.Context, func()) {
return newDiagnoseRunContext(parent, true)
}

// NewDiagnoseRunContextForTest is like NewDiagnoseRunContext but without OS
// signal handling. Use in unit and synctest tests with RequestDiagnoseGracefulStop
// and RequestDiagnoseHardCancel.
func NewDiagnoseRunContextForTest(parent context.Context) (context.Context, func()) {
return newDiagnoseRunContext(parent, false)
}

func newDiagnoseRunContext(parent context.Context, listenSignals bool) (context.Context, func()) {
ctx, hardCancel := context.WithCancel(parent)
state := &diagnoseCancelState{
hardCancel: hardCancel,
softStopCh: make(chan struct{}),
}
ctx = context.WithValue(ctx, diagnoseCancelKey{}, state)

stop := func() {
hardCancel()
}

if !listenSignals {
return ctx, stop
}

sigCh := make(chan os.Signal, 2)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)

stop = func() {
signal.Stop(sigCh)
hardCancel()
}

go func() {
defer signal.Stop(sigCh)
for {
select {
case <-ctx.Done():
return
case <-sigCh:
state.requestGracefulStop()
select {
case <-sigCh:
hardCancel()
return
case <-ctx.Done():
return
}
}
}
}()

return ctx, stop
}

func (s *diagnoseCancelState) requestGracefulStop() {
s.softOnce.Do(func() {
s.softStop.Store(true)
close(s.softStopCh)
})
}

func diagnoseCancelFromContext(ctx context.Context) *diagnoseCancelState {
if ctx == nil {
return nil
}
state, _ := ctx.Value(diagnoseCancelKey{}).(*diagnoseCancelState)
return state
}

// DiagnoseGracefulStopChan returns a channel closed on the first graceful stop
// request, or nil when ctx is not from NewDiagnoseRunContext.
func DiagnoseGracefulStopChan(ctx context.Context) <-chan struct{} {
state := diagnoseCancelFromContext(ctx)
if state == nil {
return nil
}
return state.softStopCh
}

// DiagnoseGracefulStopRequested reports whether the user requested a graceful
// stop (first Ctrl+C) on a context from NewDiagnoseRunContext.
func DiagnoseGracefulStopRequested(ctx context.Context) bool {
state := diagnoseCancelFromContext(ctx)
return state != nil && state.softStop.Load()
}

// RequestDiagnoseGracefulStop simulates the first interrupt for tests.
func RequestDiagnoseGracefulStop(ctx context.Context) {
state := diagnoseCancelFromContext(ctx)
if state != nil {
state.requestGracefulStop()
}
}

// RequestDiagnoseHardCancel simulates the second interrupt for tests.
func RequestDiagnoseHardCancel(ctx context.Context) {
state := diagnoseCancelFromContext(ctx)
if state != nil {
state.hardCancel()
}
}
63 changes: 63 additions & 0 deletions internal/runner/diagnose_cancel_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package runner

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestDiagnoseGracefulStopRequested(t *testing.T) {
t.Parallel()

t.Run("background", func(t *testing.T) {
t.Parallel()
assert.False(t, DiagnoseGracefulStopRequested(context.Background()))
})

t.Run("diagnose_context_before_signal", func(t *testing.T) {
t.Parallel()
ctx, stop := NewDiagnoseRunContext(context.Background())
defer stop()
assert.False(t, DiagnoseGracefulStopRequested(ctx))
})

t.Run("diagnose_context_after_graceful_request", func(t *testing.T) {
t.Parallel()
ctx, stop := NewDiagnoseRunContext(context.Background())
defer stop()
RequestDiagnoseGracefulStop(ctx)
assert.True(t, DiagnoseGracefulStopRequested(ctx))
})
}

func TestDiagnoseHardCancel(t *testing.T) {
t.Parallel()
ctx, stop := NewDiagnoseRunContext(context.Background())
defer stop()

RequestDiagnoseHardCancel(ctx)
require.Error(t, ctx.Err())
assert.False(t, DiagnoseGracefulStopRequested(ctx))
}

func TestDiagnoseGracefulThenHardCancel(t *testing.T) {
t.Parallel()
ctx, stop := NewDiagnoseRunContext(context.Background())
defer stop()

RequestDiagnoseGracefulStop(ctx)
assert.True(t, DiagnoseGracefulStopRequested(ctx))
require.NoError(t, ctx.Err())

RequestDiagnoseHardCancel(ctx)
require.Error(t, ctx.Err())
}

func TestDiagnoseRunContextStopTearsDown(t *testing.T) {
t.Parallel()
ctx, stop := NewDiagnoseRunContext(context.Background())
stop()
require.Error(t, ctx.Err())
}
28 changes: 28 additions & 0 deletions internal/runner/diagnose_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"path/filepath"
"runtime"
"time"

"charm.land/lipgloss/v2"
Expand Down Expand Up @@ -154,3 +155,30 @@ func diagnoseCSVPath(resultsDir string, rep *Report) string {
}
return filepath.Join(resultsDir, "report.csv")
}

func diagnoseInterruptKeyHint() string {
if runtime.GOOS == "darwin" {
return "⌘C"
}
return "Ctrl+C"
}

func printDiagnoseGracefulStopNotice(out *output.Printer, completed, total int) {
if out == nil {
return
}
if out.AIOutput() {
out.Stderrf("stop_graceful completed=%d total=%d\n", completed, total)
return
}
out.ClearInline()
hint := diagnoseInterruptKeyHint()
out.HumanStderr(
termstyle.Accent.Render(
fmt.Sprintf("Stopping diagnose run after current iteration — %d/%d completed.", completed, total),
) + "\n" +
termstyle.Muted.Render(
fmt.Sprintf("Press %s again to cancel immediately.", hint),
),
)
}
Loading
Loading