diff --git a/await.go b/await.go index 693da577..6366fc28 100644 --- a/await.go +++ b/await.go @@ -56,6 +56,9 @@ func NewPublicationAwaiter(ctx context.Context, readCheckpoint func(ctx context. type PublicationAwaiter struct { c *sync.Cond + // Only used for testing coordination + preWaitSignaller chan struct{} + // size, checkpoint, and err keep track of the latest size and checkpoint // (or error) seen by the poller. size uint64 @@ -83,17 +86,20 @@ func (a *PublicationAwaiter) Await(ctx context.Context, future IndexFuture) (Ind span.AddEvent("Waiting for tree growth") a.c.L.Lock() defer a.c.L.Unlock() + if a.preWaitSignaller != nil { + a.preWaitSignaller <- struct{}{} + } for (a.size <= i.Index && a.err == nil) && ctx.Err() == nil { a.c.Wait() } - // Ensure we propogate context done error, if any. - if err := ctx.Err(); err != nil { - a.err = err - } else { - span.AddEvent("Tree covers index") + + // Make sure we report any errors that caused us to stop early + err = a.err + if err == nil { + err = ctx.Err() } - return i, a.checkpoint, a.err + return i, a.checkpoint, err }) } diff --git a/await_test.go b/await_test.go index 89b62669..c077a99d 100644 --- a/await_test.go +++ b/await_test.go @@ -174,7 +174,7 @@ func TestAwait_multiClient(t *testing.T) { testTimeout := 1 * time.Second // Await will time out via this context, causing tests to fail // if the integration condition is never reached. - ctx, cancel := context.WithTimeout(context.Background(), testTimeout) + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) defer cancel() size := uint64(0) @@ -204,8 +204,7 @@ func TestAwait_multiClient(t *testing.T) { <-time.After(15 * time.Millisecond) return Index{Index: index}, nil } - wg.Add(1) - go func() { + wg.Go(func() { i, cpRaw, err := awaiter.Await(ctx, future) if err != nil { t.Errorf("function for %d failed: %v", i.Index, err) @@ -221,9 +220,77 @@ func TestAwait_multiClient(t *testing.T) { t.Errorf("got cp size of %d for index %d", cp.Size, i.Index) } - wg.Done() - }() + }) + } + wg.Wait() +} + +func TestAwait_contextCancel(t *testing.T) { + s, err := note.NewSigner("PRIVATE+KEY+example.com/log/testdata+33d7b496+AeymY/SZAX0jZcJ8enZ5FY1Dz+wTML2yWSkK+9DSF3eg") + if err != nil { + t.Fatal(err) + } + + t.Parallel() + testTimeout := 2 * time.Second + // Await will time out via this context, causing tests to fail + // if the integration condition is never reached. + ctx, cancel := context.WithTimeout(t.Context(), testTimeout) + defer cancel() + + var size atomic.Uint64 + readCheckpoint := func(ctx context.Context) ([]byte, error) { + thisSize := size.Load() + hash := sha256.Sum256(fmt.Append(nil, thisSize)) + cpRaw := log.Checkpoint{ + Origin: "example.com/log/testdata", + Size: thisSize, + Hash: hash[:], + }.Marshal() + n, err := note.Sign(¬e.Note{Text: string(cpRaw)}, s) + if err != nil { + return nil, fmt.Errorf("note.Sign: %w", err) + } + return n, nil } + awaiter := NewPublicationAwaiter(ctx, readCheckpoint, 10*time.Millisecond) + awaiter.preWaitSignaller = make(chan struct{}, 2) + + wg := sync.WaitGroup{} + + // This one should succeed + wg.Go(func() { + future := func() (Index, error) { + return Index{Index: 50}, nil + } + _, _, err := awaiter.Await(ctx, future) + if err != nil { + t.Error(err) + } + }) + + timeoutSuccess := make(chan struct{}) + // This one is expected to hit deadline exceeded + wg.Go(func() { + cctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + defer close(timeoutSuccess) + future := func() (Index, error) { + return Index{Index: 100}, nil + } + _, _, err := awaiter.Await(cctx, future) + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("Expected context deadline exceeded, got %v", err) + } + }) + + // Block until both of the goroutines are waiting + <-awaiter.preWaitSignaller + <-awaiter.preWaitSignaller + // Then wait until the second one has timed out + <-timeoutSuccess + // And finally release the final one for success + size.Store(75) wg.Wait() }