Skip to content

Commit 63d4781

Browse files
Expose server error codes via mongo package
1 parent 3c97a45 commit 63d4781

File tree

4 files changed

+85
-6
lines changed

4 files changed

+85
-6
lines changed

internal/integration/clam_prose_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ func clamMultiByteTruncLogs(mt *mtest.T) []truncValidator {
198198

199199
// Insert started.
200200
validators[0] = newTruncValidator(mt, cmd, func(cmd string) error {
201-
202201
// Remove the suffix from the command string.
203202
cmd = cmd[:len(cmd)-len(logger.TruncationSuffix)]
204203

mongo/errors.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,12 @@ func hasErrorCode(srvErr ServerError, code int) bool {
364364
return false
365365
}
366366

367-
var _ ServerError = CommandError{}
368-
var _ ServerError = WriteError{}
369-
var _ ServerError = WriteException{}
370-
var _ ServerError = BulkWriteException{}
367+
var (
368+
_ ServerError = CommandError{}
369+
_ ServerError = WriteError{}
370+
_ ServerError = WriteException{}
371+
_ ServerError = BulkWriteException{}
372+
)
371373

372374
var _ error = ClientBulkWriteException{}
373375

@@ -901,3 +903,17 @@ func joinBatchErrors(errs []error) string {
901903

902904
return buf.String()
903905
}
906+
907+
// ErrorCodesFrom returns the list of server error codes contained in err.
908+
func ErrorCodesFrom(err error) []int {
909+
if err == nil {
910+
return nil
911+
}
912+
913+
var ec interface{ ErrorCodes() []int }
914+
if errors.As(wrapErrors(err), &ec) {
915+
return ec.ErrorCodes()
916+
}
917+
918+
return []int{}
919+
}

mongo/errors_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,3 +760,68 @@ func (n netErr) Temporary() bool {
760760
}
761761

762762
var _ net.Error = (*netErr)(nil)
763+
764+
func TestErrorCodesFrom(t *testing.T) {
765+
tests := []struct {
766+
name string
767+
input error
768+
want []int
769+
}{
770+
{
771+
name: "nil error",
772+
input: nil,
773+
want: nil,
774+
},
775+
{
776+
name: "non-server error",
777+
input: errors.New("boom"),
778+
want: []int{},
779+
},
780+
{
781+
name: "CommandError single code",
782+
input: CommandError{Code: 123},
783+
want: []int{123},
784+
},
785+
{
786+
name: "WriteError single code",
787+
input: WriteError{Code: 45},
788+
want: []int{45},
789+
},
790+
{
791+
name: "WriteException write errors only",
792+
input: WriteException{WriteErrors: WriteErrors{{Code: 1}, {Code: 2}}},
793+
want: []int{1, 2},
794+
},
795+
{
796+
name: "WriteException with write concern error",
797+
input: WriteException{WriteErrors: WriteErrors{{Code: 1}}, WriteConcernError: &WriteConcernError{Code: 64}},
798+
want: []int{1, 64},
799+
},
800+
{
801+
name: "BulkWriteException write errors only",
802+
input: BulkWriteException{WriteErrors: []BulkWriteError{{WriteError: WriteError{Code: 10}}, {WriteError: WriteError{Code: 11}}}},
803+
want: []int{10, 11},
804+
},
805+
{
806+
name: "BulkWriteException with write concern error",
807+
input: BulkWriteException{WriteErrors: []BulkWriteError{{WriteError: WriteError{Code: 10}}}, WriteConcernError: &WriteConcernError{Code: 79}},
808+
want: []int{10, 79},
809+
},
810+
{
811+
name: "driver.Error wraps to CommandError",
812+
input: driver.Error{Code: 91, Message: "shutdown in progress"},
813+
want: []int{91},
814+
},
815+
{
816+
name: "wrapped driver.Error",
817+
input: fmt.Errorf("context: %w", driver.Error{Code: 262, Message: "ExceededTimeLimit"}),
818+
want: []int{262},
819+
},
820+
}
821+
822+
for _, tt := range tests {
823+
t.Run(tt.name, func(t *testing.T) {
824+
require.Equal(t, tt.want, ErrorCodesFrom(tt.input))
825+
})
826+
}
827+
}

x/mongo/driver/operation.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,6 @@ func (op Operation) Execute(ctx context.Context) error {
747747
var moreToCome bool
748748
var startedInfo startedInformation
749749
*wm, moreToCome, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)
750-
751750
if err != nil {
752751
return err
753752
}

0 commit comments

Comments
 (0)