diff --git a/platform/view/services/comm/master.go b/platform/view/services/comm/master.go
index 5783f3117..bab4d9e5b 100644
--- a/platform/view/services/comm/master.go
+++ b/platform/view/services/comm/master.go
@@ -31,7 +31,6 @@ func (p *P2PNode) getOrCreateSession(sessionID, endpointAddress, contextID, call
 		return session, nil
 	}

-	ctx, cancel := context.WithCancel(context.Background())
 	s := &NetworkStreamSession{
 		node:         p,
 		endpointID:   endpointID,
@@ -40,10 +39,11 @@ func (p *P2PNode) getOrCreateSession(sessionID, endpointAddress, contextID, call
 		sessionID:    sessionID,
 		caller:       caller,
 		callerViewID: callerViewID,
-		incoming:     make(chan *view.Message, 1),
+		incoming:     make(chan *view.Message),
 		streams:      make(map[*streamHandler]struct{}),
-		ctx:          ctx,
-		cancel:       cancel,
+		middleCh:     make(chan *view.Message),
+		closing:      make(chan struct{}),
+		closed:       make(chan struct{}),
 	}

 	if msg != nil {
diff --git a/platform/view/services/comm/session.go b/platform/view/services/comm/session.go
index 812af6986..e9a71e651 100644
--- a/platform/view/services/comm/session.go
+++ b/platform/view/services/comm/session.go
@@ -9,7 +9,6 @@ package comm
 import (
 	"context"
 	"sync"
-	"sync/atomic"

 	"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/errors"
 	"github.com/hyperledger-labs/fabric-smart-client/pkg/utils/proto"
@@ -38,11 +37,39 @@ type NetworkStreamSession struct {
 	streams      map[*streamHandler]struct{}
 	mutex        sync.RWMutex

-	closed    atomic.Bool
-	closeOnce sync.Once
-	ctx       context.Context
-	cancel    context.CancelFunc
-	wg        sync.WaitGroup
+	startOnce sync.Once
+	middleCh  chan *view.Message
+	closing   chan struct{}
+	closed    chan struct{}
+}
+
+func (n *NetworkStreamSession) tryStart() {
+	n.startOnce.Do(func() {
+		go func() {
+			exit := func(v *view.Message, needSend bool) {
+				close(n.closed)
+				if needSend {
+					n.incoming <- v
+				}
+				close(n.incoming)
+			}
+
+			for {
+				select {
+				case <-n.closing:
+					exit(nil, false)
+					return
+				case v := <-n.middleCh:
+					select {
+					case <-n.closing:
+						exit(v, true)
+						return
+					case n.incoming <- v:
+					}
+				}
+			}
+		}()
+	})
+}

 // Info returns a view.SessionInfo.
@@ -92,18 +119,19 @@ func (n *NetworkStreamSession) enqueue(msg *view.Message) bool {
 		return false
 	}

-	if n.isClosed() {
+	// let's try to start the session
+	n.tryStart()
+
+	select {
+	case <-n.closed:
 		return false
+	default:
 	}

-	n.wg.Add(1)
-	defer n.wg.Done()
-
 	select {
-	case <-n.ctx.Done():
-		logger.Warnf("dropping message from %s for closed session [%s]", msg.Caller, msg.SessionID)
+	case <-n.closed:
 		return false
-	case n.incoming <- msg:
+	case n.middleCh <- msg:
 		return true
 	}
 }
@@ -114,7 +142,7 @@ func (n *NetworkStreamSession) Close() {
 }

 func (n *NetworkStreamSession) closeInternal() {
-	n.closeOnce.Do(func() {
+	closeStreams := func() {
 		n.mutex.Lock()
 		defer n.mutex.Unlock()

@@ -135,20 +163,27 @@ func (n *NetworkStreamSession) closeInternal() {
 			stream.close(context.TODO())
 		}
 		logger.Debugf("closing session [%s]'s streams [%d] done", n.sessionID, len(toClose))
+		clear(n.streams)
+	}

-		// next we are closing the incoming and the closing signal channel to drain the receivers;
-		n.closed.Store(true)
-		n.cancel()
-		n.wg.Wait()
-		close(n.incoming)
-		n.streams = make(map[*streamHandler]struct{})
-
+	select {
+	case n.closing <- struct{}{}:
+		<-n.closed
+		closeStreams()
 		logger.Debugf("closing session [%s] done", n.sessionID)
-	})
+
+	case <-n.closed:
+	}
 }

 func (n *NetworkStreamSession) isClosed() bool {
-	return n.closed.Load()
+	select {
+	case <-n.closed:
+		return true
+	default:
+	}
+
+	return false
 }

 func (n *NetworkStreamSession) sendWithStatus(ctx context.Context, payload []byte, status int32) error {
diff --git a/platform/view/services/comm/session_test.go b/platform/view/services/comm/session_test.go
index 8cc994dd0..051f15e52 100644
--- a/platform/view/services/comm/session_test.go
+++ b/platform/view/services/comm/session_test.go
@@ -9,6 +9,8 @@ package comm
 import (
 	"context"
 	"fmt"
+	"math/rand"
+	"runtime"
 	"sync"
 	"testing"
 	"time"
@@ -25,6 +27,7 @@ import (
 const (
 	timeout = 100 * time.Millisecond
 	tick    = 10 * time.Millisecond
+	maxVal  = 1000
 )

 type mockSender struct {
@@ -48,7 +51,6 @@ func setup() *NetworkStreamSession {
 	net := &mockSender{}

-	ctx, cancel := context.WithCancel(context.Background())
 	return &NetworkStreamSession{
 		node:         net,
 		endpointID:   endpointID,
@@ -57,10 +59,11 @@ func setup() *NetworkStreamSession {
 		sessionID:    sessionID,
 		caller:       caller,
 		callerViewID: callerViewID,
-		incoming:     make(chan *view.Message, 1),
+		incoming:     make(chan *view.Message),
 		streams:      make(map[*streamHandler]struct{}),
-		ctx:          ctx,
-		cancel:       cancel,
+		middleCh:     make(chan *view.Message),
+		closing:      make(chan struct{}),
+		closed:       make(chan struct{}),
 	}
 }

@@ -79,10 +82,11 @@ func TestSessionLifecycle(t *testing.T) {
 		Payload: []byte("some message"),
 	}

+	require.False(t, s.isClosed())
+
 	// enqueue a message
 	require.Empty(t, s.incoming)
-	s.enqueue(msg)
-	require.Len(t, s.incoming, 1)
+	require.True(t, s.enqueue(msg))

 	// we should receive this message
 	require.EventuallyWithT(t, func(c *assert.CollectT) {
@@ -101,7 +105,7 @@ func TestSessionLifecycle(t *testing.T) {

 	// enqueue on closed session should just drop the message
 	require.Empty(t, s.incoming)
-	s.enqueue(msg)
+	require.False(t, s.enqueue(msg))
 	require.Empty(t, s.incoming)

 	// on a closed session, a reader should return immediately
@@ -166,3 +170,64 @@ func TestSessionLifecycleConcurrent(t *testing.T) {

 	wg.Wait()
 }
+
+func TestSessionDeadlock(t *testing.T) {
+	// let's check that at the end of this test all our goroutines are stopped
+	defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
+
+	s := setup()
+
+	// hide the impl behind the session interface as a consumer
+	var sess view.Session = s
+	ch := sess.Receive()
+
+	msg1 := []byte("msg")
+
+	// we publish and then consume
+	assert.True(t, s.enqueue(&view.Message{Payload: msg1}))
+	require.EventuallyWithT(t, func(c *assert.CollectT) {
+		require.Equal(c, msg1, (<-ch).Payload)
+	}, timeout, 10*time.Millisecond)
+
+	// next up, we publish msg1 again and spawn another goroutine that publishes a further message
+	assert.True(t, s.enqueue(&view.Message{Payload: msg1}))
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	// another producer
+	go func() {
+		defer wg.Done()
+		// as msg1 is not yet consumed, our producer is blocked
+		assert.False(t, s.enqueue(&view.Message{Payload: msg1}))
+	}()
+
+	// let's give the producer a bit of time
+	runtime.Gosched()
+	for {
+		value := rand.Intn(maxVal)
+		if value == 0 {
+			break
+		}
+	}
+
+	// let's make sure that our producer is still waiting to complete the publish
+	require.Never(t, func() bool {
+		// we expect to be blocked
+		wg.Wait()
+		return false
+	}, timeout, tick)
+
+	// now we close the listener, which should unblock the producer
+	sess.Close()
+
+	// wait for the producer to finish
+	wg.Wait()
+
+	// we expect msg1 to be successfully published
+	require.EventuallyWithT(t, func(c *assert.CollectT) {
+		require.Equal(c, msg1, (<-ch).Payload)
+	}, timeout, tick)
+
+	// the blocked producer's message should not be published
+	require.Empty(t, ch)
+}
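
Note on the pattern this patch introduces: the ctx/cancel/WaitGroup combination is replaced by a single, lazily started forwarder goroutine (tryStart) that owns the incoming channel. Producers hand messages to the unbuffered middleCh, and closeInternal performs a rendezvous on closing that is acknowledged by closing closed, so a producer blocked in enqueue is released instead of racing a close of incoming. The sketch below is a minimal, standalone illustration of the same hand-off outside the session types; every name in it (handoff, out, run, etc.) is illustrative and not part of the patch, and it assumes, as the tests above do, that the consumer keeps draining the output channel after close.

package main

import (
	"fmt"
	"sync"
)

// handoff mirrors the channel trio of NetworkStreamSession:
// out plays the role of incoming, run of tryStart, close of closeInternal.
type handoff struct {
	startOnce sync.Once
	out       chan string   // consumer-facing channel
	middleCh  chan string   // producer-facing channel
	closing   chan struct{} // close requests
	closed    chan struct{} // closed once shutdown is acknowledged
}

func newHandoff() *handoff {
	return &handoff{
		out:      make(chan string),
		middleCh: make(chan string),
		closing:  make(chan struct{}),
		closed:   make(chan struct{}),
	}
}

// run lazily starts the single forwarder goroutine that owns out.
func (h *handoff) run() {
	h.startOnce.Do(func() {
		go func() {
			defer close(h.out)
			for {
				select {
				case <-h.closing:
					close(h.closed)
					return
				case v := <-h.middleCh:
					select {
					case <-h.closing:
						close(h.closed)
						h.out <- v // a value already taken from a producer is still delivered
						return
					case h.out <- v:
					}
				}
			}
		}()
	})
}

// enqueue blocks until the forwarder accepts v or the hand-off is closed.
func (h *handoff) enqueue(v string) bool {
	h.run()
	select {
	case <-h.closed:
		return false
	default:
	}
	select {
	case <-h.closed:
		return false
	case h.middleCh <- v:
		return true
	}
}

// close requests shutdown and waits for the forwarder to acknowledge it.
func (h *handoff) close() {
	select {
	case h.closing <- struct{}{}:
		<-h.closed
	case <-h.closed:
	}
}

func main() {
	h := newHandoff()
	go func() { h.enqueue("hello") }()
	fmt.Println(<-h.out)              // hello
	h.close()                         // unblocks any producer stuck in enqueue
	fmt.Println(h.enqueue("dropped")) // false
}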