New Upstream Release - python-channels-redis

Ready changes

Summary

Merged new upstream version: 4.1.0 (was: 4.0.0).

Diff

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index dd80974..fee857d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -2,27 +2,30 @@ name: Tests
 
 on:
   push:
-   branches:
-   - main
+    branches:
+      - main
   pull_request:
 
 jobs:
   tests:
     name: Python ${{ matrix.python-version }}
     runs-on: ubuntu-latest
+    timeout-minutes: 10
     strategy:
       fail-fast: false
       matrix:
         python-version:
-        - 3.7
-        - 3.8
-        - 3.9
-        - "3.10"
+          - 3.7
+          - 3.8
+          - 3.9
+          - "3.10"
+          # Refs #348 - skip Python 3.11 until redis-py >= 4.5.4 (or Python 3.11.3)
+          # - "3.11"
     services:
       redis:
         image: redis
         ports:
-         - 6379:6379
+          - 6379:6379
         options: >-
           --health-cmd "redis-cli ping"
           --health-interval 10s
@@ -44,9 +47,9 @@ jobs:
           REDIS_SENTINEL_PASSWORD: channels_redis
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
@@ -61,11 +64,11 @@ jobs:
     name: Lint
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.9
+          python-version: "3.11"
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip tox
diff --git a/CHANGELOG.txt b/CHANGELOG.txt
index 65d36c3..d2e282f 100644
--- a/CHANGELOG.txt
+++ b/CHANGELOG.txt
@@ -1,13 +1,35 @@
+4.1.0 (2023-03-28)
+------------------
+
+* Adjusted the way Redis connections are handled:
+
+  * Connection handling is now shared between the two layers, core and pub-sub.
+
+  * Both layers now ensure that connections are closed when an event loop shuts down.
+
+    In particular, redis-py 4.x requires that connections are manually closed.
+    In 4.0 the core layer didn't do that, which led to warnings for people
+    updating from 3.x who used `async_to_sync()` without closing connections.
+
+* Updated the minimum redis-py version to 4.5.3 because of a security release there.
+  Note that this is not a security issue in channels-redis itself: an earlier
+  version would still install the latest redis-py, but bumping the dependency
+  ensures you get the fixed redis-py when you install this update.
+
 4.0.0 (2022-10-07)
 ------------------
 
 Version 4.0.0 migrates the underlying Redis library from ``aioredis`` to ``redis-py``.
 (``aioredis`` was retired and moved into ``redis-py``, which will host the ongoing development.)
 
-The API is unchanged. Version 4.0.0 should be compatible with existing Channels 3 projects, as well as Channels 4
+Version 4.0.0 should be compatible with existing Channels 3 projects, as well as Channels 4
 projects.
 
-* Migrated from ``aioredis`` to ``redis-py``.
+* Migrated from ``aioredis`` to ``redis-py``. Specifying hosts as tuples is no longer supported.
+  If hosts are specified as dicts, only the ``address`` key will be taken into account, i.e.
+  a ``password`` must be specified inline in the address.
 
 * Added support for passing kwargs to sentinel connections.
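The 4.0.0 note above changes how hosts may be specified. As a sketch of a compatible Django `CHANNEL_LAYERS` setting (the URL and password below are placeholders, not values from this release):

```python
# Illustrative CHANNEL_LAYERS setting for channels-redis 4.x.
# Tuples like ("localhost", 6379) are no longer accepted, and a
# password must be embedded in the address URL itself.
CHANNEL_LAYERS = {
    "default": {
        "BACKEND": "channels_redis.core.RedisChannelLayer",
        "CONFIG": {
            # A plain URL string...
            "hosts": ["redis://:s3cret@localhost:6379/0"],
            # ...or equivalently a dict with an "address" key:
            # "hosts": [{"address": "redis://:s3cret@localhost:6379/0"}],
        },
    },
}
```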
 
diff --git a/README.rst b/README.rst
index 0e98ea1..b23a1c1 100644
--- a/README.rst
+++ b/README.rst
@@ -90,7 +90,7 @@ This should be slightly faster than a loopback TCP connection.
 ``prefix``
 ~~~~~~~~~~
 
-Prefix to add to all Redis keys. Defaults to ``asgi:``. If you're running
+Prefix to add to all Redis keys. Defaults to ``asgi``. If you're running
 two or more entirely separate channel layers through the same Redis instance,
 make sure they have different prefixes. All servers talking to the same layer
 should have the same prefix, though.
@@ -222,7 +222,7 @@ And then in your channels consumer, you can implement the handler:
 Dependencies
 ------------
 
-Redis >= 5.0 is required for `channels_redis`. Python 3.7 or higher is required.
+Redis server >= 5.0 is required for `channels_redis`. Python 3.7 or higher is required.
 
 
 Used commands
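The corrected default prefix ends up in every key the layer creates. A minimal sketch mirroring `_group_key` from `channels_redis/core.py` (the group name `chat` is illustrative):

```python
# How the configured prefix shapes Redis keys; mirrors _group_key in core.py.
prefix = "asgi"  # the default, with no trailing colon

def group_key(group):
    return f"{prefix}:group:{group}".encode("utf8")

print(group_key("chat"))  # b'asgi:group:chat'
```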
diff --git a/channels_redis/__init__.py b/channels_redis/__init__.py
index ce1305b..7039708 100644
--- a/channels_redis/__init__.py
+++ b/channels_redis/__init__.py
@@ -1 +1 @@
-__version__ = "4.0.0"
+__version__ = "4.1.0"
diff --git a/channels_redis/core.py b/channels_redis/core.py
index 7c04ecd..1111fc2 100644
--- a/channels_redis/core.py
+++ b/channels_redis/core.py
@@ -15,7 +15,7 @@ from redis import asyncio as aioredis
 from channels.exceptions import ChannelFull
 from channels.layers import BaseChannelLayer
 
-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts
 
 logger = logging.getLogger(__name__)
 
@@ -69,6 +69,26 @@ class BoundedQueue(asyncio.Queue):
         return super(BoundedQueue, self).put_nowait(item)
 
 
+class RedisLoopLayer:
+    def __init__(self, channel_layer):
+        self._lock = asyncio.Lock()
+        self.channel_layer = channel_layer
+        self._connections = {}
+
+    def get_connection(self, index):
+        if index not in self._connections:
+            pool = self.channel_layer.create_pool(index)
+            self._connections[index] = aioredis.Redis(connection_pool=pool)
+
+        return self._connections[index]
+
+    async def flush(self):
+        async with self._lock:
+            for index in list(self._connections):
+                connection = self._connections.pop(index)
+                await connection.close(close_connection_pool=True)
+
+
 class RedisChannelLayer(BaseChannelLayer):
     """
     Redis channel layer.
@@ -98,11 +118,10 @@ class RedisChannelLayer(BaseChannelLayer):
         self.prefix = prefix
         assert isinstance(self.prefix, str), "Prefix must be unicode"
         # Configure the host objects
-        self.hosts = self.decode_hosts(hosts)
+        self.hosts = decode_hosts(hosts)
         self.ring_size = len(self.hosts)
         # Cached redis connection pools and the event loop they are from
-        self.pools = {}
-        self.pools_loop = None
+        self._layers = {}
         # Normal channels choose a host index by cycling through the available hosts
         self._receive_index_generator = itertools.cycle(range(len(self.hosts)))
         self._send_index_generator = itertools.cycle(range(len(self.hosts)))
@@ -127,46 +146,7 @@ class RedisChannelLayer(BaseChannelLayer):
         self.receive_clean_locks = ChannelLock()
 
     def create_pool(self, index):
-        host = self.hosts[index]
-
-        if "address" in host:
-            return aioredis.ConnectionPool.from_url(host["address"])
-        elif "master_name" in host:
-            sentinels = host.pop("sentinels")
-            master_name = host.pop("master_name")
-            sentinel_kwargs = host.pop("sentinel_kwargs", None)
-            return aioredis.sentinel.SentinelConnectionPool(
-                master_name,
-                aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
-                **host
-            )
-        else:
-            return aioredis.ConnectionPool(**host)
-
-    def decode_hosts(self, hosts):
-        """
-        Takes the value of the "hosts" argument passed to the class and returns
-        a list of kwargs to use for the Redis connection constructor.
-        """
-        # If no hosts were provided, return a default value
-        if not hosts:
-            return [{"address": "redis://localhost:6379"}]
-        # If they provided just a string, scold them.
-        if isinstance(hosts, (str, bytes)):
-            raise ValueError(
-                "You must pass a list of Redis hosts, even if there is only one."
-            )
-
-        # Decode each hosts entry into a kwargs dict
-        result = []
-        for entry in hosts:
-            if isinstance(entry, dict):
-                result.append(entry)
-            elif isinstance(entry, tuple):
-                result.append({"host": entry[0], "port": entry[1]})
-            else:
-                result.append({"address": entry})
-        return result
+        return create_pool(self.hosts[index])
 
     def _setup_encryption(self, symmetric_encryption_keys):
         # See if we can do encryption if they asked
@@ -331,7 +311,7 @@ class RedisChannelLayer(BaseChannelLayer):
 
                         raise
 
-                    message, token, exception = None, None, None
+                    message = token = exception = None
                     for task in done:
                         try:
                             result = task.result()
@@ -367,7 +347,7 @@ class RedisChannelLayer(BaseChannelLayer):
                             message_channel, message = await self.receive_single(
                                 real_channel
                             )
-                            if type(message_channel) is list:
+                            if isinstance(message_channel, list):
                                 for chan in message_channel:
                                     self.receive_buffer[chan].put_nowait(message)
                             else:
@@ -459,11 +439,7 @@ class RedisChannelLayer(BaseChannelLayer):
         Returns a new channel name that can be used by something in our
         process as a specific channel.
         """
-        return "%s.%s!%s" % (
-            prefix,
-            self.client_prefix,
-            uuid.uuid4().hex,
-        )
+        return f"{prefix}.{self.client_prefix}!{uuid.uuid4().hex}"
 
     ### Flush extension ###
 
@@ -496,9 +472,8 @@ class RedisChannelLayer(BaseChannelLayer):
         # Flush all cleaners, in case somebody just wanted to close the
         # pools without flushing first.
         await self.wait_received()
-
-        for index in self.pools:
-            await self.pools[index].disconnect()
+        for layer in self._layers.values():
+            await layer.flush()
 
     async def wait_received(self):
         """
@@ -667,7 +642,7 @@ class RedisChannelLayer(BaseChannelLayer):
         """
         Common function to make the storage key for the group.
         """
-        return ("%s:group:%s" % (self.prefix, group)).encode("utf8")
+        return f"{self.prefix}:group:{group}".encode("utf8")
 
     ### Serialization ###
 
@@ -711,7 +686,7 @@ class RedisChannelLayer(BaseChannelLayer):
         return Fernet(formatted_key)
 
     def __str__(self):
-        return "%s(hosts=%s)" % (self.__class__.__name__, self.hosts)
+        return f"{self.__class__.__name__}(hosts={self.hosts})"
 
     ### Connection handling ###
 
@@ -723,18 +698,14 @@ class RedisChannelLayer(BaseChannelLayer):
         # Catch bad indexes
         if not 0 <= index < self.ring_size:
             raise ValueError(
-                "There are only %s hosts - you asked for %s!" % (self.ring_size, index)
+                f"There are only {self.ring_size} hosts - you asked for {index}!"
             )
 
+        loop = asyncio.get_running_loop()
         try:
-            loop = asyncio.get_running_loop()
-            if self.pools_loop != loop:
-                self.pools = {}
-                self.pools_loop = loop
-        except RuntimeError:
-            pass
-
-        if index not in self.pools:
-            self.pools[index] = self.create_pool(index)
+            layer = self._layers[loop]
+        except KeyError:
+            _wrap_close(self, loop)
+            layer = self._layers[loop] = RedisLoopLayer(self)
 
-        return aioredis.Redis(connection_pool=self.pools[index])
+        return layer.get_connection(index)
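The core.py changes above key cached connections by the running event loop and hook `loop.close()` so they are flushed when the loop shuts down. A self-contained sketch of that pattern, with `Registry` standing in for the channel layer and a plain dict standing in for the real connection state (illustrative names, not the library's API):

```python
import asyncio
import types

def _wrap_close(proxy, loop):
    # Wrap loop.close() so per-loop state is flushed before the loop goes away.
    original_impl = loop.close

    def _wrapper(self, *args, **kwargs):
        if loop in proxy._layers:
            layer = proxy._layers.pop(loop)
            layer["open"] = False  # stand-in for `await layer.flush()`
        self.close = original_impl
        return self.close(*args, **kwargs)

    loop.close = types.MethodType(_wrapper, loop)

class Registry:
    def __init__(self):
        self._layers = {}

    def get_layer(self, loop):
        # Lazily create per-loop state, hooking close() on first sight of a loop.
        if loop not in self._layers:
            _wrap_close(self, loop)
            self._layers[loop] = {"open": True}
        return self._layers[loop]

registry = Registry()
loop = asyncio.new_event_loop()
layer = registry.get_layer(loop)
loop.close()
print(layer["open"])  # False: state was flushed when the loop closed
```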
diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py
index 3c10378..78db68e 100644
--- a/channels_redis/pubsub.py
+++ b/channels_redis/pubsub.py
@@ -1,32 +1,16 @@
 import asyncio
 import functools
 import logging
-import types
 import uuid
 
 import msgpack
 from redis import asyncio as aioredis
 
-from .utils import _consistent_hash
+from .utils import _consistent_hash, _wrap_close, create_pool, decode_hosts
 
 logger = logging.getLogger(__name__)
 
 
-def _wrap_close(proxy, loop):
-    original_impl = loop.close
-
-    def _wrapper(self, *args, **kwargs):
-        if loop in proxy._layers:
-            layer = proxy._layers[loop]
-            del proxy._layers[loop]
-            loop.run_until_complete(layer.flush())
-
-        self.close = original_impl
-        return self.close(*args, **kwargs)
-
-    loop.close = types.MethodType(_wrapper, loop)
-
-
 async def _async_proxy(obj, name, *args, **kwargs):
     # Must be defined as a function and not a method due to
     # https://bugs.python.org/issue38364
@@ -97,12 +81,6 @@ class RedisPubSubLoopLayer:
         channel_layer=None,
         **kwargs,
     ):
-        if hosts is None:
-            hosts = ["redis://localhost:6379"]
-        assert (
-            isinstance(hosts, list) and len(hosts) > 0
-        ), "`hosts` must be a list with at least one Redis server"
-
         self.prefix = prefix
 
         self.on_disconnect = on_disconnect
@@ -118,7 +96,9 @@ class RedisPubSubLoopLayer:
         self.groups = {}
 
         # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host.
-        self._shards = [RedisSingleShardConnection(host, self) for host in hosts]
+        self._shards = [
+            RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)
+        ]
 
     def _get_shard(self, channel_or_group_name):
         """
@@ -223,12 +203,14 @@ class RedisPubSubLoopLayer:
 
     async def group_discard(self, group, channel):
         """
-        Removes the channel from a group.
+        Removes the channel from a group if it is in the group;
+        does nothing otherwise (does not error)
         """
         group_channel = self._get_group_channel_name(group)
-        assert group_channel in self.groups
-        group_channels = self.groups[group_channel]
-        assert channel in group_channels
+        group_channels = self.groups.get(group_channel, set())
+        if channel not in group_channels:
+            return
+
         group_channels.remove(channel)
         if len(group_channels) == 0:
             del self.groups[group_channel]
@@ -261,9 +243,7 @@ class RedisPubSubLoopLayer:
 
 class RedisSingleShardConnection:
     def __init__(self, host, channel_layer):
-        self.host = host.copy() if type(host) is dict else {"address": host}
-        self.master_name = self.host.pop("master_name", None)
-        self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None)
+        self.host = host
         self.channel_layer = channel_layer
         self._subscribed_to = set()
         self._lock = asyncio.Lock()
@@ -345,18 +325,7 @@ class RedisSingleShardConnection:
 
     def _ensure_redis(self):
         if self._redis is None:
-            if self.master_name is None:
-                pool = aioredis.ConnectionPool.from_url(self.host["address"])
-            else:
-                # aioredis default timeout is way too low
-                pool = aioredis.sentinel.SentinelConnectionPool(
-                    self.master_name,
-                    aioredis.sentinel.Sentinel(
-                        self.host["sentinels"],
-                        socket_timeout=2,
-                        sentinel_kwargs=self.sentinel_kwargs,
-                    ),
-                )
+            pool = create_pool(self.host)
             self._redis = aioredis.Redis(connection_pool=pool)
             self._pubsub = self._redis.pubsub()
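The `group_discard` change above replaces hard asserts with an early return, making discard idempotent. A simplified in-memory model of the new semantics (the real layer is async and keys groups by a prefixed group-channel name):

```python
# Discarding a channel that was never added is now a no-op, not an assert failure.
groups = {}

def group_add(group, channel):
    groups.setdefault(group, set()).add(channel)

def group_discard(group, channel):
    group_channels = groups.get(group, set())
    if channel not in group_channels:
        return  # nothing to do; previously this tripped an assert
    group_channels.remove(channel)
    if not group_channels:
        del groups[group]  # drop empty groups entirely

group_discard("test-group", "chan-1")  # no error even though the group doesn't exist
group_add("test-group", "chan-1")
group_discard("test-group", "chan-1")
print(groups)  # {}
```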
 
diff --git a/channels_redis/utils.py b/channels_redis/utils.py
index 7b30fdc..98e06ca 100644
--- a/channels_redis/utils.py
+++ b/channels_redis/utils.py
@@ -1,4 +1,7 @@
 import binascii
+import types
+
+from redis import asyncio as aioredis
 
 
 def _consistent_hash(value, ring_size):
@@ -15,3 +18,68 @@ def _consistent_hash(value, ring_size):
     bigval = binascii.crc32(value) & 0xFFF
     ring_divisor = 4096 / float(ring_size)
     return int(bigval / ring_divisor)
+
+
+def _wrap_close(proxy, loop):
+    original_impl = loop.close
+
+    def _wrapper(self, *args, **kwargs):
+        if loop in proxy._layers:
+            layer = proxy._layers[loop]
+            del proxy._layers[loop]
+            loop.run_until_complete(layer.flush())
+
+        self.close = original_impl
+        return self.close(*args, **kwargs)
+
+    loop.close = types.MethodType(_wrapper, loop)
+
+
+def decode_hosts(hosts):
+    """
+    Takes the value of the "hosts" argument and returns
+    a list of kwargs to use for the Redis connection constructor.
+    """
+    # If no hosts were provided, return a default value
+    if not hosts:
+        return [{"address": "redis://localhost:6379"}]
+    # If they provided just a string, scold them.
+    if isinstance(hosts, (str, bytes)):
+        raise ValueError(
+            "You must pass a list of Redis hosts, even if there is only one."
+        )
+
+    # Decode each hosts entry into a kwargs dict
+    result = []
+    for entry in hosts:
+        if isinstance(entry, dict):
+            result.append(entry)
+        elif isinstance(entry, (tuple, list)):
+            result.append({"host": entry[0], "port": entry[1]})
+        else:
+            result.append({"address": entry})
+    return result
+
+
+def create_pool(host):
+    """
+    Takes the value of the "host" argument and returns a suited connection pool to
+    the corresponding redis instance.
+    """
+    # avoid side-effects from modifying host
+    host = host.copy()
+    if "address" in host:
+        address = host.pop("address")
+        return aioredis.ConnectionPool.from_url(address, **host)
+
+    master_name = host.pop("master_name", None)
+    if master_name is not None:
+        sentinels = host.pop("sentinels")
+        sentinel_kwargs = host.pop("sentinel_kwargs", None)
+        return aioredis.sentinel.SentinelConnectionPool(
+            master_name,
+            aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs),
+            **host
+        )
+
+    return aioredis.ConnectionPool(**host)
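The `decode_hosts` helper consolidated here normalizes every supported host spec into a kwargs dict. Its body is copied from the diff above so the normalization can be exercised without a Redis server:

```python
def decode_hosts(hosts):
    """Normalize the "hosts" argument into a list of connection kwargs dicts."""
    if not hosts:
        return [{"address": "redis://localhost:6379"}]
    if isinstance(hosts, (str, bytes)):
        raise ValueError(
            "You must pass a list of Redis hosts, even if there is only one."
        )
    result = []
    for entry in hosts:
        if isinstance(entry, dict):
            result.append(entry)  # already in kwargs form
        elif isinstance(entry, (tuple, list)):
            result.append({"host": entry[0], "port": entry[1]})
        else:
            result.append({"address": entry})  # URL string
    return result

print(decode_hosts(None))
print(decode_hosts([("localhost", 6379), "redis://other:6380"]))
```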
diff --git a/debian/changelog b/debian/changelog
index b367dfb..e895171 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+python-channels-redis (4.1.0-1) UNRELEASED; urgency=low
+
+  * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk>  Thu, 18 May 2023 17:43:37 -0000
+
 python-channels-redis (4.0.0-1) unstable; urgency=medium
 
   * New upstream release.
diff --git a/setup.cfg b/setup.cfg
index fd799b4..3888deb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -2,6 +2,7 @@
 addopts = -p no:django
 testpaths = tests
 asyncio_mode = auto
+timeout = 10
 
 [flake8]
 exclude = venv/*,tox/*,specs/*,build/*
diff --git a/setup.py b/setup.py
index 3886521..566a622 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@ test_requires = crypto_requires + [
     "pytest",
     "pytest-asyncio",
     "async-timeout",
+    "pytest-timeout",
 ]
 
 
@@ -30,7 +31,7 @@ setup(
     include_package_data=True,
     python_requires=">=3.7",
     install_requires=[
-        "redis>=4.2.0",
+        "redis>=4.5.3",
         "msgpack~=1.0",
         "asgiref>=3.2.10,<4",
         "channels",
diff --git a/tests/test_core.py b/tests/test_core.py
index 6d2ff2a..2752040 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -419,6 +419,12 @@ def test_repeated_group_send_with_async_to_sync(channel_layer):
         pytest.fail(f"repeated async_to_sync wrapped group_send calls raised {exc}")
 
 
+@pytest.mark.xfail(
+    reason="""
+Fails with error in redis-py: int() argument must be a string, a bytes-like
+object or a real number, not 'NoneType'. Refs: #348
+"""
+)
 @pytest.mark.asyncio
 async def test_receive_cancel(channel_layer):
     """
@@ -551,6 +557,7 @@ async def test_message_expiry__group_send(channel_layer):
             await channel_layer.receive(channel_name)
 
 
+@pytest.mark.xfail(reason="Fails with timeout. Refs: #348")
 @pytest.mark.asyncio
 async def test_message_expiry__group_send__one_channel_expires_message(channel_layer):
     expiry = 3
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index d9a1082..78ad080 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -19,7 +19,8 @@ async def channel_layer():
     """
     channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
     yield channel_layer
-    await channel_layer.flush()
+    async with async_timeout.timeout(1):
+        await channel_layer.flush()
 
 
 @pytest.fixture()
@@ -252,3 +253,10 @@ async def test_auto_reconnect(channel_layer):
     with pytest.raises(asyncio.TimeoutError):
         async with async_timeout.timeout(1):
             await channel_layer.receive(channel_name2)
+
+
+@pytest.mark.asyncio
+async def test_discard_before_add(channel_layer):
+    channel_name = await channel_layer.new_channel(prefix="test-channel")
+    # Make sure that we can remove a group before it was ever added without crashing.
+    await channel_layer.group_discard("test-group", channel_name)
diff --git a/tests/test_pubsub_sentinel.py b/tests/test_pubsub_sentinel.py
index 300b106..049e39b 100644
--- a/tests/test_pubsub_sentinel.py
+++ b/tests/test_pubsub_sentinel.py
@@ -25,7 +25,8 @@ async def channel_layer():
     """
     channel_layer = RedisPubSubChannelLayer(hosts=TEST_HOSTS)
     yield channel_layer
-    await channel_layer.flush()
+    async with async_timeout.timeout(1):
+        await channel_layer.flush()
 
 
 @pytest.mark.asyncio
diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py
index 0a0fc2f..4fb7de5 100644
--- a/tests/test_sentinel.py
+++ b/tests/test_sentinel.py
@@ -461,6 +461,12 @@ async def test_group_send_capacity_multiple_channels(channel_layer, caplog):
         )
 
 
+@pytest.mark.xfail(
+    reason="""
+Fails with error in redis-py: int() argument must be a string, a bytes-like
+object or a real number, not 'NoneType'. Refs: #348
+"""
+)
 @pytest.mark.asyncio
 async def test_receive_cancel(channel_layer):
     """
@@ -593,6 +599,7 @@ async def test_message_expiry__group_send(channel_layer):
             await channel_layer.receive(channel_name)
 
 
+@pytest.mark.xfail(reason="Fails with timeout. Refs: #348")
 @pytest.mark.asyncio
 async def test_message_expiry__group_send__one_channel_expires_message(channel_layer):
     expiry = 3
diff --git a/tox.ini b/tox.ini
index 1aaa9a2..3ada8dc 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,6 +1,6 @@
 [tox]
 envlist =
-    py{37,38,39,310}
+    py{37,38,39,310,311}
     qa
 
 [testenv]
