Use NamedTemporaryFile instead of io.BytesIO for msgpack
This change avoids keeping the entire msgpack payload in memory,
which can be a problem for large data sets.
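
For context, a minimal sketch of the idea (a hypothetical `records_to_tempfile` helper, not the pytd code itself; the real logic lives in `BulkImportWriter._write_msgpack_stream`, which also normalizes and chunks records before packing): instead of accumulating the gzip'd msgpack in an `io.BytesIO` buffer, records are streamed into a `NamedTemporaryFile` on disk.

```python
import gzip
import os
import tempfile

import msgpack


def records_to_tempfile(records):
    """Pack records as a gzip'd msgpack stream spooled to a temp file on disk.

    Unlike io.BytesIO, memory use stays bounded no matter how many records
    are written, at the cost of some temp-file housekeeping.
    """
    fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
    packer = msgpack.Packer()
    with gzip.GzipFile(mode="wb", fileobj=fp) as gz:
        for record in records:
            gz.write(packer.pack(record))
    fp.seek(0)  # rewind so the caller can stream the file to the upload API
    return fp


# Usage: the caller owns cleanup of the named file.
fp = records_to_tempfile([{"a": 1, "time": 1234}, {"a": 2, "time": 1235}])
try:
    print(fp.name, os.path.getsize(fp.name))
finally:
    fp.close()
    os.unlink(fp.name)
```

The caller uploads from the rewound file handle and removes the named file afterwards, which is why the writer below registers `os.unlink` cleanup callbacks.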
chezou committed Aug 27, 2024
1 parent a0767f3 commit 8714d12
Showing 2 changed files with 51 additions and 37 deletions.
54 changes: 28 additions & 26 deletions pytd/tests/test_writer.py
@@ -1,7 +1,7 @@
import io
import os
import tempfile
import unittest
from unittest.mock import ANY, MagicMock
from unittest.mock import ANY, MagicMock, patch

import numpy as np
import pandas as pd
@@ -281,14 +281,13 @@ def test_write_dataframe_msgpack(self):
api_client = self.table.client.api_client
self.assertTrue(api_client.create_bulk_import.called)
self.assertTrue(api_client.create_bulk_import().upload_part.called)
_bytes = BulkImportWriter()._write_msgpack_stream(
df.to_dict(orient="records"), io.BytesIO()
)
size = _bytes.getbuffer().nbytes
fp = tempfile.NamedTemporaryFile(delete=False)
fp = BulkImportWriter()._write_msgpack_stream(df.to_dict(orient="records"), fp)
api_client.create_bulk_import().upload_part.assert_called_with(
"part-0", ANY, size
"part-0", ANY, 62
)
self.assertFalse(api_client.create_bulk_import().upload_file.called)
os.unlink(fp.name)

def test_write_dataframe_msgpack_with_int_na(self):
# Although this conversion ensures pd.NA Int64 dtype to None,
@@ -305,13 +304,14 @@ def test_write_dataframe_msgpack_with_int_na(self):
{"a": 3, "b": 4, "c": 5, "time": 1234},
)
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
print(self.writer._write_msgpack_stream.call_args[0][0][0:2])
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
@@ -327,12 +327,13 @@ def test_write_dataframe_msgpack_with_string_na(self):
{"a": "buzz", "b": "buzz", "c": "alice", "time": 1234},
)
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

@unittest.skipIf(
pd.__version__ < "1.0.0", "pd.NA not supported in this pandas version"
@@ -348,12 +349,13 @@ def test_write_dataframe_msgpack_with_boolean_na(self):
{"a": "false", "b": "true", "c": "true", "time": 1234},
)
self.writer._write_msgpack_stream = MagicMock()
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)
with patch("pytd.writer.os.unlink"):
self.writer.write_dataframe(df, self.table, "overwrite", fmt="msgpack")
self.assertTrue(self.writer._write_msgpack_stream.called)
self.assertEqual(
self.writer._write_msgpack_stream.call_args[0][0][0:2],
expected_list,
)

def test_write_dataframe_invalid_if_exists(self):
with self.assertRaises(ValueError):
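Because write_dataframe now registers os.unlink cleanups for its on-disk temp files, tests that replace _write_msgpack_stream with a MagicMock wrap the call in patch("pytd.writer.os.unlink"): the mocked stream has no real file behind it, so the cleanup would otherwise be handed a MagicMock instead of a path. A toy illustration of that pattern (stand-in classes, not the pytd test suite):

```python
import os
import unittest
from unittest.mock import MagicMock, patch


class ToyWriter:
    """Minimal stand-in for the writer's temp-file cleanup behaviour."""

    def _write_msgpack_stream(self, records, fp):
        return fp

    def write_dataframe(self, records, fp):
        fp = self._write_msgpack_stream(records, fp)
        # Cleanup of the on-disk temp file; with a mocked stream, fp.name is
        # itself a MagicMock and the real os.unlink would raise TypeError.
        os.unlink(fp.name)


class ToyWriterTest(unittest.TestCase):
    def test_with_mocked_stream(self):
        writer = ToyWriter()
        writer._write_msgpack_stream = MagicMock()
        with patch("os.unlink"):  # keep cleanup from touching the mock
            writer.write_dataframe([{"a": 1}], MagicMock())
        self.assertTrue(writer._write_msgpack_stream.called)


if __name__ == "__main__":
    unittest.main()
```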
34 changes: 23 additions & 11 deletions pytd/writer.py
@@ -1,6 +1,5 @@
import abc
import gzip
import io
import logging
import os
import tempfile
@@ -450,11 +449,18 @@ def write_dataframe(
_replace_pd_na(dataframe)

records = dataframe.to_dict(orient="records")
for group in zip_longest(*(iter(records),) * chunk_record_size):
fp = io.BytesIO()
fp = self._write_msgpack_stream(group, fp)
fps.append(fp)
stack.callback(fp.close)
try:
for group in zip_longest(*(iter(records),) * chunk_record_size):
fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
fp = self._write_msgpack_stream(group, fp)
fps.append(fp)
stack.callback(os.unlink, fp.name)
stack.callback(fp.close)
except OSError as e:
raise RuntimeError(
"failed to create a temporary file. "
"Increase chunk_record_size may mitigate the issue."
) from e
else:
raise ValueError(
f"unsupported format '{fmt}' for bulk import. "
@@ -514,19 +520,21 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
bulk_import = table.client.api_client.create_bulk_import(
session_name, table.database, table.table, params=params
)
s_time = time.time()
try:
logger.info(f"uploading data converted into a {fmt} file")
if fmt == "msgpack":
with ThreadPoolExecutor(max_workers=max_workers) as executor:
_ = [
for i, fp in enumerate(file_like):
fsize = fp.tell()
fp.seek(0)
executor.submit(
bulk_import.upload_part,
f"part-{i}",
fp,
fp.getbuffer().nbytes,
fsize,
)
for i, fp in enumerate(file_like)
]
logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
else:
fp = file_like[0]
bulk_import.upload_file("part", fmt, fp)
@@ -535,6 +543,8 @@ def _bulk_import(self, table, file_like, if_exists, fmt="csv", max_workers=5):
bulk_import.delete()
raise RuntimeError(f"failed to upload file: {e}")

logger.info(f"uploaded data in {time.time() - s_time:.2f} sec")

logger.info("performing a bulk import job")
job = bulk_import.perform(wait=True)

@@ -581,7 +591,9 @@ def _write_msgpack_stream(self, items, stream):
mp = packer.pack(normalized_msgpack(item))
gz.write(mp)

stream.seek(0)
logger.debug(
f"created a msgpack file: {stream.name}. File size: {stream.tell()}"
)
return stream


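With a file-backed stream, the part size can no longer come from getbuffer().nbytes (a BytesIO-only API); the upload loop instead records fp.tell() before rewinding with fp.seek(0), and the temp files are removed once the ExitStack unwinds. A condensed sketch of that flow, with upload_part standing in for bulk_import.upload_part and raw bytes in place of the packed chunks:

```python
import os
import tempfile
from contextlib import ExitStack


def upload_parts(chunks, upload_part):
    """Spill each chunk to a named temp file, measure it with tell(),
    rewind, upload, and clean up when the ExitStack unwinds."""
    with ExitStack() as stack:
        for i, chunk in enumerate(chunks):
            fp = tempfile.NamedTemporaryFile(suffix=".msgpack.gz", delete=False)
            stack.callback(os.unlink, fp.name)  # runs last: delete the file
            stack.callback(fp.close)            # runs first: close the handle
            for payload in chunk:
                fp.write(payload)
            fsize = fp.tell()  # bytes written so far == file size
            fp.seek(0)         # rewind so the upload reads from the start
            upload_part(f"part-{i}", fp, fsize)


# Usage with a dummy uploader:
upload_parts(
    [[b"abc", b"def"], [b"ghi"]],
    lambda name, fp, size: print(name, size, len(fp.read())),
)
```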
