Skip to content

Commit 3ad2119

Browse files
committed
support SCARD and SMEMBERS with tests
1 parent 4243d53 commit 3ad2119

File tree

3 files changed

+345
-181
lines changed

3 files changed

+345
-181
lines changed

internal/server/redis/command.go

Lines changed: 103 additions & 140 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"math/rand/v2"
87
"strconv"
98
"strings"
109

@@ -598,140 +597,6 @@ func (s *session) handleIncrByFloat(ctx context.Context, args []resp.Value) (str
598597
return resp.FormatBulkString(string(res)), nil
599598
}
600599

601-
// Interns, stores, and returns a new UID. The upper 32 bits are a random sequence number, and the lower
602-
// 32 bits are a sequential number within that space. This allows for up to 4 billion unique IDs per
603-
// sequence and very low contention when allocating new IDs
604-
func (s *session) allocateNewUID(ctx context.Context, tx fdb.Transaction) (uint64, error) {
605-
ctx, span := s.tracer.Start(ctx, "allocateNewUID") // nolint
606-
defer span.End()
607-
608-
// sequenceNum is the random uint32 sequence we are using for this allocation
609-
var sequenceNum uint32
610-
var sequenceKey fdb.Key
611-
612-
// assignedUID is the uint32 within the sequence we will assign
613-
var assignedUID uint32
614-
615-
for range 20 {
616-
// pick a random uint32 as the sequence we will be using for this member
617-
sequenceNum = rand.Uint32()
618-
619-
var err error
620-
sequenceKey, err = s.uidKey(strconv.FormatUint(uint64(sequenceNum), 10))
621-
if err != nil {
622-
span.RecordError(err)
623-
return 0, fmt.Errorf("failed to get uid key: %w", err)
624-
}
625-
626-
val, err := tx.Get(sequenceKey).Get()
627-
if err != nil {
628-
return 0, recordErr(span, fmt.Errorf("failed to get last uid: %w", err))
629-
}
630-
631-
if len(val) == 0 {
632-
assignedUID = 1
633-
} else {
634-
lastUID, err := strconv.ParseUint(string(val), 10, 32)
635-
if err != nil {
636-
span.RecordError(err)
637-
return 0, fmt.Errorf("failed to parse last uid: %w", err)
638-
}
639-
640-
// if we have exhausted this sequence, pick a new random sequence
641-
if lastUID >= 0xFFFFFFFF {
642-
continue
643-
}
644-
645-
assignedUID = uint32(lastUID) + 1
646-
}
647-
}
648-
649-
// if we failed to find a sequence after several tries, return an error
650-
if assignedUID == 0 {
651-
err := fmt.Errorf("failed to allocate new uid after multiple attempts")
652-
span.RecordError(err)
653-
return 0, err
654-
}
655-
656-
// assemble the 64-bit UID from the sequence and assigned UID
657-
newUID := (uint64(sequenceNum) << 32) | uint64(assignedUID)
658-
659-
// store the assigned UID back to the sequence key for the next allocation
660-
tx.Set(sequenceKey, []byte(strconv.FormatUint(uint64(assignedUID), 10)))
661-
662-
// return the full 64-bit UID
663-
return newUID, nil
664-
}
665-
666-
// Returns the UID for the given member string, creating a new one if it does not exist.
667-
// If peek is true, it will only look up the UID without creating a new one. If peeking
668-
// and the member does not exist, it returns 0 without an error. UIDs start at 1, so 0
669-
// is never a valid UID.
670-
func (s *session) getUID(ctx context.Context, tx fdb.Transaction, member string, peek bool) (uint64, error) {
671-
ctx, span := s.tracer.Start(ctx, "getUID") // nolint
672-
defer span.End()
673-
674-
memberToUIDKey, err := s.reverseUIDKey(member)
675-
if err != nil {
676-
span.RecordError(err)
677-
return 0, fmt.Errorf("failed to get uid key: %w", err)
678-
}
679-
680-
val, err := tx.Get(memberToUIDKey).Get()
681-
if err != nil {
682-
span.RecordError(err)
683-
return 0, fmt.Errorf("failed to get member to uid mapping: %w", err)
684-
}
685-
if peek {
686-
// just look up the UID without creating a new one
687-
if len(val) == 0 {
688-
val = []byte("0")
689-
}
690-
} else {
691-
// check if we've already assigned a UID to this member
692-
val, err = tx.Get(memberToUIDKey).Get()
693-
if err != nil {
694-
span.RecordError(err)
695-
return 0, fmt.Errorf("failed to get member to uid mapping: %w", err)
696-
}
697-
698-
if len(val) == 0 {
699-
// allocate a new UID for this member string
700-
uid, err := s.allocateNewUID(ctx, tx)
701-
if err != nil {
702-
span.RecordError(err)
703-
return 0, fmt.Errorf("failed to allocate new uid: %w", err)
704-
}
705-
706-
uidStr := strconv.FormatUint(uid, 10)
707-
uidToMemberKey, err := s.uidKey(uidStr)
708-
if err != nil {
709-
span.RecordError(err)
710-
return 0, fmt.Errorf("failed to get uid key: %w", err)
711-
}
712-
713-
// store the bi-directional mapping
714-
tx.Set(memberToUIDKey, []byte(uidStr))
715-
tx.Set(uidToMemberKey, []byte(member))
716-
717-
val = []byte(uidStr)
718-
}
719-
}
720-
if err != nil {
721-
span.RecordError(err)
722-
return 0, fmt.Errorf("failed to get or create uid: %w", err)
723-
}
724-
725-
uid, err := strconv.ParseUint(string(val), 10, 64)
726-
if err != nil {
727-
span.RecordError(err)
728-
return 0, fmt.Errorf("failed to parse uid: %w", err)
729-
}
730-
731-
span.SetStatus(codes.Ok, "getUID ok")
732-
return uid, nil
733-
}
734-
735600
func (s *session) handleSetAdd(ctx context.Context, args []resp.Value) (string, error) {
736601
ctx, span := s.tracer.Start(ctx, "handleSetAdd")
737602
defer span.End()
@@ -764,7 +629,7 @@ func (s *session) handleSetAdd(ctx context.Context, args []resp.Value) (string,
764629
}
765630
}()
766631

