Skip to content

Commit 4b0c5de

Browse files
committed
add custom logger support
This adds the ability to set a custom logger for the library to use. The logger is passed through client and server configuration objects and is propagated to inner objects (i.e. transports) as needed. If unspecified, the previous behavior of logging to stdout is preserved.
1 parent 144a48c commit 4b0c5de

10 files changed

+96
-33
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,11 @@ Byte encoding/endianness/word ordering:
180180
* Little and Big endian, with and without word swap for 32 and 64-bit
181181
integers and floating point numbers.
182182

183+
### Logging ###
184+
Both client and server objects will log to stdout by default.
185+
This behavior can be overriden by passing a log.Logger object
186+
through the Logger property of ClientConfiguration/ServerConfiguration.
187+
183188
### TODO (in no particular order)
184189
* Add RTU (serial) support to the server
185190
* Add more tests

client.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/tls"
55
"crypto/x509"
66
"fmt"
7+
"log"
78
"net"
89
"os"
910
"strings"
@@ -52,6 +53,9 @@ type ClientConfiguration struct {
5253
// the server (tcp+tls only). Leaf (i.e. server) certificates can also
5354
// be used in case of self-signed certs, or if cert pinning is required.
5455
TLSRootCAs *x509.CertPool
56+
// Logger provides a custom sink for log messages.
57+
// If nil, messages will be written to stdout.
58+
Logger *log.Logger
5559
}
5660

5761
// Modbus client object.
@@ -81,7 +85,8 @@ func NewClient(conf *ClientConfiguration) (mc *ModbusClient, err error) {
8185
mc.conf.URL = splitURL[1]
8286
}
8387

84-
mc.logger = newLogger(fmt.Sprintf("modbus-client(%s)", mc.conf.URL))
88+
mc.logger = newLogger(
89+
fmt.Sprintf("modbus-client(%s)", mc.conf.URL), conf.Logger)
8590

