Skip to content

Commit 71f4501

Browse files
committed
Refactor, fix bugs
1 parent 17f58a1 commit 71f4501

File tree

1 file changed

+219
-57
lines changed

1 file changed

+219
-57
lines changed

src/filelock/_read_write.py

+219-57
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,85 @@
11
import os
22
import sqlite3
33
import threading
4-
4+
import logging
55
from _error import Timeout
6-
from filelock._api import BaseFileLock
7-
8-
class _SQLiteLock(BaseFileLock):
9-
def __init__(self, lock_file: str | os.PathLike[str], timeout: float = -1, blocking: bool = True):
10-
super().__init__(lock_file, timeout, blocking)
11-
self.procLock = threading.Lock()
12-
self.con = sqlite3.connect(self._context.lock_file, check_same_thread=False)
13-
# Redundant unless there are "rogue" processes that open the db
14-
# and switch the the db to journal_mode=WAL.
6+
from filelock._api import AcquireReturnProxy, BaseFileLock
7+
from typing import Literal, Any
8+
from contextlib import contextmanager
9+
from weakref import WeakValueDictionary
10+
11+
_LOGGER = logging.getLogger("filelock")
12+
13+
# PRAGMA busy_timeout=N delegates to https://www.sqlite.org/c3ref/busy_timeout.html,
14+
# which accepts an int argument, which has the maximum value of 2_147_483_647 on 32-bit
15+
# systems. Use even a lower value to be safe. This 2 bln milliseconds is about 23 days.
16+
_MAX_SQLITE_TIMEOUT_MS = 2_000_000_000 - 1
17+
18+
def timeout_for_sqlite(timeout: float = -1, blocking: bool = True) -> int:
19+
if blocking is False:
20+
return 0
21+
if timeout == -1:
22+
return _MAX_SQLITE_TIMEOUT_MS
23+
if timeout < 0:
24+
raise ValueError("timeout must be a non-negative number or -1")
25+
26+
assert timeout >= 0
27+
timeout_ms = int(timeout * 1000)
28+
if timeout_ms > _MAX_SQLITE_TIMEOUT_MS or timeout_ms < 0:
29+
_LOGGER.warning("timeout %s is too large for SQLite, using %s ms instead", timeout, _MAX_SQLITE_TIMEOUT_MS)
30+
return _MAX_SQLITE_TIMEOUT_MS
31+
return timeout_ms
32+
33+
34+
class _ReadWriteLockMeta(type):
35+
"""Metaclass that redirects instance creation to get_lock() when is_singleton=True."""
36+
def __call__(cls, lock_file: str | os.PathLike[str],
37+
timeout: float = -1, blocking: bool = True,
38+
is_singleton: bool = True, *args: Any, **kwargs: Any) -> "ReadWriteLock":
39+
if is_singleton:
40+
return cls.get_lock(lock_file, timeout, blocking)
41+
return super().__call__(lock_file, timeout, blocking, is_singleton, *args, **kwargs)
42+
43+
44+
class ReadWriteLock(metaclass=_ReadWriteLockMeta):
45+
# Singleton storage and its lock.
46+
_instances = WeakValueDictionary()
47+
_instances_lock = threading.Lock()
48+
49+
@classmethod
50+
def get_lock(cls, lock_file: str | os.PathLike[str],
51+
timeout: float = -1, blocking: bool = True) -> "ReadWriteLock":
52+
"""Return the one-and-only ReadWriteLock for a given file."""
53+
normalized = os.path.abspath(lock_file)
54+
with cls._instances_lock:
55+
if normalized not in cls._instances:
56+
cls._instances[normalized] = cls(lock_file, timeout, blocking)
57+
instance = cls._instances[normalized]
58+
if instance.timeout != timeout or instance.blocking != blocking:
59+
raise ValueError("Singleton lock created with timeout=%s, blocking=%s, cannot be changed to timeout=%s, blocking=%s", instance.timeout, instance.blocking, timeout, blocking)
60+
return instance
61+
62+
def __init__(
63+
self,
64+
lock_file: str | os.PathLike[str],
65+
timeout: float = -1,
66+
blocking: bool = True,
67+
is_singleton: bool = True,
68+
) -> None:
69+
self.lock_file = lock_file
70+
self.timeout = timeout
71+
self.blocking = blocking
72+
# _transaction_lock is for the SQLite transaction work.
73+
self._transaction_lock = threading.Lock()
74+
# _internal_lock protects the short critical sections that update _lock_level
75+
# and rollback the transaction in release().
76+
self._internal_lock = threading.Lock()
77+
self._lock_level = 0 # Reentrance counter.
78+
# _current_mode holds the active lock mode ("read" or "write") or None if no lock is held.
79+
self._current_mode: Literal["read", "write", None] = None
80+
# _lock_level is the reentrance counter.
81+
self._lock_level = 0
82+
self.con = sqlite3.connect(self.lock_file, check_same_thread=False)
1583
# Using the legacy journal mode rather than more modern WAL mode because,
1684
# apparently, in WAL mode it's impossible to enforce that read transactions
1785
# (started with BEGIN TRANSACTION) are blocked if a concurrent write transaction,
@@ -20,55 +88,149 @@ def __init__(self, lock_file: str | os.PathLike[str], timeout: float = -1, block
2088
# it seems, it's possible to do this read-write locking without table data
2189
# modification at each exclusive lock.
2290
# See https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
23-
self.con.execute('PRAGMA journal_mode=DELETE;')
24-
self.cur = None
25-
26-
def _release(self):
27-
with self.procLock:
28-
if self.cur is None:
29-
return # Nothing to release
30-
try:
31-
self.cur.execute('ROLLBACK TRANSACTION;')
32-
except sqlite3.ProgrammingError:
33-
pass # Already rolled back or transaction not active
34-
finally:
35-
self.cur.close()
36-
self.cur = None
37-
38-
class WriteLock(_SQLiteLock):
39-
def _acquire(self) -> None:
40-
timeout_ms = int(self._context.timeout*1000) if self._context.blocking else 0
41-
with self.procLock:
42-
if self.cur is not None:
43-
return
91+
# "MEMORY" journal mode is fine because no actual writes to the are happening in write-lock
92+
# acquire, so crashes cannot adversely affect the DB. Even journal_mode=OFF would probably
93+
# be fine, too, but the SQLite documentation says that ROLLBACK becomes *undefined behaviour*
94+
# with journal_mode=OFF which sounds scarier.
95+
self.con.execute('PRAGMA journal_mode=MEMORY;')
96+
97+
def acquire_read(self, timeout: float = -1, blocking: bool = True) -> AcquireReturnProxy:
98+
"""Acquire a read lock. If a lock is already held, it must be a read lock.
99+
Upgrading from read to write is prohibited."""
100+
with self._internal_lock:
101+
if self._lock_level > 0:
102+
# Must already be in read mode.
103+
if self._current_mode != "read":
104+
raise RuntimeError("Cannot acquire read lock when a write lock is held (no upgrade allowed)")
105+
self._lock_level += 1
106+
return AcquireReturnProxy(lock=self)
107+
108+
timeout_ms = timeout_for_sqlite(timeout, blocking)
109+
110+
# Acquire the transaction lock so that the (possibly blocking) SQLite work
111+
# happens without conflicting with other threads' transaction work.
112+
if not self._transaction_lock.acquire(blocking, timeout):
113+
raise Timeout(self.lock_file)
114+
try:
115+
# Double-check: another thread might have completed acquisition meanwhile.
116+
with self._internal_lock:
117+
if self._lock_level > 0:
118+
# Must already be in read mode.
119+
if self._current_mode != "read":
120+
raise RuntimeError("Cannot acquire read lock when a write lock is held (no upgrade allowed)")
121+
self._lock_level += 1
122+
return AcquireReturnProxy(lock=self)
123+
44124
self.con.execute('PRAGMA busy_timeout=?;', (timeout_ms,))
45-
try:
46-
self.cur = self.con.execute('BEGIN EXCLUSIVE TRANSACTION;')
47-
except sqlite3.OperationalError as e:
48-
if 'database is locked' not in str(e):
49-
raise # Re-raise unexpected errors
50-
raise Timeout(self._context.lock_file)
51-
52-
class ReadLock(_SQLiteLock):
53-
def _acquire(self):
54-
timeout_ms = int(self._context.timeout * 1000) if self._context.blocking else 0
55-
with self.procLock:
56-
if self.cur is not None:
57-
return
125+
self.con.execute('BEGIN TRANSACTION;')
126+
# Need to make SELECT to compel SQLite to actually acquire a SHARED db lock.
127+
# See https://www.sqlite.org/lockingv3.html#transaction_control
128+
self.con.execute('SELECT name from sqlite_schema LIMIT 1;')
129+
130+
with self._internal_lock:
131+
self._current_mode = "read"
132+
self._lock_level = 1
133+
134+
return AcquireReturnProxy(lock=self)
135+
136+
except sqlite3.OperationalError as e:
137+
if 'database is locked' not in str(e):
138+
raise # Re-raise unexpected errors.
139+
raise Timeout(self.lock_file)
140+
finally:
141+
self._transaction_lock.release()
142+
143+
def acquire_write(self, timeout: float = -1, blocking: bool = True) -> AcquireReturnProxy:
144+
"""Acquire a write lock. If a lock is already held, it must be a write lock.
145+
Upgrading from read to write is prohibited."""
146+
with self._internal_lock:
147+
if self._lock_level > 0:
148+
if self._current_mode != "write":
149+
raise RuntimeError("Cannot acquire write lock: already holding a read lock (no upgrade allowed)")
150+
self._lock_level += 1
151+
return AcquireReturnProxy(lock=self)
152+
153+
timeout_ms = timeout_for_sqlite(timeout, blocking)
154+
if not self._transaction_lock.acquire(blocking, timeout):
155+
raise Timeout(self.lock_file)
156+
try:
157+
# Double-check: another thread might have completed acquisition meanwhile.
158+
with self._internal_lock:
159+
if self._lock_level > 0:
160+
if self._current_mode != "write":
161+
raise RuntimeError("Cannot acquire write lock: already holding a read lock (no upgrade allowed)")
162+
self._lock_level += 1
163+
return AcquireReturnProxy(lock=self)
164+
58165
self.con.execute('PRAGMA busy_timeout=?;', (timeout_ms,))
59-
cur = None # Initialize cur to avoid potential UnboundLocalError
60-
try:
61-
cur = self.con.execute('BEGIN TRANSACTION;')
62-
# BEGIN doesn't itself acquire a SHARED lock on the db, that is needed for
63-
# effective exclusion with writeLock(). A SELECT is needed.
64-
cur.execute('SELECT name from sqlite_schema LIMIT 1;')
65-
self.cur = cur
66-
except sqlite3.OperationalError as e:
67-
if 'database is locked' not in str(e):
68-
raise # Re-raise unexpected errors
69-
if cur is not None:
70-
cur.close()
71-
raise Timeout(self._context.lock_file)
166+
self.con.execute('BEGIN EXCLUSIVE TRANSACTION;')
72167

168+
with self._internal_lock:
169+
self._current_mode = "write"
170+
self._lock_level = 1
171+
172+
return AcquireReturnProxy(lock=self)
173+
174+
except sqlite3.OperationalError as e:
175+
if 'database is locked' not in str(e):
176+
raise # Re-raise if it is an unexpected error.
177+
raise Timeout(self.lock_file)
178+
finally:
179+
self._transaction_lock.release()
180+
181+
def release(self, force: bool = False) -> None:
182+
with self._internal_lock:
183+
if self._lock_level == 0:
184+
if force:
185+
return
186+
raise RuntimeError("Cannot release a lock that is not held")
187+
if force:
188+
self._lock_level = 0
189+
else:
190+
self._lock_level -= 1
191+
if self._lock_level == 0:
192+
# Clear current mode and rollback the SQLite transaction.
193+
self._current_mode = None
194+
# Unless there are bugs in this code, sqlite3.ProgrammingError
195+
# must not be raise here, that is, the transaction should have been
196+
# started in acquire().
197+
self.con.rollback()
198+
199+
# ----- Context Manager Protocol -----
200+
# (We provide two context managers as helpers.)
201+
202+
@contextmanager
203+
def read_lock(self, timeout: float | None = None,
204+
blocking: bool | None = None):
205+
"""Context manager for acquiring a read lock.
206+
Attempts to upgrade to write lock are disallowed."""
207+
if timeout is None:
208+
timeout = self.timeout
209+
if blocking is None:
210+
blocking = self.blocking
211+
self.acquire_read(timeout, blocking)
212+
try:
213+
yield
214+
finally:
215+
self.release()
216+
217+
@contextmanager
218+
def write_lock(self, timeout: float | None = None,
219+
blocking: bool | None = None):
220+
"""Context manager for acquiring a write lock.
221+
Acquiring read locks on the same file while helding a write lock is prohibited."""
222+
if timeout is None:
223+
timeout = self.timeout
224+
if blocking is None:
225+
blocking = self.blocking
226+
self.acquire_write(timeout, blocking)
227+
try:
228+
yield
229+
finally:
230+
self.release()
231+
232+
def __del__(self) -> None:
233+
"""Called when the lock object is deleted."""
234+
self.release(force=True)
73235

74236

0 commit comments

Comments
 (0)