diff --git a/internal/assert/assertbson/assertbson.go b/internal/assert/assertbson/assertbson.go new file mode 100644 index 0000000000..247d33ac37 --- /dev/null +++ b/internal/assert/assertbson/assertbson.go @@ -0,0 +1,55 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package assertbson + +import ( + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + +type tHelper interface { + Helper() +} + +// EqualDocument asserts that the expected and actual BSON documents are equal. +// If the documents are not equal, it prints both the binary diff and Extended +// JSON representation of the BSON documents. +func EqualDocument(t assert.TestingT, expected, actual []byte) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return assert.Equal(t, + expected, + actual, + `expected and actual BSON documents do not match +As Extended JSON: +Expected: %s +Actual : %s`, + bson.Raw(expected), + bson.Raw(actual)) +} + +// EqualValue asserts that the expected and actual BSON values are equal. If the +// values are not equal, it prints both the binary diff and Extended JSON +// representation of the BSON values. +func EqualValue[T bson.RawValue | bsoncore.Value](t assert.TestingT, expected, actual T) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return assert.Equal(t, + expected, + actual, + `expected and actual BSON values do not match +As Extended JSON: +Expected: %s +Actual : %s`, + expected, + actual) +} diff --git a/internal/assert/assertbson/assertbson_test.go b/internal/assert/assertbson/assertbson_test.go new file mode 100644 index 0000000000..e553795750 --- /dev/null +++ b/internal/assert/assertbson/assertbson_test.go @@ -0,0 +1,249 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package assertbson + +import ( + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + +func TestEqualDocument(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expected []byte + actual []byte + want bool + }{ + { + name: "equal bson.Raw", + expected: bson.Raw{5, 0, 0, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: true, + }, + { + name: "different bson.Raw", + expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "invalid bson.Raw", + expected: bson.Raw{99, 99, 99, 99}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "nil bson.Raw", + expected: bson.Raw(nil), + actual: bson.Raw(nil), + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := EqualDocument(new(testing.T), tc.expected, tc.actual) + if got != tc.want { + t.Errorf("EqualDocument(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) + } + }) + } +} + +func TestEqualValue(t *testing.T) { + t.Parallel() + + t.Run("bson.RawValue", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expected bson.RawValue + actual bson.RawValue + want bool + }{ + { + name: "equal", + expected: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 0, 0, 0}, + }, + actual: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 0, 0, 0}, + }, + want: true, + }, + { + name: "same type, different value", + expected: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 0, 0, 0}, + }, + actual: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 1, 1, 1}, + }, + want: false, + }, + { + name: "same value, different type", + expected: bson.RawValue{ + Type: bson.TypeDouble, + Value: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + }, + actual: bson.RawValue{ + Type: bson.TypeInt64, + Value: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different value, different type", + expected: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 0, 0, 0}, + }, + actual: bson.RawValue{ + Type: bson.TypeString, + Value: []byte{1, 1, 1, 1}, + }, + want: false, + }, + { + name: "invalid", + expected: bson.RawValue{ + Type: bson.TypeInt64, + Value: []byte{1, 0, 0, 0}, + }, + actual: bson.RawValue{ + Type: bson.TypeInt32, + Value: []byte{1, 0, 0, 0}, + }, + want: false, + }, + { + name: "empty", + expected: bson.RawValue{}, + actual: bson.RawValue{}, + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := EqualValue(new(testing.T), tc.expected, tc.actual) + if got != tc.want { + t.Errorf("EqualValue(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) + } + }) + } + }) + + t.Run("bsoncore.Value", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expected bsoncore.Value + actual bsoncore.Value + want bool + }{ + { + name: "equal", + expected: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 0, 0, 0}, + }, + actual: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 0, 0, 0}, + }, + want: true, + }, + { + name: "same type, different value", + expected: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 0, 0, 0}, + }, + actual: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 1, 1, 1}, + }, + want: false, + }, + { + name: "same value, different type", + expected: bsoncore.Value{ + Type: bsoncore.TypeDouble, + Data: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + }, + actual: bsoncore.Value{ + Type: bsoncore.TypeInt64, + Data: []byte{1, 0, 0, 0, 0, 0, 0, 0}, + }, + want: false, + }, + { + name: "different value, different type", + expected: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 0, 0, 0}, + }, + actual: bsoncore.Value{ + Type: bsoncore.TypeString, + Data: []byte{1, 1, 1, 1}, + }, + want: false, + }, + { + name: "invalid", + expected: bsoncore.Value{ + Type: bsoncore.TypeInt64, + Data: []byte{1, 0, 0, 0}, + }, + actual: bsoncore.Value{ + Type: bsoncore.TypeInt32, + Data: []byte{1, 0, 0, 0}, + }, + want: false, + }, + { + name: "empty", + expected: bsoncore.Value{}, + actual: bsoncore.Value{}, + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := EqualValue(new(testing.T), tc.expected, tc.actual) + if got != tc.want { + t.Errorf("EqualValue(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) + } + }) + } + }) +} diff --git a/internal/assert/assertion_mongo.go b/internal/assert/assertion_mongo.go index e47fdf93e1..76e5c37e5e 100644 --- a/internal/assert/assertion_mongo.go +++ b/internal/assert/assertion_mongo.go @@ -11,7 +11,6 @@ package assert import ( "context" - "fmt" "reflect" "time" "unsafe" @@ -71,26 +70,6 @@ func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) { return false } -// EqualBSON asserts that the expected and actual BSON binary values are equal. -// If the values are not equal, it prints both the binary and Extended JSON diff -// of the BSON values. The provided BSON value types must implement the -// fmt.Stringer interface. -func EqualBSON(t TestingT, expected, actual interface{}) bool { - if h, ok := t.(tHelper); ok { - h.Helper() - } - - return Equal(t, - expected, - actual, - `expected and actual BSON values do not match -As Extended JSON: -Expected: %s -Actual : %s`, - expected.(fmt.Stringer).String(), - actual.(fmt.Stringer).String()) -} - // Soon runs the provided callback and fails the passed-in test if the callback // does not complete within timeout. The provided callback should respect the // passed-in context and cease execution when it has expired. diff --git a/internal/assert/assertion_mongo_test.go b/internal/assert/assertion_mongo_test.go index 9fe6f485d5..490adcef3d 100644 --- a/internal/assert/assertion_mongo_test.go +++ b/internal/assert/assertion_mongo_test.go @@ -8,8 +8,6 @@ package assert import ( "testing" - - "go.mongodb.org/mongo-driver/v2/bson" ) func TestDifferentAddressRanges(t *testing.T) { @@ -74,52 +72,3 @@ func TestDifferentAddressRanges(t *testing.T) { }) } } - -func TestEqualBSON(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - expected interface{} - actual interface{} - want bool - }{ - { - name: "equal bson.Raw", - expected: bson.Raw{5, 0, 0, 0, 0}, - actual: bson.Raw{5, 0, 0, 0, 0}, - want: true, - }, - { - name: "different bson.Raw", - expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0}, - actual: bson.Raw{5, 0, 0, 0, 0}, - want: false, - }, - { - name: "invalid bson.Raw", - expected: bson.Raw{99, 99, 99, 99}, - actual: bson.Raw{5, 0, 0, 0, 0}, - want: false, - }, - { - name: "nil bson.Raw", - expected: bson.Raw(nil), - actual: bson.Raw(nil), - want: true, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - got := EqualBSON(new(testing.T), tc.expected, tc.actual) - if got != tc.want { - t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) - } - }) - } -} diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 5cd4cec180..b51ea5bd10 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -12,6 +12,7 @@ import ( "net" "os" "reflect" + "runtime" "strings" "testing" "time" @@ -19,9 +20,9 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/assert/assertbson" "go.mongodb.org/mongo-driver/v2/internal/eventtest" "go.mongodb.org/mongo-driver/v2/internal/failpoint" - "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/require" @@ -29,6 +30,7 @@ import ( "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" + "go.mongodb.org/mongo-driver/v2/version" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" @@ -457,26 +459,27 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 2, "expected at least 2 events sent, got %v", len(msgPairs)) + want := mustMarshalBSON(bson.D{ + {Key: "application", Value: bson.D{ + bson.E{Key: "name", Value: "foo"}, + }}, + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + }) - // First two messages should be connection handshakes: one for the heartbeat connection and the other for the - // application connection. - for idx, pair := range msgPairs[:2] { - helloCommand := handshake.LegacyHello - // Expect "hello" command name with API version. - if os.Getenv("REQUIRE_API_VERSION") == "true" { - helloCommand = "hello" - } - assert.Equal(mt, pair.CommandName, helloCommand, "expected command name %s at index %d, got %s", helloCommand, idx, - pair.CommandName) - - sent := pair.Sent - appNameVal, err := sent.Command.LookupErr("client", "application", "name") - assert.Nil(mt, err, "expected command %s at index %d to contain app name", sent.Command, idx) - appName := appNameVal.StringValue() - assert.Equal(mt, testAppName, appName, "expected app name %v at index %d, got %v", testAppName, idx, - appName) + for i := 0; i < 2; i++ { + message := mt.GetProxyCapture().TryNext() + require.NotNil(mt, message, "expected handshake message, got nil") + + clientMetadata := clientMetadataFromHandshake(mt, message.Sent.Command) + assertbson.EqualDocument(mt, want, clientMetadata) } }) @@ -605,24 +608,32 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) + proxyCapture := mt.GetProxyCapture() // The first message should be a connection handshake. - pair := msgPairs[0] - assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", - handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, - "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) - - // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire - // version is now known to be >= 6. + firstMessage := proxyCapture.TryNext() + require.NotNil(mt, firstMessage, "expected handshake message, got nil") + + assert.True(t, firstMessage.IsHandshake()) + + opCode := firstMessage.Sent.OpCode + assert.Equal(mt, wiremessage.OpQuery, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) + + // Look for a saslContinue in the remaining proxied messages and assert that + // it uses the OP_MSG OpCode, as wire version is now known to be >= 6. var saslContinueFound bool - for _, pair := range msgPairs[1:] { - if pair.CommandName == "saslContinue" { + for { + message := proxyCapture.TryNext() + if message == nil { + break + } + + if message.CommandName == "saslContinue" { saslContinueFound = true - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %s", pair.Sent.OpCode.String()) + opCode := message.Sent.OpCode + assert.Equal(mt, wiremessage.OpMsg, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) break } } @@ -635,18 +646,18 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) - // First three messages should be connection handshakes: one for the heartbeat connection, another for the // application connection, and a final one for the RTT monitor connection. - for idx, pair := range msgPairs[:3] { - assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx, - pair.CommandName) + for idx := 0; idx < 3; idx++ { + message := mt.GetProxyCapture().TryNext() + require.NotNil(mt, message, "expected handshake message, got nil") + + assert.True(t, message.IsHandshake()) // Assert that appended OpCode is OP_MSG when API version is set. - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + opCode := message.Sent.OpCode + assert.Equal(mt, wiremessage.OpMsg, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) } }) @@ -1172,7 +1183,7 @@ func TestClient_BSONOptions(t *testing.T) { got, err := sr.Raw() require.NoError(mt, err, "Raw error") - assert.EqualBSON(mt, tc.wantRaw, got) + assertbson.EqualDocument(mt, tc.wantRaw, got) } }) } diff --git a/internal/integration/handshake_test.go b/internal/integration/handshake_test.go index f4c449e30e..d2756faa43 100644 --- a/internal/integration/handshake_test.go +++ b/internal/integration/handshake_test.go @@ -9,12 +9,13 @@ package integration import ( "context" "os" - "reflect" "runtime" "testing" + "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/assert/assertbson" "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/require" @@ -35,40 +36,6 @@ func TestHandshakeProse(t *testing.T) { CreateCollection(false). ClientType(mtest.Proxy) - clientMetadata := func(env bson.D, info *options.DriverInfo) bson.D { - var ( - driverName = "mongo-go-driver" - driverVersion = version.Driver - platform = runtime.Version() - ) - - if info != nil { - driverName = driverName + "|" + info.Name - driverVersion = driverVersion + "|" + info.Version - platform = platform + "|" + info.Platform - } - - elems := bson.D{ - {Key: "driver", Value: bson.D{ - {Key: "name", Value: driverName}, - {Key: "version", Value: driverVersion}, - }}, - {Key: "os", Value: bson.D{ - {Key: "type", Value: runtime.GOOS}, - {Key: "architecture", Value: runtime.GOARCH}, - }}, - } - - elems = append(elems, bson.E{Key: "platform", Value: platform}) - - // If env is empty, don't include it in the metadata. - if env != nil && !reflect.DeepEqual(env, bson.D{}) { - elems = append(elems, bson.E{Key: "env", Value: env}) - } - - return elems - } - driverInfo := &options.DriverInfo{ Name: "outer-library-name", Version: "outer-library-version", @@ -88,11 +55,11 @@ func TestHandshakeProse(t *testing.T) { t.Setenv("FUNCTION_REGION", "") t.Setenv("VERCEL_REGION", "") - for _, test := range []struct { + testCases := []struct { name string env map[string]string opts *options.ClientOptions - want bson.D + want []byte }{ { name: "1. valid AWS", @@ -102,11 +69,22 @@ func TestHandshakeProse(t *testing.T) { "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - {Key: "memory_mb", Value: 1024}, - {Key: "region", Value: "us-east-2"}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + bson.E{Key: "name", Value: "aws.lambda"}, + bson.E{Key: "memory_mb", Value: 1024}, + bson.E{Key: "region", Value: "us-east-2"}, + }}, + }), }, { name: "2. valid Azure", @@ -114,9 +92,20 @@ func TestHandshakeProse(t *testing.T) { "FUNCTIONS_WORKER_RUNTIME": "node", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "azure.func"}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + bson.E{Key: "name", Value: "azure.func"}, + }}, + }), }, { name: "3. valid GCP", @@ -127,12 +116,23 @@ func TestHandshakeProse(t *testing.T) { "FUNCTION_REGION": "us-central1", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "gcp.func"}, - {Key: "memory_mb", Value: 1024}, - {Key: "region", Value: "us-central1"}, - {Key: "timeout_sec", Value: 60}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + bson.E{Key: "name", Value: "gcp.func"}, + bson.E{Key: "memory_mb", Value: 1024}, + bson.E{Key: "region", Value: "us-central1"}, + bson.E{Key: "timeout_sec", Value: int32(60)}, + }}, + }), }, { name: "4. valid Vercel", @@ -141,10 +141,21 @@ func TestHandshakeProse(t *testing.T) { "VERCEL_REGION": "cdg1", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "vercel"}, - {Key: "region", Value: "cdg1"}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + bson.E{Key: "name", Value: "vercel"}, + bson.E{Key: "region", Value: "cdg1"}, + }}, + }), }, { name: "5. invalid multiple providers", @@ -153,7 +164,17 @@ func TestHandshakeProse(t *testing.T) { "FUNCTIONS_WORKER_RUNTIME": "node", }, opts: nil, - want: clientMetadata(nil, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + }), }, { name: "6. invalid long string", @@ -168,9 +189,20 @@ func TestHandshakeProse(t *testing.T) { }(), }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + {Key: "name", Value: "aws.lambda"}, + }}, + }), }, { name: "7. invalid wrong types", @@ -179,9 +211,20 @@ func TestHandshakeProse(t *testing.T) { "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - }, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + {Key: "env", Value: bson.D{ + {Key: "name", Value: "aws.lambda"}, + }}, + }), }, { name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"", @@ -189,51 +232,58 @@ func TestHandshakeProse(t *testing.T) { "AWS_EXECUTION_ENV": "EC2", }, opts: nil, - want: clientMetadata(nil, nil), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + }), }, { name: "driver info included", opts: options.Client().SetDriverInfo(driverInfo), - want: clientMetadata(nil, driverInfo), + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|outer-library-name"}, + {Key: "version", Value: version.Driver + "|outer-library-version"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|outer-library-platform"}, + }), }, - } { - test := test + } - mt.RunOpts(test.name, opts, func(mt *mtest.T) { - for k, v := range test.env { + for _, tc := range testCases { + tc := tc // Avoid implicit memory aliasing in for loop. + + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + for k, v := range tc.env { mt.Setenv(k, v) } - if test.opts != nil { - mt.ResetClient(test.opts) + if tc.opts != nil { + mt.ResetClient(tc.opts) } // Ping the server to ensure the handshake has completed. err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] - - hello := handshake.LegacyHello - if os.Getenv("REQUIRE_API_VERSION") == "true" { - hello = "hello" - } - - assert.Equal(mt, hello, handshakeMessage.CommandName) - - // Lookup the "client" field in the command document. - clientVal, err := handshakeMessage.Sent.Command.LookupErr("client") - require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command) + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") - got, ok := clientVal.DocumentOK() - require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) + assert.True(mt, firstMessage.IsHandshake(), "expected first message to be a handshake") - wantBytes, err := bson.Marshal(test.want) - require.NoError(mt, err, "error marshaling want document: %v", err) - - want := bsoncore.Document(wantBytes) - assert.Equal(mt, want, got, "want: %v, got: %v", want, got) + clientMetadata := clientMetadataFromHandshake(mt, firstMessage.Sent.Command) + assertbson.EqualDocument(mt, tc.want, clientMetadata) }) } } @@ -249,13 +299,13 @@ func TestLoadBalancedConnectionHandshake(t *testing.T) { err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") // Per the specifications, if loadBalanced=true, drivers MUST use the hello // command for the initial handshake and use the OP_MSG protocol. - assert.Equal(mt, "hello", handshakeMessage.CommandName) - assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode) + assert.Equal(mt, "hello", firstMessage.CommandName) + assert.Equal(mt, wiremessage.OpMsg, firstMessage.Sent.OpCode) }) opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( @@ -269,8 +319,8 @@ func TestLoadBalancedConnectionHandshake(t *testing.T) { err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") want := wiremessage.OpQuery @@ -283,7 +333,946 @@ func TestLoadBalancedConnectionHandshake(t *testing.T) { want = wiremessage.OpMsg } - assert.Equal(mt, hello, handshakeMessage.CommandName) - assert.Equal(mt, want, handshakeMessage.Sent.OpCode) + assert.Equal(mt, hello, firstMessage.CommandName, "expected first message to be a handshake") + assert.Equal(mt, want, firstMessage.Sent.OpCode) }) } + +// Test 1: Test that the driver updates metadata +// Test 2: Multiple Successive Metadata Updates +// Test 3: Multiple Successive Metadata Updates with Duplicate Data +func TestHandshakeProse_AppendMetadata_Test1_Test2_Test3(t *testing.T) { + mt := mtest.New(t) + + initialDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + testCases := []struct { + name string + driverInfo options.DriverInfo + want []byte + + // append initialDriverInfo using client.AppendDriverInfo instead of as a + // client-level constructor. + append bool + }{ + { + name: "test1.1: append new driver info", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: false, + }, + { + name: "test1.2: append with no platform", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: false, + }, + { + name: "test1.3: append with no version", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: false, + }, + { + name: "test1.4: append with name only", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: false, + }, + { + name: "test2.1: append new driver info after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: true, + }, + { + name: "test2.2: append with no platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test2.3: append with no version after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: true, + }, + { + name: "test2.4: append with name only after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test3.1: same driver info after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test3.2: same version and platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test3.3: same name and platform after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "2.0", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test3.4: same name and version after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: true, + }, + { + name: "test3.5: same platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + append: true, + }, + { + name: "test3.6: same version after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: true, + }, + { + name: "test3.7: same name after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "2.0", + Platform: "Framework Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }), + append: true, + }, + } + + for _, tc := range testCases { + tc := tc // Avoid implicit memory aliasing in for loop. + + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + if !tc.append { + clientOpts = clientOpts.SetDriverInfo(&initialDriverInfo) + } + + mt.ResetClient(clientOpts) + + if tc.append { + mt.Client.AppendDriverInfo(initialDriverInfo) + } + + // Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Save intercepted `client` document as `initialClientMetadata`. + initialClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, initialClientMetadata, "expected to capture a proxied message") + assert.True(mt, initialClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + mt.Client.AppendDriverInfo(tc.driverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Capture the first message sent after appending driver info. + gotMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, gotMessage, "expected to capture a proxied message") + assert.True(mt, gotMessage.IsHandshake(), "expected first message to be a handshake") + + clientMetadata := clientMetadataFromHandshake(mt, gotMessage.Sent.Command) + assertbson.EqualDocument(mt, tc.want, clientMetadata) + }) + } +} + +// Test 4: Multiple Metadata Updates with Duplicate Data. +func TestHandshakeProse_AppendMetadata_MultipleUpdatesWithDuplicateFields(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + // 2. Append initial driver info using client.AppendDriverInfo. + mt.Client.AppendDriverInfo(originalDriverInfo) + + // 3. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 4. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 5. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 6. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Save intercepted `client` document as `clientMetadata`. + // + // NOTE: The Go Driver statically defineds the expected client + // metadata value to make the tests more reliable and prevent + // false-positive assertion results. That deviates from the prose + // test. + want := mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2|2.0"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }) + + // 8. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 9. Append the original driver info again. + mt.Client.AppendDriverInfo(originalDriverInfo) + + // 10. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 11. Save intercepted `client` document as `clientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + clientMetadata := clientMetadataFromHandshake(mt, updatedClientMetadata.Sent.Command) + assertbson.EqualDocument(mt, want, clientMetadata) +} + +// Test 5: Metadata is not appended if identical to initial metadata +func TestHandshakeProse_AppendMetadata_NotAppendedIfIdentical(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&originalDriverInfo) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 2. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // NOTE: The Go Driver statically defineds the expected client + // metadata value to make the tests more reliable and prevent + // false-positive assertion results. That deviates from the prose + // test. + want := mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }) + + // 3. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 5. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 6. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Save intercepted `client` document as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + clientMetadata := clientMetadataFromHandshake(mt, updatedClientMetadata.Sent.Command) + assertbson.EqualDocument(mt, want, clientMetadata) + +} + +// Test 6: Metadata is not appended if identical to initial metadata (separated +// by non-identical metadata) +func TestHandshakeProse_AppendMetadata_NotAppendedIfIdentical_NonSequential(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&originalDriverInfo) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 2. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 3. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 4. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Framework Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 5. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 6. Save intercepted `client` document as `clientMetadata`. + // + // NOTE: The Go Driver statically defineds the expected client + // metadata value to make the tests more reliable and prevent + // false-positive assertion results. That deviates from the prose + // test. + want := mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library|framework"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform|Framework Platform"}, + }) + + // 7. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 8. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 9. Send a `ping` command to the server and verify that the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 10. Save intercepted `client` document as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + clientMetadata := clientMetadataFromHandshake(mt, updatedClientMetadata.Sent.Command) + assertbson.EqualDocument(mt, want, clientMetadata) +} + +// Test 7: Empty strings are considered unset when appending duplicate metadata. +func TestHandshakeProse_AppendMetadata_EmptyStrings(t *testing.T) { + mt := mtest.New(t) + + testCases := []struct { + name string + initialDriverInfo options.DriverInfo + toAppendDriverInfo options.DriverInfo + want []byte + }{ + { + name: "name empty", + initialDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + }, + { + name: "version empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + }, + { + name: "platform empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + }), + }, + } + + for _, tc := range testCases { + tc := tc // Avoid implicit memory aliasing in for loop. + + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + // 1. Create a `MongoClient` instance. + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + mt.ResetClient(clientOpts) + + // 2. Append the `DriverInfoOptions` from the selected test case from + // the initial metadata section. + mt.Client.AppendDriverInfo(tc.initialDriverInfo) + + mt.GetProxyCapture().Drain() + + // 3. Send a `ping` command to the server and verify that the command + // succeeds. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 4. Save intercepted `client` document as `initialClientMetadata`. + // + // NOTE: The Go Driver statically defineds the expected client + // metadata value to make the tests more reliable and prevent + // false-positive assertion results. That deviates from the prose + // test. + // + // See the "want" field in each test case for the expected client + // metadata value. + + // 5. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 6. Append the `DriverInfoOptions` from the selected test case from + // the appended metadata section. + mt.Client.AppendDriverInfo(tc.toAppendDriverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // 7. Send a `ping` command to the server and verify the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Capture the first message sent after appending driver info. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + clientMetadata := clientMetadataFromHandshake(mt, updatedClientMetadata.Sent.Command) + assertbson.EqualDocument(mt, tc.want, clientMetadata) + }) + } +} + +// Test 8: Empty strings are considered unset when appending metadata identical +// to initial metadata +func TestHandshakeProse_AppendMetadata_EmptyStrings_InitializedClient(t *testing.T) { + mt := mtest.New(t) + + testCases := []struct { + name string + initialDriverInfo options.DriverInfo + toAppendDriverInfo options.DriverInfo + want []byte + }{ + { + name: "name empty", + initialDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + }, + { + name: "version empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version() + "|Library Platform"}, + }), + }, + { + name: "platform empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + want: mustMarshalBSON(bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: "mongo-go-driver|library"}, + {Key: "version", Value: version.Driver + "|1.2"}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + {Key: "platform", Value: runtime.Version()}, + }), + }, + } + + for _, tc := range testCases { + tc := tc // Avoid implicit memory aliasing in for loop. + + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + // 1. Create a `MongoClient` instance. + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&tc.initialDriverInfo) + + mt.ResetClient(clientOpts) + + // 2. Send a `ping` command to the server and verify that the command + // succeeds. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 3. Save intercepted `client` document as `initialClientMetadata`. + // + // NOTE: The Go Driver statically defineds the expected client + // metadata value to make the tests more reliable and prevent + // false-positive assertion results. That deviates from the prose + // test. + // + // See the "want" field in each test case for the expected client + // metadata value. + + // 4. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 5. Append the `DriverInfoOptions` from the selected test case from + // the appended metadata section. + mt.Client.AppendDriverInfo(tc.toAppendDriverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // 6. Send a `ping` command to the server and verify the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Store the response as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 8. Assert that `initialClientMetadata` is identical to `updatedClientMetadata`. + clientMetadata := clientMetadataFromHandshake(mt, updatedClientMetadata.Sent.Command) + assertbson.EqualDocument(mt, tc.want, clientMetadata) + }) + } +} + +// mustMarshalBSON marshals a value to BSON. It panics if any error occurs. +func mustMarshalBSON(val interface{}) []byte { + bytes, err := bson.Marshal(val) + if err != nil { + panic(err) + } + + return bytes +} + +// clientMetadataFromHandshake returns the BSON document from the "client" field +// of the command document. +func clientMetadataFromHandshake(mt *mtest.T, cmd bsoncore.Document) []byte { + mt.Helper() + + client, err := cmd.LookupErr("client") + require.NoError(mt, err, "no client field in handshake command document") + + clientDoc, ok := client.DocumentOK() + require.True(mt, ok, "the client field is not a BSON document") + + return clientDoc +} diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index 07c67a4ce0..6d1b1f1255 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -337,13 +337,13 @@ func (t *T) FilterFailedEvents(filter func(*event.CommandFailedEvent) bool) { t.failed = newEvents } -// GetProxiedMessages returns the messages proxied to the server by the test. If the client type is not Proxy, this -// returns nil. -func (t *T) GetProxiedMessages() []*ProxyMessage { +// GetProxyCapture returns the ProxyCapture used by the test. If the client +// type is not Proxy, this returns nil. +func (t *T) GetProxyCapture() *ProxyCapture { if t.proxyDialer == nil { return nil } - return t.proxyDialer.Messages() + return t.proxyDialer.proxyCapture } // NumberConnectionsCheckedOut returns the number of connections checked out from the test Client. diff --git a/internal/integration/mtest/proxy_capture.go b/internal/integration/mtest/proxy_capture.go new file mode 100644 index 0000000000..5cd43cc984 --- /dev/null +++ b/internal/integration/mtest/proxy_capture.go @@ -0,0 +1,53 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "sync" +) + +// ProxyCapture provides a FIFO channel for handshake messages passed +// through the mtest proxyDialer. +type ProxyCapture struct { + messages chan *ProxyMessage + mu sync.Mutex +} + +func newProxyCapture(bufferSize int) *ProxyCapture { + return &ProxyCapture{ + messages: make(chan *ProxyMessage, bufferSize), + } +} + +func (hc *ProxyCapture) Capture(msg *ProxyMessage) { + hc.mu.Lock() + defer hc.mu.Unlock() + + hc.messages <- msg +} + +func (hc *ProxyCapture) TryNext() *ProxyMessage { + select { + case msg := <-hc.messages: + return msg + default: + return nil + } +} + +// Drain removes all messages from the channel and returns them as a slice. +func (hc *ProxyCapture) Drain() []*ProxyMessage { + messages := []*ProxyMessage{} + for { + select { + case msg := <-hc.messages: + messages = append(messages, msg) + default: + return messages + } + } +} diff --git a/internal/integration/mtest/proxy_dialer.go b/internal/integration/mtest/proxy_dialer.go index 7f17dbbdb1..0d980c406c 100644 --- a/internal/integration/mtest/proxy_dialer.go +++ b/internal/integration/mtest/proxy_dialer.go @@ -11,9 +11,11 @@ import ( "errors" "fmt" "net" + "os" "sync" "time" + "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -32,7 +34,6 @@ type proxyDialer struct { *net.Dialer sync.Mutex - messages []*ProxyMessage // sentMap temporarily stores the message sent to the server using the requestID so it can map requests to their // responses. sentMap sync.Map @@ -40,13 +41,16 @@ type proxyDialer struct { // differ. This can happen if a connection is dialed to a host name, in which case the reported remote address will // be the resolved IP address. addressTranslations sync.Map + + proxyCapture *ProxyCapture } var _ options.ContextDialer = (*proxyDialer)(nil) func newProxyDialer() *proxyDialer { return &proxyDialer{ - Dialer: &net.Dialer{Timeout: 30 * time.Second}, + Dialer: &net.Dialer{Timeout: 30 * time.Second}, + proxyCapture: newProxyCapture(100), } } @@ -121,21 +125,10 @@ func (p *proxyDialer) storeReceivedMessage(wm []byte, addr string) error { Sent: sent, Received: parsed, } - p.messages = append(p.messages, msgPair) + p.proxyCapture.Capture(msgPair) return nil } -// Messages returns a slice of proxied messages. This slice is a copy of the messages proxied so far and will not be -// updated for messages proxied after this call. -func (p *proxyDialer) Messages() []*ProxyMessage { - p.Lock() - defer p.Unlock() - - copiedMessages := make([]*ProxyMessage, len(p.messages)) - copy(copiedMessages, p.messages) - return copiedMessages -} - // proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored // in the associated proxyDialer and are forwarded over the wrapped connection. Errors encountered when parsing and // storing wire messages are wrapped to add context, while errors returned from the underlying network connection are @@ -184,3 +177,12 @@ func (pc *proxyConn) Read(buffer []byte) (int, error) { return n, nil } + +func (msg *ProxyMessage) IsHandshake() bool { + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } + + return hello == msg.CommandName +} diff --git a/internal/integration/sdam_prose_test.go b/internal/integration/sdam_prose_test.go index 274d6c0abb..ac9572ac02 100644 --- a/internal/integration/sdam_prose_test.go +++ b/internal/integration/sdam_prose_test.go @@ -69,7 +69,7 @@ func TestSDAMProse(t *testing.T) { } start := time.Now() time.Sleep(2 * time.Second) - messages := mt.GetProxiedMessages() + messages := mt.GetProxyCapture().Drain() duration := time.Since(start) hosts, err := mongoutil.HostsFromURI(mtest.ClusterURI()) diff --git a/internal/integration/unified/client_operation_execution.go b/internal/integration/unified/client_operation_execution.go index 86f161761d..1631d88ba7 100644 --- a/internal/integration/unified/client_operation_execution.go +++ b/internal/integration/unified/client_operation_execution.go @@ -307,6 +307,34 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return newDocumentResult(rawBuilder.Build(), err), nil } +func executeAppendMetadata(ctx context.Context, op *operation) (*operationResult, error) { + client, err := entities(ctx).client(op.Object) + if err != nil { + return nil, fmt.Errorf("error getting client entity: %w", err) + } + + elems, err := op.Arguments.Elements() + if err != nil { + return nil, fmt.Errorf("error getting appendMetadata arguments: %w", err) + } + + driverInfo := options.DriverInfo{} + for _, elem := range elems { + key := elem.Key() + val := elem.Value() + + if key == "driverInfoOptions" { + if err = bson.Unmarshal(val.Value, &driverInfo); err != nil { + return nil, fmt.Errorf("error unmarshaling driverInfoOptions: %w", err) + } + } + } + + client.AppendDriverInfo(driverInfo) + + return newEmptyResult(), nil +} + func createClientInsertOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 9baf785dcb..1b591d66af 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -128,7 +128,9 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat // executeWithTransaction internally verifies results/errors for each operation, so it doesn't return a result. return newEmptyResult(), executeWithTransaction(ctx, op, loopDone) - // Client operations + // Client operations + case "appendMetadata": + return executeAppendMetadata(ctx, op) case "createChangeStream": return executeCreateChangeStream(ctx, op) case "listDatabases": diff --git a/internal/integration/unified/unified_spec_test.go b/internal/integration/unified/unified_spec_test.go index 7ff374fd2e..7c8c4ee5dd 100644 --- a/internal/integration/unified/unified_spec_test.go +++ b/internal/integration/unified/unified_spec_test.go @@ -36,6 +36,7 @@ var ( "server-discovery-and-monitoring/tests/unified", "run-command/tests/unified", "index-management/tests", + "mongodb-handshake/tests/unified", } failDirectories = []string{ "unified-test-format/tests/valid-fail", diff --git a/mongo/client.go b/mongo/client.go index 4ed0348e17..10182276b2 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -11,6 +11,8 @@ import ( "errors" "fmt" "net/http" + "sync" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -56,24 +58,26 @@ var ( // The Client type opens and closes connections automatically and maintains a pool of idle connections. For // connection pool configuration options, see documentation for the ClientOptions type in the mongo/options package. type Client struct { - id uuid.UUID - deployment driver.Deployment - localThreshold time.Duration - retryWrites bool - retryReads bool - clock *session.ClusterClock - readPreference *readpref.ReadPref - readConcern *readconcern.ReadConcern - writeConcern *writeconcern.WriteConcern - bsonOpts *options.BSONOptions - registry *bson.Registry - monitor *event.CommandMonitor - serverAPI *driver.ServerAPIOptions - serverMonitor *event.ServerMonitor - sessionPool *session.Pool - timeout *time.Duration - httpClient *http.Client - logger *logger.Logger + id uuid.UUID + deployment driver.Deployment + localThreshold time.Duration + retryWrites bool + retryReads bool + clock *session.ClusterClock + readPreference *readpref.ReadPref + readConcern *readconcern.ReadConcern + writeConcern *writeconcern.WriteConcern + bsonOpts *options.BSONOptions + registry *bson.Registry + monitor *event.CommandMonitor + serverAPI *driver.ServerAPIOptions + serverMonitor *event.ServerMonitor + sessionPool *session.Pool + timeout *time.Duration + httpClient *http.Client + logger *logger.Logger + currentDriverInfo *atomic.Pointer[options.DriverInfo] + seenDriverInfo sync.Map // in-use encryption fields isAutoEncryptionSet bool @@ -132,7 +136,11 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if err != nil { return nil, err } - client := &Client{id: id} + + client := &Client{ + id: id, + currentDriverInfo: &atomic.Pointer[options.DriverInfo]{}, + } // ClusterClock client.clock = new(session.ClusterClock) @@ -217,7 +225,16 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { } } - cfg, err := topology.NewConfigFromOptionsWithAuthenticator(clientOpts, client.clock, client.authenticator) + if clientOpts.DriverInfo != nil { + client.AppendDriverInfo(*clientOpts.DriverInfo) + } + + cfg, err := topology.NewAuthenticatorConfig(client.authenticator, + topology.WithAuthConfigClock(client.clock), + topology.WithAuthConfigClientOptions(clientOpts), + topology.WithAuthConfigDriverInfo(client.currentDriverInfo), + ) + if err != nil { return nil, err } @@ -294,6 +311,45 @@ func (c *Client) connect() error { return nil } +// AppendDriverInfo appends the provided [options.DriverInfo] to the metadata +// (e.g. name, version, platform) that will be sent to the server in handshake +// requests when establishing new connections. +// +// Repeated calls to AppendDriverInfo with equivalent DriverInfo is a no-op. +// +// Metadata is limited to 512 bytes; any excess will be truncated. +func (c *Client) AppendDriverInfo(info options.DriverInfo) { + if _, loaded := c.seenDriverInfo.LoadOrStore(info, struct{}{}); loaded { + return + } + + if old := c.currentDriverInfo.Load(); old != nil { + if old.Name != "" && info.Name != "" && old.Name != info.Name { + info.Name = old.Name + "|" + info.Name + } else if old.Name != "" { + info.Name = old.Name + } + + if old.Version != "" && info.Version != "" && old.Version != info.Version { + info.Version = old.Version + "|" + info.Version + } else if old.Version != "" { + info.Version = old.Version + } + + if old.Platform != "" && info.Platform != "" && old.Platform != info.Platform { + info.Platform = old.Platform + "|" + info.Platform + } else if old.Platform != "" { + info.Platform = old.Platform + } + } + + // Copy-on-write so that the info stored in the client is immutable. + infoCopy := new(options.DriverInfo) + *infoCopy = info + + c.currentDriverInfo.Store(infoCopy) +} + // Disconnect closes sockets to the topology referenced by this Client. It will // shut down any monitoring goroutines, close the idle connection pool, and will // wait until all the in use connections have been returned to the connection diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index f0cb6125dc..933743e1a3 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -14,6 +14,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/assert/assertbson" "go.mongodb.org/mongo-driver/v2/internal/codecutil" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" @@ -603,7 +604,7 @@ func TestMarshalValue(t *testing.T) { t.Parallel() got, err := marshalValue(tc.value, tc.bsonOpts, tc.registry) - assert.EqualBSON(t, tc.want, got) + assertbson.EqualValue(t, tc.want, got) assert.Equal(t, tc.wantErr, err, "expected and actual error do not match") }) } diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index b665387404..a408aee561 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -807,9 +807,18 @@ func (s *Server) createConnection() *connection { opts := copyConnectionOpts(s.cfg.connectionOpts) opts = append(opts, WithHandshaker(func(Handshaker) Handshaker { - return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). - ServerAPI(s.cfg.serverAPI).OuterLibraryName(s.cfg.outerLibraryName). - OuterLibraryVersion(s.cfg.outerLibraryVersion).OuterLibraryPlatform(s.cfg.outerLibraryPlatform) + handshaker := operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). + ServerAPI(s.cfg.serverAPI) + + if s.cfg.driverInfo != nil { + driverInfo := s.cfg.driverInfo.Load() + if driverInfo != nil { + handshaker = handshaker.OuterLibraryName(driverInfo.Name).OuterLibraryVersion(driverInfo.Version). + OuterLibraryPlatform(driverInfo.Platform) + } + } + + return handshaker }), // Override any monitors specified in options with nil to avoid monitoring heartbeats. WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }), diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 490834cbef..297cafc701 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -7,11 +7,13 @@ package topology import ( + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" @@ -32,6 +34,7 @@ type serverConfig struct { monitoringDisabled bool serverAPI *driver.ServerAPIOptions loadBalanced bool + driverInfo *atomic.Pointer[options.DriverInfo] // Connection pool options. maxConns uint64 @@ -41,11 +44,6 @@ type serverConfig struct { logger *logger.Logger poolMaxIdleTime time.Duration poolMaintainInterval time.Duration - - // Fields provided by a library that wraps the Go Driver. - outerLibraryName string - outerLibraryVersion string - outerLibraryPlatform string } func newServerConfig(connectTimeout time.Duration, opts ...ServerOption) *serverConfig { @@ -101,27 +99,12 @@ func WithServerAppName(fn func(string) string) ServerOption { } } -// WithOuterLibraryName configures the name for the outer library to include -// in the drivers section of the handshake metadata. -func WithOuterLibraryName(fn func(string) string) ServerOption { - return func(cfg *serverConfig) { - cfg.outerLibraryName = fn(cfg.outerLibraryName) - } -} - -// WithOuterLibraryVersion configures the version for the outer library to -// include in the drivers section of the handshake metadata. -func WithOuterLibraryVersion(fn func(string) string) ServerOption { - return func(cfg *serverConfig) { - cfg.outerLibraryVersion = fn(cfg.outerLibraryVersion) - } -} - -// WithOuterLibraryPlatform configures the platform for the outer library to -// include in the platform section of the handshake metadata. -func WithOuterLibraryPlatform(fn func(string) string) ServerOption { +// WithDriverInfo sets at atomic pointer to the server configuration, which will +// be used to create the "driver" section on handshake commands. An atomic +// pointer is used so that the driver info can be updated concurrently. +func WithDriverInfo(info *atomic.Pointer[options.DriverInfo]) ServerOption { return func(cfg *serverConfig) { - cfg.outerLibraryPlatform = fn(cfg.outerLibraryPlatform) + cfg.driverInfo = info } } diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index be01504fd1..2d53805d3b 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "fmt" "net/http" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/event" @@ -139,14 +140,56 @@ func NewConfig(opts *options.ClientOptions, clock *session.ClusterClock) (*Confi return nil, fmt.Errorf("error creating authenticator: %w", err) } } - return NewConfigFromOptionsWithAuthenticator(opts, clock, authenticator) + return NewAuthenticatorConfig(authenticator, + WithAuthConfigClock(clock), + WithAuthConfigClientOptions(opts), + ) +} + +type authConfigOptions struct { + clock *session.ClusterClock + opts *options.ClientOptions + driverInfo *atomic.Pointer[options.DriverInfo] } -// NewConfigFromOptionsWithAuthenticator will translate data from client options into a +// AuthConfigOption is a function that configures authConfigOptions. +type AuthConfigOption func(*authConfigOptions) + +// WithAuthConfigClock sets the cluster clock in authConfigOptions. +func WithAuthConfigClock(clock *session.ClusterClock) AuthConfigOption { + return func(co *authConfigOptions) { + co.clock = clock + } +} + +// WithAuthConfigClientOptions sets the client options in authConfigOptions. +func WithAuthConfigClientOptions(opts *options.ClientOptions) AuthConfigOption { + return func(co *authConfigOptions) { + co.opts = opts + } +} + +// WithAuthConfigDriverInfo sets the driver info in authConfigOptions. +func WithAuthConfigDriverInfo(driverInfo *atomic.Pointer[options.DriverInfo]) AuthConfigOption { + return func(co *authConfigOptions) { + co.driverInfo = driverInfo + } +} + +// NewAuthenticatorConfig will translate data from client options into a // topology config for building non-default deployments. Server and topology // options are not honored if a custom deployment is used. It uses a passed in // authenticator to authenticate the connection. -func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { +func NewAuthenticatorConfig(authenticator driver.Authenticator, clientOpts ...AuthConfigOption) (*Config, error) { + settings := authConfigOptions{} + for _, apply := range clientOpts { + apply(&settings) + } + + opts := settings.opts + clock := settings.clock + driverInfo := settings.driverInfo + var serverAPI *driver.ServerAPIOptions if err := opts.Validate(); err != nil { @@ -200,23 +243,8 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s })) } - var outerLibraryName, outerLibraryVersion, outerLibraryPlatform string - if opts.DriverInfo != nil { - outerLibraryName = opts.DriverInfo.Name - outerLibraryVersion = opts.DriverInfo.Version - outerLibraryPlatform = opts.DriverInfo.Platform - - serverOpts = append(serverOpts, WithOuterLibraryName(func(string) string { - return outerLibraryName - })) - - serverOpts = append(serverOpts, WithOuterLibraryVersion(func(string) string { - return outerLibraryVersion - })) - - serverOpts = append(serverOpts, WithOuterLibraryPlatform(func(string) string { - return outerLibraryPlatform - })) + if driverInfo != nil { + serverOpts = append(serverOpts, WithDriverInfo(driverInfo)) } // Compressors & ZlibLevel @@ -256,44 +284,57 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s // Handshaker var handshaker func(driver.Handshaker) driver.Handshaker if authenticator != nil { - handshakeOpts := &auth.HandshakeOptions{ - AppName: appName, - Authenticator: authenticator, - Compressors: comps, - ServerAPI: serverAPI, - LoadBalanced: loadBalanced, - ClusterClock: clock, - OuterLibraryName: outerLibraryName, - OuterLibraryVersion: outerLibraryVersion, - OuterLibraryPlatform: outerLibraryPlatform, - } + handshaker = func(driver.Handshaker) driver.Handshaker { + handshakeOpts := &auth.HandshakeOptions{ + AppName: appName, + Authenticator: authenticator, + Compressors: comps, + ServerAPI: serverAPI, + LoadBalanced: loadBalanced, + ClusterClock: clock, + } - if opts.Auth.AuthMechanism == "" { - // Required for SASL mechanism negotiation during handshake - handshakeOpts.DBUser = opts.Auth.AuthSource + "." + opts.Auth.Username - } - if auth, ok := optionsutil.Value(opts.Custom, "authenticateToAnything").(bool); ok && auth { - // Authenticate arbiters - handshakeOpts.PerformAuthentication = func(_ description.Server) bool { - return true + if opts.Auth.AuthMechanism == "" { + // Required for SASL mechanism negotiation during handshake + handshakeOpts.DBUser = opts.Auth.AuthSource + "." + opts.Auth.Username + } + + if auth, ok := optionsutil.Value(opts.Custom, "authenticateToAnything").(bool); ok && auth { + // Authenticate arbiters + handshakeOpts.PerformAuthentication = func(_ description.Server) bool { + return true + } + } + + if driverInfo != nil { + if di := driverInfo.Load(); di != nil { + handshakeOpts.OuterLibraryName = di.Name + handshakeOpts.OuterLibraryVersion = di.Version + handshakeOpts.OuterLibraryPlatform = di.Platform + } } - } - handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) } } else { handshaker = func(driver.Handshaker) driver.Handshaker { - return operation.NewHello(). + op := operation.NewHello(). AppName(appName). Compressors(comps). ClusterClock(clock). ServerAPI(serverAPI). - LoadBalanced(loadBalanced). - OuterLibraryName(outerLibraryName). - OuterLibraryVersion(outerLibraryVersion). - OuterLibraryPlatform(outerLibraryPlatform) + LoadBalanced(loadBalanced) + + if driverInfo != nil { + if di := driverInfo.Load(); di != nil { + op = op.OuterLibraryName(di.Name). + OuterLibraryVersion(di.Version). + OuterLibraryPlatform(di.Platform) + } + } + + return op } } diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index 680aa638a7..319dabf9c2 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -149,7 +149,7 @@ func TestAuthenticateToAnything(t *testing.T) { opt := options.Client().SetAuth(options.Credential{Username: "foo", Password: "bar"}) err := tc.set(opt) require.NoError(t, err, "error setting authenticateToAnything: %v", err) - cfg, err := NewConfigFromOptionsWithAuthenticator(opt, nil, &testAuthenticator{}) + cfg, err := NewAuthenticatorConfig(&testAuthenticator{}, WithAuthConfigClientOptions(opt)) require.NoError(t, err, "error constructing topology config: %v", err) srvrCfg := newServerConfig(defaultConnectionTimeout, cfg.ServerOpts...)