From ee0173164f9d3876c688147c4a2a5f54407a91b1 Mon Sep 17 00:00:00 2001 From: kamo-naoyuki Date: Sat, 26 Sep 2020 01:00:48 +0900 Subject: [PATCH 1/3] add option of "max_cache_fd" --- kaldiio/matio.py | 31 ++++++++++++++++++++++++++----- kaldiio/utils.py | 29 +++++++++++++++++++++++++++++ tests/test_limited_size_dict.py | 23 +++++++++++++++++++++++ tests/test_mat_ark.py | 7 +++++-- 4 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 tests/test_limited_size_dict.py diff --git a/kaldiio/matio.py b/kaldiio/matio.py index 1f1c4c5..38ec417 100644 --- a/kaldiio/matio.py +++ b/kaldiio/matio.py @@ -17,6 +17,7 @@ from kaldiio.compression_header import GlobalHeader from kaldiio.compression_header import PerColHeader from kaldiio.utils import LazyLoader +from kaldiio.utils import LimitedSizeDict from kaldiio.utils import MultiFileDescriptor from kaldiio.utils import default_encoding from kaldiio.utils import open_like_kaldi @@ -58,7 +59,7 @@ def from_bytes(s, endianess="little"): return int(codecs.encode(s, "hex"), 16) -def load_scp(fname, endian="<", separator=None, segments=None): +def load_scp(fname, endian="<", separator=None, segments=None, max_cache_fd=0): """Lazy loader for kaldi scp file. Args: @@ -68,8 +69,16 @@ def load_scp(fname, endian="<", separator=None, segments=None): segments (str): The path of segments """ assert endian in ("<", ">"), endian + + if max_cache_fd != 0: + if segments is not None: + raise ValueError("max_cache_fd is not supported for segments mode") + d = LimitedSizeDict(max_cache_fd) + else: + d = None + if segments is None: - load_func = partial(load_mat, endian=endian) + load_func = partial(load_mat, endian=endian, fd_dict=d) loader = LazyLoader(load_func) with open_like_kaldi(fname, "r") as fd: for line in fd: @@ -213,11 +222,23 @@ def _return(self, array, st, et): return rate, array[int(st * rate) :] -def load_mat(ark_name, endian="<"): +def load_mat(ark_name, endian="<", fd_dict=None): assert endian in ("<", ">"), endian + if fd_dict is not None and not isinstance(fd_dict, Mapping): + raise RuntimeError( + "fd_dict must be dict or None, bot got {}".format(type(fd_dict)) + ) + ark, offset, slices = _parse_arkpath(ark_name) - with open_like_kaldi(ark, "rb") as fd: + + if fd_dict is not None and not (ark.strip()[-1] == "|" or ark.strip()[0] == "|"): + if ark not in fd_dict: + fd_dict[ark] = open_like_kaldi(ark, "rb") + fd = fd_dict[ark] return _load_mat(fd, offset, slices, endian=endian) + else: + with open_like_kaldi(ark, "rb") as fd: + return _load_mat(fd, offset, slices, endian=endian) def _parse_arkpath(ark_name): @@ -239,7 +260,7 @@ def _parse_arkpath(ark_name): """ - if ark_name.rstrip()[-1] == "|" or ark_name.rstrip()[0] == "|": + if ark_name.strip()[-1] == "|" or ark_name.strip()[0] == "|": # Something like: "| cat foo" or "cat bar|" shouldn't be parsed return ark_name, None, None diff --git a/kaldiio/utils.py b/kaldiio/utils.py index 423997b..c0b129c 100644 --- a/kaldiio/utils.py +++ b/kaldiio/utils.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from collections import UserDict from contextlib import contextmanager import io from io import TextIOBase @@ -512,3 +513,31 @@ def seekable(f): return True else: return False + + +class LimitedSizeDict(MutableMapping): + def __init__(self, maxsize: int): + self._maxsize = maxsize + self.data = {} + + def __repr__(self): + return repr(self.data) + + def __setitem__(self, key, value): + if len(self) >= self._maxsize: + self.data.pop(next(iter(self.data))) + + self.data[key] = value + + def __getitem__(self, item): + return self.data[item] + + def __delitem__(self, key): + self._maxsize -= 1 + del self.data[key] + + def __iter__(self): + return iter(self.data) + + def __len__(self): + return len(self.data) diff --git a/tests/test_limited_size_dict.py b/tests/test_limited_size_dict.py new file mode 100644 index 0000000..f51744a --- /dev/null +++ b/tests/test_limited_size_dict.py @@ -0,0 +1,23 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from kaldiio.utils import LimitedSizeDict + + +def test_limted_size_dict(): + d = LimitedSizeDict(3) + d["foo"] = 1 + d["bar"] = 2 + d["baz"] = 3 + + assert "foo" in d + assert "bar" in d + assert "baz" in d + + d["foo2"] = 4 + assert "foo" not in d + assert "foo2" in d + + d["bar2"] = 4 + assert "bar" not in d + assert "bar2" in d diff --git a/tests/test_mat_ark.py b/tests/test_mat_ark.py index 9c2a558..42db8c1 100644 --- a/tests/test_mat_ark.py +++ b/tests/test_mat_ark.py @@ -29,7 +29,8 @@ def test_read_arks(fname): ) @pytest.mark.parametrize("endian", ["<", ">"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_write_read(tmpdir, shape1, shape2, endian, dtype): +@pytest.mark.parametrize("max_cache_fd", [0, 3]) +def test_write_read(tmpdir, shape1, shape2, endian, dtype, max_cache_fd): path = tmpdir.mkdir("test") a = np.random.rand(*shape1).astype(dtype) @@ -45,7 +46,9 @@ def test_write_read(tmpdir, shape1, shape2, endian, dtype): d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath, endian=endian)} d5 = { k: v - for k, v in kaldiio.load_scp(path.join("b.scp").strpath, endian=endian).items() + for k, v in kaldiio.load_scp( + path.join("b.scp").strpath, endian=endian, max_cache_fd=max_cache_fd + ).items() } with io.open(path.join("a.ark").strpath, "rb") as fd: d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)} From 72ae4df6106f8942622fe95f34af26b2a69060b5 Mon Sep 17 00:00:00 2001 From: kamo-naoyuki Date: Sat, 26 Sep 2020 01:07:05 +0900 Subject: [PATCH 2/3] remove unused import: kaldiio/utils.py --- kaldiio/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/kaldiio/utils.py b/kaldiio/utils.py index c0b129c..66d432b 100644 --- a/kaldiio/utils.py +++ b/kaldiio/utils.py @@ -1,6 +1,5 @@ from __future__ import unicode_literals -from collections import UserDict from contextlib import contextmanager import io from io import TextIOBase From d9f41ddd3c570d8cc7f9a45a1398103fa5c74d68 Mon Sep 17 00:00:00 2001 From: kamo-naoyuki Date: Sat, 26 Sep 2020 01:13:43 +0900 Subject: [PATCH 3/3] fix: kaldiio/utils.py --- kaldiio/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/kaldiio/utils.py b/kaldiio/utils.py index 66d432b..ab76da7 100644 --- a/kaldiio/utils.py +++ b/kaldiio/utils.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +from collections import OrderedDict from contextlib import contextmanager import io from io import TextIOBase @@ -515,9 +516,9 @@ def seekable(f): class LimitedSizeDict(MutableMapping): - def __init__(self, maxsize: int): + def __init__(self, maxsize): self._maxsize = maxsize - self.data = {} + self.data = OrderedDict() def __repr__(self): return repr(self.data)