diff --git a/connection_reactor.go b/connection_reactor.go index 25b4dec5..bbd85595 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -69,6 +69,10 @@ func (c *connection) onClose() error { // closeBuffer recycle input & output LinkBuffer. func (c *connection) closeBuffer() { + // if c is not closed by user, we shouldn't reuse buffer because user may still hold the buffer + if !c.isCloseBy(user) { + return + } var onConnect, _ = c.onConnectCallback.Load().(OnConnect) var onRequest, _ = c.onRequestCallback.Load().(OnRequest) // if client close the connection, we cannot ensure that the poller is not process the buffer, diff --git a/connection_test.go b/connection_test.go index 548d98a2..d36436d2 100644 --- a/connection_test.go +++ b/connection_test.go @@ -31,6 +31,7 @@ import ( "syscall" "testing" "time" + "unsafe" ) func TestConnectionWrite(t *testing.T) { @@ -182,6 +183,51 @@ func TestConnectionWaitReadHalfPacket(t *testing.T) { wg.Wait() } +func TestConnectionReadRepeatedlyAfterClosed(t *testing.T) { + buffers := map[uintptr][]byte{} + for i := 0; i < 1000; i++ { + msg1 := []byte(fmt.Sprintf("%5d", i)) + msg2 := []byte(fmt.Sprintf("%5d", i+1)) + + r1, w1 := GetSysFdPairs() + r2, w2 := GetSysFdPairs() + rconn1, rconn2 := new(connection), new(connection) + rconn1.init(&netFD{fd: r1}, nil) + rconn2.init(&netFD{fd: r2}, nil) + rconn1.SetOnConnect(func(ctx context.Context, connection Connection) context.Context { + return ctx + }) + rconn2.SetOnConnect(func(ctx context.Context, connection Connection) context.Context { + return ctx + }) + trigger := make(chan struct{}) + go func() { + syscall.Write(w1, msg1) + trigger <- struct{}{} + <-trigger // wait read msg1 + + syscall.Close(w1) + syscall.Write(w2, msg2) + trigger <- struct{}{} + //syscall.Close(w2) + }() + + <-trigger // wait write msg1 + buf1, _ := rconn1.Reader().Next(5) + Equal(t, string(buf1), string(msg1)) + trigger <- struct{}{} + + <-trigger // wait write msg2 + buf2, _ := rconn2.Reader().Next(5) + Equal(t, string(buf2), string(msg2)) + Equal(t, string(buf1), string(msg1)) + Assert(t, buffers[uintptr(unsafe.Pointer(&buf1[0]))] == nil) + Assert(t, buffers[uintptr(unsafe.Pointer(&buf2[0]))] == nil) + buffers[uintptr(unsafe.Pointer(&buf1[0]))] = buf1 + buffers[uintptr(unsafe.Pointer(&buf2[0]))] = buf2 + } +} + func TestReadTimer(t *testing.T) { read := time.NewTimer(time.Second) MustTrue(t, read.Stop())