diff --git a/ascii_over_tcp_client.go b/ascii_over_tcp_client.go index cc263e4..2e9007f 100644 --- a/ascii_over_tcp_client.go +++ b/ascii_over_tcp_client.go @@ -10,16 +10,16 @@ import ( // ASCIIOverTCPClientHandler implements Packager and Transporter interface. type ASCIIOverTCPClientHandler struct { - asciiPackager - asciiTCPTransporter + ASCIIPackager + ASCIITCPTransporter } // NewASCIIOverTCPClientHandler allocates and initializes a ASCIIOverTCPClientHandler. func NewASCIIOverTCPClientHandler(address string) *ASCIIOverTCPClientHandler { - handler := &ASCIIOverTCPClientHandler{} - handler.Address = address - handler.Timeout = tcpTimeout - handler.IdleTimeout = tcpIdleTimeout + handler := &ASCIIOverTCPClientHandler{ + ASCIIPackager: ASCIIPackager{}, + ASCIITCPTransporter: NewASCIITCPTransporter(address), + } return handler } @@ -29,12 +29,22 @@ func ASCIIOverTCPClient(address string) Client { return NewClient(handler) } -// asciiTCPTransporter implements Transporter interface. -type asciiTCPTransporter struct { - tcpTransporter +var _ Transporter = (*ASCIITCPTransporter)(nil) + +// ASCIITCPTransporter implements Transporter interface. +type ASCIITCPTransporter struct { + TCPTransporter +} + +// NewASCIITCPTransporter creates ASCIITCPTransporter with default values +func NewASCIITCPTransporter(address string) ASCIITCPTransporter { + return ASCIITCPTransporter{ + TCPTransporter: NewTCPTransporter(address), + } } -func (mb *asciiTCPTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { +// Send sends data to server and ensures response has required length. +func (mb *ASCIITCPTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { mb.mu.Lock() defer mb.mu.Unlock() diff --git a/asciiclient.go b/asciiclient.go index 94a85f0..85d0655 100644 --- a/asciiclient.go +++ b/asciiclient.go @@ -24,17 +24,16 @@ var asciiStart = []string{":", ">"} // ASCIIClientHandler implements Packager and Transporter interface. type ASCIIClientHandler struct { - asciiPackager - asciiSerialTransporter + ASCIIPackager + ASCIISerialTransporter } // NewASCIIClientHandler allocates and initializes a ASCIIClientHandler. func NewASCIIClientHandler(address string) *ASCIIClientHandler { - handler := &ASCIIClientHandler{} - handler.Address = address - handler.Timeout = serialTimeout - handler.IdleTimeout = serialIdleTimeout - handler.serialPort.Logger = handler // expose the logger + handler := &ASCIIClientHandler{ + ASCIIPackager: ASCIIPackager{}, + ASCIISerialTransporter: NewASCIISerialTransporter(address), + } return handler } @@ -44,24 +43,25 @@ func ASCIIClient(address string) Client { return NewClient(handler) } -// asciiPackager implements Packager interface. -type asciiPackager struct { +// ASCIIPackager implements Packager interface. +type ASCIIPackager struct { SlaveID byte } // SetSlave sets modbus slave id for the next client operations -func (mb *asciiPackager) SetSlave(slaveID byte) { +func (mb *ASCIIPackager) SetSlave(slaveID byte) { mb.SlaveID = slaveID } // Encode encodes PDU in a ASCII frame: -// Start : 1 char -// Address : 2 chars -// Function : 2 chars -// Data : 0 up to 2x252 chars -// LRC : 2 chars -// End : 2 chars -func (mb *asciiPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { +// +// Start : 1 char +// Address : 2 chars +// Function : 2 chars +// Data : 0 up to 2x252 chars +// LRC : 2 chars +// End : 2 chars +func (mb *ASCIIPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { var buf bytes.Buffer if _, err = buf.WriteString(asciiStart[0]); err != nil { @@ -88,7 +88,7 @@ func (mb *asciiPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { } // Verify verifies response length, frame boundary and slave id. -func (mb *asciiPackager) Verify(aduRequest []byte, aduResponse []byte) (err error) { +func (mb *ASCIIPackager) Verify(aduRequest []byte, aduResponse []byte) (err error) { length := len(aduResponse) // Minimum size (including address, function and LRC) if length < asciiMinSize+6 { @@ -129,7 +129,7 @@ func (mb *asciiPackager) Verify(aduRequest []byte, aduResponse []byte) (err erro } // Decode extracts PDU from ASCII frame and verify LRC. -func (mb *asciiPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { +func (mb *ASCIIPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { pdu = &ProtocolDataUnit{} // Slave address address, err := readHex(adu[1:]) @@ -163,32 +163,42 @@ func (mb *asciiPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { return } -// asciiSerialTransporter implements Transporter interface. -type asciiSerialTransporter struct { - serialPort - Logger logger +// ASCIISerialTransporter implements Transporter interface. +type ASCIISerialTransporter struct { + SerialPort +} + +// NewASCIISerialTransporter creates ASCIISerialTransporter with default values +func NewASCIISerialTransporter(address string) ASCIISerialTransporter { + t := ASCIISerialTransporter{ + SerialPort: *NewSerialPort(address), + } + t.SerialPort.Logger = &t + return t } -func (mb *asciiSerialTransporter) Printf(format string, v ...interface{}) { +// Printf implements the Logger interface +func (mb *ASCIISerialTransporter) Printf(format string, v ...interface{}) { if mb.Logger != nil { mb.Logger.Printf(format, v...) } } -func (mb *asciiSerialTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { - mb.serialPort.mu.Lock() - defer mb.serialPort.mu.Unlock() +// Send sends data to serial device and ensures adequate response for request type +func (mb *ASCIISerialTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { + mb.SerialPort.mu.Lock() + defer mb.SerialPort.mu.Unlock() // Make sure port is connected - if err = mb.serialPort.connect(); err != nil { + if err = mb.SerialPort.connect(); err != nil { return } // Start the timer to close when idle - mb.serialPort.lastActivity = time.Now() - mb.serialPort.startCloseTimer() + mb.SerialPort.lastActivity = time.Now() + mb.SerialPort.startCloseTimer() // Send the request - mb.serialPort.logf("modbus: send % x\n", aduRequest) + mb.SerialPort.logf("modbus: send % x\n", aduRequest) if _, err = mb.port.Write(aduRequest); err != nil { return } @@ -211,7 +221,7 @@ func (mb *asciiSerialTransporter) Send(aduRequest []byte) (aduResponse []byte, e } } aduResponse = data[:length] - mb.serialPort.logf("modbus: recv % x\n", aduResponse) + mb.SerialPort.logf("modbus: recv % x\n", aduResponse) return } diff --git a/asciiclient_test.go b/asciiclient_test.go index 4679f86..0b443ec 100644 --- a/asciiclient_test.go +++ b/asciiclient_test.go @@ -10,7 +10,7 @@ import ( ) func TestASCIIEncoding(t *testing.T) { - encoder := asciiPackager{} + encoder := ASCIIPackager{} encoder.SlaveID = 17 pdu := ProtocolDataUnit{} @@ -28,7 +28,7 @@ func TestASCIIEncoding(t *testing.T) { } func TestASCIIDecoding(t *testing.T) { - decoder := asciiPackager{} + decoder := ASCIIPackager{} decoder.SlaveID = 247 adu := []byte(":F7031389000A60\r\n") @@ -47,7 +47,7 @@ func TestASCIIDecoding(t *testing.T) { } func TestASCIIDecodeStartCharacter(t *testing.T) { - decoder := asciiPackager{} + decoder := ASCIIPackager{} aduReq := []byte(":010300010002F9\r\n") aduRespGreaterThan := []byte(">010304010F1509CA\r\n") aduRespColon := []byte(":010304010F1509CA\r\n") @@ -69,7 +69,7 @@ func TestASCIIDecodeStartCharacter(t *testing.T) { } func BenchmarkASCIIEncoder(b *testing.B) { - encoder := asciiPackager{ + encoder := ASCIIPackager{ SlaveID: 10, } pdu := ProtocolDataUnit{ @@ -85,7 +85,7 @@ func BenchmarkASCIIEncoder(b *testing.B) { } func BenchmarkASCIIDecoder(b *testing.B) { - decoder := asciiPackager{ + decoder := ASCIIPackager{ SlaveID: 10, } adu := []byte(":F7031389000A60\r\n") diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..e31a8b7 --- /dev/null +++ b/client_test.go @@ -0,0 +1,35 @@ +package modbus + +import "testing" + +const localhost = ":502" + +func TestTcp(t *testing.T) { + pack := TCPPackager{SlaveID: 1} + trans := NewTCPTransporter(localhost) + _ = NewClient2(&pack, &trans) +} + +func TestRtuOverTcp(t *testing.T) { + pack := RtuPackager{SlaveID: 1} + trans := NewRTUTCPTransporter(localhost) + _ = NewClient2(&pack, &trans) +} + +func TestAsciiOverTcp(t *testing.T) { + pack := ASCIIPackager{SlaveID: 1} + trans := NewASCIITCPTransporter(localhost) + _ = NewClient2(&pack, &trans) +} + +func TestRtu(t *testing.T) { + pack := RtuPackager{SlaveID: 1} + trans := NewRtuSerialTransporter(localhost) + _ = NewClient2(&pack, &trans) +} + +func TestAscii(t *testing.T) { + pack := ASCIIPackager{SlaveID: 1} + trans := NewASCIISerialTransporter(localhost) + _ = NewClient2(&pack, &trans) +} diff --git a/rtu_over_tcp_client.go b/rtu_over_tcp_client.go index d9ec19d..d72a45a 100644 --- a/rtu_over_tcp_client.go +++ b/rtu_over_tcp_client.go @@ -11,16 +11,15 @@ import ( // RTUOverTCPClientHandler implements Packager and Transporter interface. type RTUOverTCPClientHandler struct { - rtuPackager - rtuTCPTransporter + RtuPackager + RtuTCPTransporter } // NewRTUOverTCPClientHandler allocates and initializes a RTUOverTCPClientHandler. func NewRTUOverTCPClientHandler(address string) *RTUOverTCPClientHandler { - handler := &RTUOverTCPClientHandler{} - handler.Address = address - handler.Timeout = tcpTimeout - handler.IdleTimeout = tcpIdleTimeout + handler := &RTUOverTCPClientHandler{ + RtuTCPTransporter: NewRTUTCPTransporter(address), + } return handler } @@ -30,13 +29,22 @@ func RTUOverTCPClient(address string) Client { return NewClient(handler) } -// rtuTCPTransporter implements Transporter interface. -type rtuTCPTransporter struct { - tcpTransporter +var _ Transporter = (*RtuTCPTransporter)(nil) + +// RtuTCPTransporter implements Transporter interface. +type RtuTCPTransporter struct { + TCPTransporter +} + +// NewRTUTCPTransporter creates RtuTCPTransporter with default values +func NewRTUTCPTransporter(address string) RtuTCPTransporter { + return RtuTCPTransporter{ + TCPTransporter: NewTCPTransporter(address), + } } // Send sends data to server and ensures adequate response for request type -func (mb *rtuTCPTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { +func (mb *RtuTCPTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { mb.mu.Lock() defer mb.mu.Unlock() diff --git a/rtuclient.go b/rtuclient.go index 615d95c..5af8b7a 100644 --- a/rtuclient.go +++ b/rtuclient.go @@ -44,17 +44,16 @@ const ( // RTUClientHandler implements Packager and Transporter interface. type RTUClientHandler struct { - rtuPackager - rtuSerialTransporter + RtuPackager + RtuSerialTransporter } // NewRTUClientHandler allocates and initializes a RTUClientHandler. func NewRTUClientHandler(address string) *RTUClientHandler { - handler := &RTUClientHandler{} - handler.Address = address - handler.Timeout = serialTimeout - handler.IdleTimeout = serialIdleTimeout - handler.serialPort.Logger = handler // expose the logger + handler := &RTUClientHandler{ + RtuPackager: RtuPackager{}, + RtuSerialTransporter: NewRtuSerialTransporter(address), + } return handler } @@ -64,22 +63,23 @@ func RTUClient(address string) Client { return NewClient(handler) } -// rtuPackager implements Packager interface. -type rtuPackager struct { +// RtuPackager implements Packager interface. +type RtuPackager struct { SlaveID byte } // SetSlave sets modbus slave id for the next client operations -func (mb *rtuPackager) SetSlave(slaveID byte) { +func (mb *RtuPackager) SetSlave(slaveID byte) { mb.SlaveID = slaveID } // Encode encodes PDU in a RTU frame: -// Slave Address : 1 byte -// Function : 1 byte -// Data : 0 up to 252 bytes -// CRC : 2 byte -func (mb *rtuPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { +// +// Slave Address : 1 byte +// Function : 1 byte +// Data : 0 up to 252 bytes +// CRC : 2 byte +func (mb *RtuPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { length := len(pdu.Data) + 4 if length > rtuMaxSize { err = fmt.Errorf("modbus: length of data '%v' must not be bigger than '%v'", length, rtuMaxSize) @@ -102,7 +102,7 @@ func (mb *rtuPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { } // Verify verifies response length and slave id. -func (mb *rtuPackager) Verify(aduRequest []byte, aduResponse []byte) (err error) { +func (mb *RtuPackager) Verify(aduRequest []byte, aduResponse []byte) (err error) { length := len(aduResponse) // Minimum size (including address, function and CRC) if length < rtuMinSize { @@ -118,7 +118,7 @@ func (mb *rtuPackager) Verify(aduRequest []byte, aduResponse []byte) (err error) } // Decode extracts PDU from RTU frame and verify CRC. -func (mb *rtuPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { +func (mb *RtuPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { length := len(adu) // Calculate checksum var crc crc @@ -135,13 +135,24 @@ func (mb *rtuPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { return } -// rtuSerialTransporter implements Transporter interface. -type rtuSerialTransporter struct { - serialPort - Logger logger +// RtuSerialTransporter implements Transporter interface. +type RtuSerialTransporter struct { + SerialPort } -func (mb *rtuSerialTransporter) Printf(format string, v ...interface{}) { +var _ Transporter = (*RtuSerialTransporter)(nil) + +// NewRtuSerialTransporter creates RtuSerialTransporter with default values +func NewRtuSerialTransporter(address string) RtuSerialTransporter { + t := RtuSerialTransporter{ + SerialPort: *NewSerialPort(address), + } + t.SerialPort.Logger = &t + return t +} + +// Printf implements the Logger interface +func (mb *RtuSerialTransporter) Printf(format string, v ...interface{}) { if mb.Logger != nil { mb.Logger.Printf(format, v...) } @@ -257,20 +268,21 @@ func readIncrementally(slaveID, functionCode byte, r io.Reader, deadline time.Ti } -func (mb *rtuSerialTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { +// Send sends data to serial device and ensures adequate response for request type +func (mb *RtuSerialTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { mb.mu.Lock() defer mb.mu.Unlock() // Make sure port is connected - if err = mb.serialPort.connect(); err != nil { + if err = mb.SerialPort.connect(); err != nil { return } // Start the timer to close when idle - mb.serialPort.lastActivity = time.Now() - mb.serialPort.startCloseTimer() + mb.SerialPort.lastActivity = time.Now() + mb.SerialPort.startCloseTimer() // Send the request - mb.serialPort.logf("modbus: send % x\n", aduRequest) + mb.SerialPort.logf("modbus: send % x\n", aduRequest) if _, err = mb.port.Write(aduRequest); err != nil { return } @@ -279,15 +291,15 @@ func (mb *rtuSerialTransporter) Send(aduRequest []byte) (aduResponse []byte, err bytesToRead := calculateResponseLength(aduRequest) time.Sleep(mb.calculateDelay(len(aduRequest) + bytesToRead)) - data, err := readIncrementally(aduRequest[0], aduRequest[1], mb.port, time.Now().Add(mb.serialPort.Config.Timeout)) - mb.serialPort.logf("modbus: recv % x\n", data[:]) - aduResponse = data - return + data, err := readIncrementally(aduRequest[0], aduRequest[1], mb.port, time.Now().Add(mb.SerialPort.Config.Timeout)) + mb.SerialPort.logf("modbus: recv % x\n", data[:]) + + return data, err } // calculateDelay roughly calculates time needed for the next frame. // See MODBUS over Serial Line - Specification and Implementation Guide (page 13). -func (mb *rtuSerialTransporter) calculateDelay(chars int) time.Duration { +func (mb *RtuSerialTransporter) calculateDelay(chars int) time.Duration { var characterDelay, frameDelay int // us if mb.BaudRate <= 0 || mb.BaudRate > 19200 { diff --git a/rtuclient_prop_test.go b/rtuclient_prop_test.go index 4da20a6..d6490bb 100644 --- a/rtuclient_prop_test.go +++ b/rtuclient_prop_test.go @@ -9,7 +9,7 @@ import ( func TestRTUEncodeDecode(t *testing.T) { rapid.Check(t, func(t *rapid.T) { - packager := &rtuPackager{ + packager := &RtuPackager{ SlaveID: rapid.Byte().Draw(t, "SlaveID").(byte), } diff --git a/rtuclient_test.go b/rtuclient_test.go index 815b613..38a0559 100644 --- a/rtuclient_test.go +++ b/rtuclient_test.go @@ -12,7 +12,7 @@ import ( ) func TestRTUEncoding(t *testing.T) { - encoder := rtuPackager{} + encoder := RtuPackager{} encoder.SlaveID = 0x01 pdu := ProtocolDataUnit{} @@ -30,7 +30,7 @@ func TestRTUEncoding(t *testing.T) { } func TestRTUDecoding(t *testing.T) { - decoder := rtuPackager{} + decoder := RtuPackager{} adu := []byte{0x01, 0x10, 0x8A, 0x00, 0x00, 0x03, 0xAA, 0x10} pdu, err := decoder.Decode(adu) @@ -71,7 +71,7 @@ func TestCalculateResponseLength(t *testing.T) { } func BenchmarkRTUEncoder(b *testing.B) { - encoder := rtuPackager{ + encoder := RtuPackager{ SlaveID: 10, } pdu := ProtocolDataUnit{ @@ -87,7 +87,7 @@ func BenchmarkRTUEncoder(b *testing.B) { } func BenchmarkRTUDecoder(b *testing.B) { - decoder := rtuPackager{ + decoder := RtuPackager{ SlaveID: 10, } adu := []byte{0x01, 0x10, 0x8A, 0x00, 0x00, 0x03, 0xAA, 0x10} diff --git a/serial.go b/serial.go index a910b95..0a9e22f 100644 --- a/serial.go +++ b/serial.go @@ -19,8 +19,8 @@ const ( serialIdleTimeout = 60 * time.Second ) -// serialPort has configuration and I/O controller. -type serialPort struct { +// SerialPort has configuration and I/O controller. +type SerialPort struct { // Serial port configuration. serial.Config @@ -34,7 +34,19 @@ type serialPort struct { closeTimer *time.Timer } -func (mb *serialPort) Connect() (err error) { +// NewSerialPort creates a serial port with default configuration. +func NewSerialPort(address string) *SerialPort { + return &SerialPort{ + Config: serial.Config{ + Address: address, + Timeout: serialTimeout, + }, + IdleTimeout: serialIdleTimeout, + } +} + +// Connect opens the port. +func (mb *SerialPort) Connect() (err error) { mb.mu.Lock() defer mb.mu.Unlock() @@ -42,7 +54,7 @@ func (mb *serialPort) Connect() (err error) { } // connect connects to the serial port if it is not connected. Caller must hold the mutex. -func (mb *serialPort) connect() error { +func (mb *SerialPort) connect() error { if mb.port == nil { port, err := serial.Open(&mb.Config) if err != nil { @@ -53,7 +65,8 @@ func (mb *serialPort) connect() error { return nil } -func (mb *serialPort) Close() (err error) { +// Close closes the port. +func (mb *SerialPort) Close() (err error) { mb.mu.Lock() defer mb.mu.Unlock() @@ -61,7 +74,7 @@ func (mb *serialPort) Close() (err error) { } // close closes the serial port if it is connected. Caller must hold the mutex. -func (mb *serialPort) close() (err error) { +func (mb *SerialPort) close() (err error) { if mb.port != nil { err = mb.port.Close() mb.port = nil @@ -69,13 +82,13 @@ func (mb *serialPort) close() (err error) { return } -func (mb *serialPort) logf(format string, v ...interface{}) { +func (mb *SerialPort) logf(format string, v ...interface{}) { if mb.Logger != nil { mb.Logger.Printf(format, v...) } } -func (mb *serialPort) startCloseTimer() { +func (mb *SerialPort) startCloseTimer() { if mb.IdleTimeout <= 0 { return } @@ -87,7 +100,7 @@ func (mb *serialPort) startCloseTimer() { } // closeIdle closes the connection if last activity is passed behind IdleTimeout. -func (mb *serialPort) closeIdle() { +func (mb *SerialPort) closeIdle() { mb.mu.Lock() defer mb.mu.Unlock() diff --git a/serial_test.go b/serial_test.go index 78fce0a..5e9b4fd 100644 --- a/serial_test.go +++ b/serial_test.go @@ -22,7 +22,7 @@ func TestSerialCloseIdle(t *testing.T) { port := &nopCloser{ ReadWriter: &bytes.Buffer{}, } - s := serialPort{ + s := SerialPort{ port: port, IdleTimeout: 100 * time.Millisecond, } diff --git a/tcpclient.go b/tcpclient.go index 767d6ce..4d09a0c 100644 --- a/tcpclient.go +++ b/tcpclient.go @@ -35,16 +35,16 @@ func (length ErrTCPHeaderLength) Error() string { // TCPClientHandler implements Packager and Transporter interface. type TCPClientHandler struct { - tcpPackager - tcpTransporter + TCPPackager + TCPTransporter } // NewTCPClientHandler allocates a new TCPClientHandler. func NewTCPClientHandler(address string) *TCPClientHandler { - h := &TCPClientHandler{} - h.Address = address - h.Timeout = tcpTimeout - h.IdleTimeout = tcpIdleTimeout + h := &TCPClientHandler{ + TCPPackager: TCPPackager{}, + TCPTransporter: NewTCPTransporter(address), + } return h } @@ -54,8 +54,8 @@ func TCPClient(address string) Client { return NewClient(handler) } -// tcpPackager implements Packager interface. -type tcpPackager struct { +// TCPPackager implements Packager interface. +type TCPPackager struct { // For synchronization between messages of server & client transactionID uint32 // Broadcast address is 0 @@ -63,18 +63,19 @@ type tcpPackager struct { } // SetSlave sets modbus slave id for the next client operations -func (mb *tcpPackager) SetSlave(slaveID byte) { +func (mb *TCPPackager) SetSlave(slaveID byte) { mb.SlaveID = slaveID } // Encode adds modbus application protocol header: -// Transaction identifier: 2 bytes -// Protocol identifier: 2 bytes -// Length: 2 bytes -// Unit identifier: 1 byte -// Function code: 1 byte -// Data: n bytes -func (mb *tcpPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { +// +// Transaction identifier: 2 bytes +// Protocol identifier: 2 bytes +// Length: 2 bytes +// Unit identifier: 1 byte +// Function code: 1 byte +// Data: n bytes +func (mb *TCPPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { adu = make([]byte, tcpHeaderSize+1+len(pdu.Data)) // Transaction identifier @@ -95,16 +96,17 @@ func (mb *tcpPackager) Encode(pdu *ProtocolDataUnit) (adu []byte, err error) { } // Verify confirms transaction, protocol and unit id. -func (mb *tcpPackager) Verify(aduRequest []byte, aduResponse []byte) error { +func (mb *TCPPackager) Verify(aduRequest []byte, aduResponse []byte) error { return verify(aduRequest, aduResponse) } // Decode extracts PDU from TCP frame: -// Transaction identifier: 2 bytes -// Protocol identifier: 2 bytes -// Length: 2 bytes -// Unit identifier: 1 byte -func (mb *tcpPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { +// +// Transaction identifier: 2 bytes +// Protocol identifier: 2 bytes +// Length: 2 bytes +// Unit identifier: 1 byte +func (mb *TCPPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { // Read length value in the header length := binary.BigEndian.Uint16(adu[4:]) pduLength := len(adu) - tcpHeaderSize @@ -119,8 +121,10 @@ func (mb *tcpPackager) Decode(adu []byte) (pdu *ProtocolDataUnit, err error) { return } -// tcpTransporter implements Transporter interface. -type tcpTransporter struct { +var _ Transporter = (*TCPTransporter)(nil) + +// TCPTransporter implements Transporter interface. +type TCPTransporter struct { // Connect string Address string // Connect & Read timeout @@ -155,8 +159,17 @@ const ( readResultCloseRetry ) +// NewTCPTransporter creates TCPTransporter with default settings +func NewTCPTransporter(address string) TCPTransporter { + return TCPTransporter{ + Address: address, + Timeout: tcpTimeout, + IdleTimeout: tcpIdleTimeout, + } +} + // Send sends data to server and ensures response length is greater than header length. -func (mb *tcpTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { +func (mb *TCPTransporter) Send(aduRequest []byte) (aduResponse []byte, err error) { mb.mu.Lock() defer mb.mu.Unlock() @@ -206,7 +219,7 @@ func (mb *tcpTransporter) Send(aduRequest []byte) (aduResponse []byte, err error } } -func (mb *tcpTransporter) readResponse(aduRequest []byte, data []byte, recoveryDeadline time.Time) (aduResponse []byte, res readResult, err error) { +func (mb *TCPTransporter) readResponse(aduRequest []byte, data []byte, recoveryDeadline time.Time) (aduResponse []byte, res readResult, err error) { // res is readResultDone by default, which either means we succeeded or err contains the fatal error for { if _, err = io.ReadFull(mb.conn, data[:tcpHeaderSize]); err == nil { @@ -262,7 +275,7 @@ func (mb *tcpTransporter) readResponse(aduRequest []byte, data []byte, recoveryD } } -func (mb *tcpTransporter) processResponse(data []byte) (aduResponse []byte, err error) { +func (mb *TCPTransporter) processResponse(data []byte) (aduResponse []byte, err error) { // Read length, ignore transaction & protocol id (4 bytes) length := int(binary.BigEndian.Uint16(data[4:])) if length <= 0 { @@ -317,14 +330,14 @@ func verify(aduRequest []byte, aduResponse []byte) (err error) { // Connect establishes a new connection to the address in Address. // Connect and Close are exported so that multiple requests can be done with one session -func (mb *tcpTransporter) Connect() error { +func (mb *TCPTransporter) Connect() error { mb.mu.Lock() defer mb.mu.Unlock() return mb.connect() } -func (mb *tcpTransporter) connect() error { +func (mb *TCPTransporter) connect() error { if mb.conn == nil { dialer := net.Dialer{Timeout: mb.Timeout} conn, err := dialer.Dial("tcp", mb.Address) @@ -339,7 +352,7 @@ func (mb *tcpTransporter) connect() error { return nil } -func (mb *tcpTransporter) startCloseTimer() { +func (mb *TCPTransporter) startCloseTimer() { if mb.IdleTimeout <= 0 { return } @@ -351,7 +364,7 @@ func (mb *tcpTransporter) startCloseTimer() { } // Close closes current connection. -func (mb *tcpTransporter) Close() error { +func (mb *TCPTransporter) Close() error { mb.mu.Lock() defer mb.mu.Unlock() @@ -360,7 +373,7 @@ func (mb *tcpTransporter) Close() error { // flush flushes pending data in the connection, // returns io.EOF if connection is closed. -func (mb *tcpTransporter) flush(b []byte) (err error) { +func (mb *TCPTransporter) flush(b []byte) (err error) { if err = mb.conn.SetReadDeadline(time.Now()); err != nil { return } @@ -374,14 +387,14 @@ func (mb *tcpTransporter) flush(b []byte) (err error) { return } -func (mb *tcpTransporter) logf(format string, v ...interface{}) { +func (mb *TCPTransporter) logf(format string, v ...interface{}) { if mb.Logger != nil { mb.Logger.Printf(format, v...) } } // closeLocked closes current connection. Caller must hold the mutex before calling this method. -func (mb *tcpTransporter) close() (err error) { +func (mb *TCPTransporter) close() (err error) { if mb.conn != nil { err = mb.conn.Close() mb.conn = nil @@ -390,7 +403,7 @@ func (mb *tcpTransporter) close() (err error) { } // closeIdle closes the connection if last activity is passed behind IdleTimeout. -func (mb *tcpTransporter) closeIdle() { +func (mb *TCPTransporter) closeIdle() { mb.mu.Lock() defer mb.mu.Unlock() diff --git a/tcpclient_prop_test.go b/tcpclient_prop_test.go index 87f18e2..15d09b6 100644 --- a/tcpclient_prop_test.go +++ b/tcpclient_prop_test.go @@ -9,7 +9,7 @@ import ( func TestTCPEncodeDecode(t *testing.T) { rapid.Check(t, func(t *rapid.T) { - packager := &tcpPackager{ + packager := &TCPPackager{ transactionID: rapid.Uint32().Draw(t, "transactionID").(uint32), SlaveID: rapid.Byte().Draw(t, "SlaveID").(byte), } diff --git a/tcpclient_test.go b/tcpclient_test.go index 0871083..7cfbfbf 100644 --- a/tcpclient_test.go +++ b/tcpclient_test.go @@ -13,7 +13,7 @@ import ( ) func TestTCPEncoding(t *testing.T) { - packager := tcpPackager{} + packager := TCPPackager{} pdu := ProtocolDataUnit{} pdu.FunctionCode = 3 pdu.Data = []byte{0, 4, 0, 3} @@ -30,7 +30,7 @@ func TestTCPEncoding(t *testing.T) { } func TestTCPDecoding(t *testing.T) { - packager := tcpPackager{} + packager := TCPPackager{} packager.transactionID = 1 packager.SlaveID = 17 adu := []byte{0, 1, 0, 0, 0, 6, 17, 3, 0, 120, 0, 3} @@ -69,7 +69,7 @@ func TestTCPTransporter(t *testing.T) { return } }() - client := &tcpTransporter{ + client := &TCPTransporter{ Address: ln.Addr().String(), Timeout: 1 * time.Second, IdleTimeout: 100 * time.Millisecond, @@ -112,7 +112,7 @@ func TestTCPTransactionMismatchRetry(t *testing.T) { defer conn.Close() // ensure that answer is only written after second read attempt failed time.Sleep(2500 * time.Millisecond) - packager := &tcpPackager{SlaveID: 0} + packager := &TCPPackager{SlaveID: 0} pdu := &ProtocolDataUnit{ FunctionCode: FuncCodeReadInputRegisters, Data: append([]byte{0x02}, data...), @@ -173,7 +173,7 @@ func TestTCPTransactionMismatchRetry(t *testing.T) { } func BenchmarkTCPEncoder(b *testing.B) { - encoder := tcpPackager{ + encoder := TCPPackager{ SlaveID: 10, } pdu := ProtocolDataUnit{ @@ -189,7 +189,7 @@ func BenchmarkTCPEncoder(b *testing.B) { } func BenchmarkTCPDecoder(b *testing.B) { - decoder := tcpPackager{ + decoder := TCPPackager{ SlaveID: 10, } adu := []byte{0, 1, 0, 0, 0, 6, 17, 3, 0, 120, 0, 3}