diff --git a/pkg/tlsx/tls/tls.go b/pkg/tlsx/tls/tls.go index c07a5ed2..49e4a698 100644 --- a/pkg/tlsx/tls/tls.go +++ b/pkg/tlsx/tls/tls.go @@ -236,10 +236,18 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) - if err := conn.Handshake(); err == nil { + handshakeCtx := context.Background() + var cancel context.CancelFunc + if c.options.Timeout != 0 { + handshakeCtx, cancel = context.WithTimeout(handshakeCtx, time.Duration(c.options.Timeout)*time.Second) + } + if err := conn.HandshakeContext(handshakeCtx); err == nil { ciphersuite := conn.ConnectionState().CipherSuite enumeratedCiphers = append(enumeratedCiphers, tls.CipherSuiteName(ciphersuite)) } + if cancel != nil { + cancel() + } _ = conn.Close() // close baseConn internally } return enumeratedCiphers, nil diff --git a/pkg/tlsx/ztls/timeout_test.go b/pkg/tlsx/ztls/timeout_test.go new file mode 100644 index 00000000..58edf959 --- /dev/null +++ b/pkg/tlsx/ztls/timeout_test.go @@ -0,0 +1,68 @@ +package ztls + +import ( + "context" + "net" + "testing" + "time" + + "github.com/zmap/zcrypto/tls" +) + +// TestHandshakeTimeout verifies that tlsHandshakeWithTimeout returns promptly +// when the context deadline is reached rather than hanging indefinitely. +// This is the regression test for issue #819. +func TestHandshakeTimeout(t *testing.T) { + // Start a TCP listener that accepts connections but never sends data, + // simulating the hosts that cause tlsx to hang indefinitely. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to start listener: %v", err) + } + t.Cleanup(func() { ln.Close() }) + + // Accept connections in the background so the dial succeeds, + // but never do anything with them (simulating a hanging server). + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + // hold the connection open, never respond + t.Cleanup(func() { conn.Close() }) + } + }() + + // Dial the hanging server. + rawConn, err := net.DialTimeout("tcp", ln.Addr().String(), 2*time.Second) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + t.Cleanup(func() { rawConn.Close() }) + + // Wrap in a ztls TLS connection. + tlsConn := tls.Client(rawConn, &tls.Config{InsecureSkipVerify: true}) + + // Use a short context deadline. + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + // Create a minimal Client (receiver is not used by tlsHandshakeWithTimeout). + c := &Client{} + + start := time.Now() + err = c.tlsHandshakeWithTimeout(tlsConn, rawConn, ctx) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected timeout error, got nil") + } + + // Must complete within 2 seconds (generous margin above the 500ms deadline). + if elapsed > 2*time.Second { + t.Fatalf("handshake took %v, expected to timeout around 500ms", elapsed) + } + + t.Logf("handshake timed out correctly in %v", elapsed) +} diff --git a/pkg/tlsx/ztls/ztls.go b/pkg/tlsx/ztls/ztls.go index a03b7267..35a8105c 100644 --- a/pkg/tlsx/ztls/ztls.go +++ b/pkg/tlsx/ztls/ztls.go @@ -140,7 +140,7 @@ func (c *Client) ConnectWithOptions(hostname, ip, port string, options clients.C // new tls connection tlsConn := tls.Client(conn, config) - err = c.tlsHandshakeWithTimeout(tlsConn, ctx) + err = c.tlsHandshakeWithTimeout(tlsConn, conn, ctx) if err != nil { if clients.IsClientCertRequiredError(err) { clientCertRequired = true @@ -257,10 +257,18 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) baseCfg.CipherSuites = []uint16{ztlsCiphers[v]} - if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil { + handshakeCtx := context.Background() + var cancel context.CancelFunc + if c.options.Timeout != 0 { + handshakeCtx, cancel = context.WithTimeout(handshakeCtx, time.Duration(c.options.Timeout)*time.Second) + } + if err := c.tlsHandshakeWithTimeout(conn, baseConn, handshakeCtx); err == nil { h1 := conn.GetHandshakeLog() enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String()) } + if cancel != nil { + cancel() + } _ = conn.Close() // also closes baseConn internally } return enumeratedCiphers, nil @@ -320,20 +328,33 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt return config, nil } -// tlsHandshakeWithCtx attempts tls handshake with given timeout -func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error { +// tlsHandshakeWithTimeout attempts tls handshake with given timeout. +// rawConn is the underlying TCP connection; on timeout we close it +// instead of tlsConn because zcrypto's Close() acquires the handshake +// mutex, causing deadlock when the handshake is blocked on I/O. +func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, rawConn net.Conn, ctx context.Context) error { errChan := make(chan error, 1) - defer close(errChan) + go func() { + errChan <- tlsConn.Handshake() + }() select { + case err := <-errChan: + if err == tls.ErrCertsOnly { + return nil + } + return err case <-ctx.Done(): + // Prefer a completed handshake that arrived simultaneously with the deadline + select { + case err := <-errChan: + if err == tls.ErrCertsOnly { + return nil + } + return err + default: + } + _ = rawConn.Close() return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint - case errChan <- tlsConn.Handshake(): - } - - err := <-errChan - if err == tls.ErrCertsOnly { - err = nil } - return err }