Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions acceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package quickfix
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"io"
"net"
Expand Down Expand Up @@ -319,7 +320,8 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
}
}

sessID := SessionID{BeginString: string(beginString),
sessID := SessionID{
BeginString: string(beginString),
SenderCompID: string(targetCompID), SenderSubID: string(targetSubID), SenderLocationID: string(targetLocationID),
TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID),
}
Expand Down Expand Up @@ -361,6 +363,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {
a.sessionAddr.Store(sessID, netConn.RemoteAddr())
msgIn := make(chan fixIn)
msgOut := make(chan []byte)
ctx := context.Background()

if err := session.connect(msgIn, msgOut); err != nil {
a.globalLog.OnEventf("Unable to accept session %v connection: %v", sessID, err.Error())
Expand All @@ -369,16 +372,16 @@ func (a *Acceptor) handleConnection(netConn net.Conn) {

go func() {
msgIn <- fixIn{msgBytes, parser.lastRead}
readLoop(parser, msgIn, a.globalLog)
readLoop(ctx, parser, msgIn, a.globalLog)
}()

writeLoop(netConn, msgOut, a.globalLog)
writeLoop(ctx, netConn, msgOut, a.globalLog)
}

func (a *Acceptor) dynamicSessionsLoop() {
var id int
var sessions = map[int]*session{}
var complete = make(chan int)
sessions := map[int]*session{}
complete := make(chan int)
defer close(complete)
LOOP:
for {
Expand Down
29 changes: 21 additions & 8 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,38 @@

package quickfix

import "io"
import (
"context"
"io"
)

func writeLoop(connection io.Writer, messageOut chan []byte, log Log) {
func writeLoop(ctx context.Context, connection io.Writer, messageOut chan []byte, log Log) {
for {
msg, ok := <-messageOut
if !ok {
select {
case <-ctx.Done():
return
}
case msg, ok := <-messageOut:
if !ok {
return
}

if _, err := connection.Write(msg); err != nil {
log.OnEvent(err.Error())
if _, err := connection.Write(msg); err != nil {
log.OnEvent(err.Error())
}
}
}
}

func readLoop(parser *parser, msgIn chan fixIn, log Log) {
func readLoop(ctx context.Context, parser *parser, msgIn chan fixIn, log Log) {
defer close(msgIn)

for {
select {
case <-ctx.Done():
return
default:
}

msg, err := parser.ReadMessage()
if err != nil {
log.OnEvent(err.Error())
Expand Down
41 changes: 39 additions & 2 deletions connection_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ package quickfix

import (
"bytes"
"context"
"strings"
"testing"
)

func TestWriteLoop(t *testing.T) {
ctx := context.Background()
writer := bytes.NewBufferString("")
msgOut := make(chan []byte)

Expand All @@ -31,7 +33,7 @@ func TestWriteLoop(t *testing.T) {
msgOut <- []byte("test msg 3")
close(msgOut)
}()
writeLoop(writer, msgOut, nullLog{})
writeLoop(ctx, writer, msgOut, nullLog{})

expected := "test msg 1 test msg 2 test msg 3"

Expand All @@ -40,12 +42,32 @@ func TestWriteLoop(t *testing.T) {
}
}

func TestWriteLoopCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
writer := bytes.NewBufferString("")
msgOut := make(chan []byte)

go func() {
msgOut <- []byte("test msg 1")
cancel()
}()
writeLoop(ctx, writer, msgOut, nullLog{})

expected := "test msg 1"

if writer.String() != expected {
t.Errorf("expected %v got %v", expected, writer.String())
}
}

func TestReadLoop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
msgIn := make(chan fixIn)
stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103"

parser := newParser(strings.NewReader(stream))
go readLoop(parser, msgIn, nullLog{})
go readLoop(ctx, parser, msgIn, nullLog{})

var tests = []struct {
expectedMsg string
Expand All @@ -71,3 +93,18 @@ func TestReadLoop(t *testing.T) {
}
}
}

func TestReadLoopCancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
msgIn := make(chan fixIn)
stream := "hello8=FIX.4.09=5blah10=103garbage8=FIX.4.09=4foo10=103"

parser := newParser(strings.NewReader(stream))

cancel()
go readLoop(ctx, parser, msgIn, nullLog{})
_, ok := <-msgIn
if ok {
t.Error("Channel should be closed on context cancel")
}
}
36 changes: 24 additions & 12 deletions initiator.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,27 +151,37 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
wg.Done()
}()

// This is used to cancel read/write loops on reconnect/stop.
// During reconnect attempts, we are creating new context for read/write loops and canceling the old one immediately.
// During stop, we cancel the last created context here only after the session is stopped and we are safe to stop the read/write loop following the logout and disconnect.
readWriteCancel := func() {}

defer func() {
session.stop()
wg.Wait()
readWriteCancel()
}()

connectionAttempt := 0

ctx := context.Background()

for {
if !i.waitForInSessionTime(session) {
return
}

ctx, cancel := context.WithCancel(context.Background())
dialCtx, dialCancel := context.WithCancel(ctx)
readWriteCtx, rwCancel := context.WithCancel(ctx)
readWriteCancel = rwCancel

// We start a goroutine in order to be able to cancel the dialer mid-connection
// on receiving a stop signal to stop the initiator.
go func() {
select {
case <-i.stopChan:
cancel()
case <-ctx.Done():
dialCancel()
case <-dialCtx.Done():
return
}
}()
Expand All @@ -183,7 +193,7 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
address := session.SocketConnectAddress[connectionAttempt%len(session.SocketConnectAddress)]
session.log.OnEventf("Connecting to: %v", address)

netConn, err := dialer.DialContext(ctx, "tcp", address)
netConn, err := dialer.DialContext(dialCtx, "tcp", address)
if err != nil {
session.log.OnEventf("Failed to connect: %v", err)
goto reconnect
Expand All @@ -207,24 +217,25 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di

msgIn = make(chan fixIn)
msgOut = make(chan []byte)
if err := session.connect(msgIn, msgOut); err != nil {
session.log.OnEventf("Failed to initiate: %v", err)
goto reconnect
}

go readLoop(newParser(bufio.NewReader(netConn)), msgIn, session.log)
go readLoop(readWriteCtx, newParser(bufio.NewReader(netConn)), msgIn, session.log)
disconnected = make(chan interface{})
go func() {
writeLoop(netConn, msgOut, session.log)
writeLoop(readWriteCtx, netConn, msgOut, session.log)
if err := netConn.Close(); err != nil {
session.log.OnEvent(err.Error())
}
close(disconnected)
}()

if err := session.connect(msgIn, msgOut); err != nil {
session.log.OnEventf("Failed to initiate: %v", err)
goto reconnect
}

// This ensures we properly cleanup the goroutine and context used for
// dial cancelation after successful connection.
cancel()
dialCancel()

select {
case <-disconnected:
Expand All @@ -233,7 +244,8 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di
}

reconnect:
cancel()
dialCancel()
readWriteCancel()

connectionAttempt++
session.log.OnEventf("Reconnecting in %v", session.ReconnectInterval)
Expand Down
Loading
Loading