Commit

Merge pull request #51 from kamo-naoyuki/nttcslab-sp
Add option of "max_cache_fd"
kamo-naoyuki authored Sep 25, 2020
2 parents 1164391 + d9f41dd commit 72f9088
Showing 4 changed files with 83 additions and 7 deletions.
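
For context, a minimal usage sketch of the new option (the scp file name here is hypothetical): load_scp now accepts max_cache_fd, which keeps up to that many ark file descriptors open and reuses them across lookups instead of reopening the ark for every key; it is rejected when segments is given.

import kaldiio

# Hypothetical feats.scp pointing into one or more ark files.
# max_cache_fd > 0 caches up to that many open ark file descriptors,
# so repeated lookups do not reopen the same ark over and over.
d = kaldiio.load_scp("feats.scp", max_cache_fd=3)
for key in d:
    mat = d[key]  # loaded lazily, one matrix per key
    print(key, mat.shape)
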
31 changes: 26 additions & 5 deletions kaldiio/matio.py
@@ -17,6 +17,7 @@
from kaldiio.compression_header import GlobalHeader
from kaldiio.compression_header import PerColHeader
from kaldiio.utils import LazyLoader
from kaldiio.utils import LimitedSizeDict
from kaldiio.utils import MultiFileDescriptor
from kaldiio.utils import default_encoding
from kaldiio.utils import open_like_kaldi
@@ -58,7 +59,7 @@ def from_bytes(s, endianess="little"):
return int(codecs.encode(s, "hex"), 16)


def load_scp(fname, endian="<", separator=None, segments=None):
def load_scp(fname, endian="<", separator=None, segments=None, max_cache_fd=0):
"""Lazy loader for kaldi scp file.
Args:
@@ -68,8 +69,16 @@ def load_scp(fname, endian="<", separator=None, segments=None):
segments (str): The path of segments
"""
    assert endian in ("<", ">"), endian

    if max_cache_fd != 0:
        if segments is not None:
            raise ValueError("max_cache_fd is not supported for segments mode")
        d = LimitedSizeDict(max_cache_fd)
    else:
        d = None

    if segments is None:
        load_func = partial(load_mat, endian=endian)
        load_func = partial(load_mat, endian=endian, fd_dict=d)
        loader = LazyLoader(load_func)
        with open_like_kaldi(fname, "r") as fd:
            for line in fd:
@@ -213,11 +222,23 @@ def _return(self, array, st, et):
return rate, array[int(st * rate) :]


def load_mat(ark_name, endian="<"):
def load_mat(ark_name, endian="<", fd_dict=None):
    assert endian in ("<", ">"), endian
    if fd_dict is not None and not isinstance(fd_dict, Mapping):
        raise RuntimeError(
            "fd_dict must be dict or None, but got {}".format(type(fd_dict))
        )

    ark, offset, slices = _parse_arkpath(ark_name)
    with open_like_kaldi(ark, "rb") as fd:

    if fd_dict is not None and not (ark.strip()[-1] == "|" or ark.strip()[0] == "|"):
        if ark not in fd_dict:
            fd_dict[ark] = open_like_kaldi(ark, "rb")
        fd = fd_dict[ark]
        return _load_mat(fd, offset, slices, endian=endian)
    else:
        with open_like_kaldi(ark, "rb") as fd:
            return _load_mat(fd, offset, slices, endian=endian)


def _parse_arkpath(ark_name):
@@ -239,7 +260,7 @@ def _parse_arkpath(ark_name):
"""

    if ark_name.rstrip()[-1] == "|" or ark_name.rstrip()[0] == "|":
    if ark_name.strip()[-1] == "|" or ark_name.strip()[0] == "|":
        # Something like: "| cat foo" or "cat bar|" shouldn't be parsed
        return ark_name, None, None

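
To illustrate the caching path added to load_mat above, here is a hedged sketch (the ark name and byte offsets are made up): any Mapping can serve as fd_dict, cached descriptors are reused for plain ark paths, and piped commands ("cat foo |") still bypass the cache. The caller owns the cached file objects, so they should be closed when no longer needed.

from kaldiio import load_mat

fd_dict = {}  # a plain dict also satisfies the Mapping check
x = load_mat("feats.ark:13", fd_dict=fd_dict)    # opens feats.ark and caches the fd
y = load_mat("feats.ark:7241", fd_dict=fd_dict)  # reuses the cached fd for the second read

# close the descriptors held by the cache once done
for f in fd_dict.values():
    f.close()
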
29 changes: 29 additions & 0 deletions kaldiio/utils.py
@@ -1,5 +1,6 @@
from __future__ import unicode_literals

from collections import OrderedDict
from contextlib import contextmanager
import io
from io import TextIOBase
@@ -512,3 +513,31 @@ def seekable(f):
return True
else:
return False


class LimitedSizeDict(MutableMapping):
    def __init__(self, maxsize):
        self._maxsize = maxsize
        self.data = OrderedDict()

    def __repr__(self):
        return repr(self.data)

    def __setitem__(self, key, value):
        if len(self) >= self._maxsize:
            self.data.pop(next(iter(self.data)))

        self.data[key] = value

    def __getitem__(self, item):
        return self.data[item]

    def __delitem__(self, key):
        self._maxsize -= 1
        del self.data[key]

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)
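
As a quick illustration of the eviction behaviour (a sketch, not part of the commit; the new test below exercises the same thing): once maxsize entries are stored, inserting another key drops the oldest entry in insertion order.

from kaldiio.utils import LimitedSizeDict

cache = LimitedSizeDict(2)
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3       # "a", the oldest key, is evicted
print(list(cache))   # ['b', 'c']
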
23 changes: 23 additions & 0 deletions tests/test_limited_size_dict.py
@@ -0,0 +1,23 @@
# coding: utf-8
from __future__ import unicode_literals

from kaldiio.utils import LimitedSizeDict


def test_limted_size_dict():
    d = LimitedSizeDict(3)
    d["foo"] = 1
    d["bar"] = 2
    d["baz"] = 3

    assert "foo" in d
    assert "bar" in d
    assert "baz" in d

    d["foo2"] = 4
    assert "foo" not in d
    assert "foo2" in d

    d["bar2"] = 4
    assert "bar" not in d
    assert "bar2" in d
7 changes: 5 additions & 2 deletions tests/test_mat_ark.py
@@ -29,7 +29,8 @@ def test_read_arks(fname):
)
@pytest.mark.parametrize("endian", ["<", ">"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_write_read(tmpdir, shape1, shape2, endian, dtype):
@pytest.mark.parametrize("max_cache_fd", [0, 3])
def test_write_read(tmpdir, shape1, shape2, endian, dtype, max_cache_fd):
    path = tmpdir.mkdir("test")

    a = np.random.rand(*shape1).astype(dtype)
@@ -45,7 +46,9 @@ def test_write_read(tmpdir, shape1, shape2, endian, dtype):
    d2 = {k: v for k, v in kaldiio.load_ark(path.join("a.ark").strpath, endian=endian)}
    d5 = {
        k: v
        for k, v in kaldiio.load_scp(path.join("b.scp").strpath, endian=endian).items()
        for k, v in kaldiio.load_scp(
            path.join("b.scp").strpath, endian=endian, max_cache_fd=max_cache_fd
        ).items()
    }
    with io.open(path.join("a.ark").strpath, "rb") as fd:
        d6 = {k: v for k, v in kaldiio.load_ark(fd, endian=endian)}
