@@ -493,24 +493,32 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
493
493
):
494
494
raise error
495
495
496
- # COMMAND EXECUTION AND PROTOCOL PARSING
497
- async def execute_command (self , * args , ** options ):
498
- """Execute a command and return a parsed response"""
499
- await self .initialize ()
500
- pool = self .connection_pool
501
- command_name = args [0 ]
502
- conn = self .connection or await pool .get_connection (command_name , ** options )
503
-
496
+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
504
497
try :
505
498
return await conn .retry .call_with_retry (
506
499
lambda : self ._send_command_parse_response (
507
- conn , command_name , * args , ** options
500
+ conn , args [ 0 ] , * args , ** options
508
501
),
509
502
lambda error : self ._disconnect_raise (conn , error ),
510
503
)
504
+ except asyncio .CancelledError :
505
+ await conn .disconnect (nowait = True )
506
+ raise
511
507
finally :
512
508
if not self .connection :
513
- await pool .release (conn )
509
+ await self .connection_pool .release (conn )
510
+
511
+ # COMMAND EXECUTION AND PROTOCOL PARSING
512
+ async def execute_command (self , * args , ** options ):
513
+ """Execute a command and return a parsed response"""
514
+ await self .initialize ()
515
+ pool = self .connection_pool
516
+ command_name = args [0 ]
517
+ conn = self .connection or await pool .get_connection (command_name , ** options )
518
+
519
+ return await asyncio .shield (
520
+ self ._try_send_command_parse_response (conn , * args , ** options )
521
+ )
514
522
515
523
async def parse_response (
516
524
self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -749,10 +757,18 @@ async def _disconnect_raise_connect(self, conn, error):
749
757
is not a TimeoutError. Otherwise, try to reconnect
750
758
"""
751
759
await conn .disconnect ()
760
+
752
761
if not (conn .retry_on_timeout and isinstance (error , TimeoutError )):
753
762
raise error
754
763
await conn .connect ()
755
764
765
+ async def _try_execute (self , conn , command , * arg , ** kwargs ):
766
+ try :
767
+ return await command (* arg , ** kwargs )
768
+ except asyncio .CancelledError :
769
+ await conn .disconnect ()
770
+ raise
771
+
756
772
async def _execute (self , conn , command , * args , ** kwargs ):
757
773
"""
758
774
Connect manually upon disconnection. If the Redis server is down,
@@ -761,9 +777,11 @@ async def _execute(self, conn, command, *args, **kwargs):
761
777
called by the # connection to resubscribe us to any channels and
762
778
patterns we were previously listening to
763
779
"""
764
- return await conn .retry .call_with_retry (
765
- lambda : command (* args , ** kwargs ),
766
- lambda error : self ._disconnect_raise_connect (conn , error ),
780
+ return await asyncio .shield (
781
+ conn .retry .call_with_retry (
782
+ lambda : self ._try_execute (conn , command , * args , ** kwargs ),
783
+ lambda error : self ._disconnect_raise_connect (conn , error ),
784
+ )
767
785
)
768
786
769
787
async def parse_response (self , block : bool = True , timeout : float = 0 ):
@@ -1165,6 +1183,18 @@ async def _disconnect_reset_raise(self, conn, error):
1165
1183
await self .reset ()
1166
1184
raise
1167
1185
1186
+ async def _try_send_command_parse_response (self , conn , * args , ** options ):
1187
+ try :
1188
+ return await conn .retry .call_with_retry (
1189
+ lambda : self ._send_command_parse_response (
1190
+ conn , args [0 ], * args , ** options
1191
+ ),
1192
+ lambda error : self ._disconnect_reset_raise (conn , error ),
1193
+ )
1194
+ except asyncio .CancelledError :
1195
+ await conn .disconnect ()
1196
+ raise
1197
+
1168
1198
async def immediate_execute_command (self , * args , ** options ):
1169
1199
"""
1170
1200
Execute a command immediately, but don't auto-retry on a
@@ -1180,12 +1210,8 @@ async def immediate_execute_command(self, *args, **options):
1180
1210
command_name , self .shard_hint
1181
1211
)
1182
1212
self .connection = conn
1183
-
1184
- return await conn .retry .call_with_retry (
1185
- lambda : self ._send_command_parse_response (
1186
- conn , command_name , * args , ** options
1187
- ),
1188
- lambda error : self ._disconnect_reset_raise (conn , error ),
1213
+ return await asyncio .shield (
1214
+ self ._try_send_command_parse_response (conn , * args , ** options )
1189
1215
)
1190
1216
1191
1217
def pipeline_execute_command (self , * args , ** options ):
@@ -1353,6 +1379,19 @@ async def _disconnect_raise_reset(self, conn: Connection, error: Exception):
1353
1379
await self .reset ()
1354
1380
raise
1355
1381
1382
+ async def _try_execute (self , conn , execute , stack , raise_on_error ):
1383
+ try :
1384
+ return await conn .retry .call_with_retry (
1385
+ lambda : execute (conn , stack , raise_on_error ),
1386
+ lambda error : self ._disconnect_raise_reset (conn , error ),
1387
+ )
1388
+ except asyncio .CancelledError :
1389
+ # not supposed to be possible, yet here we are
1390
+ await conn .disconnect (nowait = True )
1391
+ raise
1392
+ finally :
1393
+ await self .reset ()
1394
+
1356
1395
async def execute (self , raise_on_error : bool = True ):
1357
1396
"""Execute all the commands in the current pipeline"""
1358
1397
stack = self .command_stack
@@ -1375,15 +1414,10 @@ async def execute(self, raise_on_error: bool = True):
1375
1414
1376
1415
try :
1377
1416
return await asyncio .shield (
1378
- conn .retry .call_with_retry (
1379
- lambda : execute (conn , stack , raise_on_error ),
1380
- lambda error : self ._disconnect_raise_reset (conn , error ),
1381
- )
1417
+ self ._try_execute (conn , execute , stack , raise_on_error )
1382
1418
)
1383
- except asyncio .CancelledError :
1384
- # not supposed to be possible, yet here we are
1385
- await conn .disconnect (nowait = True )
1386
- raise
1419
+ except RuntimeError :
1420
+ await self .reset ()
1387
1421
finally :
1388
1422
await self .reset ()
1389
1423
0 commit comments