8691
switch clientType {
8792
case "rtu":
@@ -221,7 +226,7 @@ func (mc *ModbusClient) Open() (err error) {
221226

222227
// create the RTU transport
223228
mc.transport = newRTUTransport(
224-
spw, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout)
229+
spw, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger)
225230

226231
case modbusRTUOverTCP:
227232
// connect to the remote host
@@ -235,7 +240,7 @@ func (mc *ModbusClient) Open() (err error) {
235240

236241
// create the RTU transport
237242
mc.transport = newRTUTransport(
238-
sock, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout)
243+
sock, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger)
239244

240245
case modbusRTUOverUDP:
241246
// open a socket to the remote host (note: no actual connection is
@@ -250,7 +255,7 @@ func (mc *ModbusClient) Open() (err error) {
250255
// packets byte per byte
251256
mc.transport = newRTUTransport(
252257
newUDPSockWrapper(sock),
253-
mc.conf.URL, mc.conf.Speed, mc.conf.Timeout)
258+
mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger)
254259

255260
case modbusTCP:
256261
// connect to the remote host
@@ -260,7 +265,7 @@ func (mc *ModbusClient) Open() (err error) {
260265
}
261266

262267
// create the TCP transport
263-
mc.transport = newTCPTransport(sock, mc.conf.Timeout)
268+
mc.transport = newTCPTransport(sock, mc.conf.Timeout, mc.conf.Logger)
264269

265270
case modbusTCPOverTLS:
266271
// connect to the remote host with TLS
@@ -288,7 +293,7 @@ func (mc *ModbusClient) Open() (err error) {
288293
}
289294

290295
// create the TCP transport
291-
mc.transport = newTCPTransport(sock, mc.conf.Timeout)
296+
mc.transport = newTCPTransport(sock, mc.conf.Timeout, mc.conf.Logger)
292297

293298
case modbusTCPOverUDP:
294299
// open a socket to the remote host (note: no actual connection is
@@ -302,7 +307,7 @@ func (mc *ModbusClient) Open() (err error) {
302307
// an adapter to allow the transport to read the stream of
303308
// packets byte per byte
304309
mc.transport = newTCPTransport(
305-
newUDPSockWrapper(sock), mc.conf.Timeout)
310+
newUDPSockWrapper(sock), mc.conf.Timeout, mc.conf.Logger)
306311

307312
default:
308313
// should never happen

logger.go

+15-12
Original file line numberDiff line numberDiff line change
@@ -3,52 +3,55 @@ package modbus
33
import (
44
"fmt"
55
"os"
6+
"log"
67
)
78

89
type logger struct {
9-
prefix string
10+
prefix string
11+
customLogger *log.Logger
1012
}
1113

12-
func newLogger(prefix string) (l *logger) {
14+
func newLogger(prefix string, customLogger *log.Logger) (l *logger) {
1315
l = &logger{
1416
prefix: prefix,
17+
customLogger: customLogger,
1518
}
1619

1720
return
1821
}
1922

2023
func (l *logger) Info(msg string) {
21-
l.write(false, fmt.Sprintf("%s [info]: %s\n", l.prefix, msg))
24+
l.write(fmt.Sprintf("%s [info]: %s\n", l.prefix, msg))
2225

2326
return
2427
}
2528

2629
func (l *logger) Infof(format string, msg ...interface{}) {
27-
l.write(false, fmt.Sprintf("%s [info]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
30+
l.write(fmt.Sprintf("%s [info]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
2831

2932
return
3033
}
3134

3235
func (l *logger) Warning(msg string) {
33-
l.write(false, fmt.Sprintf("%s [warn]: %s\n", l.prefix, msg))
36+
l.write(fmt.Sprintf("%s [warn]: %s\n", l.prefix, msg))
3437

3538
return
3639
}
3740

3841
func (l *logger) Warningf(format string, msg ...interface{}) {
39-
l.write(false, fmt.Sprintf("%s [warn]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
42+
l.write(fmt.Sprintf("%s [warn]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
4043

4144
return
4245
}
4346

4447
func (l *logger) Error(msg string) {
45-
l.write(false, fmt.Sprintf("%s [error]: %s\n", l.prefix, msg))
48+
l.write(fmt.Sprintf("%s [error]: %s\n", l.prefix, msg))
4649

4750
return
4851
}
4952

5053
func (l *logger) Errorf(format string, msg ...interface{}) {
51-
l.write(false, fmt.Sprintf("%s [error]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
54+
l.write(fmt.Sprintf("%s [error]: %s\n", l.prefix, fmt.Sprintf(format, msg...)))
5255

5356
return
5457
}
@@ -67,11 +70,11 @@ func (l *logger) Fatalf(format string, msg ...interface{}) {
6770
return
6871
}
6972

70-
func (l *logger) write(stderr bool, msg string) {
71-
if stderr {
72-
os.Stderr.WriteString(msg)
73-
} else {
73+
func (l *logger) write(msg string) {
74+
if l.customLogger == nil {
7475
os.Stdout.WriteString(msg)
76+
} else {
77+
l.customLogger.Print(msg)
7578
}
7679

7780
return

logger_test.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package modbus
2+
3+
import (
4+
"bytes"
5+
"log"
6+
"testing"
7+
)
8+
9+
func TestClientCustomLogger(t *testing.T) {
10+
var buf bytes.Buffer
11+
var logger *log.Logger
12+
13+
logger = log.New(&buf, "external-prefix: ", 0)
14+
15+
_, _ = NewClient(&ClientConfiguration{
16+
Logger: logger,
17+
URL: "sometype://sometarget",
18+
})
19+
20+
if buf.String() != "external-prefix: modbus-client(sometarget) [error]: unsupported client type 'sometype'\n" {
21+
t.Errorf("unexpected logger output '%s'", buf.String())
22+
}
23+
24+
return
25+
}
26+
27+
func TestServerCustomLogger(t *testing.T) {
28+
var buf bytes.Buffer
29+
var logger *log.Logger
30+
31+
logger = log.New(&buf, "external-prefix: ", 0)
32+
33+
_, _ = NewServer(&ServerConfiguration{
34+
Logger: logger,
35+
URL: "tcp://",
36+
}, nil)
37+
38+
if buf.String() != "external-prefix: modbus-server() [error]: missing host part in URL 'tcp://'\n" {
39+
t.Errorf("unexpected logger output '%s'", buf.String())
40+
}
41+
42+
return
43+
}

rtu_transport.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package modbus
33
import (
44
"fmt"
55
"io"
6+
"log"
67
"time"
78
)
89

@@ -27,9 +28,9 @@ type rtuLink interface {
2728
}
2829

2930
// Returns a new RTU transport.
30-
func newRTUTransport(link rtuLink, addr string, speed uint, timeout time.Duration) (rt *rtuTransport) {
31+
func newRTUTransport(link rtuLink, addr string, speed uint, timeout time.Duration, customLogger *log.Logger) (rt *rtuTransport) {
3132
rt = &rtuTransport{
32-
logger: newLogger(fmt.Sprintf("rtu-transport(%s)", addr)),
33+
logger: newLogger(fmt.Sprintf("rtu-transport(%s)", addr), customLogger),
3334
link: link,
3435
timeout: timeout,
3536
t1: serialCharTime(speed),

rtu_transport_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestRTUTransportReadRTUFrame(t *testing.T) {
7070
go feedTestPipe(t, txchan, p1)
7171

7272

73-
rt = newRTUTransport(p2, "", 9600, 10 * time.Millisecond)
73+
rt = newRTUTransport(p2, "", 9600, 10 * time.Millisecond, nil)
7474

7575
// read a valid response (illegal data address)
7676
txchan <- []byte{

server.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@ import (
66
"encoding/asn1"
77
"errors"
88
"fmt"
9-
"time"
9+
"log"
1010
"net"
1111
"strings"
1212
"sync"
13+
"time"
1314
)
1415

1516
// Modbus Role PEM OID (see R-21 of the MBAPS spec)
@@ -32,6 +33,9 @@ type ServerConfiguration struct {
3233
// client connections (tcp+tls only). Leaf (i.e. client) certificates can
3334
// also be used in case of self-signed certs, or if cert pinning is required.
3435
TLSClientCAs *x509.CertPool
36+
// Logger provides a custom sink for log messages.
37+
// If nil, messages will be written to stdout.
38+
Logger *log.Logger
3539
}
3640

3741
// Request object passed to the coil handler.
@@ -166,10 +170,11 @@ func NewServer(conf *ServerConfiguration, reqHandler RequestHandler) (
166170
ms.conf.URL = splitURL[1]
167171
}
168172

169-
ms.logger = newLogger(fmt.Sprintf("modbus-server(%s)", ms.conf.URL))
173+
ms.logger = newLogger(
174+
fmt.Sprintf("modbus-server(%s)", ms.conf.URL), ms.conf.Logger)
170175

171176
if ms.conf.URL == "" {
172-
ms.logger.Errorf("missing host part in URL '%s')", conf.URL)
177+
ms.logger.Errorf("missing host part in URL '%s'", conf.URL)
173178
err = ErrConfigurationError
174179
return
175180
}
@@ -333,7 +338,7 @@ func (ms *ModbusServer) handleTCPClient(sock net.Conn) {
333338
case modbusTCP:
334339
// serve modbus requests over the raw TCP connection
335340
ms.handleTransport(
336-
newTCPTransport(sock, ms.conf.Timeout),
341+
newTCPTransport(sock, ms.conf.Timeout, ms.conf.Logger),
337342
sock.RemoteAddr().String(), "")
338343

339344
case modbusTCPOverTLS:
@@ -345,7 +350,7 @@ func (ms *ModbusServer) handleTCPClient(sock net.Conn) {
345350
} else {
346351
// serve modbus requests over the TLS tunnel
347352
ms.handleTransport(
348-
newTCPTransport(tlsSock, ms.conf.Timeout),
353+
newTCPTransport(tlsSock, ms.conf.Timeout, ms.conf.Logger),
349354
sock.RemoteAddr().String(), clientRole)
350355
}
351356

server_tls_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ func TestServerExtractRole(t *testing.T) {
406406
var role string
407407

408408
ms = &ModbusServer{
409-
logger: newLogger("test-server-role-extraction"),
409+
logger: newLogger("test-server-role-extraction", nil),
410410
}
411411

412412
// load a client cert without role OID

tcp_transport.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package modbus
33
import (
44
"fmt"
55
"io"
6+
"log"
67
"net"
78
"time"
89
)
@@ -20,11 +21,11 @@ type tcpTransport struct {
2021
}
2122

2223
// Returns a new TCP transport.
23-
func newTCPTransport(socket net.Conn, timeout time.Duration) (tt *tcpTransport) {
24+
func newTCPTransport(socket net.Conn, timeout time.Duration, customLogger *log.Logger) (tt *tcpTransport) {
2425
tt = &tcpTransport{
2526
socket: socket,
2627
timeout: timeout,
27-
logger: newLogger(fmt.Sprintf("tcp-transport(%s)", socket.RemoteAddr())),
28+
logger: newLogger(fmt.Sprintf("tcp-transport(%s)", socket.RemoteAddr()), customLogger),
2829
}
2930

3031
return

tcp_transport_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestTCPTransportReadResponse(t *testing.T) {
7171
go feedTestPipe(t, txchan, p1)
7272

7373

74-
tt = newTCPTransport(p2, 10 * time.Millisecond)
74+
tt = newTCPTransport(p2, 10 * time.Millisecond, nil)
7575
tt.lastTxnId = 0x9218
7676

7777
// read a valid response
@@ -221,7 +221,7 @@ func TestTCPTransportReadRequest(t *testing.T) {
221221
go feedTestPipe(t, txchan, p1)
222222

223223

224-
tt = newTCPTransport(p2, 10 * time.Millisecond)
224+
tt = newTCPTransport(p2, 10 * time.Millisecond, nil)
225225
tt.lastTxnId = 0x0a00
226226

227227
// push three frames in a row:
@@ -348,7 +348,7 @@ func TestTCPTransportWriteResponse(t *testing.T) {
348348
}(t, p2, done)
349349

350350

351-
tt = newTCPTransport(p1, 10 * time.Millisecond)
351+
tt = newTCPTransport(p1, 10 * time.Millisecond, nil)
352352
tt.lastTxnId = 0xc01f
353353

354354
err = tt.WriteResponse(&pdu{

0 commit comments

Comments
 (0)