Skip to content

Commit 44ba2ee

Browse files
authored
feat: reduce the type assertion of CheckConn (#3066)
* feat: reduce the type assertion of CheckConn Signed-off-by: monkey92t <[email protected]> * fix: correct the function names Signed-off-by: monkey92t <[email protected]> --------- Signed-off-by: monkey92t <[email protected]>
1 parent 9cfeb30 commit 44ba2ee

File tree

5 files changed

+49
-26
lines changed

5 files changed

+49
-26
lines changed

internal/pool/conn.go

+22
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ package pool
33
import (
44
"bufio"
55
"context"
6+
"crypto/tls"
67
"net"
78
"sync/atomic"
9+
"syscall"
810
"time"
911

1012
"github.com/redis/go-redis/v9/internal/proto"
@@ -16,6 +18,9 @@ type Conn struct {
1618
usedAt int64 // atomic
1719
netConn net.Conn
1820

21+
// for checking the health status of the connection, it may be nil.
22+
sysConn syscall.Conn
23+
1924
rd *proto.Reader
2025
bw *bufio.Writer
2126
wr *proto.Writer
@@ -34,6 +39,7 @@ func NewConn(netConn net.Conn) *Conn {
3439
cn.bw = bufio.NewWriter(netConn)
3540
cn.wr = proto.NewWriter(cn.bw)
3641
cn.SetUsedAt(time.Now())
42+
cn.setSysConn()
3743
return cn
3844
}
3945

@@ -50,6 +56,22 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
5056
cn.netConn = netConn
5157
cn.rd.Reset(netConn)
5258
cn.bw.Reset(netConn)
59+
cn.setSysConn()
60+
}
61+
62+
func (cn *Conn) setSysConn() {
63+
cn.sysConn = nil
64+
conn := cn.netConn
65+
if conn == nil {
66+
return
67+
}
68+
if tlsConn, ok := conn.(*tls.Conn); ok {
69+
conn = tlsConn.NetConn()
70+
}
71+
72+
if sysConn, ok := conn.(syscall.Conn); ok {
73+
cn.sysConn = sysConn
74+
}
5375
}
5476

5577
func (cn *Conn) Write(b []byte) (int, error) {

internal/pool/conn_check.go

+1-15
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,14 @@
33
package pool
44

55
import (
6-
"crypto/tls"
76
"errors"
87
"io"
9-
"net"
108
"syscall"
11-
"time"
129
)
1310

1411
var errUnexpectedRead = errors.New("unexpected read from socket")
1512

16-
func connCheck(conn net.Conn) error {
17-
// Reset previous timeout.
18-
_ = conn.SetDeadline(time.Time{})
19-
20-
// Check if tls.Conn.
21-
if c, ok := conn.(*tls.Conn); ok {
22-
conn = c.NetConn()
23-
}
24-
sysConn, ok := conn.(syscall.Conn)
25-
if !ok {
26-
return nil
27-
}
13+
func connCheck(sysConn syscall.Conn) error {
2814
rawConn, err := sysConn.SyscallConn()
2915
if err != nil {
3016
return err

internal/pool/conn_check_dummy.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
package pool
44

5-
import "net"
5+
import "syscall"
66

7-
func connCheck(conn net.Conn) error {
7+
func connCheck(_ syscall.Conn) error {
88
return nil
99
}

internal/pool/conn_check_test.go

+16-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"crypto/tls"
77
"net"
88
"net/http/httptest"
9+
"syscall"
910
"time"
1011

1112
. "github.com/bsm/ginkgo/v2"
@@ -16,50 +17,58 @@ var _ = Describe("tests conn_check with real conns", func() {
1617
var ts *httptest.Server
1718
var conn net.Conn
1819
var tlsConn *tls.Conn
20+
var sysConn syscall.Conn
21+
var tlsSysConn syscall.Conn
1922
var err error
2023

2124
BeforeEach(func() {
2225
ts = httptest.NewServer(nil)
2326
conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
2427
Expect(err).NotTo(HaveOccurred())
28+
sysConn = conn.(syscall.Conn)
2529
tlsTestServer := httptest.NewUnstartedServer(nil)
2630
tlsTestServer.StartTLS()
2731
tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
2832
Expect(err).NotTo(HaveOccurred())
33+
tlsSysConn = tlsConn.NetConn().(syscall.Conn)
2934
})
3035

3136
AfterEach(func() {
3237
ts.Close()
3338
})
3439

3540
It("good conn check", func() {
36-
Expect(connCheck(conn)).NotTo(HaveOccurred())
41+
Expect(connCheck(sysConn)).NotTo(HaveOccurred())
3742

3843
Expect(conn.Close()).NotTo(HaveOccurred())
39-
Expect(connCheck(conn)).To(HaveOccurred())
44+
Expect(connCheck(sysConn)).To(HaveOccurred())
4045
})
4146

4247
It("good tls conn check", func() {
43-
Expect(connCheck(tlsConn)).NotTo(HaveOccurred())
48+
Expect(connCheck(tlsSysConn)).NotTo(HaveOccurred())
4449

4550
Expect(tlsConn.Close()).NotTo(HaveOccurred())
46-
Expect(connCheck(tlsConn)).To(HaveOccurred())
51+
Expect(connCheck(tlsSysConn)).To(HaveOccurred())
4752
})
4853

4954
It("bad conn check", func() {
5055
Expect(conn.Close()).NotTo(HaveOccurred())
51-
Expect(connCheck(conn)).To(HaveOccurred())
56+
Expect(connCheck(sysConn)).To(HaveOccurred())
5257
})
5358

5459
It("bad tls conn check", func() {
5560
Expect(tlsConn.Close()).NotTo(HaveOccurred())
56-
Expect(connCheck(tlsConn)).To(HaveOccurred())
61+
Expect(connCheck(tlsSysConn)).To(HaveOccurred())
5762
})
5863

5964
It("check conn deadline", func() {
6065
Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred())
6166
time.Sleep(time.Millisecond * 10)
62-
Expect(connCheck(conn)).NotTo(HaveOccurred())
67+
Expect(connCheck(sysConn)).To(HaveOccurred())
68+
69+
Expect(conn.SetDeadline(time.Now().Add(time.Minute))).NotTo(HaveOccurred())
70+
time.Sleep(time.Millisecond * 10)
71+
Expect(connCheck(sysConn)).NotTo(HaveOccurred())
6372
Expect(conn.Close()).NotTo(HaveOccurred())
6473
})
6574
})

internal/pool/pool.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,8 @@ func (p *ConnPool) Close() error {
499499
return firstErr
500500
}
501501

502+
var zeroTime = time.Time{}
503+
502504
func (p *ConnPool) isHealthyConn(cn *Conn) bool {
503505
now := time.Now()
504506

@@ -509,8 +511,12 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
509511
return false
510512
}
511513

512-
if connCheck(cn.netConn) != nil {
513-
return false
514+
if cn.sysConn != nil {
515+
// reset previous timeout.
516+
_ = cn.netConn.SetDeadline(zeroTime)
517+
if connCheck(cn.sysConn) != nil {
518+
return false
519+
}
514520
}
515521

516522
cn.SetUsedAt(now)

0 commit comments

Comments
 (0)