diff --git a/pkg/tlsx/ztls/ztls.go b/pkg/tlsx/ztls/ztls.go index a03b7267..0eeea34e 100644 --- a/pkg/tlsx/ztls/ztls.go +++ b/pkg/tlsx/ztls/ztls.go @@ -140,7 +140,7 @@ func (c *Client) ConnectWithOptions(hostname, ip, port string, options clients.C // new tls connection tlsConn := tls.Client(conn, config) - err = c.tlsHandshakeWithTimeout(tlsConn, ctx) + err = c.tlsHandshakeWithTimeout(ctx, tlsConn) if err != nil { if clients.IsClientCertRequiredError(err) { clientCertRequired = true @@ -257,7 +257,7 @@ func (c *Client) EnumerateCiphers(hostname, ip, port string, options clients.Con conn := tls.Client(baseConn, baseCfg) baseCfg.CipherSuites = []uint16{ztlsCiphers[v]} - if err := c.tlsHandshakeWithTimeout(conn, context.TODO()); err == nil { + if err := c.tlsHandshakeWithTimeout(context.TODO(), conn); err == nil { h1 := conn.GetHandshakeLog() enumeratedCiphers = append(enumeratedCiphers, h1.ServerHello.CipherSuite.String()) } @@ -320,20 +320,22 @@ func (c *Client) getConfig(hostname, ip, port string, options clients.ConnectOpt return config, nil } -// tlsHandshakeWithCtx attempts tls handshake with given timeout -func (c *Client) tlsHandshakeWithTimeout(tlsConn *tls.Conn, ctx context.Context) error { +// tlsHandshakeWithTimeout attempts tls handshake with given timeout +func (c *Client) tlsHandshakeWithTimeout(ctx context.Context, tlsConn *tls.Conn) error { errChan := make(chan error, 1) - defer close(errChan) + + go func() { + errChan <- tlsConn.Handshake() + }() select { case <-ctx.Done(): - return errorutil.NewWithTag("ztls", "timeout while attempting handshake") //nolint - case errChan <- tlsConn.Handshake(): - } - - err := <-errChan - if err == tls.ErrCertsOnly { - err = nil + return errorutil.NewWithTag("ztls", "timeout while attempting handshake").Wrap(ctx.Err()) //nolint + case err := <-errChan: + if err == tls.ErrCertsOnly { + err = nil + } + return err } - return err } +