New Upstream Release - zict
Ready changes
Summary
Merged new upstream version: 3.0.0 (was: 2.2.0).
Diff
diff --git a/PKG-INFO b/PKG-INFO
index 9d88586..bfa869c 100644
--- a/PKG-INFO
+++ b/PKG-INFO
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: zict
-Version: 2.2.0
+Version: 3.0.0
Summary: Mutable mapping tools
Home-page: http://zict.readthedocs.io/en/latest/
Maintainer: Matthew Rocklin
@@ -13,11 +13,11 @@ Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
-Requires-Python: >=3.7
+Classifier: Programming Language :: Python :: 3.11
+Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE.txt
diff --git a/debian/changelog b/debian/changelog
index 1eccf53..9960efb 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+zict (3.0.0-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Wed, 03 May 2023 07:35:36 -0000
+
zict (2.2.0-1) unstable; urgency=medium
* Team upload.
diff --git a/requirements.txt b/requirements.txt
index c6d4535..e69de29 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +0,0 @@
-heapdict
diff --git a/setup.cfg b/setup.cfg
index 00d26fe..b69e684 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name = zict
-version = 2.2.0
+version = 3.0.0
maintainer = Matthew Rocklin
maintainer_email = mrocklin@coiled.io
license = BSD
@@ -19,18 +19,17 @@ classifiers =
Operating System :: OS Independent
Programming Language :: Python
Programming Language :: Python :: 3
- Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
[options]
packages = zict
zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html
include_package_data = True
-python_requires = >=3.7
+python_requires = >=3.8
install_requires =
- heapdict
[options.package_data]
zict =
@@ -46,13 +45,23 @@ universal = 1
extend-ignore = E203, E266, E501
exclude = __init__.py
ignore =
- E4, # Import formatting
- E731, # Assigning lambda expression
- W503, # line break before binary operator
+ E4
+ E731
+ W503
max-line-length = 88
[tool:pytest]
-addopts = -v --doctest-modules
+addopts =
+ -v
+ --doctest-modules
+ --durations=20
+ --strict-markers
+ --strict-config
+ -p no:legacypath
+timeout_method = thread
+timeout = 180
+markers =
+ stress: slow-running stress test with a random component. Pass --stress <n> to change number of reruns.
[isort]
sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
@@ -62,6 +71,22 @@ force_to_top = true
default_section = THIRDPARTY
known_first_party = zict
+[mypy]
+python_version = 3.9
+platform = linux
+allow_incomplete_defs = false
+allow_untyped_decorators = false
+allow_untyped_defs = false
+ignore_missing_imports = true
+no_implicit_optional = true
+show_error_codes = true
+warn_redundant_casts = true
+warn_unused_ignores = true
+warn_unreachable = true
+
+[mypy-zict.tests.*]
+allow_untyped_defs = true
+
[egg_info]
tag_build =
tag_date = 0
diff --git a/zict.egg-info/PKG-INFO b/zict.egg-info/PKG-INFO
index 9d88586..bfa869c 100644
--- a/zict.egg-info/PKG-INFO
+++ b/zict.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: zict
-Version: 2.2.0
+Version: 3.0.0
Summary: Mutable mapping tools
Home-page: http://zict.readthedocs.io/en/latest/
Maintainer: Matthew Rocklin
@@ -13,11 +13,11 @@ Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
-Requires-Python: >=3.7
+Classifier: Programming Language :: Python :: 3.11
+Requires-Python: >=3.8
Description-Content-Type: text/x-rst
License-File: LICENSE.txt
diff --git a/zict.egg-info/SOURCES.txt b/zict.egg-info/SOURCES.txt
index 63d62d9..c37cd75 100644
--- a/zict.egg-info/SOURCES.txt
+++ b/zict.egg-info/SOURCES.txt
@@ -5,6 +5,7 @@ requirements.txt
setup.cfg
setup.py
zict/__init__.py
+zict/async_buffer.py
zict/buffer.py
zict/cache.py
zict/common.py
@@ -14,20 +15,24 @@ zict/lmdb.py
zict/lru.py
zict/py.typed
zict/sieve.py
+zict/utils.py
zict/zip.py
zict.egg-info/PKG-INFO
zict.egg-info/SOURCES.txt
zict.egg-info/dependency_links.txt
zict.egg-info/not-zip-safe
-zict.egg-info/requires.txt
zict.egg-info/top_level.txt
zict/tests/__init__.py
+zict/tests/conftest.py
+zict/tests/test_async_buffer.py
zict/tests/test_buffer.py
zict/tests/test_cache.py
+zict/tests/test_common.py
zict/tests/test_file.py
zict/tests/test_func.py
zict/tests/test_lmdb.py
zict/tests/test_lru.py
zict/tests/test_sieve.py
+zict/tests/test_utils.py
zict/tests/test_zip.py
zict/tests/utils_test.py
\ No newline at end of file
diff --git a/zict.egg-info/requires.txt b/zict.egg-info/requires.txt
deleted file mode 100644
index c6d4535..0000000
--- a/zict.egg-info/requires.txt
+++ /dev/null
@@ -1 +0,0 @@
-heapdict
diff --git a/zict/__init__.py b/zict/__init__.py
index e3d5303..fcdc44b 100644
--- a/zict/__init__.py
+++ b/zict/__init__.py
@@ -1,11 +1,14 @@
-from zict.buffer import Buffer
-from zict.cache import Cache, WeakValueMapping
-from zict.file import File
-from zict.func import Func
-from zict.lmdb import LMDB
-from zict.lru import LRU
-from zict.sieve import Sieve
-from zict.zip import Zip
+from zict.async_buffer import AsyncBuffer as AsyncBuffer
+from zict.buffer import Buffer as Buffer
+from zict.cache import Cache as Cache
+from zict.cache import WeakValueMapping as WeakValueMapping
+from zict.file import File as File
+from zict.func import Func as Func
+from zict.lmdb import LMDB as LMDB
+from zict.lru import LRU as LRU
+from zict.sieve import Sieve as Sieve
+from zict.utils import InsertionSortedSet as InsertionSortedSet
+from zict.zip import Zip as Zip
# Must be kept aligned with setup.cfg
-__version__ = "2.2.0"
+__version__ = "3.0.0"
diff --git a/zict/async_buffer.py b/zict/async_buffer.py
new file mode 100644
index 0000000..aa4281f
--- /dev/null
+++ b/zict/async_buffer.py
@@ -0,0 +1,176 @@
+from __future__ import annotations
+
+import asyncio
+import contextvars
+from collections.abc import Callable, Collection
+from concurrent.futures import Executor, ThreadPoolExecutor
+from functools import wraps
+from itertools import chain
+from typing import Any, Literal
+
+from zict.buffer import Buffer
+from zict.common import KT, VT, T, locked
+
+
+class AsyncBuffer(Buffer[KT, VT]):
+ """Extension of :class:`~zict.Buffer` that allows offloading all reads from and
+ writes to ``slow`` to a separate worker thread.
+
+ This requires ``fast`` to be fully thread-safe (e.g. a plain dict).
+
+ ``slow.__setitem__`` and ``slow.__getitem__`` will be called from the offloaded
+ thread, while all of its other methods (including, notably for the purpose of
+ thread-safety consideration, ``__contains__`` and ``__delitem__``) will be called
+ from the main thread.
+
+ See Also
+ --------
+ Buffer
+
+ Parameters
+ ----------
+ Same as in Buffer, plus:
+
+ executor: concurrent.futures.Executor, optional
+ An Executor instance to use for offloading. It must not pickle/unpickle.
+ Defaults to an internal ThreadPoolExecutor.
+ nthreads: int, optional
+ Number of offloaded threads to run in parallel. Defaults to 1.
+ Mutually exclusive with executor parameter.
+ """
+
+ executor: Executor | None
+ nthreads: int | None
+ futures: set[asyncio.Future]
+ evicting: dict[asyncio.Future, float]
+
+ @wraps(Buffer.__init__)
+ def __init__(
+ self,
+ *args: Any,
+ executor: Executor | None = None,
+ nthreads: int = 1,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.executor = executor
+ self.nthreads = None if executor else nthreads
+ self._internal_executor = executor is None
+ self.futures = set()
+ self.evicting = {}
+
+ def close(self) -> None:
+ # Call LRU.close(), which stops LRU.evict_until_below_target() halfway through
+ super().close()
+ for future in self.futures:
+ future.cancel()
+ if self.executor is not None and self.nthreads is not None:
+ self.executor.shutdown(wait=True)
+ self.executor = None
+
+ def _offload(self, func: Callable[..., T], *args: Any) -> asyncio.Future[T]:
+ if self.executor is None:
+ assert self.nthreads
+ self.executor = ThreadPoolExecutor(
+ self.nthreads, thread_name_prefix="zict.AsyncBuffer offloader"
+ )
+
+ loop = asyncio.get_running_loop()
+ context = contextvars.copy_context()
+ future = loop.run_in_executor(self.executor, context.run, func, *args)
+ self.futures.add(future)
+ future.add_done_callback(self.futures.remove)
+ return future # type: ignore[return-value]
+
+ # Return an asyncio.Future, instead of just writing it as an async function, to make
+ # it easier for overriders to tell apart the use case when all keys were already
+ # in fast
+ @locked
+ def async_get(
+ self, keys: Collection[KT], missing: Literal["raise", "omit"] = "raise"
+ ) -> asyncio.Future[dict[KT, VT]]:
+ """Fetch one or more key/value pairs. If not all keys are available in fast,
+ offload to a worker thread moving keys from slow to fast, as well as possibly
+ moving older keys from fast to slow.
+
+ Parameters
+ ----------
+ keys:
+ collection of zero or more keys to get
+ missing: raise or omit, optional
+ raise (default)
+ If any key is missing, raise KeyError.
+ omit
+ If a key is missing, return a dict with fewer keys than those requested.
+
+ Notes
+ -----
+ All keys may be present when you call ``async_get``, but ``__delitem__`` may be
+ called on one of them before the actual data is fetched. ``__setitem__`` also
+ internally calls ``__delitem__`` in a non-atomic way, so you may get
+ ``KeyError`` when updating a value too.
+ """
+ # This block avoids spawning a thread if keys are missing from both fast and
+ # slow. It is otherwise just a performance optimization.
+ if missing == "omit":
+ keys = [key for key in keys if key in self]
+ elif missing == "raise":
+ for key in keys:
+ if key not in self:
+ raise KeyError(key)
+ else:
+ raise ValueError(f"missing: expected raise or omit; got {missing}")
+ # End performance optimization
+
+ try:
+ # Do not pull keys towards the top of the LRU unless they are all available.
+ # This matters when there is a very long queue of async_get futures.
+ d = self.fast.get_all_or_nothing(keys)
+ except KeyError:
+ pass
+ else:
+ f: asyncio.Future[dict[KT, VT]] = asyncio.Future()
+ f.set_result(d)
+ return f
+
+ def _async_get() -> dict[KT, VT]:
+ d = {}
+ for k in keys:
+ if self.fast.closed:
+ raise asyncio.CancelledError()
+ try:
+ # This can cause keys to be restored and older keys to be evicted
+ d[k] = self[k]
+ except KeyError:
+ # Race condition: key was there when async_get was called, but got
+ # deleted afterwards.
+ if missing == "raise":
+ raise
+ return d
+
+ return self._offload(_async_get)
+
+ def __setitem__(self, key: KT, value: VT) -> None:
+ """Immediately set a key in fast. If this causes the total weight to exceed n,
+ asynchronously start moving keys from fast to slow in a worker thread.
+ """
+ self.set_noevict(key, value)
+ self.async_evict_until_below_target()
+
+ @locked
+ def async_evict_until_below_target(self, n: float | None = None) -> None:
+ """If the total weight exceeds n, asynchronously start moving keys from fast to
+ slow in a worker thread.
+ """
+ if n is None:
+ n = self.n
+ n = max(0.0, n)
+ weight = min(chain([self.fast.total_weight], self.evicting.values()))
+ if weight <= n:
+ return
+
+ # Note: this can get cancelled by LRU.close(), which in turn is
+ # triggered by Buffer.close()
+ future = self._offload(self.evict_until_below_target, n)
+ self.evicting[future] = n
+ future.add_done_callback(self.evicting.__delitem__)
diff --git a/zict/buffer.py b/zict/buffer.py
index 1227e8a..720dd3f 100644
--- a/zict/buffer.py
+++ b/zict/buffer.py
@@ -2,8 +2,12 @@ from __future__ import annotations
from collections.abc import Callable, Iterator, MutableMapping
from itertools import chain
+from typing import ( # TODO import from collections.abc (needs Python >=3.9)
+ ItemsView,
+ ValuesView,
+)
-from zict.common import KT, VT, ZictBase, close, flush
+from zict.common import KT, VT, ZictBase, close, discard, flush, locked
from zict.lru import LRU
@@ -20,21 +24,27 @@ class Buffer(ZictBase[KT, VT]):
fast: MutableMapping
slow: MutableMapping
n: float
- Total size of fast that triggers evictions to slow
+ Number of elements to keep, or total weight if ``weight`` is used.
weight: f(k, v) -> float, optional
Weight of each key/value pair (default: 1)
fast_to_slow_callbacks: list of callables
These functions run every time data moves from the fast to the slow
- mapping. They take two arguments, a key and a value
+ mapping. They take two arguments, a key and a value.
If an exception occurs during a fast_to_slow_callbacks (e.g a callback tried
storing to disk and raised a disk full error) the key will remain in the LRU.
slow_to_fast_callbacks: list of callables
- These functions run every time data moves form the slow to the fast
- mapping.
+ These functions run every time data moves from the slow to the fast mapping.
+
+ Notes
+ -----
+ If you call methods of this class from multiple threads, access will be fast as long
+ as all methods of ``fast``, plus ``slow.__contains__`` and ``slow.__delitem__``, are
+ fast. ``slow.__getitem__``, ``slow.__setitem__`` and callbacks are not protected
+ by locks.
Examples
--------
- >>> fast = dict()
+ >>> fast = {}
>>> slow = Func(dumps, loads, File('storage/')) # doctest: +SKIP
>>> def weight(k, v):
... return sys.getsizeof(v)
@@ -47,10 +57,10 @@ class Buffer(ZictBase[KT, VT]):
fast: LRU[KT, VT]
slow: MutableMapping[KT, VT]
- n: float
weight: Callable[[KT, VT], float]
fast_to_slow_callbacks: list[Callable[[KT, VT], None]]
slow_to_fast_callbacks: list[Callable[[KT, VT], None]]
+ _cancel_restore: dict[KT, bool]
def __init__(
self,
@@ -65,17 +75,65 @@ class Buffer(ZictBase[KT, VT]):
| list[Callable[[KT, VT], None]]
| None = None,
):
- self.fast = LRU(n, fast, weight=weight, on_evict=[self.fast_to_slow])
+ super().__init__()
+ self.fast = LRU(
+ n,
+ fast,
+ weight=weight,
+ on_evict=[self.fast_to_slow],
+ on_cancel_evict=[self._cancel_evict],
+ )
self.slow = slow
- self.n = n
- # FIXME https://github.com/python/mypy/issues/708
- self.weight = weight # type: ignore
+ self.weight = weight
if callable(fast_to_slow_callbacks):
fast_to_slow_callbacks = [fast_to_slow_callbacks]
if callable(slow_to_fast_callbacks):
slow_to_fast_callbacks = [slow_to_fast_callbacks]
self.fast_to_slow_callbacks = fast_to_slow_callbacks or []
self.slow_to_fast_callbacks = slow_to_fast_callbacks or []
+ self._cancel_restore = {}
+
+ @property
+ def n(self) -> float:
+ """Maximum weight in the fast mapping before eviction happens.
+ Can be updated; this won't trigger eviction by itself; you should call
+ :meth:`evict_until_below_target` afterwards.
+
+ See also
+ --------
+ offset
+ evict_until_below_target
+ LRU.n
+ LRU.offset
+ """
+ return self.fast.n
+
+ @n.setter
+ def n(self, value: float) -> None:
+ self.fast.n = value
+
+ @property
+ def offset(self) -> float:
+ """Offset to add to the total weight in the fast buffer to determine when
+ eviction happens. Note that increasing offset is not the same as decreasing n,
+ as the latter also changes what keys qualify as "heavy" and should not be stored
+ in fast.
+
+ Always starts at zero and can be updated; this won't trigger eviction by itself;
+ you should call :meth:`evict_until_below_target` afterwards.
+
+ See also
+ --------
+ n
+ evict_until_below_target
+ LRU.n
+ LRU.offset
+ """
+ return self.fast.offset
+
+ @offset.setter
+ def offset(self, value: float) -> None:
+ self.fast.offset = value
def fast_to_slow(self, key: KT, value: VT) -> None:
self.slow[key] = value
@@ -89,55 +147,112 @@ class Buffer(ZictBase[KT, VT]):
raise
def slow_to_fast(self, key: KT) -> VT:
- value = self.slow[key]
+ self._cancel_restore[key] = False
+ try:
+ with self.unlock():
+ value = self.slow[key]
+ if self._cancel_restore[key]:
+ raise KeyError(key)
+ finally:
+ del self._cancel_restore[key]
+
# Avoid useless movement for heavy values
- w = self.weight(key, value) # type: ignore
+ w = self.weight(key, value)
if w <= self.n:
+ # Multithreaded edge case:
+ # - Thread 1 starts slow_to_fast(x) and puts it at the top of fast
+ # - This causes the eviction of older key(s)
+ # - While thread 1 is evicting older keys, thread 2 is loading fast with
+ # set_noevict()
+ # - By the time the eviction of the older key(s) is done, there is
+ # enough weight in fast that thread 1 will spill x
+ # - If the below code was just `self.fast[key] = value; del
+ # self.slow[key]` now the key would be in neither slow nor fast!
+ self.fast.set_noevict(key, value)
del self.slow[key]
- self.fast[key] = value
- for cb in self.slow_to_fast_callbacks:
- cb(key, value)
+
+ with self.unlock():
+ self.fast.evict_until_below_target()
+ for cb in self.slow_to_fast_callbacks:
+ cb(key, value)
+
return value
+ @locked
def __getitem__(self, key: KT) -> VT:
- if key in self.fast:
+ try:
return self.fast[key]
- elif key in self.slow:
+ except KeyError:
return self.slow_to_fast(key)
- else:
- raise KeyError(key)
def __setitem__(self, key: KT, value: VT) -> None:
- if key in self.slow:
- del self.slow[key]
- # This may trigger an eviction from fast to slow of older keys.
- # If the weight is individually greater than n, then key/value will be stored
- # into self.slow instead (see LRU.__setitem__).
+ with self.lock:
+ discard(self.slow, key)
+ if key in self._cancel_restore:
+ self._cancel_restore[key] = True
self.fast[key] = value
+ @locked
+ def set_noevict(self, key: KT, value: VT) -> None:
+ """Variant of ``__setitem__`` that does not move keys from fast to slow if the
+ total weight exceeds n
+ """
+ discard(self.slow, key)
+ if key in self._cancel_restore:
+ self._cancel_restore[key] = True
+ self.fast.set_noevict(key, value)
+
+ def evict_until_below_target(self, n: float | None = None) -> None:
+ """Wrapper around :meth:`zict.LRU.evict_until_below_target`.
+ Presented here to allow easier overriding.
+ """
+ self.fast.evict_until_below_target(n)
+
+ @locked
def __delitem__(self, key: KT) -> None:
- if key in self.fast:
+ if key in self._cancel_restore:
+ self._cancel_restore[key] = True
+ try:
del self.fast[key]
- elif key in self.slow:
+ except KeyError:
del self.slow[key]
- else:
- raise KeyError(key)
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def keys(self) -> Iterator[KT]: # type: ignore
- return chain(self.fast.keys(), self.slow.keys())
+ @locked
+ def _cancel_evict(self, key: KT, value: VT) -> None:
+ discard(self.slow, key)
- def values(self) -> Iterator[VT]: # type: ignore
- return chain(self.fast.values(), self.slow.values())
+ def values(self) -> ValuesView[VT]:
+ return BufferValuesView(self)
- def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore
- return chain(self.fast.items(), self.slow.items())
+ def items(self) -> ItemsView[KT, VT]:
+ return BufferItemsView(self)
def __len__(self) -> int:
- return len(self.fast) + len(self.slow)
+ with self.lock, self.fast.lock:
+ return (
+ len(self.fast)
+ + len(self.slow)
+ - sum(
+ k in self.fast and k in self.slow
+ for k in chain(self._cancel_restore, self.fast._cancel_evict)
+ )
+ )
def __iter__(self) -> Iterator[KT]:
- return chain(iter(self.fast), iter(self.slow))
+ """Make sure that the iteration is not disrupted if you evict/restore a key in
+ the middle of it
+ """
+ seen = set()
+ while True:
+ try:
+ for d in (self.fast, self.slow):
+ for key in d:
+ if key not in seen:
+ seen.add(key)
+ yield key
+ return
+ except RuntimeError:
+ pass
def __contains__(self, key: object) -> bool:
return key in self.fast or key in self.slow
@@ -152,3 +267,25 @@ class Buffer(ZictBase[KT, VT]):
def close(self) -> None:
close(self.fast, self.slow)
+
+
+class BufferItemsView(ItemsView[KT, VT]):
+ _mapping: Buffer # FIXME CPython implementation detail
+ __slots__ = ()
+
+ def __iter__(self) -> Iterator[tuple[KT, VT]]:
+ # Avoid changing the LRU
+ return chain(self._mapping.fast.items(), self._mapping.slow.items())
+
+
+class BufferValuesView(ValuesView[VT]):
+ _mapping: Buffer # FIXME CPython implementation detail
+ __slots__ = ()
+
+ def __contains__(self, value: object) -> bool:
+ # Avoid changing the LRU
+ return any(value == v for v in self)
+
+ def __iter__(self) -> Iterator[VT]:
+ # Avoid changing the LRU
+ return chain(self._mapping.fast.values(), self._mapping.slow.values())
diff --git a/zict/cache.py b/zict/cache.py
index eb247c6..3030e51 100644
--- a/zict/cache.py
+++ b/zict/cache.py
@@ -1,10 +1,10 @@
from __future__ import annotations
import weakref
-from collections.abc import Iterator, KeysView, MutableMapping
+from collections.abc import Iterator, MutableMapping
from typing import TYPE_CHECKING
-from zict.common import KT, VT, ZictBase, close, flush
+from zict.common import KT, VT, ZictBase, close, discard, flush, locked
class Cache(ZictBase[KT, VT]):
@@ -22,10 +22,16 @@ class Cache(ZictBase[KT, VT]):
If True (default), the cache will be updated both when writing and reading.
If False, update the cache when reading, but just invalidate it when writing.
+ Notes
+ -----
+ If you call methods of this class from multiple threads, access will be fast as long
+ as all methods of ``cache``, plus ``data.__delitem__``, are fast. Other methods of
+ ``data`` are not protected by locks.
+
Examples
--------
Keep the latest 100 accessed values in memory
- >>> from zict import File, LRU
+ >>> from zict import Cache, File, LRU, WeakValueMapping
>>> d = Cache(File('myfile'), LRU(100, {})) # doctest: +SKIP
Read data from disk every time, unless it was previously accessed and it's still in
@@ -36,6 +42,8 @@ class Cache(ZictBase[KT, VT]):
data: MutableMapping[KT, VT]
cache: MutableMapping[KT, VT]
update_on_set: bool
+ _gen: int
+ _last_updated: dict[KT, int]
def __init__(
self,
@@ -43,32 +51,66 @@ class Cache(ZictBase[KT, VT]):
cache: MutableMapping[KT, VT],
update_on_set: bool = True,
):
+ super().__init__()
self.data = data
self.cache = cache
self.update_on_set = update_on_set
+ self._gen = 0
+ self._last_updated = {}
+ @locked
def __getitem__(self, key: KT) -> VT:
try:
return self.cache[key]
except KeyError:
pass
- value = self.data[key]
- self.cache[key] = value
- return value
+ gen = self._last_updated[key]
- def __setitem__(self, key: KT, value: VT) -> None:
- # If the item was already in cache and data.__setitem__ fails, e.g. because it's
- # a File and the disk is full, make sure that the cache is invalidated.
- # FIXME https://github.com/python/mypy/issues/10152
- self.cache.pop(key, None) # type: ignore
+ with self.unlock():
+ value = self.data[key]
- self.data[key] = value
- if self.update_on_set:
+ # Could another thread have called __setitem__ or __delitem__ on the
+ # same key in the meantime? If not, update the cache
+ if gen == self._last_updated.get(key):
self.cache[key] = value
+ self._last_updated[key] += 1
+ return value
+ @locked
+ def __setitem__(self, key: KT, value: VT) -> None:
+ # If the item was already in cache and data.__setitem__ fails, e.g. because
+ # it's a File and the disk is full, make sure that the cache is invalidated.
+ discard(self.cache, key)
+ gen = self._gen
+ gen += 1
+ self._last_updated[key] = self._gen = gen
+
+ with self.unlock():
+ self.data[key] = value
+
+ if key not in self._last_updated:
+ # Another thread called __delitem__ in the meantime
+ discard(self.data, key)
+ elif gen != self._last_updated[key]:
+ # Another thread called __setitem__ in the meantime. We have no idea which
+ # of the two ended up actually setting self.data.
+ # Case 1: the other thread did not enter this locked code block yet.
+ # Prevent it from setting the cache.
+ self._last_updated[key] += 1
+ # Case 2: the other thread already exited this locked code block and set the
+ # cache. Invalidate it.
+ discard(self.cache, key)
+ else:
+ # No race condition
+ self._last_updated[key] += 1
+ if self.update_on_set:
+ self.cache[key] = value
+
+ @locked
def __delitem__(self, key: KT) -> None:
- self.cache.pop(key, None) # type: ignore
del self.data[key]
+ del self._last_updated[key]
+ discard(self.cache, key)
def __len__(self) -> int:
return len(self.data)
@@ -80,11 +122,6 @@ class Cache(ZictBase[KT, VT]):
# Do not let MutableMapping call self.data[key]
return key in self.data
- def keys(self) -> KeysView[KT]:
- # Return a potentially optimized set-like, instead of letting MutableMapping
- # build it from __iter__ on the fly
- return self.data.keys()
-
def flush(self) -> None:
flush(self.cache, self.data)
@@ -93,7 +130,7 @@ class Cache(ZictBase[KT, VT]):
if TYPE_CHECKING:
- # TODO Python 3.9: remove this branch and just use [] in the implementation below
+ # TODO remove this branch and just use [] in the implementation below (needs Python >=3.9)
class WeakValueMapping(weakref.WeakValueDictionary[KT, VT]):
...
diff --git a/zict/common.py b/zict/common.py
index 42648f3..ebfb1e0 100644
--- a/zict/common.py
+++ b/zict/common.py
@@ -1,67 +1,95 @@
from __future__ import annotations
-from collections.abc import Iterable, Mapping
+import threading
+from collections.abc import Callable, Iterable, Iterator, Mapping
+from contextlib import contextmanager
+from enum import Enum
+from functools import wraps
from itertools import chain
from typing import MutableMapping # TODO move to collections.abc (needs Python >=3.9)
-from typing import Any, TypeVar, overload
+from typing import TYPE_CHECKING, Any, TypeVar, cast
T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")
+if TYPE_CHECKING:
+ # TODO import ParamSpec from typing (needs Python >=3.10)
+ # TODO import Self from typing (needs Python >=3.11)
+ from typing_extensions import ParamSpec, Self
+
+ P = ParamSpec("P")
+
+
+class NoDefault(Enum):
+ nodefault = None
+
+
+nodefault = NoDefault.nodefault
+
class ZictBase(MutableMapping[KT, VT]):
"""Base class for zict mappings"""
- # TODO use positional-only arguments to protect self (requires Python 3.8+)
- @overload
- def update(self, __m: Mapping[KT, VT], **kwargs: VT) -> None:
- ...
-
- @overload
- def update(self, __m: Iterable[tuple[KT, VT]], **kwargs: VT) -> None:
- ...
-
- @overload
- def update(self, **kwargs: VT) -> None:
- ...
-
- def update(*args, **kwds):
- # Boilerplate for implementing an update() method
- if not args:
- raise TypeError(
- "descriptor 'update' of MutableMapping object " "needs an argument"
- )
- self = args[0]
- args = args[1:]
- if len(args) > 1:
- raise TypeError("update expected at most 1 arguments, got %d" % len(args))
- items = []
- if args:
- other = args[0]
- if isinstance(other, Mapping) or hasattr(other, "items"):
- items = other.items()
- else:
- # Assuming (key, value) pairs
- items = other
- if kwds:
- items = chain(items, kwds.items())
- self._do_update(items)
+ lock: threading.RLock
+
+ def __init__(self) -> None:
+ self.lock = threading.RLock()
+
+ def __getstate__(self) -> dict[str, Any]:
+ state = self.__dict__.copy()
+ del state["lock"]
+ return state
+
+ def __setstate__(self, state: dict[str, Any]) -> None:
+ self.__dict__ = state
+ self.lock = threading.RLock()
+
+ def update( # type: ignore[override]
+ self,
+ other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] = (),
+ /,
+ **kwargs: VT,
+ ) -> None:
+ if hasattr(other, "items"):
+ other = other.items()
+ other = chain(other, kwargs.items()) # type: ignore
+ self._do_update(other)
def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None:
# Default implementation, can be overriden for speed
for k, v in items:
self[k] = v
+ def discard(self, key: KT) -> None:
+ """Delete *key* if present; do nothing if it is missing.
+ Not the same as ``self.pop(key, None)``, as it doesn't trigger ``__getitem__``.
+ """
+ discard(self, key)
+
def close(self) -> None:
"""Release any system resources held by this object"""
- def __enter__(self: T) -> T:
+ def __enter__(self) -> Self:
return self
- def __exit__(self, *args) -> None:
+ def __exit__(self, *args: Any) -> None:
+ self.close()
+
+ def __del__(self) -> None:
self.close()
+ @contextmanager
+ def unlock(self) -> Iterator[None]:
+ """To be used in a method decorated by ``@locked``.
+ Temporarily releases the mapping's RLock.
+ """
+ self.lock.release()
+ try:
+ yield
+ finally:
+ self.lock.acquire()
+
def close(*z: Any) -> None:
"""Close *z* if possible."""
@@ -75,3 +103,27 @@ def flush(*z: Any) -> None:
for zi in z:
if hasattr(zi, "flush"):
zi.flush()
+
+
+def discard(m: MutableMapping[KT, VT], key: KT) -> None:
+ """Delete *key* from *m* if present; do nothing if it is missing.
+ Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``.
+ """
+ try:
+ del m[key]
+ except KeyError:
+ pass
+
+
+def locked(func: Callable[P, VT]) -> Callable[P, VT]:
+ """Decorator for a method of ZictBase, which wraps the whole method in an
+ instance-specific (but not key-specific) rlock.
+ """
+
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> VT:
+ self = cast(ZictBase, args[0])
+ with self.lock:
+ return func(*args, **kwargs)
+
+ return wrapper
diff --git a/zict/file.py b/zict/file.py
index ebae757..60564da 100644
--- a/zict/file.py
+++ b/zict/file.py
@@ -2,39 +2,38 @@ from __future__ import annotations
import mmap
import os
+import pathlib
from collections.abc import Iterator
from urllib.parse import quote, unquote
-from zict.common import ZictBase
-
-
-def _safe_key(key: str) -> str:
- """
- Escape key so as to be usable on all filesystems.
- """
- # Even directory separators are unsafe.
- return quote(key, safe="")
-
-
-def _unsafe_key(key: str) -> str:
- """
- Undo the escaping done by _safe_key().
- """
- return unquote(key)
+from zict.common import ZictBase, locked
class File(ZictBase[str, bytes]):
"""Mutable Mapping interface to a directory
- Keys must be strings, values must be bytes
+ Keys must be strings, values must be buffers
Note this shouldn't be used for interprocess persistence, as keys
are cached in memory.
Parameters
----------
- directory: string
- mode: string, ('r', 'w', 'a'), defaults to 'a'
+ directory: str
+ Directory to write to. If it already exists, existing files will be imported as
+ mapping elements. If it doesn't exist, it will be created.
+ memmap: bool (optional)
+ If True, use `mmap` for reading. Defaults to False.
+
+ Notes
+ -----
+ If you call methods of this class from multiple threads, access will be fast as long
+ as atomic disk access such as ``open``, ``os.fstat``, and ``os.remove`` is fast.
+ This is not always the case, e.g. in case of slow network mounts or spun-down
+ magnetic drives.
+ Reading/writing bytes in the files is not protected by locks; this could cause failures
+ on Windows, NFS, and in general whenever it's not OK to delete a file while there
+ are file descriptors open on it.
Examples
--------
@@ -55,69 +54,105 @@ class File(ZictBase[str, bytes]):
"""
directory: str
- mode: str
memmap: bool
- _keys: set[str]
+ filenames: dict[str, str]
+ _inc: int
- def __init__(self, directory: str, mode: str = "a", memmap: bool = False):
- self.directory = directory
- self.mode = mode
+ def __init__(self, directory: str | pathlib.Path, memmap: bool = False):
+ super().__init__()
+ self.directory = str(directory)
self.memmap = memmap
- self._keys = set()
+ self.filenames = {}
+ self._inc = 0
+
if not os.path.exists(self.directory):
os.makedirs(self.directory, exist_ok=True)
else:
- for n in os.listdir(self.directory):
- self._keys.add(_unsafe_key(n))
+ for fn in os.listdir(self.directory):
+ self.filenames[self._unsafe_key(fn)] = fn
+ self._inc += 1
+
+ def _safe_key(self, key: str) -> str:
+ """Escape key so that it is usable on all filesystems.
+
+ Append to the filenames a unique suffix that changes every time this method is
+ called. This prevents race conditions when another thread accesses the same
+ key, e.g. ``__setitem__`` on one thread and ``__getitem__`` on another.
+ """
+ # `#` is escaped by quote and is supported by most file systems
+ key = quote(key, safe="") + f"#{self._inc}"
+ self._inc += 1
+ return key
+
+ @staticmethod
+ def _unsafe_key(key: str) -> str:
+ """Undo the escaping done by _safe_key()"""
+ key = key.split("#")[0]
+ return unquote(key)
def __str__(self) -> str:
- return f'<File: {self.directory}, mode="{self.mode}", {len(self)} elements>'
+ return f"<File: {self.directory}, {len(self)} elements>"
__repr__ = __str__
- def __getitem__(self, key: str) -> bytes:
- if key not in self._keys:
- raise KeyError(key)
- fn = os.path.join(self.directory, _safe_key(key))
- with open(fn, "rb") as fh:
- if self.memmap:
- return memoryview(mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ))
- else:
- return fh.read()
+ @locked
+ def __getitem__(self, key: str) -> bytearray | memoryview:
+ fn = os.path.join(self.directory, self.filenames[key])
+
+ # distributed.protocol.numpy.deserialize_numpy_ndarray makes sure that, if the
+ # numpy array was writeable before serialization, it remains writeable afterwards.
+ # If it receives a read-only buffer (e.g. from fh.read() or from a mmap to a
+ # read-only file descriptor), it performs an expensive memcpy.
+ # Note that this is a dask-specific feature; vanilla pickle.loads will instead
+ # return an array with flags.writeable=False.
+ if self.memmap:
+ with open(fn, "r+b") as fh:
+ return memoryview(mmap.mmap(fh.fileno(), 0))
+ else:
+ with open(fn, "rb") as fh:
+ size = os.fstat(fh.fileno()).st_size
+ buf = bytearray(size)
+ with self.unlock():
+ nread = fh.readinto(buf)
+ assert nread == size
+ return buf
+
+ @locked
def __setitem__(
self,
key: str,
value: bytes
| bytearray
- | list[bytes]
- | list[bytearray]
- | tuple[bytes]
- | tuple[bytearray],
+ | memoryview
+ | list[bytes | bytearray | memoryview]
+ | tuple[bytes | bytearray | memoryview, ...],
) -> None:
- fn = os.path.join(self.directory, _safe_key(key))
- with open(fn, "wb") as fh:
+ self.discard(key)
+ fn = self._safe_key(key)
+ with open(os.path.join(self.directory, fn), "wb") as fh, self.unlock():
if isinstance(value, (tuple, list)):
fh.writelines(value)
else:
fh.write(value)
- self._keys.add(key)
- def __contains__(self, key: object) -> bool:
- return key in self._keys
+ if key in self.filenames:
+ # Race condition: two calls to __setitem__ from different threads on the
+ # same key at the same time
+ os.remove(os.path.join(self.directory, fn))
+ else:
+ self.filenames[key] = fn
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def keys(self) -> set[str]: # type: ignore
- return self._keys
+ def __contains__(self, key: object) -> bool:
+ return key in self.filenames
def __iter__(self) -> Iterator[str]:
- return iter(self._keys)
+ return iter(self.filenames)
+ @locked
def __delitem__(self, key: str) -> None:
- if key not in self._keys:
- raise KeyError(key)
- os.remove(os.path.join(self.directory, _safe_key(key)))
- self._keys.remove(key)
+ fn = self.filenames.pop(key)
+ os.remove(os.path.join(self.directory, fn))
def __len__(self) -> int:
- return len(self._keys)
+ return len(self.filenames)
diff --git a/zict/func.py b/zict/func.py
index 00b66dc..b45cfa0 100644
--- a/zict/func.py
+++ b/zict/func.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from collections.abc import Callable, Iterable, Iterator, KeysView, MutableMapping
+from collections.abc import Callable, Iterable, Iterator, MutableMapping
from typing import Generic, TypeVar
from zict.common import KT, VT, ZictBase, close, flush
@@ -27,7 +27,7 @@ class Func(ZictBase[KT, VT], Generic[KT, VT, WT]):
>>> def halve(x):
... return x / 2
- >>> d = dict()
+ >>> d = {}
>>> f = Func(double, halve, d)
>>> f['x'] = 10
>>> d
@@ -46,16 +46,16 @@ class Func(ZictBase[KT, VT], Generic[KT, VT, WT]):
load: Callable[[WT], VT],
d: MutableMapping[KT, WT],
):
- # FIXME https://github.com/python/mypy/issues/708
- self.dump = dump # type: ignore
- self.load = load # type: ignore
+ super().__init__()
+ self.dump = dump
+ self.load = load
self.d = d
def __getitem__(self, key: KT) -> VT:
- return self.load(self.d[key]) # type: ignore
+ return self.load(self.d[key])
def __setitem__(self, key: KT, value: VT) -> None:
- self.d[key] = self.dump(value) # type: ignore
+ self.d[key] = self.dump(value)
def __contains__(self, key: object) -> bool:
return key in self.d
@@ -63,18 +63,8 @@ class Func(ZictBase[KT, VT], Generic[KT, VT, WT]):
def __delitem__(self, key: KT) -> None:
del self.d[key]
- def keys(self) -> KeysView[KT]:
- return self.d.keys()
-
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def values(self) -> Iterator[VT]: # type: ignore
- return (self.load(v) for v in self.d.values()) # type: ignore
-
- def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore
- return ((k, self.load(v)) for k, v in self.d.items()) # type: ignore
-
def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None:
- it = ((k, self.dump(v)) for k, v in items) # type: ignore
+ it = ((k, self.dump(v)) for k, v in items)
self.d.update(it)
def __iter__(self) -> Iterator[KT]:
@@ -95,7 +85,7 @@ class Func(ZictBase[KT, VT], Generic[KT, VT, WT]):
close(self.d)
-def funcname(func) -> str:
+def funcname(func: Callable) -> str:
"""Get the name of a function."""
while hasattr(func, "func"):
func = func.func
diff --git a/zict/lmdb.py b/zict/lmdb.py
index f4ac2db..c7648e1 100644
--- a/zict/lmdb.py
+++ b/zict/lmdb.py
@@ -1,7 +1,12 @@
from __future__ import annotations
+import pathlib
import sys
from collections.abc import Iterable, Iterator
+from typing import ( # TODO import from collections.abc (needs Python >=3.9)
+ ItemsView,
+ ValuesView,
+)
from zict.common import ZictBase
@@ -21,7 +26,18 @@ class LMDB(ZictBase[str, bytes]):
Parameters
----------
- directory: string
+ directory: str
+ map_size: int
+ On Linux and MacOS, maximum size of the database file on disk.
+ Defaults to 1 TiB on 64 bit systems and ~512 MiB (sys.maxsize // 4) on 32 bit ones.
+
+ On Windows, preallocated total size of the database file on disk. Defaults to
+ 10 MiB to encourage explicitly setting it.
+
+ Notes
+ -----
+ None of this class is thread-safe - not even normally trivial methods such as
+ ``__len__`` or ``__contains__``.
Examples
--------
@@ -31,24 +47,27 @@ class LMDB(ZictBase[str, bytes]):
b'123'
"""
- def __init__(self, directory: str):
+ def __init__(self, directory: str | pathlib.Path, map_size: int | None = None):
import lmdb
- # map_size is the maximum database size but shouldn't fill up the
- # virtual address space
- map_size = 1 << 40 if sys.maxsize >= 2**32 else 1 << 28
- # writemap requires sparse file support otherwise the whole
- # `map_size` may be reserved up front on disk
- writemap = sys.platform.startswith("linux")
+ super().__init__()
+ if map_size is None:
+ if sys.platform != "win32":
+ map_size = min(2**40, sys.maxsize // 4)
+ else:
+ map_size = 10 * 2**20
+
self.db = lmdb.open(
- directory,
+ str(directory),
subdir=True,
map_size=map_size,
sync=False,
- writemap=writemap,
+ writemap=True,
)
def __getitem__(self, key: str) -> bytes:
+ if not isinstance(key, str):
+ raise KeyError(key)
with self.db.begin() as txn:
value = txn.get(_encode_key(key))
if value is None:
@@ -56,6 +75,10 @@ class LMDB(ZictBase[str, bytes]):
return value
def __setitem__(self, key: str, value: bytes) -> None:
+ if not isinstance(key, str):
+ raise TypeError(key)
+ if not isinstance(value, bytes):
+ raise TypeError(value)
with self.db.begin(write=True) as txn:
txn.put(_encode_key(key), value)
@@ -65,30 +88,33 @@ class LMDB(ZictBase[str, bytes]):
with self.db.begin() as txn:
return txn.cursor().set_key(_encode_key(key))
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def items(self) -> Iterator[tuple[str, bytes]]: # type: ignore
- cursor = self.db.begin().cursor()
- return ((_decode_key(k), v) for k, v in cursor.iternext(keys=True, values=True))
-
- def keys(self) -> Iterator[str]: # type: ignore
+ def __iter__(self) -> Iterator[str]:
cursor = self.db.begin().cursor()
return (_decode_key(k) for k in cursor.iternext(keys=True, values=False))
- def values(self) -> Iterator[bytes]: # type: ignore
- cursor = self.db.begin().cursor()
- return cursor.iternext(keys=False, values=True)
+ def items(self) -> ItemsView[str, bytes]:
+ return LMDBItemsView(self)
+
+ def values(self) -> ValuesView[bytes]:
+ return LMDBValuesView(self)
def _do_update(self, items: Iterable[tuple[str, bytes]]) -> None:
# Optimized version of update() using a single putmulti() call.
- items_enc = [(_encode_key(k), v) for k, v in items]
+ items_enc = []
+ for key, value in items:
+ if not isinstance(key, str):
+ raise TypeError(key)
+ if not isinstance(value, bytes):
+ raise TypeError(value)
+ items_enc.append((_encode_key(key), value))
+
with self.db.begin(write=True) as txn:
consumed, added = txn.cursor().putmulti(items_enc)
assert consumed == added == len(items_enc)
- def __iter__(self) -> Iterator[str]:
- return self.keys()
-
def __delitem__(self, key: str) -> None:
+ if not isinstance(key, str):
+ raise KeyError(key)
with self.db.begin(write=True) as txn:
if not txn.delete(_encode_key(key)):
raise KeyError(key)
@@ -98,3 +124,35 @@ class LMDB(ZictBase[str, bytes]):
def close(self) -> None:
self.db.close()
+
+
+class LMDBItemsView(ItemsView[str, bytes]):
+ _mapping: LMDB # FIXME CPython implementation detail
+ __slots__ = ()
+
+ def __contains__(self, item: object) -> bool:
+ key: str
+ value: object
+ key, value = item # type: ignore
+ try:
+ v = self._mapping[key]
+ except KeyError:
+ return False
+ else:
+ return v == value
+
+ def __iter__(self) -> Iterator[tuple[str, bytes]]:
+ cursor = self._mapping.db.begin().cursor()
+ return ((_decode_key(k), v) for k, v in cursor.iternext(keys=True, values=True))
+
+
+class LMDBValuesView(ValuesView[bytes]):
+ _mapping: LMDB # FIXME CPython implementation detail
+ __slots__ = ()
+
+ def __contains__(self, value: object) -> bool:
+ return any(value == v for v in self)
+
+ def __iter__(self) -> Iterator[bytes]:
+ cursor = self._mapping.db.begin().cursor()
+ return cursor.iternext(keys=False, values=True)
diff --git a/zict/lru.py b/zict/lru.py
index 345ed83..c3f8b42 100644
--- a/zict/lru.py
+++ b/zict/lru.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from collections.abc import (
Callable,
+ Collection,
ItemsView,
Iterator,
KeysView,
@@ -9,9 +10,8 @@ from collections.abc import (
ValuesView,
)
-from heapdict import heapdict
-
-from zict.common import KT, VT, ZictBase, close, flush
+from zict.common import KT, VT, NoDefault, ZictBase, close, flush, locked, nodefault
+from zict.utils import InsertionSortedSet
class LRU(ZictBase[KT, VT]):
@@ -20,17 +20,33 @@ class LRU(ZictBase[KT, VT]):
Parameters
----------
n: int or float
- Number of elements to keep, or total weight if weight= is used
+ Number of elements to keep, or total weight if ``weight`` is used.
+ Any individual key that is heavier than n will be automatically evicted as soon
+ as it is inserted.
+
+ It can be updated after initialization. See also: ``offset`` attribute.
d: MutableMapping
- Dict-like in which to hold elements
- on_evict: list of callables
- Function:: k, v -> action to call on key value pairs prior to eviction
+ Dict-like in which to hold elements. There are no expectations on its internal
+ ordering. Iteration on the LRU follows the order of the underlying mapping.
+ on_evict: callable or list of callables
+ Function:: k, v -> action to call on key/value pairs prior to eviction
If an exception occurs during an on_evict callback (e.g a callback tried
storing to disk and raised a disk full error) the key will remain in the LRU.
+ on_cancel_evict: callable or list of callables
+ Function:: k, v -> action to call on key/value pairs if they're deleted or
+ updated from a thread while the on_evict callables are being executed in
+ another.
+ If you're not accessing the LRU from multiple threads, ignore this parameter.
weight: callable
Function:: k, v -> number to determine the size of keeping the item in
the mapping. Defaults to ``(k, v) -> 1``
+ Notes
+ -----
+ If you call methods of this class from multiple threads, access will be fast as long
+ as all methods of ``d`` are fast. Callbacks are not protected by locks and can be
+ arbitrarily slow.
+
Examples
--------
>>> lru = LRU(2, {}, on_evict=lambda k, v: print("Lost", k, v))
@@ -41,101 +57,196 @@ class LRU(ZictBase[KT, VT]):
"""
d: MutableMapping[KT, VT]
- heap: heapdict[KT, VT]
+ order: InsertionSortedSet[KT]
+ heavy: InsertionSortedSet[KT]
on_evict: list[Callable[[KT, VT], None]]
+ on_cancel_evict: list[Callable[[KT, VT], None]]
weight: Callable[[KT, VT], float]
+ #: Maximum weight before eviction is triggered, as set during initialization.
+ #: Updating this attribute doesn't trigger eviction by itself; you should call
+ #: :meth:`evict_until_below_target` explicitly afterwards.
n: float
- i: int
- total_weight: float
+ #: Offset to add to ``total_weight`` to determine if key/value pairs should be
+ #: evicted. It always starts at zero and can be updated afterwards. Updating this
+ #: attribute doesn't trigger eviction by itself; you should call
+ #: :meth:`evict_until_below_target` explicitly afterwards.
+ #: Increasing ``offset`` is not the same as reducing ``n``, as the latter will also
+ #: reduce the threshold below which a value is considered "heavy" and qualifies for
+ #: immediate eviction.
+ offset: float
weights: dict[KT, float]
+ closed: bool
+ total_weight: float
+ _cancel_evict: dict[KT, bool]
def __init__(
self,
n: float,
d: MutableMapping[KT, VT],
+ *,
on_evict: Callable[[KT, VT], None]
| list[Callable[[KT, VT], None]]
| None = None,
+ on_cancel_evict: Callable[[KT, VT], None]
+ | list[Callable[[KT, VT], None]]
+ | None = None,
weight: Callable[[KT, VT], float] = lambda k, v: 1,
):
+ super().__init__()
self.d = d
self.n = n
- self.heap = heapdict()
- self.i = 0
+ self.offset = 0
+
if callable(on_evict):
on_evict = [on_evict]
self.on_evict = on_evict or []
- # FIXME https://github.com/python/mypy/issues/708
- self.weight = weight # type: ignore
- self.total_weight = 0
- self.weights = {}
+ if callable(on_cancel_evict):
+ on_cancel_evict = [on_cancel_evict]
+ self.on_cancel_evict = on_cancel_evict or []
+
+ self.weight = weight
+ self.weights = {k: weight(k, v) for k, v in d.items()}
+ self.total_weight = sum(self.weights.values())
+ self.order = InsertionSortedSet(d)
+ self.heavy = InsertionSortedSet(k for k, v in self.weights.items() if v >= n)
+ self.closed = False
+ self._cancel_evict = {}
+ @locked
def __getitem__(self, key: KT) -> VT:
result = self.d[key]
- self.i += 1
- self.heap[key] = self.i
+ self.order.remove(key)
+ self.order.add(key)
+ return result
+
+ @locked
+ def get_all_or_nothing(self, keys: Collection[KT]) -> dict[KT, VT]:
+ """If all keys exist in the LRU, mark them as most recently used and return their
+ values; this would be the same as ``{k: lru[k] for k in keys}``.
+ If any keys are missing, however, raise KeyError for the first one missing and
+ do not bring any of the available keys to the top of the LRU.
+ """
+ result = {key: self.d[key] for key in keys}
+ for key in keys:
+ self.order.remove(key)
+ self.order.add(key)
return result
def __setitem__(self, key: KT, value: VT) -> None:
- if key in self.d:
- del self[key]
-
- weight = self.weight(key, value) # type: ignore
-
- def set_():
- self.d[key] = value
- self.i += 1
- self.heap[key] = self.i
- self.weights[key] = weight
- self.total_weight += weight
- # Evicting the last key/value pair is guaranteed to fail, so don't try.
- # This is because it is always the last one inserted by virtue of this
- # being an LRU, which in turn means we reached this point because
- # weight > self.n and a callbacks raised exception (e.g. disk full).
- while self.total_weight > self.n and len(self.d) > 1:
- self.evict()
+ self.set_noevict(key, value)
+ try:
+ self.evict_until_below_target()
+ except Exception:
+ if self.weights.get(key, 0) > self.n and key not in self.heavy:
+ # weight(value) > n and evicting the key we just inserted failed.
+ # Evict the rest of the LRU instead.
+ try:
+ while len(self.d) > 1:
+ self.evict()
+ except Exception:
+ pass
+ raise
+
+ @locked
+ def set_noevict(self, key: KT, value: VT) -> None:
+ """Variant of ``__setitem__`` that does not evict if the total weight exceeds n.
+ Unlike ``__setitem__``, this method does not depend on the ``on_evict``
+ functions to be thread-safe for its own thread-safety. It also is not prone to
+ re-raising exceptions from the ``on_evict`` callbacks.
+ """
+ self.discard(key)
+ weight = self.weight(key, value)
+ if key in self._cancel_evict:
+ self._cancel_evict[key] = True
+ self.d[key] = value
+ self.order.add(key)
+ if weight > self.n:
+ self.heavy.add(key) # Mark this key to be evicted first
+ self.weights[key] = weight
+ self.total_weight += weight
- if weight <= self.n:
- set_()
- else:
+ def evict_until_below_target(self, n: float | None = None) -> None:
+ """Evict key/value pairs until the total weight plus ``offset`` is at most n
+
+ Parameters
+ ----------
+ n: float, optional
+ Total weight threshold to achieve. Defaults to self.n.
+ """
+ if n is None:
+ n = self.n
+ while self.total_weight + self.offset > n and not self.closed:
try:
- for cb in self.on_evict:
- cb(key, value)
- except Exception:
- # e.g. if a callback tried storing to disk and raised a disk full error
- set_()
- raise
+ self.evict()
+ except KeyError:
+ return # Multithreaded race condition
- def evict(self) -> tuple[KT, VT, float]:
- """Evict least recently used key
+ @locked
+ def evict(
+ self, key: KT | NoDefault = nodefault
+ ) -> tuple[KT, VT, float] | tuple[None, None, float]:
+ """Evict least recently used key, or least recently inserted key with individual
+ weight > n, if any. You may also evict a specific key.
This is typically called from internal use, but can be externally
triggered as well.
Returns
-------
- k: key
- v: value
- w: weight
+ Tuple of (key, value, weight)
+
+ Or (None, None, 0) if the key that was being evicted was updated or deleted from
+ another thread while the on_evict callbacks were being executed. This outcome is
+ only possible in multithreaded access.
"""
- k, priority = self.heap.popitem()
- v = self.d.pop(k)
+ if key is nodefault:
+ try:
+ key = next(iter(self.heavy or self.order))
+ except StopIteration:
+ raise KeyError("evict(): dictionary is empty")
+
+ if key in self._cancel_evict:
+ return None, None, 0
+
+ # For the purpose of multithreaded access, it's important that the value remains
+ # in self.d until all callbacks are successful.
+ # When this is used inside a Buffer, there must never be a moment when the key
+ # is neither in fast nor in slow.
+ value = self.d[key]
+
+ # If we are evicting a heavy key we just inserted and one of the callbacks
+ # fails, put it at the bottom of the LRU instead of the top. This way lighter
+ # keys will have a chance to be evicted first and make space.
+ self.heavy.discard(key)
+
+ self._cancel_evict[key] = False
try:
- for cb in self.on_evict:
- cb(k, v)
- except Exception:
- # e.g. if a callback tried storing to disk and raised a disk full error
- self.heap[k] = priority
- self.d[k] = v
- raise
+ with self.unlock():
+ # This may raise; e.g. if a callback tries storing to a full disk
+ for cb in self.on_evict:
+ cb(key, value)
+
+ if self._cancel_evict[key]:
+ for cb in self.on_cancel_evict:
+ cb(key, value)
+ return None, None, 0
+ finally:
+ del self._cancel_evict[key]
- weight = self.weights.pop(k)
+ del self.d[key]
+ self.order.remove(key)
+ weight = self.weights.pop(key)
self.total_weight -= weight
- return k, v, weight
+ return key, value, weight
+
+ @locked
def __delitem__(self, key: KT) -> None:
+ if key in self._cancel_evict:
+ self._cancel_evict[key] = True
del self.d[key]
- del self.heap[key]
+ self.order.remove(key)
+ self.heavy.discard(key)
self.total_weight -= self.weights.pop(key)
def keys(self) -> KeysView[KT]:
@@ -158,7 +269,7 @@ class LRU(ZictBase[KT, VT]):
def __str__(self) -> str:
sub = str(self.d) if not isinstance(self.d, dict) else "dict"
- return f"<LRU: {self.total_weight}/{self.n} on {sub}>"
+ return f"<LRU: {self.total_weight + self.offset}/{self.n} on {sub}>"
__repr__ = __str__
@@ -166,4 +277,5 @@ class LRU(ZictBase[KT, VT]):
flush(self.d)
def close(self) -> None:
+ self.closed = True
close(self.d)
diff --git a/zict/sieve.py b/zict/sieve.py
index 8640b8d..4ac4e3f 100644
--- a/zict/sieve.py
+++ b/zict/sieve.py
@@ -2,10 +2,9 @@ from __future__ import annotations
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator, Mapping, MutableMapping
-from itertools import chain
from typing import Generic, TypeVar
-from zict.common import KT, VT, ZictBase, close, flush
+from zict.common import KT, VT, ZictBase, close, discard, flush, locked
MKT = TypeVar("MKT")
@@ -23,85 +22,101 @@ class Sieve(ZictBase[KT, VT], Generic[KT, VT, MKT]):
mappings: dict of {mapping key: MutableMapping}
selector: callable (key, value) -> mapping key
+ Notes
+ -----
+ If you call methods of this class from multiple threads, access will be fast as long
+ as the ``__contains__`` and ``__delitem__`` methods of all underlying mappings are
+ fast. ``__getitem__`` and ``__setitem__`` methods of the underlying mappings are not
+ protected by locks.
+
Examples
--------
>>> small = {}
>>> large = DataBase() # doctest: +SKIP
>>> mappings = {True: small, False: large} # doctest: +SKIP
>>> def is_small(key, value): # doctest: +SKIP
- return sys.getsizeof(value) < 10000
+ ... return sys.getsizeof(value) < 10000 # doctest: +SKIP
>>> d = Sieve(mappings, is_small) # doctest: +SKIP
-
- See Also
- --------
- Buffer
"""
mappings: Mapping[MKT, MutableMapping[KT, VT]]
selector: Callable[[KT, VT], MKT]
key_to_mapping: dict[KT, MutableMapping[KT, VT]]
+ gen: int
def __init__(
self,
mappings: Mapping[MKT, MutableMapping[KT, VT]],
selector: Callable[[KT, VT], MKT],
):
+ super().__init__()
self.mappings = mappings
- # FIXME https://github.com/python/mypy/issues/708
- self.selector = selector # type: ignore
+ self.selector = selector
self.key_to_mapping = {}
+ self.gen = 0
def __getitem__(self, key: KT) -> VT:
+ # Note that this may raise KeyError if you call it in the middle of __setitem__
+ # or update for an already existing key
return self.key_to_mapping[key][key]
+ @locked
def __setitem__(self, key: KT, value: VT) -> None:
- old_mapping = self.key_to_mapping.get(key)
- mkey = self.selector(key, value) # type: ignore
+ discard(self, key)
+ mkey = self.selector(key, value)
mapping = self.mappings[mkey]
- if old_mapping is not None and old_mapping is not mapping:
- del old_mapping[key]
- mapping[key] = value
self.key_to_mapping[key] = mapping
+ self.gen += 1
+ gen = self.gen
+
+ with self.unlock():
+ mapping[key] = value
+
+ if gen != self.gen and self.key_to_mapping.get(key) is not mapping:
+ # Multithreaded race condition
+ discard(mapping, key)
+ @locked
def __delitem__(self, key: KT) -> None:
- del self.key_to_mapping.pop(key)[key]
+ mapping = self.key_to_mapping.pop(key)
+ self.gen += 1
+ discard(mapping, key)
+ @locked
def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None:
# Optimized update() implementation issuing a single update()
# call per underlying mapping.
updates = defaultdict(list)
- mapping_ids = {id(m): m for m in self.mappings.values()}
+ self.gen += 1
+ gen = self.gen
for key, value in items:
- old_mapping = self.key_to_mapping.get(key)
- mkey = self.selector(key, value) # type: ignore
+ old_mapping = self.key_to_mapping.pop(key, None)
+ if old_mapping is not None:
+ discard(old_mapping, key)
+ mkey = self.selector(key, value)
mapping = self.mappings[mkey]
- if old_mapping is not None and old_mapping is not mapping:
- del old_mapping[key]
- # Can't hash a mutable mapping, so use its id() instead
- updates[id(mapping)].append((key, value))
-
- for mid, mitems in updates.items():
- mapping = mapping_ids[mid]
- mapping.update(mitems)
- for key, _ in mitems:
- self.key_to_mapping[key] = mapping
-
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def keys(self) -> Iterator[KT]: # type: ignore
- return chain.from_iterable(self.mappings.values())
-
- def values(self) -> Iterator[VT]: # type: ignore
- return chain.from_iterable(m.values() for m in self.mappings.values())
-
- def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore
- return chain.from_iterable(m.items() for m in self.mappings.values())
+ updates[mkey].append((key, value))
+ self.key_to_mapping[key] = mapping
+
+ with self.unlock():
+ for mkey, mitems in updates.items():
+ mapping = self.mappings[mkey]
+ mapping.update(mitems)
+
+ if gen != self.gen:
+ # Multithreaded race condition
+ for mkey, mitems in updates.items():
+ mapping = self.mappings[mkey]
+ for key, _ in mitems:
+ if self.key_to_mapping.get(key) is not mapping:
+ discard(mapping, key)
def __len__(self) -> int:
- return sum(map(len, self.mappings.values()))
+ return len(self.key_to_mapping)
def __iter__(self) -> Iterator[KT]:
- return self.keys()
+ return iter(self.key_to_mapping)
def __contains__(self, key: object) -> bool:
return key in self.key_to_mapping
diff --git a/zict/tests/conftest.py b/zict/tests/conftest.py
new file mode 100644
index 0000000..395e7f5
--- /dev/null
+++ b/zict/tests/conftest.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+import gc
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor
+
+import pytest
+
+try:
+ import psutil
+except ImportError:
+ psutil = None # type: ignore
+
+
+@pytest.fixture
+def check_fd_leaks():
+ if sys.platform == "win32" or psutil is None:
+ yield
+ else:
+ proc = psutil.Process()
+ before = proc.num_fds()
+ yield
+ gc.collect()
+ assert proc.num_fds() == before
+
+
+@pytest.fixture
+def is_locked():
+ """Callable that returns True if the given zict mapping has its RLock engaged"""
+ with ThreadPoolExecutor(1) as ex:
+
+ def __is_locked(d):
+ out = d.lock.acquire(blocking=False)
+ if out:
+ d.lock.release()
+ return not out
+
+ def _is_locked(d):
+ return ex.submit(__is_locked, d).result()
+
+ yield _is_locked
+
+
+@pytest.fixture
+def check_thread_leaks():
+ active_threads_start = threading.enumerate()
+
+ yield
+
+ bad_threads = [
+ thread for thread in threading.enumerate() if thread not in active_threads_start
+ ]
+ if bad_threads:
+ raise RuntimeError(f"Leaked thread(s): {bad_threads}")
diff --git a/zict/tests/test_async_buffer.py b/zict/tests/test_async_buffer.py
new file mode 100644
index 0000000..5165faf
--- /dev/null
+++ b/zict/tests/test_async_buffer.py
@@ -0,0 +1,272 @@
+import asyncio
+import contextvars
+import threading
+import time
+from collections import UserDict
+from concurrent.futures import Executor, Future
+
+import pytest
+
+from zict import AsyncBuffer, Func
+from zict.tests import utils_test
+
+
+@pytest.mark.asyncio
+async def test_simple(check_thread_leaks):
+ with AsyncBuffer({}, utils_test.SlowDict(0.01), n=3) as buff:
+ buff["a"] = 1
+ buff["b"] = 2
+ buff["c"] = 3
+ assert set(buff.fast) == {"a", "b", "c"}
+ assert not buff.slow
+ assert not buff.futures
+
+ buff["d"] = 4
+ assert set(buff.fast) == {"a", "b", "c", "d"}
+ assert not buff.slow
+ assert buff.futures
+ await asyncio.wait(buff.futures)
+ assert set(buff.fast) == {"b", "c", "d"}
+ assert set(buff.slow) == {"a"}
+
+ buff.async_evict_until_below_target()
+ assert not buff.futures
+ buff.async_evict_until_below_target(10)
+ assert not buff.futures
+ buff.async_evict_until_below_target(2)
+ assert buff.futures
+ await asyncio.wait(buff.futures)
+ assert set(buff.fast) == {"c", "d"}
+ assert set(buff.slow) == {"a", "b"}
+
+ # Do not incur threading sync cost if everything is in fast
+ assert list(buff.fast.order) == ["c", "d"]
+ future = buff.async_get(["c"])
+ assert future.done()
+ assert await future == {"c": 3}
+ assert list(buff.fast.order) == ["d", "c"]
+
+ # Do not disturb LRU order in case of missing keys
+ with pytest.raises(KeyError, match="m"):
+ _ = buff.async_get(["d", "m"], missing="raise")
+ assert list(buff.fast.order) == ["d", "c"]
+
+ future = buff.async_get(["d", "m"], missing="omit")
+ assert future.done()
+ assert await future == {"d": 4}
+ assert list(buff.fast.order) == ["c", "d"]
+
+ with pytest.raises(ValueError):
+ _ = buff.async_get(["a"], missing="misspell")
+
+ # Asynchronously retrieve from slow
+ future = buff.async_get(["a", "b"])
+ assert not future.done()
+ assert future in buff.futures
+ assert await future == {"a": 1, "b": 2}
+ assert not buff.futures
+ assert set(buff.fast) == {"d", "a", "b"}
+ assert set(buff.slow) == {"c"}
+
+
+@pytest.mark.asyncio
+async def test_double_evict(check_thread_leaks):
+ """User calls async_evict_until_below_target() while the same is already running"""
+ with AsyncBuffer({}, utils_test.SlowDict(0.01), n=3) as buff:
+ buff["x"] = 1
+ buff["y"] = 2
+ buff["z"] = 3
+ assert len(buff.fast) == 3
+ assert not buff.futures
+
+ buff.async_evict_until_below_target(2)
+ assert len(buff.futures) == 1
+ assert list(buff.evicting.values()) == [2]
+
+ # Evicting to the same n is a no-op
+ buff.async_evict_until_below_target(2)
+ assert len(buff.futures) == 1
+ assert list(buff.evicting.values()) == [2]
+
+ # Evicting to a lower n while a previous eviction is still running does not
+ # cancel the previous eviction
+ buff.async_evict_until_below_target(1)
+ assert len(buff.futures) == 2
+ assert list(buff.evicting.values()) == [2, 1]
+ await asyncio.wait(buff.futures, return_when=asyncio.FIRST_COMPLETED)
+ assert len(buff.futures) == 1
+ assert list(buff.evicting.values()) == [1]
+ await asyncio.wait(buff.futures)
+ assert not buff.futures
+ assert not buff.evicting
+
+ assert buff.fast == {"z": 3}
+ assert buff.slow.data == {"x": 1, "y": 2}
+
+ # Evicting to negative n while fast is empty does nothing
+ buff.evict_until_below_target(0)
+ buff.async_evict_until_below_target(-1)
+ assert not buff.futures
+ assert not buff.evicting
+
+
+@pytest.mark.asyncio
+async def test_close_during_evict(check_thread_leaks):
+ buff = AsyncBuffer({}, utils_test.SlowDict(0.01), n=100)
+ buff.update({i: i for i in range(100)})
+ assert not buff.futures
+ assert len(buff.fast) == 100
+
+ buff.async_evict_until_below_target(0)
+ while not buff.slow:
+ await asyncio.sleep(0.01)
+ assert buff.fast
+ assert buff.futures
+
+ buff.close()
+ await asyncio.wait(buff.futures)
+ assert not buff.futures
+ assert buff.fast
+ assert buff.slow
+
+
+@pytest.mark.asyncio
+async def test_close_during_get(check_thread_leaks):
+ buff = AsyncBuffer({}, utils_test.SlowDict(0.01), n=100)
+ buff.slow.data.update({i: i for i in range(100)})
+ assert len(buff) == 100
+ assert not buff.fast
+
+ future = buff.async_get(list(range(100)))
+ assert buff.futures
+ while not buff.fast:
+ await asyncio.sleep(0.01)
+
+ buff.close()
+ with pytest.raises(asyncio.CancelledError):
+ await future
+ await asyncio.wait(buff.futures)
+ assert not buff.futures
+
+ assert buff.fast
+ assert buff.slow
+
+
+@pytest.mark.asyncio
+async def test_contextvars(check_thread_leaks):
+ ctx = contextvars.ContextVar("v", default=0)
+ in_dump = threading.Event()
+ in_load = threading.Event()
+ block_dump = threading.Event()
+ block_load = threading.Event()
+
+ def dump(v):
+ in_dump.set()
+ assert block_dump.wait(timeout=5)
+ return v + ctx.get()
+
+ def load(v):
+ in_load.set()
+ assert block_load.wait(timeout=5)
+ return v + ctx.get()
+
+ with AsyncBuffer({}, Func(dump, load, {}), n=0.1) as buff:
+ ctx.set(20) # Picked up by dump
+ buff["x"] = 1
+ assert buff.futures
+ assert in_dump.wait(timeout=5)
+ ctx.set(300) # Changed while dump runs. Won't be picked up until load.
+ block_dump.set()
+ await asyncio.wait(buff.futures)
+ assert buff.slow.d == {"x": 21}
+ fut = buff.async_get(["x"])
+ assert in_load.wait(timeout=5)
+ ctx.set(4000) # Changed while load runs. Won't be picked up.
+ block_load.set()
+ assert await fut == {"x": 321} # 1 + 20 (added by dump) + 300 (added by load)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("missing", ["raise", "omit"])
+async def test_race_condition_get_async_delitem(check_thread_leaks, missing):
+ """All required keys exist in slow when you call async_get(); however some are
+ deleted by the time the offloaded thread retrieves their values.
+ """
+
+ class Slow(UserDict):
+ def __getitem__(self, key):
+ if key in self:
+ time.sleep(0.01)
+ return super().__getitem__(key)
+
+ with AsyncBuffer({}, Slow(), n=100) as buff:
+ buff.slow.update({i: i for i in range(100)})
+ assert len(buff) == 100
+
+ future = buff.async_get(list(range(100)), missing=missing)
+ while not buff.fast:
+ await asyncio.sleep(0.01)
+ assert buff.slow
+ # Don't use clear(); it uses __iter__ which would not return until restore is
+ # completed
+ for i in range(100):
+ del buff[i]
+ assert not buff.fast
+ assert not buff.slow
+ assert not future.done()
+
+ if missing == "raise":
+ with pytest.raises(KeyError):
+ await future
+ else:
+ out = await future
+ assert 0 < len(out) < 100
+
+
+@pytest.mark.asyncio
+async def test_multiple_offload_threads():
+ barrier = threading.Barrier(2)
+
+ class Slow(UserDict):
+ def __getitem__(self, key):
+ barrier.wait(timeout=5)
+ return super().__getitem__(key)
+
+ with AsyncBuffer({}, Slow(), n=100, nthreads=2) as buff:
+ buff["x"] = 1
+ buff["y"] = 2
+ buff.evict_until_below_target(0)
+ assert not buff.fast
+ assert set(buff.slow) == {"x", "y"}
+
+ out = await asyncio.gather(buff.async_get(["x"]), buff.async_get(["y"]))
+ assert out == [{"x": 1}, {"y": 2}]
+
+
+@pytest.mark.asyncio
+async def test_external_executor():
+ n_submit = 0
+
+ class MyExecutor(Executor):
+ def submit(self, fn, /, *args, **kwargs):
+ nonlocal n_submit
+ n_submit += 1
+ out = fn(*args, **kwargs)
+ f = Future()
+ f.set_result(out)
+ return f
+
+ def shutdown(self, *args, **kwargs):
+ raise AssertionError("AsyncBuffer.close() called executor.shutdown()")
+
+ ex = MyExecutor()
+ buff = AsyncBuffer({}, {}, n=1, executor=ex)
+ buff["x"] = 1
+ buff["y"] = 2 # Evict x
+ assert buff.fast.d == {"y": 2}
+ assert buff.slow == {"x": 1}
+ assert await buff.async_get(["x"]) == {"x": 1} # Restore x, evict y
+ assert buff.fast.d == {"x": 1}
+ assert buff.slow == {"y": 2}
+ assert n_submit == 2
+ buff.close()
diff --git a/zict/tests/test_buffer.py b/zict/tests/test_buffer.py
index 5211551..b5d23bd 100644
--- a/zict/tests/test_buffer.py
+++ b/zict/tests/test_buffer.py
@@ -1,13 +1,18 @@
+import random
+import threading
+from collections import UserDict
+from concurrent.futures import ThreadPoolExecutor
+
import pytest
-import zict
+from zict import Buffer
from zict.tests import utils_test
def test_simple():
- a = dict()
- b = dict()
- buff = zict.Buffer(a, b, n=10, weight=lambda k, v: v)
+ a = {}
+ b = {}
+ buff = Buffer(a, b, n=10, weight=lambda k, v: v)
buff["x"] = 1
buff["y"] = 2
@@ -64,10 +69,9 @@ def test_simple():
def test_setitem_avoid_fast_slow_duplicate():
-
- a = dict()
- b = dict()
- buff = zict.Buffer(a, b, n=10, weight=lambda k, v: v)
+ a = {}
+ b = {}
+ buff = Buffer(a, b, n=10, weight=lambda k, v: v)
for first, second in [(1, 12), (12, 1)]:
buff["a"] = first
assert buff["a"] == first
@@ -91,10 +95,19 @@ def test_mapping():
"""
a = {}
b = {}
- buff = zict.Buffer(a, b, n=2)
+ buff = Buffer(a, b, n=2)
utils_test.check_mapping(buff)
utils_test.check_closing(buff)
+ buff.clear()
+ assert not buff.slow
+ assert not buff._cancel_restore
+ assert not buff.fast
+ assert not buff.fast.d
+ assert not buff.fast.weights
+ assert not buff.fast.total_weight
+ assert not buff.fast._cancel_evict
+
def test_callbacks():
f2s = []
@@ -107,9 +120,9 @@ def test_callbacks():
def s2f_cb(k, v):
s2f.append(k)
- a = dict()
- b = dict()
- buff = zict.Buffer(
+ a = {}
+ b = {}
+ buff = Buffer(
a,
b,
n=10,
@@ -158,7 +171,7 @@ def test_callbacks_exception_catch():
a = {}
b = {}
- buff = zict.Buffer(
+ buff = Buffer(
a,
b,
n=10,
@@ -186,7 +199,7 @@ def test_callbacks_exception_catch():
assert b == {"x": 1}
# Add key > n, again total weight > n this will move everything to slow except w
- # that stays in fast due after callback raise
+ # that stays in fast due to callback raising
with pytest.raises(MyError):
buff["w"] = 11
@@ -194,3 +207,219 @@ def test_callbacks_exception_catch():
assert s2f == []
assert a == {"w": 11}
assert b == {"x": 1, "y": 2, "z": 8}
+
+
+def test_n_offset():
+ buff = Buffer({}, {}, n=5)
+ assert buff.n == 5
+ assert buff.fast.n == 5
+ buff.n = 3
+ assert buff.fast.n == 3
+ assert buff.offset == 0
+ assert buff.fast.offset == 0
+ buff.offset = 2
+ assert buff.offset == 2
+ assert buff.fast.offset == 2
+
+
+def test_set_noevict():
+ a = {}
+ b = {}
+ f2s = []
+ s2f = []
+ buff = Buffer(
+ a,
+ b,
+ n=5,
+ weight=lambda k, v: v,
+ fast_to_slow_callbacks=lambda k, v: f2s.append(k),
+ slow_to_fast_callbacks=lambda k, v: s2f.append(k),
+ )
+ buff.set_noevict("x", 3)
+ buff.set_noevict("y", 3) # Would cause x to move to slow
+ buff.set_noevict("z", 6) # >n; would be immediately evicted
+
+ assert a == {"x": 3, "y": 3, "z": 6}
+ assert b == {}
+ assert f2s == s2f == []
+
+ buff.evict_until_below_target()
+ assert a == {"y": 3}
+ assert b == {"z": 6, "x": 3}
+ assert f2s == ["z", "x"]
+ assert s2f == []
+
+ # set_noevict clears slow
+ f2s.clear()
+ buff.set_noevict("x", 1)
+ assert a == {"y": 3, "x": 1}
+ assert b == {"z": 6}
+ assert f2s == s2f == []
+
+ # Custom target; 0 != None
+ buff.evict_until_below_target(0)
+ assert a == {}
+ assert b == {"z": 6, "x": 1, "y": 3}
+ assert f2s == ["y", "x"]
+ assert s2f == []
+
+
+def test_evict_restore_during_iter():
+ """Test that __iter__ won't be disrupted if another thread evicts or restores a key"""
+ buff = Buffer({"x": 1, "y": 2}, {"z": 3}, n=5)
+ assert list(buff) == ["x", "y", "z"]
+ it = iter(buff)
+ assert next(it) == "x"
+ buff.fast.evict("x")
+ assert next(it) == "y"
+ assert buff["x"] == 1
+ assert next(it) == "z"
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+@pytest.mark.parametrize("event", ("set", "set_noevict", "del"))
+@pytest.mark.parametrize("when", ("before", "after"))
+def test_cancel_evict(event, when):
+ """See also:
+
+ test_cancel_restore
+ test_lru.py::test_cancel_evict
+ """
+ ev1 = threading.Event()
+ ev2 = threading.Event()
+
+ class Slow(UserDict):
+ def __setitem__(self, k, v):
+ if when == "before":
+ ev1.set()
+ assert ev2.wait(timeout=5)
+ super().__setitem__(k, v)
+ else:
+ super().__setitem__(k, v)
+ ev1.set()
+ assert ev2.wait(timeout=5)
+
+ buff = Buffer({}, Slow(), n=100, weight=lambda k, v: v)
+ buff.set_noevict("x", 1)
+ with ThreadPoolExecutor(1) as ex:
+ fut = ex.submit(buff.fast.evict)
+ assert ev1.wait(timeout=5)
+ # cb is running
+
+ if event == "set":
+ buff["x"] = 2
+ elif event == "set_noevict":
+ buff.set_noevict("x", 2)
+ else:
+ assert event == "del"
+ del buff["x"]
+ ev2.set()
+ assert fut.result() == (None, None, 0)
+
+ if event in ("set", "set_noevict"):
+ assert buff.fast == {"x": 2}
+ assert not buff.slow
+ assert buff.fast.weights == {"x": 2}
+ assert list(buff.fast.order) == ["x"]
+ else:
+ assert not buff.fast
+ assert not buff.slow
+ assert not buff.fast.weights
+ assert not buff.fast.order
+
+ assert not buff.fast._cancel_evict
+
+
+@pytest.mark.parametrize("event", ("set", "set_noevict", "del"))
+@pytest.mark.parametrize("when", ("before", "after"))
+def test_cancel_restore(event, when):
+ """See also:
+
+ test_cancel_evict
+ test_lru.py::test_cancel_evict
+ """
+ ev1 = threading.Event()
+ ev2 = threading.Event()
+
+ class Slow(UserDict):
+ def __getitem__(self, k):
+ if when == "before":
+ ev1.set()
+ assert ev2.wait(timeout=5)
+ return super().__getitem__(k)
+ else:
+ out = super().__getitem__(k)
+ ev1.set()
+ assert ev2.wait(timeout=5)
+ return out
+
+ buff = Buffer({}, Slow(), n=100, weight=lambda k, v: v)
+ buff.set_noevict("x", 1)
+ buff.fast.evict()
+ assert not buff.fast
+ assert set(buff.slow) == {"x"}
+
+ with ThreadPoolExecutor(1) as ex:
+ fut = ex.submit(buff.__getitem__, "x")
+ assert ev1.wait(timeout=5)
+ # cb is running
+
+ if event == "set":
+ buff["x"] = 2
+ elif event == "set_noevict":
+ buff.set_noevict("x", 2)
+ else:
+ assert event == "del"
+ del buff["x"]
+ ev2.set()
+
+ with pytest.raises(KeyError, match="x"):
+ fut.result()
+
+ if event in ("set", "set_noevict"):
+ assert buff.fast == {"x": 2}
+ assert not buff.slow
+ assert buff.fast.weights == {"x": 2}
+ assert list(buff.fast.order) == ["x"]
+ else:
+ assert not buff.fast
+ assert not buff.slow
+ assert not buff.fast.weights
+ assert not buff.fast.order
+
+ assert not buff._cancel_restore
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_different_keys_threadsafe():
+ # Sometimes x and y can cohexist without triggering eviction
+ # Sometimes x and y are individually <n but when they're both in they cause eviction
+ # Sometimes x or y are heavy
+ buff = Buffer(
+ {},
+ utils_test.SlowDict(0.001),
+ n=1,
+ weight=lambda k, v: random.choice([0.4, 0.9, 1.1]),
+ )
+ utils_test.check_different_keys_threadsafe(buff)
+ assert not buff.fast
+ assert not buff.slow
+ utils_test.check_mapping(buff)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_same_key_threadsafe():
+ # Sometimes x is heavy
+ buff = Buffer(
+ {},
+ utils_test.SlowDict(0.001),
+ n=1,
+ weight=lambda k, v: random.choice([0.9, 1.1]),
+ )
+ utils_test.check_same_key_threadsafe(buff)
+ assert not buff.fast
+ assert not buff.slow
+ utils_test.check_mapping(buff)
diff --git a/zict/tests/test_cache.py b/zict/tests/test_cache.py
index ebdbcf1..70c61e6 100644
--- a/zict/tests/test_cache.py
+++ b/zict/tests/test_cache.py
@@ -1,9 +1,12 @@
import gc
+import threading
from collections import UserDict
+from concurrent.futures import ThreadPoolExecutor
import pytest
-from zict.cache import Cache, WeakValueMapping
+from zict import Cache, WeakValueMapping
+from zict.tests import utils_test
def test_cache_get_set_del():
@@ -57,7 +60,7 @@ def test_do_not_read_from_data():
class D(UserDict):
def __getitem__(self, key):
- assert False
+ raise AssertionError()
d = Cache(D({1: 10, 2: 20}), {})
assert len(d) == 2
@@ -129,3 +132,280 @@ def test_weakvaluemapping():
b = "bbb"
d["b"] = b
assert "b" not in d
+
+
+def test_mapping():
+ """
+ Test mapping interface for Cache().
+ """
+ buff = Cache({}, {})
+ utils_test.check_mapping(buff)
+ utils_test.check_closing(buff)
+
+ buff.clear()
+ assert not buff.cache
+ assert not buff.data
+ assert not buff._last_updated
+
+
+@pytest.mark.parametrize("get_when", ("before", "after"))
+@pytest.mark.parametrize("set_when", ("before", "after"))
+@pytest.mark.parametrize(
+ "starts_first,seed,update_on_set",
+ [
+ ("get", True, False),
+ ("set", False, False),
+ ("set", False, True),
+ ("set", True, False),
+ ("set", True, True),
+ ],
+)
+@pytest.mark.parametrize("ends_first", ("get", "set"))
+def test_multithread_race_condition_set_get(
+ get_when, set_when, starts_first, seed, update_on_set, ends_first
+):
+ """Test race conditions between __setitem__ and __getitem__ on the same key"""
+ in_get = threading.Event()
+ block_get = threading.Event()
+ in_set = threading.Event()
+ block_set = threading.Event()
+
+ class Slow(UserDict):
+ def __getitem__(self, k):
+ if get_when == "before":
+ try:
+ v = self.data[k]
+ finally:
+ in_get.set()
+ assert block_get.wait(timeout=5)
+ return v
+ else:
+ in_get.set()
+ assert block_get.wait(timeout=5)
+ return self.data[k]
+
+ def __setitem__(self, k, v):
+ if set_when == "before":
+ self.data[k] = v
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ else:
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ self.data[k] = v
+
+ z = Cache(Slow(), {}, update_on_set=update_on_set)
+ if seed:
+ block_set.set()
+ z["x"] = 1
+ in_set.clear()
+ block_set.clear()
+ assert z.data.data == {"x": 1}
+ assert set(z._last_updated) == {"x"}
+ if update_on_set:
+ assert z.cache == {"x": 1}
+ else:
+ assert z.cache == {}
+
+ with ThreadPoolExecutor(2) as ex:
+ if starts_first == "get":
+ get_fut = ex.submit(z.__getitem__, "x")
+ assert in_get.wait(timeout=5)
+ set_fut = ex.submit(z.__setitem__, "x", 2)
+ assert in_set.wait(timeout=5)
+ else:
+ set_fut = ex.submit(z.__setitem__, "x", 2)
+ assert in_set.wait(timeout=5)
+ get_fut = ex.submit(z.__getitem__, "x")
+ assert in_get.wait(timeout=5)
+
+ if ends_first == "get":
+ block_get.set()
+ try:
+ assert get_fut.result() in (1, 2)
+ except KeyError:
+ pass
+ block_set.set()
+ set_fut.result()
+ else:
+ block_set.set()
+ set_fut.result()
+ block_get.set()
+ try:
+ assert get_fut.result() in (1, 2)
+ except KeyError:
+ pass
+
+ assert z.data.data == {"x": 2}
+ # The cache is either not populated or up to date
+ assert z.cache in ({}, {"x": 2})
+ assert set(z._last_updated) == {"x"}
+
+
+@pytest.mark.parametrize("get_when", ("before", "after"))
+def test_multithread_race_condition_del_get(get_when):
+ """Test race conditions between __delitem__ and __getitem__ on the same key"""
+ in_get = threading.Event()
+ block_get = threading.Event()
+
+ class Slow(UserDict):
+ def __getitem__(self, k):
+ if get_when == "before":
+ v = self.data[k]
+ in_get.set()
+ assert block_get.wait(timeout=5)
+ return v
+ else:
+ in_get.set()
+ assert block_get.wait(timeout=5)
+ return self.data[k]
+
+ z = Cache(Slow(), {}, update_on_set=False)
+ z["x"] = 1
+ assert z.data.data == {"x": 1}
+ assert set(z._last_updated) == {"x"}
+ assert z.cache == {}
+
+ with ThreadPoolExecutor(1) as ex:
+ get_fut = ex.submit(z.__getitem__, "x")
+ assert in_get.wait(timeout=5)
+ del z["x"]
+
+ block_get.set()
+ if get_when == "before":
+ assert get_fut.result() == 1
+ else:
+ with pytest.raises(KeyError):
+ get_fut.result()
+
+ assert not z.data.data
+ assert not z.cache
+ assert not z._last_updated
+
+
+@pytest.mark.parametrize("set_when", ("before", "after"))
+@pytest.mark.parametrize(
+ "seed_data,seed_cache", [(False, False), (True, False), (True, True)]
+)
+@pytest.mark.parametrize("update_on_set", [False, True])
+def test_multithread_race_condition_del_set(
+ set_when, seed_data, seed_cache, update_on_set
+):
+ """Test race conditions between __delitem__ and __setitem__ on the same key"""
+ in_set = threading.Event()
+ block_set = threading.Event()
+
+ class Slow(UserDict):
+ def __setitem__(self, k, v):
+ if set_when == "before":
+ self.data[k] = v
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ else:
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ self.data[k] = v
+
+ z = Cache(Slow(), {}, update_on_set=update_on_set)
+ if seed_data:
+ block_set.set()
+ z["x"] = 1
+ in_set.clear()
+ block_set.clear()
+ if seed_cache and not update_on_set:
+ _ = z["x"]
+
+ with ThreadPoolExecutor(1) as ex:
+ set_fut = ex.submit(z.__setitem__, "x", 2)
+ assert in_set.wait(timeout=5)
+ try:
+ del z["x"]
+ except KeyError:
+ pass
+ block_set.set()
+ set_fut.result()
+
+ assert z.data.data in ({"x": 2}, {})
+ assert z.cache in (z.data.data, {})
+ assert z._last_updated.keys() == z.data.data.keys()
+
+
+@pytest.mark.parametrize("set1_when", ("before", "after"))
+@pytest.mark.parametrize("set2_when", ("before", "after"))
+@pytest.mark.parametrize("starts_first", (1, 2))
+@pytest.mark.parametrize("ends_first", (1, 2))
+@pytest.mark.parametrize(
+ "seed_data,seed_cache", [(False, False), (True, False), (True, True)]
+)
+@pytest.mark.parametrize("update_on_set", [False, True])
+def test_multithread_race_condition_set_set(
+ set1_when, set2_when, starts_first, ends_first, seed_data, seed_cache, update_on_set
+):
+ """Test __setitem__ in race condition with itself"""
+ when = {1: set1_when, 2: set2_when}
+ in_set = {1: threading.Event(), 2: threading.Event()}
+ block_set = {1: threading.Event(), 2: threading.Event()}
+
+ class Slow(UserDict):
+ def __setitem__(self, k, v):
+ if v == 0:
+ # seed
+ self.data[k] = v
+ return
+
+ if when[v] == "before":
+ self.data[k] = v
+ in_set[v].set()
+ assert block_set[v].wait(timeout=5)
+ else:
+ in_set[v].set()
+ assert block_set[v].wait(timeout=5)
+ self.data[k] = v
+
+ z = Cache(Slow(), {}, update_on_set=update_on_set)
+ if seed_data:
+ z["x"] = 0
+ if seed_cache and not update_on_set:
+ _ = z["x"]
+
+ with ThreadPoolExecutor(2) as ex:
+ futures = {}
+ futures[starts_first] = ex.submit(z.__setitem__, "x", starts_first)
+ assert in_set[starts_first].wait(timeout=5)
+ starts_second = 2 if starts_first == 1 else 1
+ futures[starts_second] = ex.submit(z.__setitem__, "x", starts_second)
+ assert in_set[starts_second].wait(timeout=5)
+
+ block_set[ends_first].set()
+ futures[ends_first].result()
+ ends_second = 2 if ends_first == 1 else 1
+ block_set[ends_second].set()
+ futures[ends_second].result()
+
+ assert z.data.data in ({"x": 1}, {"x": 2})
+ assert z.cache in (z.data.data, {})
+ assert set(z._last_updated) == {"x"}
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+@pytest.mark.parametrize("update_on_set", [False, True])
+def test_stress_different_keys_threadsafe(update_on_set):
+ buff = Cache(utils_test.SlowDict(0.001), {}, update_on_set=update_on_set)
+ utils_test.check_different_keys_threadsafe(buff)
+ assert not buff.cache
+ assert not buff.data
+ assert not buff._last_updated
+ utils_test.check_mapping(buff)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+@pytest.mark.parametrize("update_on_set", [False, True])
+def test_stress_same_key_threadsafe(update_on_set):
+ buff = Cache(utils_test.SlowDict(0.001), {}, update_on_set=update_on_set)
+ utils_test.check_same_key_threadsafe(buff)
+ assert not buff.cache
+ assert not buff.data
+ assert not buff._last_updated
+ utils_test.check_mapping(buff)
diff --git a/zict/tests/test_common.py b/zict/tests/test_common.py
new file mode 100644
index 0000000..4d7a2b8
--- /dev/null
+++ b/zict/tests/test_common.py
@@ -0,0 +1,108 @@
+import pickle
+
+import pytest
+
+from zict.common import locked
+from zict.tests.utils_test import SimpleDict
+
+
+def test_close_on_del():
+ closed = False
+
+ class D(SimpleDict):
+ def close(self):
+ nonlocal closed
+ closed = True
+
+ d = D()
+ del d
+ assert closed
+
+
+def test_context():
+ closed = False
+
+ class D(SimpleDict):
+ def close(self):
+ nonlocal closed
+ closed = True
+
+ d = D()
+ with d as d2:
+ assert d2 is d
+ assert closed
+
+
+def test_update():
+ items = []
+
+ class D(SimpleDict):
+ def _do_update(self, items_):
+ nonlocal items
+ items = items_
+
+ d = D()
+ d.update({"x": 1})
+ assert list(items) == [("x", 1)]
+ d.update(iter([("x", 2)]))
+ assert list(items) == [("x", 2)]
+ d.update({"x": 3}, y=4)
+ assert list(items) == [("x", 3), ("y", 4)]
+ d.update(x=5)
+ assert list(items) == [("x", 5)]
+
+ # Special kwargs can't overwrite positional-only parameters
+ d.update(self=1, other=2)
+ assert list(items) == [("self", 1), ("other", 2)]
+
+
+def test_discard():
+ class D(SimpleDict):
+ def __getitem__(self, key):
+ raise AssertionError()
+
+ d = D()
+ d["x"] = 1
+ d["z"] = 2
+ d.discard("x")
+ d.discard("y")
+ assert d.data == {"z": 2}
+
+
+def test_pickle():
+ d = SimpleDict()
+ d["x"] = 1
+ d2 = pickle.loads(pickle.dumps(d))
+ assert d2.data == {"x": 1}
+
+
+def test_lock(is_locked):
+ class CustomError(Exception):
+ pass
+
+ class D(SimpleDict):
+ @locked
+ def f(self, crash):
+ assert is_locked(self)
+ with self.unlock():
+ assert not is_locked(self)
+ assert is_locked(self)
+
+ # context manager re-acquires the lock on failure
+ with pytest.raises(CustomError):
+ with self.unlock():
+ raise CustomError()
+ assert is_locked(self)
+
+ if crash:
+ raise CustomError()
+
+ d = D()
+ assert not is_locked(d)
+ d.f(crash=False)
+ assert not is_locked(d)
+
+ # decorator releases the lock on failure
+ with pytest.raises(CustomError):
+ d.f(crash=True)
+ assert not is_locked(d)
diff --git a/zict/tests/test_file.py b/zict/tests/test_file.py
index a79c16d..6981f35 100644
--- a/zict/tests/test_file.py
+++ b/zict/tests/test_file.py
@@ -1,95 +1,96 @@
import os
-import shutil
+import pathlib
+import sys
import pytest
-from zict.file import File
+from zict import File
from zict.tests import utils_test
-@pytest.fixture
-def fn():
- filename = ".tmp"
- if os.path.exists(filename):
- shutil.rmtree(filename)
-
- yield filename
-
- if os.path.exists(filename):
- shutil.rmtree(filename)
-
-
-def test_mapping(fn):
+def test_mapping(tmp_path, check_fd_leaks):
"""
Test mapping interface for File().
"""
- z = File(fn)
+ z = File(tmp_path)
utils_test.check_mapping(z)
-def test_implementation(fn):
- z = File(fn)
+@pytest.mark.parametrize("dirtype", [str, pathlib.Path, lambda x: x])
+def test_implementation(tmp_path, check_fd_leaks, dirtype):
+ z = File(dirtype(tmp_path))
assert not z
z["x"] = b"123"
- assert os.listdir(fn) == ["x"]
- with open(os.path.join(fn, "x"), "rb") as f:
+ assert os.listdir(tmp_path) == ["x#0"]
+ with open(tmp_path / "x#0", "rb") as f:
assert f.read() == b"123"
assert "x" in z
+ out = z["x"]
+ assert isinstance(out, bytearray)
+ assert out == b"123"
-def test_memmap_implementation(fn):
- z = File(fn, memmap=True)
+def test_memmap_implementation(tmp_path, check_fd_leaks):
+ z = File(tmp_path, memmap=True)
assert not z
- z["x"] = b"123"
- assert os.listdir(fn) == ["x"]
- assert z["x"] == memoryview(b"123")
-
+ mv = memoryview(b"123")
+ assert "x" not in z
+ z["x"] = mv
+ assert os.listdir(tmp_path) == ["x#0"]
assert "x" in z
+ mv2 = z["x"]
+ assert mv2 == b"123"
+ # Buffer is writeable
+ mv2[0] = mv2[1]
+ assert mv2 == b"223"
-def test_str(fn):
- z = File(fn)
- assert fn in str(z)
- assert fn in repr(z)
- assert z.mode in str(z)
- assert z.mode in repr(z)
+def test_str(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
+ assert str(z) == repr(z) == f"<File: {tmp_path}, 0 elements>"
-def test_setitem_typeerror(fn):
- z = File(fn)
+def test_setitem_typeerror(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
with pytest.raises(TypeError):
z["x"] = 123
-def test_contextmanager(fn):
- with File(fn) as z:
+def test_contextmanager(tmp_path, check_fd_leaks):
+ with File(tmp_path) as z:
z["x"] = b"123"
- with open(os.path.join(fn, "x"), "rb") as f:
- assert f.read() == b"123"
+ with open(tmp_path / "x#0", "rb") as fh:
+ assert fh.read() == b"123"
-def test_delitem(fn):
- z = File(fn)
+def test_delitem(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
z["x"] = b"123"
- assert os.path.exists(os.path.join(z.directory, "x"))
+ assert os.listdir(tmp_path) == ["x#0"]
del z["x"]
- assert not os.path.exists(os.path.join(z.directory, "x"))
+ assert os.listdir(tmp_path) == []
+ # File name is never repeated
+ z["x"] = b"123"
+ assert os.listdir(tmp_path) == ["x#1"]
+ # __setitem__ deletes the previous file
+ z["x"] = b"123"
+ assert os.listdir(tmp_path) == ["x#2"]
-def test_missing_key(fn):
- z = File(fn)
+def test_missing_key(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
with pytest.raises(KeyError):
z["x"]
-def test_arbitrary_chars(fn):
- z = File(fn)
+def test_arbitrary_chars(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
# Avoid hitting the Windows max filename length
chunk = 16
@@ -104,7 +105,7 @@ def test_arbitrary_chars(fn):
assert list(z.items()) == [(key, b"foo")]
assert list(z.values()) == [b"foo"]
- zz = File(fn)
+ zz = File(tmp_path)
assert zz[key] == b"foo"
assert list(zz) == [key]
assert list(zz.keys()) == [key]
@@ -117,8 +118,31 @@ def test_arbitrary_chars(fn):
z[key]
-def test_write_list_of_bytes(fn):
- z = File(fn)
+def test_write_list_of_bytes(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
z["x"] = [b"123", b"4567"]
assert z["x"] == b"1234567"
+
+
+def test_bad_types(tmp_path, check_fd_leaks):
+ z = File(tmp_path)
+ utils_test.check_bad_key_types(z)
+ utils_test.check_bad_value_types(z)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_different_keys_threadsafe(tmp_path):
+ z = File(tmp_path)
+ utils_test.check_different_keys_threadsafe(z)
+ utils_test.check_mapping(z)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+@pytest.mark.skipif(sys.platform == "win32", reason="Can't delete file with open fd")
+def test_stress_same_key_threadsafe(tmp_path):
+ z = File(tmp_path)
+ utils_test.check_same_key_threadsafe(z)
+ utils_test.check_mapping(z)
diff --git a/zict/tests/test_func.py b/zict/tests/test_func.py
index fbe89b8..f3dad4c 100644
--- a/zict/tests/test_func.py
+++ b/zict/tests/test_func.py
@@ -25,7 +25,7 @@ def rotr(x):
def test_simple():
- d = dict()
+ d = {}
f = Func(inc, dec, d)
f["x"] = 10
assert f["x"] == 10
diff --git a/zict/tests/test_lmdb.py b/zict/tests/test_lmdb.py
index b4f66f8..4db649c 100644
--- a/zict/tests/test_lmdb.py
+++ b/zict/tests/test_lmdb.py
@@ -1,67 +1,69 @@
-import gc
import os
-import shutil
-import tempfile
+import pathlib
import pytest
-from zict.lmdb import LMDB
+from zict import LMDB
from zict.tests import utils_test
+pytest.importorskip("lmdb")
-@pytest.fixture
-def fn():
- dirname = tempfile.mkdtemp(prefix="test_lmdb-")
- try:
- yield dirname
- finally:
- if os.path.exists(dirname):
- shutil.rmtree(dirname)
+@pytest.mark.parametrize("dirtype", [str, pathlib.Path, lambda x: x])
+def test_dirtypes(tmp_path, check_fd_leaks, dirtype):
+ z = LMDB(tmp_path)
+ z["x"] = b"123"
+ assert z["x"] == b"123"
+ del z["x"]
-def test_mapping(fn):
+
+def test_mapping(tmp_path, check_fd_leaks):
"""
Test mapping interface for LMDB().
"""
- z = LMDB(fn)
+ z = LMDB(tmp_path)
utils_test.check_mapping(z)
-def test_reuse(fn):
+def test_bad_types(tmp_path, check_fd_leaks):
+ z = LMDB(tmp_path)
+ utils_test.check_bad_key_types(z)
+ utils_test.check_bad_value_types(z)
+
+
+def test_reuse(tmp_path, check_fd_leaks):
"""
Test persistence of a LMDB() mapping.
"""
- with LMDB(fn) as z:
+ with LMDB(tmp_path) as z:
assert len(z) == 0
z["abc"] = b"123"
- with LMDB(fn) as z:
+ with LMDB(tmp_path) as z:
assert len(z) == 1
assert z["abc"] == b"123"
-def test_creates_dir(fn):
- with LMDB(fn):
- assert os.path.isdir(fn)
-
+def test_creates_dir(tmp_path, check_fd_leaks):
+ with LMDB(tmp_path, check_fd_leaks):
+ assert os.path.isdir(tmp_path)
-def test_file_descriptors_dont_leak(fn):
- psutil = pytest.importorskip("psutil")
- proc = psutil.Process()
- before = proc.num_fds()
- z = LMDB(fn)
+def test_file_descriptors_dont_leak(tmp_path, check_fd_leaks):
+ z = LMDB(tmp_path)
del z
- gc.collect()
- assert proc.num_fds() == before
-
- z = LMDB(fn)
+ z = LMDB(tmp_path)
z.close()
- assert proc.num_fds() == before
-
- with LMDB(fn) as z:
+ with LMDB(tmp_path) as z:
pass
- assert proc.num_fds() == before
+
+def test_map_size(tmp_path, check_fd_leaks):
+ import lmdb
+
+ z = LMDB(tmp_path, map_size=2**20)
+ z["x"] = b"x" * 2**19
+ with pytest.raises(lmdb.MapFullError):
+ z["y"] = b"x" * 2**20
diff --git a/zict/tests/test_lru.py b/zict/tests/test_lru.py
index 1b0a048..9da41f5 100644
--- a/zict/tests/test_lru.py
+++ b/zict/tests/test_lru.py
@@ -1,3 +1,8 @@
+import random
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+
import pytest
from zict import LRU
@@ -5,7 +10,7 @@ from zict.tests import utils_test
def test_simple():
- d = dict()
+ d = {}
lru = LRU(2, d)
lru["x"] = 1
@@ -28,11 +33,11 @@ def test_simple():
assert "y" not in lru
lru["a"] = 5
- assert set(lru.keys()) == {"z", "a"}
+ assert set(lru) == {"z", "a"}
def test_str():
- d = dict()
+ d = {}
lru = LRU(2, d)
lru["x"] = 1
@@ -56,9 +61,15 @@ def test_mapping():
utils_test.check_mapping(lru)
utils_test.check_closing(lru)
+ lru.clear()
+ assert not lru.d
+ assert not lru.weights
+ assert not lru.total_weight
+ assert not lru._cancel_evict
+
def test_overwrite():
- d = dict()
+ d = {}
lru = LRU(2, d)
lru["x"] = 1
@@ -78,8 +89,8 @@ def test_callbacks():
def cb(k, v):
count[0] += 1
- L = list()
- d = dict()
+ L = []
+ d = {}
lru = LRU(2, d, on_evict=[lambda k, v: L.append((k, v)), cb])
lru["x"] = 1
@@ -126,7 +137,7 @@ def test_cb_exception_keep_on_lru():
assert set(lru) == {"x", "y", "z"}
assert lru.d == {"x": 1, "y": 2, "z": 3}
- assert dict(lru.heap) == {"x": 1, "y": 2, "z": 3}
+ assert list(lru.order) == ["x", "y", "z"]
def test_cb_exception_keep_on_lru_weights():
@@ -165,7 +176,7 @@ def test_cb_exception_keep_on_lru_weights():
assert set(lru) == {"y"}
assert lru.d == {"y": 3}
- assert dict(lru.heap) == {"y": 2}
+ assert list(lru.order) == ["y"]
with pytest.raises(MyError):
# value is individually heavier than n
@@ -184,11 +195,11 @@ def test_cb_exception_keep_on_lru_weights():
assert set(lru) == {"y", "z"}
assert lru.d == {"y": 3, "z": 4}
- assert dict(lru.heap) == {"y": 2, "z": 3}
+ assert list(lru.order) == ["y", "z"]
def test_weight():
- d = dict()
+ d = {}
weight = lambda k, v: v
lru = LRU(10, d, weight=weight)
@@ -210,17 +221,239 @@ def test_weight():
assert d == {"y": 4}
+def test_manual_eviction():
+ a = []
+ lru = LRU(100, {}, weight=lambda k, v: v, on_evict=lambda k, v: a.append(k))
+ lru.set_noevict("x", 70)
+ lru.set_noevict("y", 50)
+ lru.set_noevict("z", 110)
+ assert lru.total_weight == 70 + 50 + 110
+ assert lru.heavy == {"z"}
+ assert list(lru.order) == ["x", "y", "z"]
+ assert a == []
+
+ lru.evict_until_below_target()
+ assert dict(lru) == {"y": 50}
+ assert a == ["z", "x"]
+ assert lru.weights == {"y": 50}
+ assert lru.order == {"y"}
+ assert not lru.heavy
+
+ lru.evict_until_below_target() # No-op
+ assert dict(lru) == {"y": 50}
+ lru.evict_until_below_target(50) # Custom target
+ assert dict(lru) == {"y": 50}
+ lru.evict_until_below_target(0) # 0 != None
+ assert not lru
+ assert not lru.order
+ assert not lru.weights
+ assert a == ["z", "x", "y"]
+
+
def test_explicit_evict():
- d = dict()
+ d = {}
lru = LRU(10, d)
lru["x"] = 1
lru["y"] = 2
+ lru["z"] = 3
- assert set(d) == {"x", "y"}
+ assert set(d) == {"x", "y", "z"}
- k, v, w = lru.evict()
+ assert lru.evict() == ("x", 1, 1)
+ assert set(d) == {"y", "z"}
+ assert lru.evict("z") == ("z", 3, 1)
assert set(d) == {"y"}
- assert k == "x"
- assert v == 1
- assert w == 1
+ assert lru.evict() == ("y", 2, 1)
+ with pytest.raises(KeyError, match=r"'evict\(\): dictionary is empty'"):
+ lru.evict()
+
+ # evict() with explicit key
+ lru["v"] = 4
+ lru["w"] = 5
+ assert lru.evict("w") == ("w", 5, 1)
+ with pytest.raises(KeyError, match="notexist"):
+ lru.evict("notexist")
+
+
+def test_init_not_empty():
+ lru1 = LRU(100, {}, weight=lambda k, v: v * 2)
+ lru1.set_noevict(1, 10)
+ lru1.set_noevict(2, 20)
+ lru1.set_noevict(3, 30)
+ lru1.set_noevict(4, 60)
+ lru2 = LRU(100, {1: 10, 2: 20, 3: 30, 4: 60}, weight=lambda k, v: v * 2)
+ assert lru1.d == lru2.d == {1: 10, 2: 20, 3: 30, 4: 60}
+ assert lru1.weights == lru2.weights == {1: 20, 2: 40, 3: 60, 4: 120}
+ assert lru1.total_weight == lru2.total_weight == 240
+ assert list(lru1.order) == list(lru2.order) == [1, 2, 3, 4]
+ assert list(lru1.heavy) == list(lru2.heavy) == [4]
+
+
+def test_get_all_or_nothing():
+ lru = LRU(100, {"x": 1, "y": 2, "z": 3})
+ assert list(lru.order) == ["x", "y", "z"]
+ with pytest.raises(KeyError, match="w"):
+ lru.get_all_or_nothing(["x", "w", "y"])
+ assert list(lru.order) == ["x", "y", "z"]
+ assert lru.get_all_or_nothing(["y", "x"]) == {"y": 2, "x": 1}
+ assert list(lru.order) == ["z", "y", "x"]
+
+
+def test_close_aborts_eviction():
+ evicted = []
+
+ def cb(k, v):
+ evicted.append(k)
+ if len(evicted) == 3:
+ lru.close()
+
+ lru = LRU(100, {}, weight=lambda k, v: v, on_evict=cb)
+ lru["a"] = 20
+ lru["b"] = 20
+ lru["c"] = 20
+ lru["d"] = 20
+ lru["e"] = 90 # Trigger eviction of a, b, c, d
+
+ assert lru.closed
+ assert evicted == ["a", "b", "c"]
+ assert dict(lru) == {"d": 20, "e": 90}
+
+
+def test_flush_close():
+ flushed = 0
+ closed = False
+
+ class D(utils_test.SimpleDict):
+ def flush(self):
+ nonlocal flushed
+ flushed += 1
+
+ def close(self):
+ nonlocal closed
+ closed = True
+
+ with LRU(10, D()) as lru:
+ lru.flush()
+
+ assert flushed == 1
+ assert closed
+
+
+def test_update_n():
+ evicted = []
+ z = LRU(10, {}, on_evict=lambda k, v: evicted.append(k), weight=lambda k, v: v)
+ z["x"] = 5
+ assert not evicted
+
+ # Update n. This also changes what keys are considered heavy
+ # (but there isn't a full scan on the weights for already existing keys)
+ z.n = 3
+ assert not evicted
+ assert not z.heavy
+ z["y"] = 1
+ assert evicted == ["x"]
+ z["z"] = 4
+ assert evicted == ["x", "z"]
+
+
+def test_update_offset():
+ evicted = []
+ z = LRU(5, {}, on_evict=lambda k, v: evicted.append(k), weight=lambda k, v: v)
+
+ z.offset = 2
+ z["x"] = 1
+ # y would be a heavy key if we had reduced n by 2 instead of increasing offset
+ z["y"] = 2.5
+ assert evicted == ["x"]
+ z["z"] = 5.5 # Still heavy according to n alone
+ assert evicted == ["x", "z"]
+
+
+@pytest.mark.parametrize("event", ("set", "set_noevict", "del"))
+def test_cancel_evict(event):
+ """See also:
+
+ test_buffer.py::test_cancel_evict
+ test_buffer.py::test_cancel_restore
+ """
+ ev1 = threading.Event()
+ ev2 = threading.Event()
+ log = []
+
+ def cb(k, v):
+ ev1.set()
+ assert ev2.wait(timeout=5)
+
+ def cancel_cb(k, v):
+ log.append((k, v))
+
+ lru = LRU(100, {}, on_evict=cb, on_cancel_evict=cancel_cb, weight=lambda k, v: v)
+ lru.set_noevict("x", 1)
+ with ThreadPoolExecutor(1) as ex:
+ fut = ex.submit(lru.evict)
+ assert ev1.wait(timeout=5)
+ # cb is running
+
+ assert lru.evict() == (None, None, 0)
+ if event == "set":
+ lru["x"] = 2
+ elif event == "set_noevict":
+ lru.set_noevict("x", 2)
+ else:
+ assert event == "del"
+ del lru["x"]
+
+ ev2.set()
+ assert fut.result() == (None, None, 0)
+
+ assert log == [("x", 1)]
+ if event in ("set", "set_noevict"):
+ assert lru.d == {"x": 2}
+ assert lru.weights == {"x": 2}
+ assert list(lru.order) == ["x"]
+ else:
+ assert not lru.d
+ assert not lru.weights
+ assert not lru.order
+
+ assert not lru._cancel_evict
+
+
+def slow_cb(k, v):
+ time.sleep(0.01)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_different_keys_threadsafe():
+ # Sometimes x and y can cohexist without triggering eviction
+ # Sometimes x and y are individually <n but when they're both in they cause eviction
+ # Sometimes x or y are heavy
+ lru = LRU(
+ 1,
+ {},
+ weight=lambda k, v: random.choice([0.4, 0.9, 1.1]),
+ on_evict=slow_cb,
+ on_cancel_evict=slow_cb,
+ )
+ utils_test.check_different_keys_threadsafe(lru, allow_keyerror=True)
+ lru.n = 100
+ utils_test.check_mapping(lru)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_same_key_threadsafe():
+ # Sometimes x is heavy
+ lru = LRU(
+ 1,
+ {},
+ weight=lambda k, v: random.choice([0.9, 1.1]),
+ on_evict=slow_cb,
+ on_cancel_evict=slow_cb,
+ )
+
+ utils_test.check_same_key_threadsafe(lru)
+ lru.n = 100
+ utils_test.check_mapping(lru)
diff --git a/zict/tests/test_sieve.py b/zict/tests/test_sieve.py
index 9b6be69..fafb3ef 100644
--- a/zict/tests/test_sieve.py
+++ b/zict/tests/test_sieve.py
@@ -1,3 +1,10 @@
+import random
+import threading
+from collections import UserDict
+from concurrent.futures import ThreadPoolExecutor
+
+import pytest
+
from zict import Sieve
from zict.tests import utils_test
@@ -72,3 +79,146 @@ def test_mapping():
z = Sieve(mappings, selector)
utils_test.check_mapping(z)
utils_test.check_closing(z)
+
+ z.clear()
+ assert z.mappings == {0: {}, 1: {}}
+ assert not z.key_to_mapping
+
+
+@pytest.mark.parametrize("method", ("__setitem__", "update"))
+@pytest.mark.parametrize("set_when", ("before", "after"))
+@pytest.mark.parametrize("seed", [False, "same", "different"])
+def test_multithread_race_condition_del_set(method, set_when, seed):
+ """Test race conditions between __delitem__ and __setitem__/update on the same key"""
+ # in_set: signalled once the background write has entered Slow.__setitem__.
+ # block_set: releases that paused write.
+ in_set = threading.Event()
+ block_set = threading.Event()
+
+ # Mapping whose __setitem__ pauses mid-write. set_when decides whether the value
+ # is stored before ("before") or after ("after") the pause.
+ class Slow(UserDict):
+ def __setitem__(self, k, v):
+ if set_when == "before":
+ self.data[k] = v
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ else:
+ in_set.set()
+ assert block_set.wait(timeout=5)
+ self.data[k] = v
+
+ # Odd values go to the Slow mapping (1), even values to the plain dict (0)
+ z = Sieve({0: {}, 1: Slow()}, selector=lambda k, v: v % 2)
+ # Optionally pre-populate "x", either in the mapping targeted by the racing
+ # write ("same") or in the other one ("different"). Seeding mapping 1 goes
+ # through Slow, so the events are pre-released and then reset.
+ if seed == "same":
+ block_set.set()
+ z["x"] = 1 # mapping 1
+ in_set.clear()
+ block_set.clear()
+ elif seed == "different":
+ z["x"] = 0 # mapping 0
+
+ with ThreadPoolExecutor(1) as ex:
+ if method == "__setitem__":
+ set_fut = ex.submit(z.__setitem__, "x", 3) # mapping 1
+ else:
+ assert method == "update"
+ set_fut = ex.submit(z.update, {"x": 3}) # mapping 1
+ # Wait until the background write is paused inside Slow.__setitem__
+ assert in_set.wait(timeout=5)
+ # Delete concurrently. The key may legitimately be absent at this point
+ # (e.g. no seed and the value not stored yet), hence KeyError is tolerated.
+ try:
+ del z["x"]
+ except KeyError:
+ pass
+ block_set.set()
+ set_fut.result()
+
+ # Whatever the interleaving, the delete issued during the write must win:
+ # no stale entry may remain in either mapping or in the key index.
+ assert not z.mappings[0]
+ assert not z.mappings[1]
+ assert not z.key_to_mapping
+
+
+@pytest.mark.parametrize("set1_when", ("before", "after"))
+@pytest.mark.parametrize("set2_when", ("before", "after"))
+@pytest.mark.parametrize("set1_method", ("__setitem__", "update"))
+@pytest.mark.parametrize("set2_method", ("__setitem__", "update"))
+@pytest.mark.parametrize("starts_first", (1, 2))
+@pytest.mark.parametrize("ends_first", (1, 2))
+@pytest.mark.parametrize("seed", [False, 0, 1, 2])
+def test_multithread_race_condition_set_set(
+ set1_when, set2_when, set1_method, set2_method, starts_first, ends_first, seed
+):
+ """Test __setitem__/update in race condition with __setitem__/update"""
+ # Writer 1 stores x=4 (4 % 3 == 1 -> mapping 1); writer 2 stores x=5 (mapping 2).
+ # Each writer has its own pair of events to pause it inside Slow.__setitem__.
+ when = {1: set1_when, 2: set2_when}
+ method = {1: set1_method, 2: set2_method}
+ in_set = {1: threading.Event(), 2: threading.Event()}
+ block_set = {1: threading.Event(), 2: threading.Event()}
+
+ class Slow(UserDict):
+ def __setitem__(self, k, v):
+ if v < 3:
+ # seed
+ self.data[k] = v
+ return
+ # 4 -> writer 1, 5 -> writer 2
+ mkey = v % 3
+
+ # Pause either after ("before") or before ("after") storing the value
+ if when[mkey] == "before":
+ self.data[k] = v
+ in_set[mkey].set()
+ assert block_set[mkey].wait(timeout=5)
+ else:
+ in_set[mkey].set()
+ assert block_set[mkey].wait(timeout=5)
+ self.data[k] = v
+
+ z = Sieve({0: {}, 1: Slow(), 2: Slow()}, selector=lambda k, v: v % 3)
+ # Optionally pre-populate "x" in mapping 0, 1 or 2 (values < 3 never pause)
+ if seed is not False:
+ z["x"] = seed
+
+ with ThreadPoolExecutor(2) as ex:
+ futures = {}
+ starts_second = 2 if starts_first == 1 else 1
+ # Start both writers in the requested order, waiting until each is paused
+ for idx in (starts_first, starts_second):
+ if method[idx] == "__setitem__":
+ futures[idx] = ex.submit(z.__setitem__, "x", idx + 3)
+ else:
+ assert method[idx] == "update"
+ futures[idx] = ex.submit(z.update, {"x": idx + 3})
+ assert in_set[idx].wait(timeout=5)
+
+ # Release them in the requested order
+ block_set[ends_first].set()
+ futures[ends_first].result()
+ ends_second = 2 if ends_first == 1 else 1
+ block_set[ends_second].set()
+ futures[ends_second].result()
+
+ # Exactly one of the two values survives, stored only in the mapping chosen
+ # by the selector, and the key index points at that mapping
+ assert dict(z) in ({"x": 4}, {"x": 5})
+ assert z.mappings[0] == {}
+ if z["x"] == 4:
+ assert z.mappings[1] == {"x": 4}
+ assert z.mappings[2] == {}
+ assert z.key_to_mapping == {"x": z.mappings[1]}
+ else:
+ assert z.mappings[1] == {}
+ assert z.mappings[2] == {"x": 5}
+ assert z.key_to_mapping == {"x": z.mappings[2]}
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_different_keys_threadsafe():
+    """Concurrent set/get/del on distinct keys, each write routed to a random mapping."""
+
+    def random_selector(k, v):
+        return random.choice([0, 1])
+
+    left, right = {}, {}
+    z = Sieve({0: left, 1: right}, random_selector)
+    utils_test.check_different_keys_threadsafe(z)
+
+    # Every key has been deleted: nothing may linger in the underlying
+    # mappings or in the key index
+    assert not left and not right
+    assert not z.key_to_mapping
+    utils_test.check_mapping(z)
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+def test_stress_same_key_threadsafe():
+    """Concurrent operations on one key, routed at random between two slow mappings."""
+    delay = 0.001
+    left, right = utils_test.SlowDict(delay), utils_test.SlowDict(delay)
+    z = Sieve({0: left, 1: right}, lambda k, v: random.choice([0, 1]))
+    utils_test.check_same_key_threadsafe(z)
+
+    # check_same_key_threadsafe pops "x" at the end, so nothing may be left over
+    assert not left and not right
+    assert not z.key_to_mapping
+    utils_test.check_mapping(z)
diff --git a/zict/tests/test_utils.py b/zict/tests/test_utils.py
new file mode 100644
index 0000000..729f2dc
--- /dev/null
+++ b/zict/tests/test_utils.py
@@ -0,0 +1,112 @@
+from concurrent.futures import ThreadPoolExecutor
+from threading import Barrier
+
+import pytest
+
+from zict import InsertionSortedSet
+from zict.tests import utils_test
+
+
+def test_insertion_sorted_set():
+ """Single-threaded behaviour of InsertionSortedSet: set semantics and equality,
+ insertion ordering, and the order of popleft()/popright()/pop().
+ """
+ s = InsertionSortedSet()
+
+ # Empty set
+ assert not s
+ assert len(s) == 0
+ assert list(s) == []
+ assert s == set()
+ assert s != []
+ assert s == InsertionSortedSet()
+ assert 1 not in s
+ # discard() of a missing element is a no-op; remove() and all pops raise KeyError
+ s.discard(1)
+ with pytest.raises(KeyError):
+ s.remove(1)
+ with pytest.raises(KeyError):
+ s.pop()
+ with pytest.raises(KeyError):
+ s.popleft()
+ with pytest.raises(KeyError):
+ s.popright()
+
+ # Single element; equality and set operators only look at the contents
+ s.add(1)
+ assert 1 in s
+ assert 2 not in s
+ assert len(s) == 1
+ assert list(s) == [1]
+ assert s == {1}
+ assert s != [1]
+ assert s & {1, 2} == {1}
+ assert s | {1, 2} == {1, 2}
+ assert s - {1, 2} == set()
+
+ # Add already-existing element
+ s.add(1)
+ assert len(s) == 1
+ assert list(s) == [1]
+
+ # remove(), discard(), pop() and clear() each empty the set again
+ s.remove(1)
+ assert not s
+ s.add(1)
+ assert list(s) == [1]
+ s.discard(1)
+ assert not s
+ s.add(1)
+ assert s.pop() == 1
+ s.add(1)
+ s.clear()
+ assert not s
+
+ # Initialise from iterable
+ s = InsertionSortedSet(iter([3, 1, 2, 5, 4, 6, 0]))
+ assert list(s) == [3, 1, 2, 5, 4, 6, 0]
+
+ # Adding already-existing element does not change order
+ s.add(2)
+ assert list(s) == [3, 1, 2, 5, 4, 6, 0]
+
+ # Removing element does not change order
+ s.remove(2)
+ assert list(s) == [3, 1, 5, 4, 6, 0]
+
+ s.add(2) # Re-added elements are added to the end
+ s.add(7)
+ assert list(s) == [3, 1, 5, 4, 6, 0, 2, 7]
+
+ # popleft() returns elements oldest-first
+ assert [s.popleft() for _ in range(len(s))] == [3, 1, 5, 4, 6, 0, 2, 7]
+
+ # popright() returns elements newest-first
+ s |= [3, 1, 5, 4, 6, 0, 2, 7]
+ assert [s.popright() for _ in range(len(s))] == [7, 2, 0, 6, 4, 5, 1, 3]
+
+ # pop() is an alias to popright()
+ s |= [3, 1, 5, 4, 6, 0, 2, 7]
+ assert [s.pop() for _ in range(len(s))] == [7, 2, 0, 6, 4, 5, 1, 3]
+
+
+@pytest.mark.stress
+@pytest.mark.repeat(utils_test.REPEAT_STRESS_TESTS)
+@pytest.mark.parametrize("method,size", [("popleft", 100_000), ("popright", 5_000_000)])
+def test_insertion_sorted_set_threadsafe(method, size):
+ """Two threads drain the same set concurrently with popleft() or popright().
+ Each thread must see a strictly monotonic sequence of values, the set must end
+ up empty, and both threads must have made progress.
+ """
+ s = InsertionSortedSet(range(size))
+ m = getattr(s, method)
+ barrier = Barrier(2)
+
+ def t():
+ # Start both threads at the same time to maximise contention
+ barrier.wait()
+ n = 0
+ # Sentinel just outside 0..size-1, so that the first pop always passes
+ prev = -1 if method == "popleft" else size
+ while True:
+ try:
+ v = m()
+ # Strictly increasing for popleft, strictly decreasing for popright
+ assert v > prev if method == "popleft" else v < prev, (v, prev, len(s))
+ prev = v
+ n += 1
+ # KeyError means the set has been fully drained
+ except KeyError:
+ assert not s
+ return n
+
+ with ThreadPoolExecutor(2) as ex:
+ f1 = ex.submit(t)
+ f2 = ex.submit(t)
+ # On Linux, these are in the 38_000 ~ 62_000 range.
+ # On Windows, we've seen as little as 2300.
+ assert f1.result() > 100
+ assert f2.result() > 100
diff --git a/zict/tests/test_zip.py b/zict/tests/test_zip.py
index 5ae1753..63d25e4 100644
--- a/zict/tests/test_zip.py
+++ b/zict/tests/test_zip.py
@@ -1,22 +1,15 @@
-import os
import zipfile
from collections.abc import MutableMapping
import pytest
from zict import Zip
+from zict.tests import utils_test
@pytest.fixture
-def fn():
- filename = ".tmp.zip"
- if os.path.exists(filename):
- os.remove(filename)
-
- yield filename
-
- if os.path.exists(filename):
- os.remove(filename)
+def fn(tmp_path, check_fd_leaks):
+ """Path of a not-yet-existing zip file inside pytest's per-test tmp_path.
+ check_fd_leaks is requested only for its side effects (presumably a conftest
+ fixture that detects leaked file descriptors - confirm).
+ """
+ yield tmp_path / "tmp.zip"
def test_simple(fn):
@@ -83,3 +76,60 @@ def test_bytearray(fn):
with Zip(fn) as z:
assert z["x"] == b"123"
+
+
+def test_memoryview(fn):
+    """A memoryview value can be stored, and it is read back as plain bytes."""
+    payload = b"123"
+    with Zip(fn) as writer:
+        writer["x"] = memoryview(payload)
+
+    # Reopen the archive to make sure the data was actually persisted
+    with Zip(fn) as reader:
+        assert reader["x"] == payload
+
+
+def check_mapping(z):
+    """Reduced variant of utils_test.check_mapping for Zip.
+
+    Zip can neither overwrite an existing member nor delete one, so only inserts
+    of new keys, reads and membership tests are exercised.
+    """
+    assert isinstance(z, MutableMapping)
+    utils_test.check_empty_mapping(z)
+
+    z["abc"] = b"456"
+    z["xyz"] = b"12"
+    assert len(z) == 2
+    assert z["abc"] == b"456"
+
+    utils_test.check_items(z, [("abc", b"456"), ("xyz", b"12")])
+
+    for present in ("abc", "xyz"):
+        assert present in z
+    assert "def" not in z
+
+    # Reading a missing member raises KeyError, as for any Mapping
+    with pytest.raises(KeyError):
+        z["def"]
+
+
+def test_mapping(fn):
+    """Zip passes the reduced mapping-interface checks and can then be closed."""
+    with Zip(fn) as archive:
+        check_mapping(archive)
+        utils_test.check_closing(archive)
+
+
+def test_no_delete_update(fn):
+    """Deleting or overwriting a member raises, and leaves the member untouched."""
+    with Zip(fn) as z:
+        z["x"] = b"123"
+
+        def delete():
+            del z["x"]
+
+        def overwrite():
+            z["x"] = b"456"
+
+        for forbidden in (delete, overwrite):
+            with pytest.raises(NotImplementedError):
+                forbidden()
+
+        # The original member survives both attempts unchanged
+        assert len(z) == 1
+        assert z["x"] == b"123"
+
+
+def test_bad_types(fn):
+    """Zip rejects keys that are not str and values that are not bytes-like."""
+    with Zip(fn) as archive:
+        # Zip does not support deletion, so skip the delete check
+        utils_test.check_bad_key_types(archive, has_del=False)
+        utils_test.check_bad_value_types(archive)
diff --git a/zict/tests/utils_test.py b/zict/tests/utils_test.py
index 8b705c5..0640e2b 100644
--- a/zict/tests/utils_test.py
+++ b/zict/tests/utils_test.py
@@ -1,16 +1,28 @@
+from __future__ import annotations
+
import random
import string
-from collections.abc import MutableMapping
+import threading
+import time
+from collections import UserDict
+from collections.abc import ItemsView, KeysView, MutableMapping, ValuesView
+from concurrent.futures import ThreadPoolExecutor
import pytest
+from zict.common import ZictBase
+
+# How many times to repeat non-deterministic stress tests.
+# You may set it as high as 50 if you wish to run in CI.
+REPEAT_STRESS_TESTS = 1
+
def generate_random_strings(n, min_len, max_len):
r = random.Random(42)
out = []
chars = string.ascii_lowercase + string.digits
- for i in range(n):
+ for _ in range(n):
nchars = r.randint(min_len, max_len)
s = "".join(r.choice(chars) for _ in range(nchars))
out.append(s)
@@ -25,7 +37,7 @@ def to_bytestring(s):
return s.encode("latin1")
-def check_items(z, expected_items):
+def check_items(z: MutableMapping, expected_items: list[tuple[str, bytes]]) -> None:
items = list(z.items())
assert len(items) == len(expected_items)
assert sorted(items) == sorted(expected_items)
@@ -34,8 +46,21 @@ def check_items(z, expected_items):
assert list(z.values()) == [v for k, v in items]
assert list(z) == [k for k, v in items]
+ # ItemsView, KeysView, ValuesView.__contains__()
+ assert isinstance(z.keys(), KeysView)
+ assert isinstance(z.values(), ValuesView)
+ assert isinstance(z.items(), ItemsView)
+ assert items[0] in z.items()
+ assert items[0][0] in z.keys()
+ assert items[0][0] in z
+ assert items[0][1] in z.values()
+ assert (object(), object()) not in z.items()
+ assert object() not in z.keys()
+ assert object() not in z
+ assert object() not in z.values()
-def stress_test_mapping_updates(z):
+
+def stress_test_mapping_updates(z: MutableMapping) -> None:
# Certain mappings shuffle between several underlying stores
# during updates. This stress tests the internal mapping
# consistency.
@@ -53,7 +78,7 @@ def stress_test_mapping_updates(z):
assert sorted(z) == sorted(keys)
assert sorted(z.items()) == sorted(zip(keys, values))
- for i in range(3):
+ for _ in range(3):
r.shuffle(keys)
r.shuffle(values)
for k, v in zip(keys, values):
@@ -66,14 +91,24 @@ def stress_test_mapping_updates(z):
check_items(z, list(zip(keys, values)))
-def check_mapping(z):
- assert isinstance(z, MutableMapping)
+def check_empty_mapping(z: MutableMapping) -> None:
+ """Assert that ``z`` is empty when inspected through truthiness, iteration,
+ keys()/values()/items(), len() and membership tests on the mapping and its views.
+ """
assert not z
-
assert list(z) == list(z.keys()) == []
assert list(z.values()) == []
assert list(z.items()) == []
assert len(z) == 0
+ assert "x" not in z
+ assert "x" not in z.keys()
+ assert ("x", b"123") not in z.items()
+ assert b"123" not in z.values()
+
+
+def check_mapping(z: MutableMapping) -> None:
+ """See also test_zip.check_mapping"""
+ assert type(z).__name__ in str(z)
+ assert type(z).__name__ in repr(z)
+ assert isinstance(z, MutableMapping)
+ check_empty_mapping(z)
z["abc"] = b"456"
z["xyz"] = b"12"
@@ -85,6 +120,7 @@ def check_mapping(z):
assert "abc" in z
assert "xyz" in z
assert "def" not in z
+ assert object() not in z
with pytest.raises(KeyError):
z["def"]
@@ -95,6 +131,18 @@ def check_mapping(z):
check_items(z, [("abc", b"456"), ("xyz", b"654"), ("uvw", b"999")])
z.update({"xyz": b"321"})
check_items(z, [("abc", b"456"), ("xyz", b"321"), ("uvw", b"999")])
+ # Update with iterator (can read only once)
+ z.update(iter([("foo", b"132"), ("bar", b"887")]))
+ check_items(
+ z,
+ [
+ ("abc", b"456"),
+ ("xyz", b"321"),
+ ("uvw", b"999"),
+ ("foo", b"132"),
+ ("bar", b"887"),
+ ],
+ )
del z["abc"]
with pytest.raises(KeyError):
@@ -102,16 +150,154 @@ def check_mapping(z):
with pytest.raises(KeyError):
del z["abc"]
assert "abc" not in z
- assert set(z) == {"uvw", "xyz"}
- assert len(z) == 2
+ assert set(z) == {"uvw", "xyz", "foo", "bar"}
+ assert len(z) == 4
z["def"] = b"\x00\xff"
- assert len(z) == 3
+ assert len(z) == 5
assert z["def"] == b"\x00\xff"
assert "def" in z
stress_test_mapping_updates(z)
-def check_closing(z):
+def check_different_keys_threadsafe(
+    z: MutableMapping, allow_keyerror: bool = False
+) -> None:
+    """Hammer ``z`` from two threads, each one working on its own key.
+
+    Each worker repeatedly sets its key, reads it back, deletes it, and checks that
+    it is gone. Workers keep going until both have completed at least 10 full
+    cycles. At the end ``z`` must be empty.
+
+    Parameters
+    ----------
+    z:
+        Mapping under test; it must start empty.
+    allow_keyerror:
+        If True, a KeyError while reading back or deleting the key just written
+        is tolerated and the cycle is retried without counting it. This is for
+        mappings that may drop keys on their own, e.g. an LRU with a tiny capacity.
+    """
+    barrier = threading.Barrier(2)
+    counters = [0, 0]
+    # Set as soon as either worker raises. Without it, the surviving worker would
+    # spin forever waiting for the dead worker's counter to reach 10, turning a
+    # test failure into a hang.
+    failed = threading.Event()
+
+    def worker(idx, key, value):
+        barrier.wait()
+        try:
+            while not failed.is_set() and any(c < 10 for c in counters):
+                z[key] = value
+                try:
+                    assert z[key] == value
+                    del z[key]
+                except KeyError:
+                    if allow_keyerror:
+                        continue  # Try again without incrementing counters[idx]
+                    raise
+
+                assert key not in z
+                with pytest.raises(KeyError):
+                    _ = z[key]
+                with pytest.raises(KeyError):
+                    del z[key]
+                # The other worker's key may or may not be present right now
+                assert len(z) in (0, 1)
+                counters[idx] += 1
+        except BaseException:
+            # BaseException: pytest's failure outcomes do not derive from Exception
+            failed.set()
+            raise
+
+    with ThreadPoolExecutor(2) as ex:
+        f1 = ex.submit(worker, 0, "x", b"123")
+        f2 = ex.submit(worker, 1, "y", b"456")
+        f1.result()
+        f2.result()
+
+    assert not z
+
+
+def check_same_key_threadsafe(z: MutableMapping) -> None:
+    """Hammer a single key of ``z`` from four threads at once.
+
+    The threads respectively ``__setitem__``, ``update``, ``__delitem__`` and
+    ``__getitem__`` the key "x" until each of them has succeeded at least 10 times.
+    Reads must only ever observe one of the two values written. "x" is removed at
+    the end, so ``z`` is left as it was found.
+    """
+    barrier = threading.Barrier(4)
+    counters = [0, 0, 0, 0]
+    # Set as soon as any worker raises, so that the others stop instead of
+    # spinning forever waiting for the dead worker's counter to reach 10.
+    failed = threading.Event()
+
+    def keep_going() -> bool:
+        return not failed.is_set() and any(c < 10 for c in counters)
+
+    def w_set():
+        while keep_going():
+            z["x"] = b"123"
+            counters[0] += 1
+
+    def w_update():
+        while keep_going():
+            z.update(x=b"456")
+            counters[1] += 1
+
+    def w_del():
+        while keep_going():
+            try:
+                del z["x"]
+                counters[2] += 1
+            except KeyError:
+                pass
+
+    def w_get():
+        while keep_going():
+            try:
+                assert z["x"] in (b"123", b"456")
+                counters[3] += 1
+            except KeyError:
+                pass
+
+    def run(worker):
+        barrier.wait()
+        try:
+            worker()
+        except BaseException:
+            # BaseException: pytest's failure outcomes do not derive from Exception
+            failed.set()
+            raise
+
+    with ThreadPoolExecutor(4) as ex:
+        futures = [ex.submit(run, w) for w in (w_set, w_update, w_del, w_get)]
+        for f in futures:
+            f.result()
+
+    z.pop("x", None)
+
+
+def check_closing(z: ZictBase) -> None:
+ """Check that ``z.close()`` can be called without raising."""
z.close()
+
+
+def check_bad_key_types(z: MutableMapping, has_del: bool = True) -> None:
+    """Check how ``z`` reacts to a key of an unsupported type.
+
+    ``z`` is expected not to accept arbitrary hashables as keys. Membership tests
+    must report such a key as absent, writes must raise TypeError, and reads (and
+    deletes, if supported) must raise KeyError as for a missing key.
+
+    Parameters
+    ----------
+    z:
+        Mapping under test.
+    has_del:
+        Whether ``z`` supports ``__delitem__``. If False, the delete check is skipped.
+    """
+    key = object()
+
+    # Membership tests on the mapping and its views
+    assert key not in z
+    assert key not in z.keys()
+    assert (key, b"123") not in z.items()
+
+    def assign():
+        z[key] = b"123"
+
+    def update():
+        z.update({key: b"123"})
+
+    for write in (assign, update):
+        with pytest.raises(TypeError):
+            write()
+
+    with pytest.raises(KeyError):
+        z[key]
+
+    if not has_del:
+        return
+    with pytest.raises(KeyError):
+        del z[key]
+
+
+def check_bad_value_types(z: MutableMapping) -> None:
+    """Check how ``z`` reacts to a value of an unsupported type.
+
+    ``z`` is expected not to accept arbitrary Python objects as values. Membership
+    tests on the values/items views must report such a value as absent, and
+    writing it must raise TypeError.
+    """
+    value = object()
+
+    assert value not in z.values()
+    assert ("x", value) not in z.items()
+
+    def assign():
+        z["x"] = value
+
+    def update():
+        z.update({"x": value})
+
+    for write in (assign, update):
+        with pytest.raises(TypeError):
+            write()
+
+
+class SimpleDict(ZictBase, UserDict):
+    """In-memory ZictBase whose storage is provided by UserDict, for use in tests."""
+
+    def __init__(self):
+        # Initialise both bases explicitly, ZictBase first, rather than relying on
+        # cooperative super() chaining
+        for base in (ZictBase, UserDict):
+            base.__init__(self)
+
+
+class SlowDict(UserDict):
+    """A UserDict that sleeps for ``delay`` seconds before every read and write.
+
+    The stress tests use it as a slow backing store for other mappings.
+    """
+
+    def __init__(self, delay):
+        self.delay = delay
+        # Start empty. The instance itself must not be passed as the initial data:
+        # UserDict would then try to copy the mapping into itself.
+        super().__init__()
+
+    def __getitem__(self, key):
+        time.sleep(self.delay)
+        return super().__getitem__(key)
+
+    def __setitem__(self, key, value):
+        time.sleep(self.delay)
+        super().__setitem__(key, value)
diff --git a/zict/utils.py b/zict/utils.py
new file mode 100644
index 0000000..438310b
--- /dev/null
+++ b/zict/utils.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from collections.abc import Iterable, Iterator
+from typing import MutableSet # TODO import from collections.abc (needs Python >=3.9)
+
+from zict.common import T
+
+
+class InsertionSortedSet(MutableSet[T]):
+    """A set-like that retains insertion order, like a dict. Thread-safe.
+
+    Equality does not compare order or class, but only compares against the contents of
+    any other set-like, coherently with dict and the AbstractSet design.
+    """
+
+    # Keys hold the elements in insertion order; values are always None
+    _d: dict[T, None]
+    __slots__ = ("_d",)
+
+    def __init__(self, other: Iterable[T] = ()) -> None:
+        self._d = dict.fromkeys(other)
+
+    def __repr__(self) -> str:
+        # Iterate over a snapshot (dict.copy) rather than over the live dict,
+        # which another thread may be mutating
+        return f"{type(self).__name__}({list(self._d.copy())!r})"
+
+    def __contains__(self, item: object) -> bool:
+        return item in self._d
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self._d)
+
+    def __len__(self) -> int:
+        return len(self._d)
+
+    def add(self, value: T) -> None:
+        """Add element to the set. If the element is already in the set, retain original
+        insertion order.
+        """
+        self._d[value] = None
+
+    def discard(self, value: T) -> None:
+        """Remove element from the set if present; do nothing otherwise."""
+        # Don't trust the thread-safety of self._d.pop(value, None)
+        try:
+            del self._d[value]
+        except KeyError:
+            pass
+
+    def remove(self, value: T) -> None:
+        """Remove element from the set; raise KeyError if it is not present."""
+        del self._d[value]
+
+    def popleft(self) -> T:
+        """Pop the oldest-inserted key from the set.
+
+        Raises
+        ------
+        KeyError
+            If the set is empty.
+        """
+        while True:
+            try:
+                value = next(iter(self._d))
+                del self._d[value]
+                return value
+            except StopIteration:
+                # Hide the internal StopIteration from the traceback
+                raise KeyError("pop from an empty set") from None
+            except (KeyError, RuntimeError):
+                # Multithreaded race condition: another thread removed ``value``
+                # between next() and del, or mutated the dict while the iterator
+                # was live. Retry.
+                continue
+
+    def popright(self) -> T:
+        """Pop the latest-inserted key from the set.
+
+        Raises
+        ------
+        KeyError
+            If the set is empty.
+        """
+        try:
+            return self._d.popitem()[0]
+        except KeyError:
+            # Same message as popleft(), instead of dict's "popitem(): ..." one
+            raise KeyError("pop from an empty set") from None
+
+    pop = popright
+
+    def clear(self) -> None:
+        """Remove all elements."""
+        self._d.clear()
diff --git a/zict/zip.py b/zict/zip.py
index b2ca84d..207349e 100644
--- a/zict/zip.py
+++ b/zict/zip.py
@@ -3,13 +3,13 @@ from __future__ import annotations
import zipfile
from collections.abc import Iterator
from typing import MutableMapping # TODO move to collections.abc (needs Python >=3.9)
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
- # TODO: move to typing on Python 3.8+ and 3.10+ respectively
- from typing_extensions import Literal, TypeAlias
+ # TODO: import from typing (needs Python >=3.10)
+ from typing_extensions import TypeAlias
- FileMode: TypeAlias = Literal["r", "w", "x", "a"]
+FileMode: TypeAlias = Literal["r", "w", "x", "a"]
class Zip(MutableMapping[str, bytes]):
@@ -22,6 +22,11 @@ class Zip(MutableMapping[str, bytes]):
filename: string
mode: string, ('r', 'w', 'a'), defaults to 'a'
+ Notes
+ -----
+ None of this class is thread-safe - not even normally trivial methods such as
+ ``__len__ `` or ``__contains__``.
+
Examples
--------
>>> z = Zip('myfile.zip') # doctest: +SKIP
@@ -36,6 +41,7 @@ class Zip(MutableMapping[str, bytes]):
_file: zipfile.ZipFile | None
def __init__(self, filename: str, mode: FileMode = "a"):
+ super().__init__()
self.filename = filename
self.mode = mode
self._file = None
@@ -49,25 +55,32 @@ class Zip(MutableMapping[str, bytes]):
return self._file
def __getitem__(self, key: str) -> bytes:
+ if not isinstance(key, str):
+ raise KeyError(key)
return self.file.read(key)
- def __setitem__(self, key: str, value: bytes) -> None:
+ def __setitem__(self, key: str, value: bytes | bytearray | memoryview) -> None:
+ if not isinstance(key, str):
+ raise TypeError(key)
+ if not isinstance(value, (bytes, bytearray, memoryview)):
+ raise TypeError(value)
+ if key in self:
+ raise NotImplementedError("Not supported by stdlib zipfile")
self.file.writestr(key, value)
- # FIXME dictionary views https://github.com/dask/zict/issues/61
- def keys(self) -> Iterator[str]: # type: ignore
+ def __iter__(self) -> Iterator[str]:
return (zi.filename for zi in self.file.filelist)
- def values(self) -> Iterator[bytes]: # type: ignore
- return (self.file.read(key) for key in self.keys())
-
- def items(self) -> Iterator[tuple[str, bytes]]: # type: ignore
- return ((zi.filename, self.file.read(zi.filename)) for zi in self.file.filelist)
-
- def __iter__(self) -> Iterator[str]:
- return self.keys()
+ def __contains__(self, key: object) -> bool:
+ if not isinstance(key, str):
+ return False
+ try:
+ self.file.getinfo(key)
+ return True
+ except KeyError:
+ return False
- def __delitem__(self, key: str) -> None:
+ def __delitem__(self, key: str) -> None: # pragma: nocover
raise NotImplementedError("Not supported by stdlib zipfile")
def __len__(self) -> int:
@@ -87,5 +100,5 @@ class Zip(MutableMapping[str, bytes]):
def __enter__(self) -> Zip:
return self
- def __exit__(self, type, value, traceback) -> None:
+ def __exit__(self, *args: Any) -> None:
self.close()