
Commit 5efd7bd

server: prohibit more than MaxConcurrentStreams handlers from running at once (#6703) (#6708)
1 parent bd1f038 commit 5efd7bd
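
For context, a minimal usage sketch (the address and limit are illustrative; grpc.NewServer and grpc.MaxConcurrentStreams are the real APIs this commit touches). After this change, the option no longer just advertises a limit in the HTTP/2 SETTINGS frame; it also bounds how many handler goroutines the server will run at once per transport:

package main

import (
	"log"
	"net"

	"google.golang.org/grpc"
)

func main() {
	lis, err := net.Listen("tcp", "localhost:50051") // illustrative address
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// Cap each client connection at 100 concurrent streams. The server now
	// enforces this bound on running handlers instead of trusting clients
	// to respect the advertised SETTINGS value.
	s := grpc.NewServer(grpc.MaxConcurrentStreams(100))
	if err := s.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}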

4 files changed (+180 −45 lines)


internal/transport/http2_server.go (+3 −8)
@@ -171,15 +171,10 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		ID:  http2.SettingMaxFrameSize,
 		Val: http2MaxFrameLen,
 	}}
-	// TODO(zhaoq): Have a better way to signal "no limit" because 0 is
-	// permitted in the HTTP2 spec.
-	maxStreams := config.MaxStreams
-	if maxStreams == 0 {
-		maxStreams = math.MaxUint32
-	} else {
+	if config.MaxStreams != math.MaxUint32 {
 		isettings = append(isettings, http2.Setting{
 			ID:  http2.SettingMaxConcurrentStreams,
-			Val: maxStreams,
+			Val: config.MaxStreams,
 		})
 	}
 	dynamicWindow := true
@@ -258,7 +253,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
 		framer:      framer,
 		readerDone:  make(chan struct{}),
 		writerDone:  make(chan struct{}),
-		maxStreams:  maxStreams,
+		maxStreams:  config.MaxStreams,
 		inTapHandle: config.InTapHandle,
 		fc:          &trInFlow{limit: uint32(icwz)},
 		state:       reachable,
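
Net effect: the transport no longer treats MaxStreams == 0 as "no limit". Normalization of 0 to math.MaxUint32 moves up into the server options (see server.go below), and SETTINGS_MAX_CONCURRENT_STREAMS is only sent to clients when a finite limit is configured.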

internal/transport/transport_test.go (+19 −16)
@@ -336,6 +336,9 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
 		if err != nil {
 			return
 		}
+		if serverConfig.MaxStreams == 0 {
+			serverConfig.MaxStreams = math.MaxUint32
+		}
 		transport, err := NewServerTransport(conn, serverConfig)
 		if err != nil {
 			return
@@ -442,8 +445,8 @@ func setUpServerOnly(t *testing.T, port int, sc *ServerConfig, ht hType) *server
 	return server
 }
 
-func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, *http2Client, func()) {
-	return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{})
+func setUp(t *testing.T, port int, ht hType) (*server, *http2Client, func()) {
+	return setUpWithOptions(t, port, &ServerConfig{}, ht, ConnectOptions{})
 }
 
 func setUpWithOptions(t *testing.T, port int, sc *ServerConfig, ht hType, copts ConnectOptions) (*server, *http2Client, func()) {
@@ -538,7 +541,7 @@ func (s) TestInflightStreamClosing(t *testing.T) {
 
 // Tests that when streamID > MaxStreamId, the current client transport drains.
 func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	callHdr := &CallHdr{
@@ -583,7 +586,7 @@ func (s) TestClientTransportDrainsAfterStreamIDExhausted(t *testing.T) {
 }
 
 func (s) TestClientSendAndReceive(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -623,7 +626,7 @@ func (s) TestClientSendAndReceive(t *testing.T) {
 }
 
 func (s) TestClientErrorNotify(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	go server.stop()
 	// ct.reader should detect the error and activate ct.Error().
@@ -657,7 +660,7 @@ func performOneRPC(ct ClientTransport) {
 }
 
 func (s) TestClientMix(t *testing.T) {
-	s, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	s, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	time.AfterFunc(time.Second, s.stop)
 	go func(ct ClientTransport) {
@@ -671,7 +674,7 @@ func (s) TestClientMix(t *testing.T) {
 }
 
 func (s) TestLargeMessage(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -806,7 +809,7 @@ func (s) TestLargeMessageWithDelayRead(t *testing.T) {
 // proceed until they complete naturally, while not allowing creation of new
 // streams during this window.
 func (s) TestGracefulClose(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, pingpong)
+	server, ct, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer func() {
 		// Stop the server's listener to make the server's goroutines terminate
@@ -872,7 +875,7 @@ func (s) TestGracefulClose(t *testing.T) {
 }
 
 func (s) TestLargeMessageSuspension(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -980,7 +983,7 @@ func (s) TestMaxStreams(t *testing.T) {
 }
 
 func (s) TestServerContextCanceledOnClosedConnection(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, suspended)
+	server, ct, cancel := setUp(t, 0, suspended)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1452,7 +1455,7 @@ func (s) TestClientWithMisbehavedServer(t *testing.T) {
 var encodingTestStatus = status.New(codes.Internal, "\n")
 
 func (s) TestEncodingRequiredStatus(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
+	server, ct, cancel := setUp(t, 0, encodingRequiredStatus)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1480,7 +1483,7 @@ func (s) TestEncodingRequiredStatus(t *testing.T) {
 }
 
 func (s) TestInvalidHeaderField(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	callHdr := &CallHdr{
 		Host: "localhost",
@@ -1502,7 +1505,7 @@ func (s) TestInvalidHeaderField(t *testing.T) {
 }
 
 func (s) TestHeaderChanClosedAfterReceivingAnInvalidHeader(t *testing.T) {
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, invalidHeaderField)
+	server, ct, cancel := setUp(t, 0, invalidHeaderField)
 	defer cancel()
 	defer server.stop()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
@@ -2170,7 +2173,7 @@ func (s) TestPingPong1MB(t *testing.T) {
 
 // This is a stress-test of flow control logic.
 func runPingPongTest(t *testing.T, msgSize int) {
-	server, client, cancel := setUp(t, 0, 0, pingpong)
+	server, client, cancel := setUp(t, 0, pingpong)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
@@ -2252,7 +2255,7 @@ func (s) TestHeaderTblSize(t *testing.T) {
 		}
 	}()
 
-	server, ct, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, ct, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer ct.Close(fmt.Errorf("closed manually by test"))
 	defer server.stop()
@@ -2611,7 +2614,7 @@ func TestConnectionError_Unwrap(t *testing.T) {
 
 func (s) TestPeerSetInServerContext(t *testing.T) {
 	// create client and server transports.
-	server, client, cancel := setUp(t, 0, math.MaxUint32, normal)
+	server, client, cancel := setUp(t, 0, normal)
 	defer cancel()
 	defer server.stop()
 	defer client.Close(fmt.Errorf("closed manually by test"))
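
These test changes are mechanical: s.start now normalizes MaxStreams == 0 to math.MaxUint32 itself, so setUp drops its maxStreams parameter and every call site loses the math.MaxUint32 (or 0) argument.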

server.go (+48 −21)
@@ -115,12 +115,6 @@ type serviceInfo struct {
 	mdata interface{}
 }
 
-type serverWorkerData struct {
-	st     transport.ServerTransport
-	wg     *sync.WaitGroup
-	stream *transport.Stream
-}
-
 // Server is a gRPC server to serve RPC requests.
 type Server struct {
 	opts serverOptions
@@ -145,7 +139,7 @@ type Server struct {
 	channelzID *channelz.Identifier
 	czData     *channelzData
 
-	serverWorkerChannel chan *serverWorkerData
+	serverWorkerChannel chan func()
 }
 
 type serverOptions struct {
@@ -177,6 +171,7 @@ type serverOptions struct {
 }
 
 var defaultServerOptions = serverOptions{
+	maxConcurrentStreams:  math.MaxUint32,
 	maxReceiveMessageSize: defaultServerMaxReceiveMessageSize,
 	maxSendMessageSize:    defaultServerMaxSendMessageSize,
 	connectionTimeout:     120 * time.Second,
@@ -387,6 +382,9 @@ func MaxSendMsgSize(m int) ServerOption {
 // MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
 // of concurrent streams to each ServerTransport.
 func MaxConcurrentStreams(n uint32) ServerOption {
+	if n == 0 {
+		n = math.MaxUint32
+	}
 	return newFuncServerOption(func(o *serverOptions) {
 		o.maxConcurrentStreams = n
 	})
@@ -567,24 +565,19 @@ const serverWorkerResetThreshold = 1 << 16
 // [1] https://github.com/golang/go/issues/18138
 func (s *Server) serverWorker() {
 	for completed := 0; completed < serverWorkerResetThreshold; completed++ {
-		data, ok := <-s.serverWorkerChannel
+		f, ok := <-s.serverWorkerChannel
 		if !ok {
 			return
 		}
-		s.handleSingleStream(data)
+		f()
 	}
 	go s.serverWorker()
 }
 
-func (s *Server) handleSingleStream(data *serverWorkerData) {
-	defer data.wg.Done()
-	s.handleStream(data.st, data.stream, s.traceInfo(data.st, data.stream))
-}
-
 // initServerWorkers creates worker goroutines and a channel to process incoming
 // connections to reduce the time spent overall on runtime.morestack.
 func (s *Server) initServerWorkers() {
-	s.serverWorkerChannel = make(chan *serverWorkerData)
+	s.serverWorkerChannel = make(chan func())
 	for i := uint32(0); i < s.opts.numServerWorkers; i++ {
 		go s.serverWorker()
 	}
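
With the channel carrying bare func() closures, a worker no longer needs to know anything about streams, which is what lets the commit delete serverWorkerData and handleSingleStream. A minimal standalone sketch of the same pattern (worker, tasks, and done are illustrative names, not from the codebase):

package main

import "fmt"

// worker drains a channel of closures. Each closure carries its own
// arguments, so the channel type never has to change when the work does.
func worker(tasks <-chan func(), done chan<- struct{}) {
	for f := range tasks {
		f()
	}
	done <- struct{}{}
}

func main() {
	tasks := make(chan func())
	done := make(chan struct{})
	go worker(tasks, done)
	for i := 0; i < 3; i++ {
		i := i // capture the loop variable for the closure
		tasks <- func() { fmt.Println("ran task", i) }
	}
	close(tasks)
	<-done
}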
@@ -943,21 +936,26 @@ func (s *Server) serveStreams(st transport.ServerTransport) {
 	defer st.Close(errors.New("finished serving streams for the server transport"))
 	var wg sync.WaitGroup
 
+	streamQuota := newHandlerQuota(s.opts.maxConcurrentStreams)
 	st.HandleStreams(func(stream *transport.Stream) {
 		wg.Add(1)
+
+		streamQuota.acquire()
+		f := func() {
+			defer streamQuota.release()
+			defer wg.Done()
+			s.handleStream(st, stream, s.traceInfo(st, stream))
+		}
+
 		if s.opts.numServerWorkers > 0 {
-			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
 			select {
-			case s.serverWorkerChannel <- data:
+			case s.serverWorkerChannel <- f:
 				return
 			default:
 				// If all stream workers are busy, fallback to the default code path.
 			}
 		}
-		go func() {
-			defer wg.Done()
-			s.handleStream(st, stream, s.traceInfo(st, stream))
-		}()
+		go f()
 	}, func(ctx context.Context, method string) context.Context {
 		if !EnableTracing {
 			return ctx
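
This is the heart of the fix: streamQuota.acquire() runs synchronously in the HandleStreams callback, on the transport's reader goroutine, so once maxConcurrentStreams handlers are in flight no new handler can start (via a worker or via go f()) until a running one releases its slot. Previously the limit was only advertised to clients, which a misbehaving client could ignore.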
@@ -2052,3 +2050,32 @@ func validateSendCompressor(name, clientCompressors string) error {
 	}
 	return fmt.Errorf("client does not support compressor %q", name)
 }
+
+// atomicSemaphore implements a blocking, counting semaphore. acquire should be
+// called synchronously; release may be called asynchronously.
+type atomicSemaphore struct {
+	n    int64
+	wait chan struct{}
+}
+
+func (q *atomicSemaphore) acquire() {
+	if atomic.AddInt64(&q.n, -1) < 0 {
+		// We ran out of quota. Block until a release happens.
+		<-q.wait
+	}
+}
+
+func (q *atomicSemaphore) release() {
+	// N.B. the "<= 0" check below should allow for this to work with multiple
+	// concurrent calls to acquire, but also note that with synchronous calls to
+	// acquire, as our system does, n will never be less than -1. There are
+	// fairness issues (queuing) to consider if this was to be generalized.
+	if atomic.AddInt64(&q.n, 1) <= 0 {
+		// An acquire was waiting on us. Unblock it.
+		q.wait <- struct{}{}
+	}
+}
+
+func newHandlerQuota(n uint32) *atomicSemaphore {
+	return &atomicSemaphore{n: int64(n), wait: make(chan struct{}, 1)}
+}
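
To see the semaphore's behavior in isolation, here is a self-contained sketch (the driver in main is illustrative, not part of the commit; the semaphore itself is copied from the diff so the sketch compiles standalone):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// atomicSemaphore as defined in the diff above.
type atomicSemaphore struct {
	n    int64
	wait chan struct{}
}

func (q *atomicSemaphore) acquire() {
	if atomic.AddInt64(&q.n, -1) < 0 {
		<-q.wait // out of quota: block until a release arrives
	}
}

func (q *atomicSemaphore) release() {
	if atomic.AddInt64(&q.n, 1) <= 0 {
		q.wait <- struct{}{} // an acquire is blocked: hand it the slot
	}
}

func main() {
	// Quota of 2, mirroring newHandlerQuota(2).
	sem := &atomicSemaphore{n: 2, wait: make(chan struct{}, 1)}
	var running int64
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		sem.acquire() // synchronous, exactly as serveStreams calls it
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer sem.release()
			if n := atomic.AddInt64(&running, 1); n > 2 {
				panic("more than 2 handlers running at once")
			}
			time.Sleep(time.Millisecond) // simulate handler work
			atomic.AddInt64(&running, -1)
		}()
	}
	wg.Wait()
	fmt.Println("ran 10 tasks, never more than 2 at a time")
}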
