Skip to content

Commit b3c89ac

Browse files
authored
AsyncIO Race Condition Fix (#2640)
1 parent 8592cac commit b3c89ac

File tree

6 files changed

+61
-8
lines changed

6 files changed

+61
-8
lines changed

.github/workflows/integration.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ on:
1313
branches:
1414
- master
1515
- '[0-9].[0-9]'
16-
schedule:
17-
- cron: '0 1 * * *' # nightly build
16+
# schedule:
17+
# - cron: '0 1 * * *' # nightly build
1818

1919
permissions:
2020
contents: read # to fetch code (actions/checkout)

redis/asyncio/client.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1374,10 +1374,16 @@ async def execute(self, raise_on_error: bool = True):
13741374
conn = cast(Connection, conn)
13751375

13761376
try:
1377-
return await conn.retry.call_with_retry(
1378-
lambda: execute(conn, stack, raise_on_error),
1379-
lambda error: self._disconnect_raise_reset(conn, error),
1377+
return await asyncio.shield(
1378+
conn.retry.call_with_retry(
1379+
lambda: execute(conn, stack, raise_on_error),
1380+
lambda error: self._disconnect_raise_reset(conn, error),
1381+
)
13801382
)
1383+
except asyncio.CancelledError:
1384+
# not supposed to be possible, yet here we are
1385+
await conn.disconnect(nowait=True)
1386+
raise
13811387
finally:
13821388
await self.reset()
13831389

redis/asyncio/cluster.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1002,10 +1002,18 @@ async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
10021002
await connection.send_packed_command(connection.pack_command(*args), False)
10031003

10041004
# Read response
1005+
return await asyncio.shield(
1006+
self._parse_and_release(connection, args[0], **kwargs)
1007+
)
1008+
1009+
async def _parse_and_release(self, connection, *args, **kwargs):
10051010
try:
1006-
return await self.parse_response(connection, args[0], **kwargs)
1011+
return await self.parse_response(connection, *args, **kwargs)
1012+
except asyncio.CancelledError:
1013+
# should not be possible
1014+
await connection.disconnect(nowait=True)
1015+
raise
10071016
finally:
1008-
# Release connection
10091017
self._free.append(connection)
10101018

10111019
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
long_description_content_type="text/markdown",
99
keywords=["Redis", "key-value store", "database"],
1010
license="MIT",
11-
version="4.4.2",
11+
version="4.4.3",
1212
packages=find_packages(
1313
include=[
1414
"redis",

tests/test_asyncio/test_cluster.py

+17
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,23 @@ async def test_from_url(self, request: FixtureRequest) -> None:
340340
rc = RedisCluster.from_url("rediss://localhost:16379")
341341
assert rc.connection_kwargs["connection_class"] is SSLConnection
342342

343+
async def test_asynckills(self, r) -> None:
344+
345+
await r.set("foo", "foo")
346+
await r.set("bar", "bar")
347+
348+
t = asyncio.create_task(r.get("foo"))
349+
await asyncio.sleep(1)
350+
t.cancel()
351+
try:
352+
await t
353+
except asyncio.CancelledError:
354+
pytest.fail("connection is left open with unread response")
355+
356+
assert await r.get("bar") == b"bar"
357+
assert await r.ping()
358+
assert await r.get("foo") == b"foo"
359+
343360
async def test_max_connections(
344361
self, create_redis: Callable[..., RedisCluster]
345362
) -> None:

tests/test_asyncio/test_connection.py

+22
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ async def test_invalid_response(create_redis):
4141
await r.connection.disconnect()
4242

4343

44+
@pytest.mark.onlynoncluster
45+
async def test_asynckills(create_redis):
46+
47+
for b in [True, False]:
48+
r = await create_redis(single_connection_client=b)
49+
50+
await r.set("foo", "foo")
51+
await r.set("bar", "bar")
52+
53+
t = asyncio.create_task(r.get("foo"))
54+
await asyncio.sleep(1)
55+
t.cancel()
56+
try:
57+
await t
58+
except asyncio.CancelledError:
59+
pytest.fail("connection left open with unread response")
60+
61+
assert await r.get("bar") == b"bar"
62+
assert await r.ping()
63+
assert await r.get("foo") == b"foo"
64+
65+
4466
@skip_if_server_version_lt("4.0.0")
4567
@pytest.mark.redismod
4668
@pytest.mark.onlynoncluster

0 commit comments

Comments
 (0)