Skip to content

Commit

Permalink
Enhance valstore Python speed (#262)
Browse files Browse the repository at this point in the history
  • Loading branch information
weiliw-amz authored Oct 11, 2023
1 parent a19b825 commit 6bad2af
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,7 +1925,7 @@ def link_mmap_valstore_methods(self):

self.mmap_valstore_fn_dict = {
"float32": self._get_float32_mmap_valstore_methods(),
"bytes": self._get_bytes_mmap_valstore_methods(),
"str": self._get_bytes_mmap_valstore_methods(),
}

def mmap_valstore_init(self, store_type):
Expand Down
90 changes: 40 additions & 50 deletions pecos/utils/mmap_valstore_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class MmapValStoreBatchGetter(object):
Args:
max_row_size: Maximum row size
max_col_size: Maximum column size
trunc_val_len (Optional): Applicable for bytes value only. Truncated max length.
trunc_val_len (Optional): Applicable for str value only. Truncated max length.
threads (Optional): Number of threads to use.
"""

Expand Down Expand Up @@ -171,7 +171,7 @@ def get(self, rows, cols):
self.val_prealloc.ret_vals,
self.threads_c_uint32,
)
return self.val_prealloc.get_ret_memoryview(n_rows, n_cols)
return self.val_prealloc.format_ret(n_rows, n_cols)


class _MmapValStoreBase(object):
Expand Down Expand Up @@ -213,8 +213,8 @@ def init(cls, store_type, store_dir, lazy_load):

if store_type == "float32":
return _MmapValStoreFloat32ReadOnly(store_ptr, fn_dict)
elif store_type == "bytes":
return _MmapValStoreBytesReadOnly(store_ptr, fn_dict)
elif store_type == "str":
return _MmapValStoreStrReadOnly(store_ptr, fn_dict)
else:
raise NotImplementedError(f"{store_type}")

Expand Down Expand Up @@ -245,29 +245,17 @@ def __init__(self, max_row_size: int, max_col_size: int):

self.ret_vals = self.vals_ptr

def get_ret_memoryview(self, n_rows, n_cols):
def format_ret(self, n_rows, n_cols):
"""
Reshape return into desired shape (row-major), so elements could be retrieved by indices:
ret[i, j], 0<=i<n_rows, 0<=j<n_cols
Return can also be assigned to row-major Numpy array of same shape:
arr = np.zeros((n_rows, n_cols), dtype=np.float32)
arr.flat[:] = ret
This also works:
arr = np.array(ret)
Casting/Reshaping a memoryview does not copy data.
See: https://docs.python.org/3/library/stdtypes.html#memoryview.cast
Reshape the returned values into the desired (row-major) shape, so elements can be retrieved by index.
Slicing and reshaping the NumPy array here does not copy the data.
"""
# Casting to bytes first then cast to float32 with desired shape.
# 'f' = float32
# For types, see: https://docs.python.org/3/library/struct.html#format-characters
return memoryview(self.vals)[: n_rows * n_cols].cast("c").cast("f", shape=[n_rows, n_cols])
return self.vals[: n_rows * n_cols].reshape(n_rows, n_cols)


class _MmapValStoreBytesReadOnly(_MmapValStoreReadOnly):
class _MmapValStoreStrReadOnly(_MmapValStoreReadOnly):
"""
Bytes value store read only implementation.
Str value store read only implementation.
"""

def batch_get(self, n_rows, n_cols, rows_ptr, cols_ptr, ret_vals, threads_c_uint32):
Expand All @@ -286,12 +274,12 @@ def batch_get(self, n_rows, n_cols, rows_ptr, cols_ptr, ret_vals, threads_c_uint

@classmethod
def get_val_alloc(cls, max_row_size: int, max_col_size: int, trunc_val_len: int = 256):
return _BytesBatchGetterValPreAlloc(max_row_size, max_col_size, trunc_val_len)
return _StrBatchGetterValPreAlloc(max_row_size, max_col_size, trunc_val_len)


class _BytesBatchGetterValPreAlloc(object):
class _StrBatchGetterValPreAlloc(object):
"""
Batch return value pre-allocate for Bytes MmapValStore.
Batch return value pre-allocate for Str MmapValStore.
"""

def __init__(self, max_row_size: int, max_col_size: int, trunc_val_len: int):
Expand All @@ -304,25 +292,26 @@ def __init__(self, max_row_size: int, max_col_size: int, trunc_val_len: int):
self.trunc_val_len = trunc_val_len
self.ret_vals = (c_uint32(trunc_val_len), self.vals_ptr, self.vals_lens_ptr)

def get_ret_memoryview(self, n_rows, n_cols):
"""
Reshape return into memoryview of bytes matrix
"""
mat_mv = memoryview(self.vals)
len_mv = memoryview(self.vals_lens)

def start_loc(i, j):
return i * n_cols * self.trunc_val_len + j * self.trunc_val_len

ret_mat_mv = [
[
mat_mv[start_loc(i, j) : start_loc(i, j) + len_mv[i * n_cols + j]]
for j in range(n_cols)
]
for i in range(n_rows)
# Pre-calculated memory view of each string
# For str decoding, decoding from a memoryview is faster than decoding from a NumPy view
bytes_start_loc = [idx * self.trunc_val_len for idx in range(max_row_size * max_col_size)]
self.byte_mem_views = [
memoryview(self.vals[start_idx : start_idx + self.trunc_val_len])
for start_idx in bytes_start_loc
]

return ret_mat_mv
# Buffer for return string objects
self.ret_obj = np.zeros(max_row_size * max_col_size, dtype=np.object_)

def format_ret(self, n_rows, n_cols):
"""
Decode the returned values as UTF-8 and reshape them into an n_rows x n_cols nested list of strings
"""
for idx in range(n_rows * n_cols):
self.ret_obj[idx] = str(
self.byte_mem_views[idx][: self.vals_lens[idx]], "utf-8", "ignore"
)
return self.ret_obj[: n_rows * n_cols].reshape(n_rows, n_cols).tolist()


class _MmapValStoreWrite(_MmapValStoreBase):
Expand Down Expand Up @@ -355,8 +344,8 @@ def init(cls, store_type, store_dir):

if store_type == "float32":
return _MmapValStoreFloat32Write(store_ptr, fn_dict, store_dir)
elif store_type == "bytes":
return _MmapValStoreBytesWrite(store_ptr, fn_dict, store_dir)
elif store_type == "str":
return _MmapValStoreStrWrite(store_ptr, fn_dict, store_dir)
else:
raise NotImplementedError(f"{store_type}")

Expand All @@ -381,20 +370,21 @@ def from_vals(self, vals):
self.vals = vals


class _MmapValStoreBytesWrite(_MmapValStoreWrite):
class _MmapValStoreStrWrite(_MmapValStoreWrite):
def from_vals(self, vals):
"""
Args:
vals: Tuple (n_row, n_col, bytes_list)
vals: Tuple (n_row, n_col, str_list)
n_row: Number of rows
n_col: Number of columns
bytes_list: List of UTF-8 encoded strings
str_list: List of strings
"""
n_row, n_col, bytes_list = vals
n_row, n_col, str_list = vals
n_total = n_row * n_col
if len(bytes_list) != n_total:
raise ValueError(f"Should get length {n_total} bytes list, got: {len(bytes_list)}")
if len(str_list) != n_total:
raise ValueError(f"Should get length {n_total} string list, got: {len(str_list)}")

bytes_list = [ss.encode("utf-8") for ss in str_list]
bytes_ptr = (c_char_p * n_total)()
bytes_ptr[:] = bytes_list
bytes_lens = np.array([len(s) for s in bytes_list], dtype=np.uint32)
Expand Down
14 changes: 10 additions & 4 deletions test/pecos/utils/test_mmap_valstore_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def test_str_mmap_valstore(tmpdir):
# ['4', '44', '444']]
n_row = 5
n_col = 3
str_list = [[f"{j}".encode("UTF-8") * (i + 1) for i in range(n_col)] for j in range(n_row)]
str_list = [[f"{j}" * (i + 1) for i in range(n_col)] for j in range(n_row)]
flat_str_list = [item for sublist in str_list for item in sublist]

# Write-only Mode
w_store = MmapValStore("bytes")
w_store = MmapValStore("str")
w_store.open("w", store_dir)
# from array
w_store.store.from_vals((n_row, n_col, flat_str_list))
Expand All @@ -70,7 +70,7 @@ def test_str_mmap_valstore(tmpdir):
w_store.close()

# Read-only Mode
r_store = MmapValStore("bytes")
r_store = MmapValStore("str")
r_store.open("r", store_dir)
# Get sub-matrix
vs_getter = MmapValStoreBatchGetter(
Expand All @@ -81,4 +81,10 @@ def test_str_mmap_valstore(tmpdir):
str_sub_mat = vs_getter.get(sub_rows, sub_cols)
for i in range(len(sub_rows)):
for j in range(len(sub_cols)):
assert str_sub_mat[i][j].tobytes() == str_list[sub_rows[i]][sub_cols[j]] # noqa: W503
assert str_sub_mat[i][j] == str_list[sub_rows[i]][sub_cols[j]] # noqa: W503

sub_rows, sub_cols = [4, 4, 1, 2], [1, 2, 0]
str_sub_mat = vs_getter.get(sub_rows, sub_cols)
for i in range(len(sub_rows)):
for j in range(len(sub_cols)):
assert str_sub_mat[i][j] == str_list[sub_rows[i]][sub_cols[j]] # noqa: W503

0 comments on commit 6bad2af

Please sign in to comment.