-
Notifications
You must be signed in to change notification settings - Fork 367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add type annotations to base spec #1676
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,17 @@ | |
from errno import ESPIPE | ||
from glob import has_magic | ||
from hashlib import sha256 | ||
from typing import Any, ClassVar, Dict, Tuple | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Callable, | ||
ClassVar, | ||
Dict, | ||
Literal, | ||
List, | ||
Tuple, | ||
overload, | ||
) | ||
|
||
from .callbacks import DEFAULT_CALLBACK | ||
from .config import apply_config, conf | ||
|
@@ -26,6 +36,12 @@ | |
tokenize, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
import datetime | ||
|
||
from .caching import BaseCache | ||
from .mapping import FSMap | ||
|
||
logger = logging.getLogger("fsspec") | ||
|
||
|
||
|
@@ -184,8 +200,16 @@ def __eq__(self, other): | |
def __reduce__(self): | ||
return make_instance, (type(self), self.storage_args, self.storage_options) | ||
|
||
@overload | ||
@classmethod | ||
def _strip_protocol(cls, path: str) -> str: ... | ||
|
||
@overload | ||
@classmethod | ||
def _strip_protocol(cls, path): | ||
def _strip_protocol(cls, path: List[str]) -> List[str]: ... | ||
|
||
@classmethod | ||
def _strip_protocol(cls, path) -> str | List[str]: | ||
"""Turn path from fully-qualified to file-system-specific | ||
May require FS-specific handling, e.g., for relative paths or links. | ||
|
@@ -277,7 +301,7 @@ def invalidate_cache(self, path=None): | |
if self._intrans: | ||
self._invalidated_caches_in_transaction.append(path) | ||
|
||
def mkdir(self, path, create_parents=True, **kwargs): | ||
def mkdir(self, path: str, create_parents: bool = True, **kwargs): | ||
""" | ||
Create directory entry at path | ||
|
@@ -295,7 +319,7 @@ def mkdir(self, path, create_parents=True, **kwargs): | |
""" | ||
pass # not necessary to implement, may not have directories | ||
|
||
def makedirs(self, path, exist_ok=False): | ||
def makedirs(self, path: str, exist_ok: bool = False): | ||
"""Recursively make directories | ||
Creates directory at path and any intervening required directories. | ||
|
@@ -311,11 +335,11 @@ def makedirs(self, path, exist_ok=False): | |
""" | ||
pass # not necessary to implement, may not have directories | ||
|
||
def rmdir(self, path): | ||
def rmdir(self, path: str): | ||
"""Remove a directory, if empty""" | ||
pass # not necessary to implement, may not have directories | ||
|
||
def ls(self, path, detail=True, **kwargs): | ||
def ls(self, path: str, detail: bool = True, **kwargs): | ||
"""List objects at path. | ||
This should include subdirectories and files at that location. The | ||
|
@@ -381,7 +405,7 @@ def _ls_from_cache(self, path): | |
except KeyError: | ||
pass | ||
|
||
def walk(self, path, maxdepth=None, topdown=True, on_error="omit", **kwargs): | ||
def walk(self, path: str, maxdepth=None, topdown=True, on_error="omit", **kwargs): | ||
"""Return all files belows path | ||
List all files, recursing into subdirectories; output is iterator-style, | ||
|
@@ -450,8 +474,8 @@ def walk(self, path, maxdepth=None, topdown=True, on_error="omit", **kwargs): | |
files[name] = info | ||
|
||
if not detail: | ||
dirs = list(dirs) | ||
files = list(files) | ||
dirs = list(dirs) # type: ignore[assignment] | ||
files = list(files) # type: ignore[assignment] | ||
|
||
if topdown: | ||
# Yield before recursion if walking top down | ||
|
@@ -477,7 +501,14 @@ def walk(self, path, maxdepth=None, topdown=True, on_error="omit", **kwargs): | |
# Yield after recursion if walking bottom up | ||
yield path, dirs, files | ||
|
||
def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): | ||
def find( | ||
self, | ||
path, | ||
maxdepth: int | None = None, | ||
withdirs: bool = False, | ||
detail: bool = False, | ||
**kwargs, | ||
): | ||
"""List all files below path. | ||
Like posix ``find`` command without conditions | ||
|
@@ -549,7 +580,7 @@ def du(self, path, total=True, maxdepth=None, withdirs=False, **kwargs): | |
else: | ||
return sizes | ||
|
||
def glob(self, path, maxdepth=None, **kwargs): | ||
def glob(self, path: str, maxdepth: int | None = None, **kwargs): | ||
""" | ||
Find files by glob-matching. | ||
|
@@ -606,7 +637,7 @@ def glob(self, path, maxdepth=None, **kwargs): | |
depth_double_stars = path[idx_double_stars:].count("/") + 1 | ||
depth = depth - depth_double_stars + maxdepth | ||
else: | ||
depth = None | ||
depth = None # type: ignore[assignment] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Type `depth` above as an Optional and you won't need this type ignore. |
||
|
||
allpaths = self.find(root, maxdepth=depth, withdirs=True, detail=True, **kwargs) | ||
|
||
|
@@ -630,7 +661,7 @@ def glob(self, path, maxdepth=None, **kwargs): | |
else: | ||
return list(out) | ||
|
||
def exists(self, path, **kwargs): | ||
def exists(self, path, **kwargs) -> bool: | ||
"""Is there a file at the given path""" | ||
try: | ||
self.info(path, **kwargs) | ||
|
@@ -639,12 +670,12 @@ def exists(self, path, **kwargs): | |
# any exception allowed bar FileNotFoundError? | ||
return False | ||
|
||
def lexists(self, path, **kwargs): | ||
def lexists(self, path, **kwargs) -> bool: | ||
"""If there is a file at the given path (including | ||
broken links)""" | ||
return self.exists(path) | ||
|
||
def info(self, path, **kwargs): | ||
def info(self, path, **kwargs) -> Dict[str, Any]: | ||
"""Give details of entry at path | ||
Returns a single dictionary, with exactly the same information as ``ls`` | ||
|
@@ -699,14 +730,14 @@ def sizes(self, paths): | |
"""Size in bytes of each file in a list of paths""" | ||
return [self.size(p) for p in paths] | ||
|
||
def isdir(self, path): | ||
def isdir(self, path) -> bool: | ||
"""Is this entry directory-like?""" | ||
try: | ||
return self.info(path)["type"] == "directory" | ||
except OSError: | ||
return False | ||
|
||
def isfile(self, path): | ||
def isfile(self, path) -> bool: | ||
"""Is this entry file-like?""" | ||
try: | ||
return self.info(path)["type"] == "file" | ||
|
@@ -733,7 +764,7 @@ def read_text(self, path, encoding=None, errors=None, newline=None, **kwargs): | |
return f.read() | ||
|
||
def write_text( | ||
self, path, value, encoding=None, errors=None, newline=None, **kwargs | ||
self, path: str, value: str, encoding=None, errors=None, newline=None, **kwargs | ||
): | ||
"""Write the text to the given file. | ||
|
@@ -757,7 +788,7 @@ def write_text( | |
) as f: | ||
return f.write(value) | ||
|
||
def cat_file(self, path, start=None, end=None, **kwargs): | ||
def cat_file(self, path: str, start=None, end=None, **kwargs): | ||
"""Get the content of a file | ||
Parameters | ||
|
@@ -997,11 +1028,11 @@ def put_file(self, lpath, rpath, callback=DEFAULT_CALLBACK, **kwargs): | |
|
||
def put( | ||
self, | ||
lpath, | ||
rpath, | ||
recursive=False, | ||
lpath: str | List[str], | ||
rpath: str | List[str], | ||
recursive: bool = False, | ||
callback=DEFAULT_CALLBACK, | ||
maxdepth=None, | ||
maxdepth: int | None = None, | ||
**kwargs, | ||
): | ||
"""Copy file(s) from local. | ||
|
@@ -1046,8 +1077,8 @@ def put( | |
else [self._strip_protocol(p) for p in rpath] | ||
) | ||
exists = source_is_str and ( | ||
(has_magic(lpath) and source_is_file) | ||
or (not has_magic(lpath) and dest_is_dir and not trailing_sep(lpath)) | ||
(has_magic(lpath) and source_is_file) # type: ignore[arg-type] | ||
or (not has_magic(lpath) and dest_is_dir and not trailing_sep(lpath)) # type: ignore[arg-type] | ||
) | ||
rpaths = other_paths( | ||
lpaths, | ||
|
@@ -1061,12 +1092,12 @@ def put( | |
with callback.branched(lpath, rpath) as child: | ||
self.put_file(lpath, rpath, callback=child, **kwargs) | ||
|
||
def head(self, path, size=1024): | ||
def head(self, path: str, size: int = 1024): | ||
"""Get the first ``size`` bytes from file""" | ||
with self.open(path, "rb") as f: | ||
return f.read(size) | ||
|
||
def tail(self, path, size=1024): | ||
def tail(self, path: str, size: int = 1024): | ||
"""Get the last ``size`` bytes from file""" | ||
with self.open(path, "rb") as f: | ||
f.seek(max(-size, -f.size), 2) | ||
|
@@ -1076,8 +1107,14 @@ def cp_file(self, path1, path2, **kwargs): | |
raise NotImplementedError | ||
|
||
def copy( | ||
self, path1, path2, recursive=False, maxdepth=None, on_error=None, **kwargs | ||
): | ||
self, | ||
path1: str | list[str], | ||
path2: str | list[str], | ||
recursive: bool = False, | ||
maxdepth: int | None = None, | ||
on_error: Literal["ignore", "raise"] | None = None, | ||
**kwargs, | ||
) -> None: | ||
"""Copy within two locations in the filesystem | ||
on_error : "raise", "ignore" | ||
|
@@ -1112,8 +1149,8 @@ def copy( | |
) | ||
|
||
exists = source_is_str and ( | ||
(has_magic(path1) and source_is_file) | ||
or (not has_magic(path1) and dest_is_dir and not trailing_sep(path1)) | ||
(has_magic(path1) and source_is_file) # type: ignore[arg-type] | ||
or (not has_magic(path1) and dest_is_dir and not trailing_sep(path1)) # type: ignore[arg-type] | ||
) | ||
paths2 = other_paths( | ||
paths1, | ||
|
@@ -1129,7 +1166,13 @@ def copy( | |
if on_error == "raise": | ||
raise | ||
|
||
def expand_path(self, path, recursive=False, maxdepth=None, **kwargs): | ||
def expand_path( | ||
self, | ||
path: str | List[str], | ||
recursive: bool = False, | ||
maxdepth: int | None = None, | ||
**kwargs, | ||
) -> list[str]: | ||
"""Turn one or more globs or directories into a list of all matching paths | ||
to files or directories. | ||
|
@@ -1139,6 +1182,8 @@ def expand_path(self, path, recursive=False, maxdepth=None, **kwargs): | |
if maxdepth is not None and maxdepth < 1: | ||
raise ValueError("maxdepth must be at least 1") | ||
|
||
out: set[str] | list[str] | ||
|
||
if isinstance(path, (str, os.PathLike)): | ||
out = self.expand_path([path], recursive, maxdepth) | ||
else: | ||
|
@@ -1177,7 +1222,14 @@ def expand_path(self, path, recursive=False, maxdepth=None, **kwargs): | |
raise FileNotFoundError(path) | ||
return sorted(out) | ||
|
||
def mv(self, path1, path2, recursive=False, maxdepth=None, **kwargs): | ||
def mv( | ||
self, | ||
path1, | ||
path2, | ||
recursive: bool = False, | ||
maxdepth: int | None = None, | ||
**kwargs, | ||
) -> None: | ||
"""Move file(s) from one location to another""" | ||
if path1 == path2: | ||
logger.debug("%s mv: The paths are the same, so no files were moved.", self) | ||
|
@@ -1188,7 +1240,7 @@ def mv(self, path1, path2, recursive=False, maxdepth=None, **kwargs): | |
) | ||
self.rm(path1, recursive=recursive) | ||
|
||
def rm_file(self, path): | ||
def rm_file(self, path) -> None: | ||
"""Delete a file""" | ||
self._rm(path) | ||
|
||
|
@@ -1247,11 +1299,11 @@ def _open( | |
|
||
def open( | ||
self, | ||
path, | ||
mode="rb", | ||
block_size=None, | ||
cache_options=None, | ||
compression=None, | ||
path: str, | ||
mode: str = "rb", | ||
block_size: int | None = None, | ||
cache_options: dict | None = None, | ||
compression: str | None = None, | ||
**kwargs, | ||
): | ||
""" | ||
|
@@ -1312,15 +1364,15 @@ def open( | |
from fsspec.compression import compr | ||
from fsspec.core import get_compression | ||
|
||
compression = get_compression(path, compression) | ||
compress = compr[compression] | ||
f = compress(f, mode=mode[0]) | ||
compression_meth: str = get_compression(path, compression) | ||
compress: Callable[[Any, str], Any] = compr[compression_meth] # type: ignore[index] | ||
f = compress(f, mode=mode[0]) # type: ignore[call-arg] | ||
|
||
if not ac and "r" not in mode: | ||
self.transaction.files.append(f) | ||
return f | ||
|
||
def touch(self, path, truncate=True, **kwargs): | ||
def touch(self, path: str, truncate: bool = True, **kwargs) -> None: | ||
"""Create empty file, or update timestamp | ||
Parameters | ||
|
@@ -1337,11 +1389,13 @@ def touch(self, path, truncate=True, **kwargs): | |
else: | ||
raise NotImplementedError # update timestamp, if possible | ||
|
||
def ukey(self, path): | ||
def ukey(self, path: str) -> str: | ||
"""Hash of file properties, to tell if it has changed""" | ||
return sha256(str(self.info(path)).encode()).hexdigest() | ||
|
||
def read_block(self, fn, offset, length, delimiter=None): | ||
def read_block( | ||
self, fn: str, offset: int, length: int, delimiter: bytes | None = None | ||
) -> bytes: | ||
"""Read a block of bytes from | ||
Starting at ``offset`` of the file, read ``length`` bytes. If | ||
|
@@ -1530,7 +1584,9 @@ def _get_pyarrow_filesystem(self): | |
# all instances already also derive from pyarrow | ||
return self | ||
|
||
def get_mapper(self, root="", check=False, create=False, missing_exceptions=None): | ||
def get_mapper( | ||
self, root="", check=False, create=False, missing_exceptions=None | ||
) -> FSMap: | ||
"""Create key/value store based on this file-system | ||
Makes a MutableMapping interface to the FS at the given root path. | ||
|
@@ -1547,7 +1603,7 @@ def get_mapper(self, root="", check=False, create=False, missing_exceptions=None | |
) | ||
|
||
@classmethod | ||
def clear_instance_cache(cls): | ||
def clear_instance_cache(cls) -> None: | ||
""" | ||
Clear the cache of filesystem instances. | ||
|
@@ -1561,30 +1617,30 @@ def clear_instance_cache(cls): | |
""" | ||
cls._cache.clear() | ||
|
||
def created(self, path): | ||
def created(self, path: str) -> datetime.datetime: | ||
"""Return the created timestamp of a file as a datetime.datetime""" | ||
raise NotImplementedError | ||
|
||
def modified(self, path): | ||
def modified(self, path: str) -> datetime.datetime: | ||
"""Return the modified timestamp of a file as a datetime.datetime""" | ||
raise NotImplementedError | ||
|
||
# ------------------------------------------------------------------------ | ||
# Aliases | ||
|
||
def read_bytes(self, path, start=None, end=None, **kwargs): | ||
def read_bytes(self, path: str, start=None, end=None, **kwargs): | ||
"""Alias of `AbstractFileSystem.cat_file`.""" | ||
return self.cat_file(path, start=start, end=end, **kwargs) | ||
|
||
def write_bytes(self, path, value, **kwargs): | ||
"""Alias of `AbstractFileSystem.pipe_file`.""" | ||
self.pipe_file(path, value, **kwargs) | ||
|
||
def makedir(self, path, create_parents=True, **kwargs): | ||
def makedir(self, path, create_parents=True, **kwargs) -> None: | ||
"""Alias of `AbstractFileSystem.mkdir`.""" | ||
return self.mkdir(path, create_parents=create_parents, **kwargs) | ||
|
||
def mkdirs(self, path, exist_ok=False): | ||
def mkdirs(self, path, exist_ok=False) -> None: | ||
"""Alias of `AbstractFileSystem.makedirs`.""" | ||
return self.makedirs(path, exist_ok=exist_ok) | ||
|
||
|
@@ -1670,14 +1726,14 @@ class AbstractBufferedFile(io.IOBase): | |
|
||
def __init__( | ||
self, | ||
fs, | ||
path, | ||
mode="rb", | ||
block_size="default", | ||
autocommit=True, | ||
fs: AbstractFileSystem, | ||
path: str, | ||
mode: str = "rb", | ||
block_size: int | Literal["default"] | None = "default", | ||
autocommit: bool = True, | ||
cache_type="readahead", | ||
cache_options=None, | ||
size=None, | ||
cache_options: dict | None = None, | ||
size: int | None = None, | ||
**kwargs, | ||
): | ||
""" | ||
|
@@ -1711,10 +1767,10 @@ def __init__( | |
self.path = path | ||
self.fs = fs | ||
self.mode = mode | ||
self.blocksize = ( | ||
self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size | ||
self.blocksize: int = ( | ||
self.DEFAULT_BLOCK_SIZE if block_size in ["default", None] else block_size # type: ignore[assignment] | ||
) | ||
self.loc = 0 | ||
self.loc: int = 0 | ||
self.autocommit = autocommit | ||
self.end = None | ||
self.start = None | ||
|
@@ -1740,7 +1796,7 @@ def __init__( | |
self.size = size | ||
else: | ||
self.size = self.details["size"] | ||
self.cache = caches[cache_type]( | ||
self.cache: BaseCache = caches[cache_type]( | ||
self.blocksize, self._fetch_range, self.size, **cache_options | ||
) | ||
else: | ||
|
@@ -1765,22 +1821,22 @@ def full_name(self): | |
return _unstrip_protocol(self.path, self.fs) | ||
|
||
@property | ||
def closed(self): | ||
def closed(self) -> bool: | ||
# get around this attr being read-only in IOBase | ||
# use getattr here, since this can be called during del | ||
return getattr(self, "_closed", True) | ||
|
||
@closed.setter | ||
def closed(self, c): | ||
def closed(self, c: bool) -> None: | ||
self._closed = c | ||
|
||
def __hash__(self): | ||
def __hash__(self) -> int: | ||
if "w" in self.mode: | ||
return id(self) | ||
else: | ||
return int(tokenize(self.details), 16) | ||
|
||
def __eq__(self, other): | ||
def __eq__(self, other: object) -> bool: | ||
"""Files are equal if they have the same checksum, only in read mode""" | ||
if self is other: | ||
return True | ||
|
@@ -1804,11 +1860,11 @@ def info(self): | |
else: | ||
raise ValueError("Info not available while writing") | ||
|
||
def tell(self): | ||
def tell(self) -> int: | ||
"""Current file location""" | ||
return self.loc | ||
|
||
def seek(self, loc, whence=0): | ||
def seek(self, loc: int, whence: int = 0) -> int: | ||
"""Set current file location | ||
Parameters | ||
|
@@ -1834,7 +1890,7 @@ def seek(self, loc, whence=0): | |
self.loc = nloc | ||
return self.loc | ||
|
||
def write(self, data): | ||
def write(self, data) -> int: | ||
""" | ||
Write data to buffer. | ||
|
@@ -1858,7 +1914,7 @@ def write(self, data): | |
self.flush() | ||
return out | ||
|
||
def flush(self, force=False): | ||
def flush(self, force: bool = False) -> None: | ||
""" | ||
Write buffered data to backend store. | ||
|
@@ -1889,15 +1945,15 @@ def flush(self, force=False): | |
|
||
if self.offset is None: | ||
# Initialize a multipart upload | ||
self.offset = 0 | ||
self.offset = 0 # type: ignore[assignment] | ||
try: | ||
self._initiate_upload() | ||
except: | ||
self.closed = True | ||
raise | ||
|
||
if self._upload_chunk(final=force) is not False: | ||
self.offset += self.buffer.seek(0, 2) | ||
self.offset += self.buffer.seek(0, 2) # type: ignore[assignment,operator] | ||
self.buffer = io.BytesIO() | ||
|
||
def _upload_chunk(self, final=False): | ||
|
@@ -1919,7 +1975,7 @@ def _fetch_range(self, start, end): | |
"""Get the specified set of bytes from remote""" | ||
raise NotImplementedError | ||
|
||
def read(self, length=-1): | ||
def read(self, length: int = -1) -> bytes: | ||
""" | ||
Return data from cache, or fetch pieces as necessary | ||
|
@@ -1950,7 +2006,7 @@ def read(self, length=-1): | |
self.loc += len(out) | ||
return out | ||
|
||
def readinto(self, b): | ||
def readinto(self, b) -> int: | ||
"""mirrors builtin file's readinto method | ||
https://docs.python.org/3/library/io.html#io.RawIOBase.readinto | ||
|
@@ -1960,7 +2016,7 @@ def readinto(self, b): | |
out[: len(data)] = data | ||
return len(data) | ||
|
||
def readuntil(self, char=b"\n", blocks=None): | ||
def readuntil(self, char: bytes = b"\n", blocks: int | None = None) -> bytes: | ||
"""Return data between current position and first occurrence of char | ||
char is included in the output, except if the end of the tile is | ||
|
@@ -1988,15 +2044,15 @@ def readuntil(self, char=b"\n", blocks=None): | |
out.append(part) | ||
return b"".join(out) | ||
|
||
def readline(self): | ||
def readline(self) -> bytes: # type: ignore[override] | ||
"""Read until first occurrence of newline character | ||
Note that, because of character encoding, this is not necessarily a | ||
true line ending. | ||
""" | ||
return self.readuntil(b"\n") | ||
|
||
def __next__(self): | ||
def __next__(self) -> bytes: | ||
out = self.readline() | ||
if out: | ||
return out | ||
|
@@ -2005,7 +2061,7 @@ def __next__(self): | |
def __iter__(self): | ||
return self | ||
|
||
def readlines(self): | ||
def readlines(self) -> List[bytes]: # type: ignore[override] | ||
"""Return all data, split by the newline character""" | ||
data = self.read() | ||
lines = data.split(b"\n") | ||
|
@@ -2016,10 +2072,10 @@ def readlines(self): | |
return out + [lines[-1]] | ||
# return list(self) ??? | ||
|
||
def readinto1(self, b): | ||
def readinto1(self, b) -> int: | ||
return self.readinto(b) | ||
|
||
def close(self): | ||
def close(self) -> None: | ||
"""Close file | ||
Finalizes writes, discards cache | ||
|
@@ -2029,7 +2085,7 @@ def close(self): | |
if self.closed: | ||
return | ||
if self.mode == "rb": | ||
self.cache = None | ||
self.cache = None # type: ignore[assignment] | ||
else: | ||
if not self.forced: | ||
self.flush(force=True) | ||
|
@@ -2040,29 +2096,29 @@ def close(self): | |
|
||
self.closed = True | ||
|
||
def readable(self): | ||
def readable(self) -> bool: | ||
"""Whether opened for reading""" | ||
return self.mode == "rb" and not self.closed | ||
|
||
def seekable(self): | ||
def seekable(self) -> bool: | ||
"""Whether is seekable (only in read mode)""" | ||
return self.readable() | ||
|
||
def writable(self): | ||
def writable(self) -> bool: | ||
"""Whether opened for writing""" | ||
return self.mode in {"wb", "ab"} and not self.closed | ||
|
||
def __del__(self): | ||
def __del__(self) -> None: | ||
if not self.closed: | ||
self.close() | ||
|
||
def __str__(self): | ||
def __str__(self) -> str: | ||
return f"<File-like object {type(self.fs).__name__}, {self.path}>" | ||
|
||
__repr__ = __str__ | ||
|
||
def __enter__(self): | ||
def __enter__(self) -> "AbstractBufferedFile": | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You may want to consider `typing_extensions.Self` here. If `typing_extensions` is not already a project dependency, we should add it — it's one of the top 10 most-downloaded PyPI packages anyway. |
||
return self | ||
|
||
def __exit__(self, *args): | ||
def __exit__(self, *args) -> None: | ||
self.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: `on_error` has only a few valid values, right? We should type it as a union of `Literal` values (and should probably consider defining that as a reusable alias/TypeVar).