195 changes: 167 additions & 28 deletions internal/integration/unified/client_entity.go
@@ -9,6 +9,7 @@ package unified
import (
"context"
"fmt"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -24,6 +25,7 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
)

// There are no automated tests for truncation. Given that, setting the
@@ -38,6 +40,40 @@ var securitySensitiveCommands = []string{
"createUser", "updateUser", "copydbgetnonce", "copydbsaslstart", "copydb",
}

// eventSequencer provides sequence-based event filtering in support of
// awaitMinPoolSizeMS.
//
// Per the unified test format spec, when awaitMinPoolSizeMS is specified, any
// CMAP and SDAM events that occur during connection pool initialization
// (before minPoolSize is reached) must be ignored. We track this by
// assigning a monotonically increasing sequence number to each event as it's
// recorded. After pool initialization completes, we set eventCutoffSeq to the
// current sequence number. Event accessors for CMAP and SDAM types then
// filter out any events with sequence <= eventCutoffSeq.
//
// Sequencing is thread-safe to support concurrent operations that may generate
// events (e.g., connection checkouts generating CMAP events).
type eventSequencer struct {
counter atomic.Int64
cutoff atomic.Int64
mu sync.RWMutex
seqByEventType map[monitoringEventType][]int64
}

// setCutoff marks the current sequence as the filtering cutoff point.
func (es *eventSequencer) setCutoff() {
es.cutoff.Store(es.counter.Load())
}

// recordEvent stores the sequence number for a given event type.
func (es *eventSequencer) recordEvent(eventType monitoringEventType) {
next := es.counter.Add(1)

es.mu.Lock()
es.seqByEventType[eventType] = append(es.seqByEventType[eventType], next)
es.mu.Unlock()
}

// clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test
// execution.
type clientEntity struct {
@@ -72,30 +108,8 @@ type clientEntity struct {

entityMap *EntityMap

logQueue chan orderedLogMessage
}

// awaitMinimumPoolSize waits for the client's connection pool to reach the
// specified minimum size. This is a best effort operation that times out after
// some predefined amount of time to avoid blocking tests indefinitely.
func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64) error {
// Don't spend longer than 500ms awaiting minPoolSize.
awaitCtx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer cancel()

ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-awaitCtx.Done():
return fmt.Errorf("timed out waiting for client to reach minPoolSize")
case <-ticker.C:
if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize {
return nil
}
}
}
logQueue chan orderedLogMessage
eventSequencer eventSequencer
}

func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOptions) (*clientEntity, error) {
@@ -118,6 +132,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp
serverDescriptionChangedEventsCount: make(map[serverDescriptionChangedEventInfo]int32),
entityMap: em,
observeSensitiveCommands: entityOptions.ObserveSensitiveCommands,
eventSequencer: eventSequencer{
seqByEventType: make(map[monitoringEventType][]int64),
},
}
entity.setRecordEvents(true)

@@ -226,8 +243,9 @@ func newClientEntity(ctx context.Context, em *EntityMap, entityOptions *entityOp
return nil, fmt.Errorf("error creating mongo.Client: %w", err)
}

if entityOptions.AwaitMinPoolSize && clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 {
if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize); err != nil {
if entityOptions.AwaitMinPoolSizeMS != nil && *entityOptions.AwaitMinPoolSizeMS > 0 &&
clientOpts.MinPoolSize != nil && *clientOpts.MinPoolSize > 0 {
if err := awaitMinimumPoolSize(ctx, entity, *clientOpts.MinPoolSize, *entityOptions.AwaitMinPoolSizeMS); err != nil {
return nil, err
}
}
@@ -326,8 +344,39 @@ func (c *clientEntity) failedEvents() []*event.CommandFailedEvent {
return events
}

func (c *clientEntity) poolEvents() []*event.PoolEvent {
return c.pooled
// filterEventsBySeq filters events by sequence number for the given eventType.
// See comments on eventSequencer for more details.
func filterEventsBySeq[T any](c *clientEntity, events []T, eventType monitoringEventType) []T {
cutoff := c.eventSequencer.cutoff.Load()
if cutoff == 0 {
return events
}

// Lock order: eventProcessMu -> eventSequencer.mu (matches writers)
c.eventProcessMu.RLock()
c.eventSequencer.mu.RLock()

// Snapshot to minimize time under locks and avoid races
eventsSnapshot := slices.Clone(events)
seqSnapshot := slices.Clone(c.eventSequencer.seqByEventType[eventType])

c.eventSequencer.mu.RUnlock()
c.eventProcessMu.RUnlock()
Comment on lines +356 to +364
Collaborator:
I assume we need to hold the eventProcessMu read lock here to prevent race conditions reading events. That requirement is difficult to trace and may lead to bugs if future implementers don't realize the purpose (it's unconventional to require a function to take a lock to read an input param). I recommend taking a read lock before calling filterEventsBySeq.

@prestonvasquez (Member Author), Dec 4, 2025:

> It's unconventional to require a function to take a lock to read an input param. I recommend taking a read lock before calling filterEventsBySeq.

Callers would be no less idiomatic:

func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) error {
	client.eventProcessMu.RLock()
	client.eventSequencer.mu.RLock()

	pooled := filterEventsBySeq(client, client.pooled, poolAnyEvent)

	client.eventSequencer.mu.RUnlock()
	client.eventProcessMu.RUnlock()

        // ...
}

IMO, taking the locks in the nearest calling method would be even more confusing than taking them in filterEventsBySeq.

Additionally, keeping the locking encapsulated in the helper is safer and avoids spreading lock-order constraints across the codebase.

We might be able to make filterEventsBySeq a clientEntity method, but it won't be straightforward given that methods can't declare their own type parameters. What are your thoughts?


// guard against index out of range.
n := len(eventsSnapshot)
if len(seqSnapshot) < n {
n = len(seqSnapshot)
}

filtered := make([]T, 0, n)
for i := 0; i < n; i++ {
if seqSnapshot[i] > cutoff {
filtered = append(filtered, eventsSnapshot[i])
}
}

return filtered
}

func (c *clientEntity) numberConnectionsCheckedOut() int32 {
@@ -516,7 +565,10 @@ func (c *clientEntity) processPoolEvent(evt *event.PoolEvent) {

eventType := monitoringEventTypeFromPoolEvent(evt)
if _, ok := c.observedEvents[eventType]; ok {
c.eventProcessMu.Lock()
c.pooled = append(c.pooled, evt)
c.eventSequencer.recordEvent(poolAnyEvent)
c.eventProcessMu.Unlock()
}

c.addEventsCount(eventType)
@@ -539,6 +591,7 @@ func (c *clientEntity) processServerDescriptionChangedEvent(evt *event.ServerDes

if _, ok := c.observedEvents[serverDescriptionChangedEvent]; ok {
c.serverDescriptionChanged = append(c.serverDescriptionChanged, evt)
c.eventSequencer.recordEvent(serverDescriptionChangedEvent)
}

// Record object-specific unified spec test data on an event.
@@ -558,6 +611,7 @@ func (c *clientEntity) processServerHeartbeatFailedEvent(evt *event.ServerHeartb

if _, ok := c.observedEvents[serverHeartbeatFailedEvent]; ok {
c.serverHeartbeatFailedEvent = append(c.serverHeartbeatFailedEvent, evt)
c.eventSequencer.recordEvent(serverHeartbeatFailedEvent)
}

c.addEventsCount(serverHeartbeatFailedEvent)
@@ -573,6 +627,7 @@ func (c *clientEntity) processServerHeartbeatStartedEvent(evt *event.ServerHeart

if _, ok := c.observedEvents[serverHeartbeatStartedEvent]; ok {
c.serverHeartbeatStartedEvent = append(c.serverHeartbeatStartedEvent, evt)
c.eventSequencer.recordEvent(serverHeartbeatStartedEvent)
}

c.addEventsCount(serverHeartbeatStartedEvent)
@@ -588,6 +643,7 @@ func (c *clientEntity) processServerHeartbeatSucceededEvent(evt *event.ServerHea

if _, ok := c.observedEvents[serverHeartbeatSucceededEvent]; ok {
c.serverHeartbeatSucceeded = append(c.serverHeartbeatSucceeded, evt)
c.eventSequencer.recordEvent(serverHeartbeatSucceededEvent)
}

c.addEventsCount(serverHeartbeatSucceededEvent)
@@ -603,6 +659,7 @@ func (c *clientEntity) processTopologyDescriptionChangedEvent(evt *event.Topolog

if _, ok := c.observedEvents[topologyDescriptionChangedEvent]; ok {
c.topologyDescriptionChanged = append(c.topologyDescriptionChanged, evt)
c.eventSequencer.recordEvent(topologyDescriptionChangedEvent)
}

c.addEventsCount(topologyDescriptionChangedEvent)
@@ -618,6 +675,7 @@ func (c *clientEntity) processTopologyOpeningEvent(evt *event.TopologyOpeningEve

if _, ok := c.observedEvents[topologyOpeningEvent]; ok {
c.topologyOpening = append(c.topologyOpening, evt)
c.eventSequencer.recordEvent(topologyOpeningEvent)
}

c.addEventsCount(topologyOpeningEvent)
@@ -633,6 +691,7 @@ func (c *clientEntity) processTopologyClosedEvent(evt *event.TopologyClosedEvent

if _, ok := c.observedEvents[topologyClosedEvent]; ok {
c.topologyClosed = append(c.topologyClosed, evt)
c.eventSequencer.recordEvent(topologyClosedEvent)
}

c.addEventsCount(topologyClosedEvent)
@@ -724,3 +783,83 @@ func evaluateUseMultipleMongoses(clientOpts *options.ClientOptions, useMultipleM
}
return nil
}

// checkAllPoolsReady checks if all connection pools have reached the minimum
// pool size by counting ConnectionReady events per address and comparing
// against the data-bearing servers in the topology.
//
// This approach uses topology events to determine how many servers we expect,
// then verifies each server's pool has reached minPoolSize. This prevents false
// positives where we return early before all servers are discovered.
func checkAllPoolsReady(
pooledEvents []*event.PoolEvent,
topologyEvents []*event.TopologyDescriptionChangedEvent,
minPoolSize uint64,
) bool {
if len(topologyEvents) == 0 {
return false
}

// Use the most recent topology description
latestTopology := topologyEvents[len(topologyEvents)-1].NewDescription

// Get addresses of data-bearing servers from topology
expectedServers := make(map[string]bool)
for _, server := range latestTopology.Servers {
// Only track data-bearing servers
if server.Kind != "" &&
server.Kind != description.ServerKindRSArbiter.String() &&
server.Kind != description.ServerKindRSGhost.String() {
expectedServers[server.Addr.String()] = true
}
}

// If no data-bearing servers yet, not ready
if len(expectedServers) == 0 {
return false
}

// Count ConnectionReady events per address
readyByAddress := make(map[string]int)
for _, evt := range pooledEvents {
if evt.Type == event.ConnectionReady {
readyByAddress[evt.Address]++
}
}

// Check if all expected servers have reached minPoolSize
for address := range expectedServers {
if uint64(readyByAddress[address]) < minPoolSize {
return false
}
}

return true
}

// awaitMinimumPoolSize waits for each of the client's connection pools to reach
// the specified minimum size, then sets the event-sequence cutoff so that CMAP
// and SDAM events recorded during pool initialization are ignored.
func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize uint64, timeoutMS int) error {
awaitCtx, cancel := context.WithTimeout(ctx, time.Duration(timeoutMS)*time.Millisecond)
defer cancel()

ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-awaitCtx.Done():
return fmt.Errorf("timed out waiting for client to reach minPoolSize")
case <-ticker.C:
entity.eventProcessMu.RLock()
ready := checkAllPoolsReady(entity.pooled, entity.topologyDescriptionChanged, minPoolSize)
entity.eventProcessMu.RUnlock()

if ready {
entity.eventSequencer.setCutoff()
return nil
}
}
}
}
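
Editorial aside (not part of the diff above): the following is a minimal, self-contained sketch of the sequence-cutoff filtering that eventSequencer, setCutoff, and filterEventsBySeq implement. The names sequencer, record, and filter are illustrative only and do not exist in the driver; in the real entity, the pre-cutoff events would be the ConnectionReady and SDAM events emitted while awaitMinimumPoolSize warms the pools.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// sequencer is a stripped-down stand-in for eventSequencer: it assigns a
// monotonically increasing sequence number to each recorded event and
// remembers a cutoff below which events are treated as initialization noise.
type sequencer struct {
	counter atomic.Int64
	cutoff  atomic.Int64

	mu   sync.RWMutex
	seqs []int64 // sequence number of each recorded event, in recording order
}

// record assigns the next sequence number to a newly observed event.
func (s *sequencer) record() {
	next := s.counter.Add(1)
	s.mu.Lock()
	s.seqs = append(s.seqs, next)
	s.mu.Unlock()
}

// setCutoff marks everything recorded so far as pre-initialization noise.
func (s *sequencer) setCutoff() {
	s.cutoff.Store(s.counter.Load())
}

// filter returns only the events whose sequence number is above the cutoff,
// mirroring the zero-cutoff fast path and bounds guard in filterEventsBySeq.
func filter[T any](s *sequencer, events []T) []T {
	cutoff := s.cutoff.Load()
	if cutoff == 0 {
		return events
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	n := len(events)
	if len(s.seqs) < n {
		n = len(s.seqs)
	}

	out := make([]T, 0, n)
	for i := 0; i < n; i++ {
		if s.seqs[i] > cutoff {
			out = append(out, events[i])
		}
	}
	return out
}

func main() {
	var s sequencer
	events := []string{"ready-1", "ready-2", "checkout-1", "checkout-2"}

	// The first two events happen while the pool warms up to minPoolSize.
	s.record()
	s.record()
	s.setCutoff() // initialization complete; ignore everything recorded so far

	// Later events are the ones a test actually asserts on.
	s.record()
	s.record()

	fmt.Println(filter(&s, events)) // prints [checkout-1 checkout-2]
}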