767-
return s.getUID(ctx, tx, member, false)
632+
return s.getOrAllocateUID(ctx, tx, member)
768633
})
769634
if err != nil {
770635
return nil, recordErr(span, fmt.Errorf("failed to get uids for members: %w", err))
@@ -853,7 +718,7 @@ func (s *session) handleSetRemove(ctx context.Context, args []resp.Value) (strin
853718
}
854719
}()
855720

856-
return s.getUID(ctx, tx, member, true)
721+
return s.peekUID(ctx, tx, member)
857722
})
858723
if err != nil {
859724
return nil, recordErr(span, fmt.Errorf("failed to get uids for members: %w", err))
@@ -935,9 +800,9 @@ func (s *session) handleSetIsMember(ctx context.Context, args []resp.Value) (str
935800
return "", recordErr(span, fmt.Errorf("failed to parse set member argument: %w", err))
936801
}
937802

938-
containsAny, err := s.fdb.Transact(func(tx fdb.Transaction) (any, error) {
803+
containsAny, err := s.fdb.ReadTransact(func(tx fdb.ReadTransaction) (any, error) {
939804
// lookup the UID for the member
940-
memberUID, err := s.getUID(ctx, tx, member, true)
805+
memberUID, err := s.peekUID(ctx, tx, member)
941806
if err != nil {
942807
return false, recordErr(span, fmt.Errorf("failed to get uid for member: %w", err))
943808
}
@@ -969,6 +834,104 @@ func (s *session) handleSetIsMember(ctx context.Context, args []resp.Value) (str
969834
return "", recordErr(span, fmt.Errorf("invalid result type: %w", err))
970835
}
971836

972-
span.SetStatus(codes.Ok, "srem handled")
837+
span.SetStatus(codes.Ok, "sismember handled")
973838
return resp.FormatBoolAsInt(contains), nil
974839
}
840+
841+
func (s *session) handleSetCard(ctx context.Context, args []resp.Value) (string, error) {
842+
ctx, span := s.tracer.Start(ctx, "handleSetCard") // nolint
843+
defer span.End()
844+
845+
if err := validateNumArgs(args, 1); err != nil {
846+
return "", recordErr(span, err)
847+
}
848+
849+
key, err := extractStringArg(args[0])
850+
if err != nil {
851+
return "", recordErr(span, fmt.Errorf("failed to parse set key argument: %w", err))
852+
}
853+
854+
membersAny, err := s.fdb.ReadTransact(func(tx fdb.ReadTransaction) (any, error) {
855+
// get the bitmap if it exists
856+
_, blob, err := s.getObject(tx, key)
857+
if err != nil {
858+
return uint64(0), fmt.Errorf("failed to get existing set: %w", err)
859+
}
860+
if len(blob) == 0 {
861+
return uint64(0), nil
862+
}
863+
864+
bitmap := roaring.New()
865+
if err := bitmap.UnmarshalBinary(blob); err != nil {
866+
return uint64(0), fmt.Errorf("failed to unmarshal existing bitmap: %w", err)
867+
}
868+
869+
return bitmap.GetCardinality(), nil
870+
})
871+
if err != nil {
872+
return "", recordErr(span, fmt.Errorf("failed to check if member is in set: %w", err))
873+
}
874+
875+
members, err := cast[uint64](membersAny)
876+
if err != nil {
877+
return "", recordErr(span, fmt.Errorf("invalid result type: %w", err))
878+
}
879+
880+
span.SetStatus(codes.Ok, "scard handled")
881+
return resp.FormatInt(int64(members)), nil
882+
}
883+
884+
func (s *session) handleSetMembers(ctx context.Context, args []resp.Value) (string, error) {
885+
ctx, span := s.tracer.Start(ctx, "handleSetMembers") // nolint
886+
defer span.End()
887+
888+
if err := validateNumArgs(args, 1); err != nil {
889+
return "", recordErr(span, err)
890+
}
891+
892+
key, err := extractStringArg(args[0])
893+
if err != nil {
894+
return "", recordErr(span, fmt.Errorf("failed to parse set key argument: %w", err))
895+
}
896+
897+
membersAny, err := s.fdb.ReadTransact(func(tx fdb.ReadTransaction) (any, error) {
898+
// get the bitmap if it exists
899+
_, blob, err := s.getObject(tx, key)
900+
if err != nil {
901+
return []string{}, fmt.Errorf("failed to get existing set: %w", err)
902+
}
903+
904+
// if the bitmap doesn't exist, the member does not exist in the set
905+
if len(blob) == 0 {
906+
return []string{}, nil
907+
}
908+
909+
bitmap := roaring.New()
910+
if err := bitmap.UnmarshalBinary(blob); err != nil {
911+
return []string{}, fmt.Errorf("failed to unmarshal existing bitmap: %w", err)
912+
}
913+
914+
uids := bitmap.ToArray()
915+
916+
r := concurrent.New[uint64, string]()
917+
members, err := r.Do(ctx, uids, func(uid uint64) (string, error) {
918+
return s.memberFromUID(ctx, uid)
919+
})
920+
if err != nil {
921+
return []string{}, fmt.Errorf("failed to get member from uid: %w", err)
922+
}
923+
924+
return members, nil
925+
})
926+
if err != nil {
927+
return "", recordErr(span, fmt.Errorf("failed to check if member is in set: %w", err))
928+
}
929+
930+
members, err := cast[[]string](membersAny)
931+
if err != nil {
932+
return "", recordErr(span, fmt.Errorf("invalid result type: %w", err))
933+
}
934+
935+
span.SetStatus(codes.Ok, "smembers handled")
936+
return resp.FormatArrayOfBulkStrings(members), nil
937+
}

0 commit comments

Comments
 (0)