diff --git a/realtime/events.go b/realtime/events.go index 5ef5018..3e7e579 100644 --- a/realtime/events.go +++ b/realtime/events.go @@ -1,20 +1,122 @@ -package realtime; +package realtime -// Channel Events -const JOIN_EVENT = "phx_join" -const REPLY_EVENT = "phx_reply" +import ( + "fmt" + "reflect" + "strings" +) -// DB Subscription Events -const POSTGRES_CHANGE_EVENT = "postgres_changes" +// Events that are used to communicate with the server +const ( + joinEvent string = "phx_join" + replyEvent string = "phx_reply" + leaveEvent string = "phx_leave" + closeEvent string = "phx_close" -// Broadcast Events -const BROADCAST_EVENT = "broadcast" + // DB Subscription Events + postgresChangesEvent string = "postgres_changes" -// Presence Events -const PRESENCE_STATE_EVENT = "presence_state" -const PRESENCE_DIFF_EVENT ="presence_diff" + // Broadcast Events + broadcastEvent string = "broadcast" -// Other Events -const SYS_EVENT = "system" -const HEARTBEAT_EVENT = "heartbeat" -const ACCESS_TOKEN_EVENT = "access_token" + // Presence Events + presenceStateEvent string = "presence_state" + presenceDiffEvent string = "presence_diff" + + // Other Events + systemEvent string = "system" + heartbeatEvent string = "heartbeat" + accessTokennEvent string = "access_token" +) + +// Event "type" that the user can specify for channel to listen to +const ( + presenceEventType string = "presence" + broadcastEventType string = "broadcast" + postgresChangesEventType string = "postgres_changes" +) + +// type eventFilter struct {} +type eventFilter interface {} + +type postgresFilter struct { + Event string `supabase:"required" json:"event"` + Schema string `supabase:"required" json:"schema"` + Table string `supabase:"optional" json:"table,omitempty"` + Filter string `supabase:"optional" json:"filter,omitempty"` +} + +type broadcastFilter struct { + Event string `supabase:"required"` +} + +type presenceFilter struct { + Event string `supabase:"required"` +} + +// Verify if the given event type is supported +func verifyEventType(eventType string) bool { + switch eventType { + case presenceEventType: + case broadcastEventType: + case postgresChangesEventType: + return true + } + + return false +} + +// Enforce client's filter object to follow a specific message +// structure of certain events. Check messages.go for more +// information on the struct of each event. +// Only the following events are currently supported: +// - postgres_changes, broadcast, presence +func createEventFilter(eventType string, filter map[string]string) (eventFilter, error) { + var filterType reflect.Type // Type for filter + var filterConValue reflect.Value // Concrete value + var filterPtrValue reflect.Value // Pointer value to the concrete value + var missingFields []string + + switch eventType { + case postgresChangesEvent: + filterPtrValue = reflect.ValueOf(&postgresFilter{}) + break + case broadcastEvent: + filterPtrValue = reflect.ValueOf(&broadcastFilter{}) + break + case presenceEventType: + filterPtrValue = reflect.ValueOf(&presenceFilter{}) + default: + return nil, fmt.Errorf("Unsupported event type: %s", eventType) + } + + // Get the underlying filter type to identify missing fields + filterConValue = filterPtrValue.Elem() + filterType = filterConValue.Type() + missingFields = make([]string, 0, filterType.NumField()) + + for i := 0; i < filterType.NumField(); i++ { + currField := filterType.Field(i) + currFieldName := strings.ToLower(currField.Name) + isRequired := currField.Tag.Get("supabase") == "required" + + val, ok := filter[currFieldName] + if !ok && isRequired { + missingFields = append(missingFields, currFieldName) + } + + // Set field to empty string when value for currFieldName is missing + filterConValue.Field(i).SetString(val) + } + + if len(missingFields) != 0 { + return nil, fmt.Errorf("Criteria for %s is missing: %+v", eventType, missingFields) + } + + filterFinal, ok := filterConValue.Interface().(eventFilter) + if !ok { + return nil, fmt.Errorf("Unexpected Error: cannot create event filter") + } + + return filterFinal, nil +} diff --git a/realtime/messages.go b/realtime/messages.go index 21b9872..3bb2c87 100644 --- a/realtime/messages.go +++ b/realtime/messages.go @@ -1,30 +1,152 @@ -package realtime; +package realtime -type TemplateMsg struct { - Event string `json:"event"` - Topic string `json:"topic"` - Ref string `json:"ref"` +import ( + "encoding/json" + "strconv" + "time" +) + +// This is a general message strucutre. It follows the message protocol +// of the phoenix server: +/* + { + event: string, + topic: string, + payload: [key: string]: boolean | int | string | any, + ref: string + } +*/ +type Msg struct { + Metadata + Payload any `json:"payload"` +} + +// Generic message that contains raw payload. It can be used +// as a tagged union, where the event field can be used to +// determine the structure of the payload. +type RawMsg struct { + Metadata + Payload json.RawMessage `json:"payload"` +} + +// The other fields besides the payload that make up a message. +// It describes other information about a message such as type of event, +// the topic the message belongs to, and its reference. +type Metadata struct { + Event string `json:"event"` + Topic string `json:"topic"` + Ref string `json:"ref"` } -type ConnectionMsg struct { - TemplateMsg +// Payload for the conection message for when client first joins the channel. +// More info: https://supabase.com/docs/guides/realtime/protocol#connection +type ConnectionPayload struct { + Config struct { + Broadcast struct { + Self bool `json:"self"` + } `json:"broadcast,omitempty"` + + Presence struct { + Key string `json:"key"` + } `json:"presence,omitempty"` + + PostgresChanges []postgresFilter `json:"postgres_changes,omitempty"` + } `json:"config"` +} + +// Payload of the server's first response of three upon joining channel. +// It contains details about subscribed postgres events. +// More info: https://supabase.com/docs/guides/realtime/protocol#connection +type ReplyPayload struct { + Response struct { + PostgresChanges []struct{ + ID int `json:"id"` + postgresFilter + } `json:"postgres_changes"` + } `json:"response"` + Status string `json:"status"` +} - Payload struct { - Data struct { - Schema string `json:"schema"` - Table string `json:"table"` - CommitTime string `json:"commit_timestamp"` - EventType string `json:"eventType"` - New map[string]string `json:"new"` - Old map[string]string `json:"old"` - Errors string `json:"errors"` - } `json:"data"` - } `json:"payload"` +// Payload of the server's second response of three upon joining channel. +// It contains details about the status of subscribing to PostgresSQL. +// More info: https://supabase.com/docs/guides/realtime/protocol#system-messages +type SystemPayload struct { + Channel string `json:"channel"` + Extension string `json:"extension"` + Message string `json:"message"` + Status string `json:"status"` } -type HearbeatMsg struct { - TemplateMsg +// Payload of the server's third response of three upon joining channel. +// It contains details about the Presence feature of Supabase. +// More info: https://supabase.com/docs/guides/realtime/protocol#state-update +type PresenceStatePayload map[string]struct{ + Metas []struct{ + Ref string `json:"phx_ref"` + Name string `json:"name"` + T float64 `json:"t"` + } `json:"metas,omitempty"` +} + +// Payload of the server's response when there is a postgres_changes event. +// More info: https://supabase.com/docs/guides/realtime/protocol#system-messages +type PostgresCDCPayload struct { + Data struct { + Schema string `json:"schema"` + Table string `json:"table"` + CommitTime string `json:"commit_timestamp"` + Record map[string]any `json:"record"` + Columns []struct{ + Name string `json:"name"` + Type string `json:"type"` + } `json:"columns"` + ActionType string `json:"type"` + Old map[string]any `json:"old_record"` + Errors string `json:"errors"` + } `json:"data"` + IDs []int `json:"ids"` +} + +// create a template message +func createMsgMetadata(event string, topic string) *Metadata { + return &Metadata{ + Event: event, + Topic: topic, + Ref: "", + } +} + +// create a connection message depending on event type +func createConnectionMessage(topic string, bindings []*binding) *Msg { + msg := &Msg{} + + // Fill out the message template + msg.Metadata = *createMsgMetadata(joinEvent, topic) + msg.Metadata.Ref = strconv.FormatInt(time.Now().Unix(), 10) + + // Fill out the payload + payload := &ConnectionPayload{} + for _, bind := range bindings { + filter := bind.filter + switch filter.(type) { + case postgresFilter: + if payload.Config.PostgresChanges == nil { + payload.Config.PostgresChanges = make([]postgresFilter, 0, 1) + } + payload.Config.PostgresChanges = append(payload.Config.PostgresChanges, filter.(postgresFilter)) + break + case broadcastFilter: + payload.Config.Broadcast.Self = true + break + case presenceFilter: + payload.Config.Presence.Key = "" + break + default: + panic("TYPE ASSERTION FAILED: expecting one of postgresFilter, broadcastFilter, or presenceFilter") + } + } + + msg.Payload = payload - Payload struct { - } `json:"payload"` + return msg } diff --git a/realtime/realtime_channel.go b/realtime/realtime_channel.go new file mode 100644 index 0000000..f059381 --- /dev/null +++ b/realtime/realtime_channel.go @@ -0,0 +1,163 @@ +package realtime + +import ( + "context" + "fmt" + "strings" + "sync" +) + +type RealtimeChannel struct { + topic string + client *RealtimeClient + hasSubscribed bool + + rwMu sync.RWMutex + numBindings int + bindings map[string][]*binding + postgresBindingRoute map[int]*binding +} + +// Bind an event with the user's callback function +type binding struct { + eventType string + filter eventFilter + callback func(any) +} + +// Initialize a new channel +func CreateRealtimeChannel(client *RealtimeClient, topic string) *RealtimeChannel { + return &RealtimeChannel{ + client: client, + topic: topic, + numBindings: 0, + bindings: make(map[string][]*binding), + postgresBindingRoute: make(map[int]*binding), + hasSubscribed: false, + } +} + +// Perform callbacks on specific events. Successive calls to On() +// will result in multiple callbacks acting at the event +func (channel *RealtimeChannel) On(eventType string, filter map[string]string, callback func(any)) error { + eventType = strings.ToLower(eventType) + if !verifyEventType(eventType) { + return fmt.Errorf("invalid event type: %s", eventType) + } + + eventFilter, err := createEventFilter(eventType, filter) + if err != nil { + return fmt.Errorf("Invalid filter criteria for %s event type: %w", eventType, err) + } + + newBinding := &binding{ + eventType: eventType, + filter: eventFilter, + callback: callback, + } + + channel.numBindings += 1 + channel.bindings[eventType] = append(channel.bindings[eventType], newBinding) + + return nil +} + +// Subscribe to the channel and start listening to events +func (channel *RealtimeChannel) Subscribe(ctx context.Context) error { + if channel.hasSubscribed { + return fmt.Errorf("Error: Channel %s can only be subscribed once", channel.topic) + } + + // Do nothing if there are no bindings + if channel.numBindings == 0 { + return nil + } + + // Flatten all type of bindings into one slice + allBindings := make([]*binding, channel.numBindings) + startIdx := 0 + for _, eventType := range []string{postgresChangesEventType, broadcastEventType, presenceEventType} { + if startIdx >= channel.numBindings { + break + } + + copy(allBindings[startIdx:], channel.bindings[eventType]) + startIdx += len(channel.bindings[eventType]) + } + + respPayload, err := channel.client.subscribe(channel.topic, allBindings, ctx) + if err != nil { + return fmt.Errorf("Channel %s failed to subscribe: %v", channel.topic, err) + } + + // Verify and map postgres events. If there are any mismatch, channel will + // rollback, and unsubscribe to the events. + changes := respPayload.Response.PostgresChanges + postgresBindings := channel.bindings[postgresChangesEventType] + if len(postgresBindings) != len(changes) { + channel.Unsubscribe(ctx) + return fmt.Errorf("Server returns the wrong number of subscribed events: %v events", len(changes)) + } + + for i, change := range changes { + bindingFilter, ok := postgresBindings[i].filter.(postgresFilter) + if !ok { + panic("TYPE ASSERTION FAILED: expecting type postgresFilter") + } + if strings.ToLower(change.Schema) != strings.ToLower(bindingFilter.Schema) || + strings.ToUpper(change.Event) != strings.ToUpper(bindingFilter.Event) || + strings.ToLower(change.Table) != strings.ToLower(bindingFilter.Table) || + strings.ToLower(change.Filter) != strings.ToLower(bindingFilter.Filter) { + channel.Unsubscribe(ctx) + return fmt.Errorf("Configuration mismatch between server's event and channel's event") + } + channel.postgresBindingRoute[change.ID] = postgresBindings[i] + } + + channel.hasSubscribed = true + + return nil +} + +// Unsubscribe from the channel and stop listening to events +func (channel *RealtimeChannel) Unsubscribe(ctx context.Context) { + if !channel.hasSubscribed { + return + } + + // Refresh all the binding routes + channel.rwMu.Lock() + clear(channel.postgresBindingRoute) + channel.rwMu.Unlock() + + channel.client.unsubscribe(channel.topic, ctx) + channel.hasSubscribed = false +} + +// Route the id of triggered event to appropriate callback +func (channel *RealtimeChannel) routePostgresEvent(id int, payload *PostgresCDCPayload) { + channel.rwMu.RLock() + binding, ok := channel.postgresBindingRoute[id] + channel.rwMu.RUnlock() + + if !ok { + channel.client.logger.Printf("Error: Unrecognized id %v", id) + return + } + + bindFilter, ok := binding.filter.(postgresFilter) + if !ok { + panic("TYPE ASSERTION FAILED: expecting type postgresFilter") + } + + // Match * | INSERT | UPDATE | DELETE + switch strings.ToUpper(bindFilter.Event) { + case "*": + fallthrough + case payload.Data.ActionType: + go binding.callback(payload) + break + default: + return + } +} diff --git a/realtime/realtime_client.go b/realtime/realtime_client.go index 8a3603e..1e2c5e9 100644 --- a/realtime/realtime_client.go +++ b/realtime/realtime_client.go @@ -2,6 +2,7 @@ package realtime import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -25,6 +26,9 @@ type RealtimeClient struct { reconnectInterval time.Duration heartbeatDuration time.Duration heartbeatInterval time.Duration + + replyChan chan *ReplyPayload + currentTopics map[string]*RealtimeChannel } // Create a new RealtimeClient with user's speicfications @@ -44,6 +48,9 @@ func CreateRealtimeClient(projectRef string, apiKey string) *RealtimeClient { heartbeatDuration: 5 * time.Second, heartbeatInterval: 20 * time.Second, reconnectInterval: 500 * time.Millisecond, + + currentTopics: make(map[string]*RealtimeChannel), + replyChan: make(chan *ReplyPayload), } } @@ -53,18 +60,18 @@ func (client *RealtimeClient) Connect() error { return nil } + // Change status of client to alive + client.closed = make(chan struct{}) + // Attempt to dial the server err := client.dialServer() if err != nil { + close(client.closed) return fmt.Errorf("Cannot connect to the server: %w", err) } - // client is only alive after the connection has been made - client.mu.Lock() - client.closed = make(chan struct{}) - client.mu.Unlock() - go client.startHeartbeats() + go client.startListening() return nil } @@ -93,6 +100,57 @@ func (client *RealtimeClient) Disconnect() error { return nil } +// Begins subscribing to events +func (client *RealtimeClient) subscribe(topic string, bindings []*binding, ctx context.Context) (*ReplyPayload, error) { + if !client.isClientAlive() { + client.Connect() + } + + msg := createConnectionMessage(topic, bindings) + err := wsjson.Write(context.Background(), client.conn, msg) + if err != nil { + return nil, fmt.Errorf("Unable to send the connection message: %v", err) + } + select { + case rep := <- client.replyChan: + if rep == nil { + return nil, fmt.Errorf("Error: Unable to subscribe to the channel %v succesfully", msg.Topic) + } + return rep, nil + case <- ctx.Done(): + return nil, fmt.Errorf("Error: Subscribing to to the channel %v has been canceled", msg.Topic) + } +} + +// Unsubscribe from events +func (client *RealtimeClient) unsubscribe(topic string, ctx context.Context) { + // There's no connection, so no need to unsubscribe from anything + if !client.isClientAlive() { + return + } + + leaveMsg := &Msg{ + Metadata: *createMsgMetadata(leaveEvent, topic), + Payload: struct{}{}, + } + + err := wsjson.Write(ctx, client.conn, leaveMsg) + if err != nil { + fmt.Printf("Unexpected error: %v", err) + } +} + +// Create a new channel with given topic string +func (client *RealtimeClient) Channel(newTopic string) (*RealtimeChannel, error) { + if _, ok := client.currentTopics[newTopic]; ok { + return nil, fmt.Errorf("Error: channel with %v topic already created", newTopic) + } + newChannel := CreateRealtimeChannel(client, "realtime:" + newTopic) + client.currentTopics["realtime:" + newTopic] = newChannel + + return newChannel, nil +} + // Start sending heartbeats to the server to maintain connection func (client *RealtimeClient) startHeartbeats() { for client.isClientAlive() { @@ -122,14 +180,11 @@ func (client *RealtimeClient) startHeartbeats() { // Send the heartbeat to the realtime server func (client *RealtimeClient) sendHeartbeat() error { - msg := HearbeatMsg{ - TemplateMsg: TemplateMsg{ - Event: HEARTBEAT_EVENT, - Topic: "phoenix", - Ref: "", - }, + msg := &Msg{ + Metadata: *createMsgMetadata(heartbeatEvent, "phoenix"), Payload: struct{}{}, } + msg.Metadata.Ref = heartbeatEvent ctx, cancel := context.WithTimeout(context.Background(), client.heartbeatDuration) defer cancel() @@ -144,12 +199,110 @@ func (client *RealtimeClient) sendHeartbeat() error { return nil } +// Keep reading from the connection from the connection +func (client *RealtimeClient) startListening() { + ctx := context.Background() + + for client.isClientAlive() { + var msg RawMsg + + // Read from the connection + err := wsjson.Read(ctx, client.conn, &msg) + + // Check if there's a way to partially marshal bytes into an object + // Or check if polymorphism in go (from TemplateMsg to another type of messg) + if err != nil { + if client.isConnectionAlive(err) { + client.logger.Printf("Unexpected error while listening: %v", err) + } else { + // Quick sleep to prevent taking up CPU cycles. + // Client should be able to reconnect automatically if it's still alive + time.Sleep(client.reconnectInterval) + } + } else { + // Spawn a new thread to process the server's respond + go client.processMessage(msg) + } + } +} + +// Process the given message according certain events +func (client *RealtimeClient) processMessage(msg RawMsg) { + genericPayload, err := client.unmarshalPayload(msg) + if err != nil { + client.logger.Printf("Unable to process received message: %v", err) + client.logger.Printf("%v", genericPayload) + return + } + + switch payload := genericPayload.(type) { + case *ReplyPayload: + status := payload.Status + + if msg.Ref == heartbeatEvent && status != "ok" { + client.logger.Printf("Heartbeat failure from server: %v", payload) + } else if msg.Ref == heartbeatEvent && status == "ok" { + client.logger.Printf("Heartbeat success from server: %v", payload) + } else if msg.Ref != heartbeatEvent && status != "ok" { + client.replyChan <- nil + } else if msg.Ref != heartbeatEvent && status == "ok" { + client.replyChan <- payload + } + break + case *PostgresCDCPayload: + if len(payload.IDs) == 0 { + client.logger.Print("Unexpected error: CDC message doesn't have any ids") + } + for _, id := range payload.IDs { + targetedChannel, ok := client.currentTopics[msg.Topic] + if !ok { + client.logger.Printf("Error: Unrecognized topic %v", msg.Topic) + continue + } + + targetedChannel.routePostgresEvent(id, payload) + } + break + } +} + +func (client *RealtimeClient) unmarshalPayload(msg RawMsg) (any, error) { + var payload any + var err error + + // Parse the payload depending on the event type + switch msg.Event { + case closeEvent: + fallthrough + case replyEvent: + payload = new(ReplyPayload) + break + case postgresChangesEvent: + payload = new(PostgresCDCPayload) + break + case systemEvent: + payload = new(SystemPayload) + break + case presenceStateEvent: + payload = new(PresenceStatePayload) + break + default: + return struct{}{}, fmt.Errorf("Error: Unsupported event %v", msg.Event) + } + + err = json.Unmarshal(msg.Payload, payload) + if err != nil { + return struct{}{}, fmt.Errorf("Error: Unable to unmarshal payload: %v", err) + } + return payload, nil +} + // Dial the server with a certain timeout in seconds func (client *RealtimeClient) dialServer() error { client.mu.Lock() defer client.mu.Unlock() - if client.isClientAlive() { + if !client.isClientAlive() { return nil }