diff --git a/internal/integration/clam_prose_test.go b/internal/integration/clam_prose_test.go index 7dbb564280..e14b39a973 100644 --- a/internal/integration/clam_prose_test.go +++ b/internal/integration/clam_prose_test.go @@ -198,7 +198,6 @@ func clamMultiByteTruncLogs(mt *mtest.T) []truncValidator { // Insert started. validators[0] = newTruncValidator(mt, cmd, func(cmd string) error { - // Remove the suffix from the command string. cmd = cmd[:len(cmd)-len(logger.TruncationSuffix)] diff --git a/mongo/errors.go b/mongo/errors.go index 234445ab86..9a7105ebdf 100644 --- a/mongo/errors.go +++ b/mongo/errors.go @@ -364,10 +364,12 @@ func hasErrorCode(srvErr ServerError, code int) bool { return false } -var _ ServerError = CommandError{} -var _ ServerError = WriteError{} -var _ ServerError = WriteException{} -var _ ServerError = BulkWriteException{} +var ( + _ ServerError = CommandError{} + _ ServerError = WriteError{} + _ ServerError = WriteException{} + _ ServerError = BulkWriteException{} +) var _ error = ClientBulkWriteException{} @@ -901,3 +903,17 @@ func joinBatchErrors(errs []error) string { return buf.String() } + +// ErrorCodes returns the list of server error codes contained in err. +func ErrorCodes(err error) []int { + if err == nil { + return nil + } + + var ec interface{ ErrorCodes() []int } + if errors.As(wrapErrors(err), &ec) { + return ec.ErrorCodes() + } + + return []int{} +} diff --git a/mongo/errors_test.go b/mongo/errors_test.go index 2ff04c4dd2..5f62f43091 100644 --- a/mongo/errors_test.go +++ b/mongo/errors_test.go @@ -760,3 +760,79 @@ func (n netErr) Temporary() bool { } var _ net.Error = (*netErr)(nil) + +func TestErrorCodesFrom(t *testing.T) { + tests := []struct { + name string + input error + want []int + }{ + { + name: "nil error", + input: nil, + want: nil, + }, + { + name: "non-server error", + input: errors.New("boom"), + want: []int{}, + }, + { + name: "CommandError single code", + input: CommandError{Code: 123}, + want: []int{123}, + }, + { + name: "WriteError single code", + input: WriteError{Code: 45}, + want: []int{45}, + }, + { + name: "WriteException write errors only", + input: WriteException{WriteErrors: WriteErrors{{Code: 1}, {Code: 2}}}, + want: []int{1, 2}, + }, + { + name: "WriteException with write concern error", + input: WriteException{WriteErrors: WriteErrors{{Code: 1}}, WriteConcernError: &WriteConcernError{Code: 64}}, + want: []int{1, 64}, + }, + { + name: "BulkWriteException write errors only", + input: BulkWriteException{ + WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 10}}, + {WriteError: WriteError{Code: 11}}, + }, + }, + want: []int{10, 11}, + }, + { + name: "BulkWriteException with write concern error", + input: BulkWriteException{ + WriteErrors: []BulkWriteError{ + {WriteError: WriteError{Code: 10}}, + {WriteError: WriteError{Code: 11}}, + }, + WriteConcernError: &WriteConcernError{Code: 79}, + }, + want: []int{10, 11, 79}, + }, + { + name: "driver.Error wraps to CommandError", + input: driver.Error{Code: 91, Message: "shutdown in progress"}, + want: []int{91}, + }, + { + name: "wrapped driver.Error", + input: fmt.Errorf("context: %w", driver.Error{Code: 262, Message: "ExceededTimeLimit"}), + want: []int{262}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, ErrorCodes(tt.input)) + }) + } +} diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index ac2d5f69e1..9906563100 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -747,7 +747,6 @@ func (op Operation) Execute(ctx context.Context) error { var moreToCome bool var startedInfo startedInformation *wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) - if err != nil { return err }