Add show_progress flag to BulkImportWriter #141

Merged · 6 commits · Dec 6, 2024
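Since the PR adds a show_progress keyword to BulkImportWriter.write_dataframe, here is a minimal usage sketch from the caller's side. The database and table names are hypothetical, and it assumes keyword arguments are forwarded from Client.load_table_from_dataframe through to the writer.

import pandas as pd
import pytd

# Hypothetical database/table names; the API key is read from TD_API_KEY.
client = pytd.Client(database="sample_db")
df = pd.DataFrame({"id": range(100_000), "value": ["x"] * 100_000})

# With show_progress=True, tqdm bars are shown while the DataFrame is
# chunked into msgpack parts and while the parts are uploaded via the
# bulk import API.
client.load_table_from_dataframe(
    df,
    "sample_table",
    writer="bulk_import",
    if_exists="overwrite",
    fmt="msgpack",
    show_progress=True,
)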
2 changes: 1 addition & 1 deletion pytd/client.py
@@ -89,7 +89,7 @@ def __init__(
        if apikey is None:
            raise ValueError(
                "either argument 'apikey' or environment variable"
-               "'TD_API_KEY' should be set"
+               " 'TD_API_KEY' should be set"
            )
        if endpoint is None:
            endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
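The one-character change above matters because Python concatenates adjacent string literals at compile time, so without the leading space the two fragments run together in the error message. A minimal illustration:

# Adjacent string literals are joined with no separator inserted.
old = (
    "either argument 'apikey' or environment variable"
    "'TD_API_KEY' should be set"
)
new = (
    "either argument 'apikey' or environment variable"
    " 'TD_API_KEY' should be set"
)
print(old)  # ... environment variable'TD_API_KEY' should be set
print(new)  # ... environment variable 'TD_API_KEY' should be set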
6 changes: 3 additions & 3 deletions pytd/pandas_td/ipython.py
@@ -1,10 +1,10 @@
"""IPython Magics

IPython magics to access to Treasure Data. Load the magics first of all:
IPython magics to access to Treasure Data. Load the magics first of all:
Contributor Author: Not sure why the linter decided to remove the tab here?

Contributor: I think the linter is correct though. Multiline comments usually align with the """, not the text.

-.. code-block:: ipython
+.. code-block:: ipython

-    In [1]: %load_ext pytd.pandas_td.ipython
+    In [1]: %load_ext pytd.pandas_td.ipython
"""

import argparse
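For reference, a small illustrative sketch of the two docstring layouts debated in the comments above (the functions are made up, not from pytd): continuation lines flush with the opening quotes, as the linter produced, versus lines indented under the text.

def linter_style():
    """Summary line.

    Continuation lines start at the same column as the opening quotes,
    which is the layout the linter enforces.
    """


def indented_style():
    """Summary line.

        Continuation lines indented under the text, which is the layout
        the original file used before the linter reflowed it.
    """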
2 changes: 1 addition & 1 deletion pytd/spark.py
@@ -96,7 +96,7 @@ def fetch_td_spark_context(
    if apikey is None:
        raise ValueError(
            "either argument 'apikey' or environment variable"
-           "'TD_API_KEY' should be set"
+           " 'TD_API_KEY' should be set"
        )
    if endpoint is None:
        endpoint = os.getenv("TD_API_SERVER", "https://api.treasuredata.com")
68 changes: 58 additions & 10 deletions pytd/writer.py
@@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
from tdclient.util import normalized_msgpack
+from tqdm import tqdm

from .spark import fetch_td_spark_context

@@ -321,6 +322,7 @@ def write_dataframe(
        keep_list=False,
        max_workers=5,
        chunk_record_size=10_000,
+       show_progress=False,
    ):
        """Write a given DataFrame to a Treasure Data table.

@@ -367,9 +369,14 @@ def write_dataframe(
            will be converted array<T> on Treasure Data table.
            Each type of element of list will be converted by
            ``numpy.array(your_list).tolist()``.

            If True, ``fmt`` argument will be overwritten with ``msgpack``.

+       show_progress : boolean, default: False
+           If this argument is True, shows a TQDM progress bar
+           for chunking data into msgpack format and uploading before
+           performing a bulk import.
+
        Examples
        ---------

@@ -456,7 +463,15 @@ def write_dataframe(
            try:
                with ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = []
-                   for start in range(0, num_rows, _chunk_record_size):
+                   chunk_range = (
+                       tqdm(
+                           range(0, num_rows, _chunk_record_size),
+                           desc="Chunking data",
+                       )
+                       if show_progress
+                       else range(0, num_rows, _chunk_record_size)
+                   )
+                   for start in chunk_range:
                        records = dataframe.iloc[
                            start : start + _chunk_record_size
                        ].to_dict(orient="records")
@@ -473,7 +488,12 @@
                        )
                        stack.callback(os.unlink, fp.name)
                        stack.callback(fp.close)
-                   for start, future in sorted(futures):
+                   resolve_range = (
+                       tqdm(sorted(futures), desc="Resolving futures")
+                       if show_progress
+                       else sorted(futures)
+                   )
+                   for start, future in resolve_range:
                        fps.append(future.result())
            except OSError as e:
                raise RuntimeError(
@@ -485,10 +505,25 @@
                f"unsupported format '{fmt}' for bulk import. "
                "should be 'csv' or 'msgpack'"
            )
-       self._bulk_import(table, fps, if_exists, fmt, max_workers=max_workers)
+       self._bulk_import(
+           table,
+           fps,
+           if_exists,
+           fmt,
+           max_workers=max_workers,
+           show_progress=show_progress,
+       )
        stack.close()

-   def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
+   def _bulk_import(
+       self,
+       table,
+       file_likes,
+       if_exists,
+       fmt="csv",
+       max_workers=5,
+       show_progress=False,
+   ):
        """Write a specified CSV file to a Treasure Data table.

        This method uploads the file to Treasure Data via bulk import API.
@@ -515,6 +550,10 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
        max_workers : int, optional, default: 5
            The maximum number of threads that can be used to execute the given calls.
            This is used only when ``fmt`` is ``msgpack``.
+
+       show_progress : boolean, default: False
+           If this argument is True, shows a TQDM progress bar
+           for the upload process performed on multiple threads.
        """
        params = None
        if table.exists:
@@ -544,16 +583,25 @@ def _bulk_import(self, table, file_likes, if_exists, fmt="csv", max_workers=5):
        logger.info(f"uploading data converted into a {fmt} file")
        if fmt == "msgpack":
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
+               futures = []
                for i, fp in enumerate(file_likes):
                    fsize = fp.tell()
                    fp.seek(0)
-                   executor.submit(
-                       bulk_import.upload_part,
-                       f"part-{i}",
-                       fp,
-                       fsize,
+                   futures.append(
+                       executor.submit(
+                           bulk_import.upload_part,
+                           f"part-{i}",
+                           fp,
+                           fsize,
+                       )
+                   )
                    logger.debug(f"to upload {fp.name} to TD. File size: {fsize}B")
+               if show_progress:
+                   for _ in tqdm(futures, desc="Uploading parts"):
+                       _.result()
+               else:
+                   for future in futures:
+                       future.result()
        else:
            fp = file_likes[0]
            bulk_import.upload_file("part", fmt, fp)
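The writer.py changes above follow one pattern: wrap an iterable in tqdm only when show_progress is set, and keep the submitted futures so a bar can advance as each result() returns. A standalone sketch of that pattern (names here are illustrative, not pytd's API):

from concurrent.futures import ThreadPoolExecutor
import time

from tqdm import tqdm


def maybe_progress(iterable, show_progress, **tqdm_kwargs):
    # Wrap the iterable in a tqdm bar only when requested; otherwise
    # return it unchanged so the caller's loop is identical either way.
    return tqdm(iterable, **tqdm_kwargs) if show_progress else iterable


def upload_part(part_id):
    time.sleep(0.1)  # stand-in for a network upload
    return part_id


def upload_all(num_parts, show_progress=False, max_workers=5):
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit everything first, keeping the futures...
        futures = [
            executor.submit(upload_part, i)
            for i in maybe_progress(range(num_parts), show_progress, desc="Submitting")
        ]
        # ...then iterate over them so the bar advances as results resolve.
        for future in maybe_progress(futures, show_progress, desc="Uploading parts"):
            future.result()


if __name__ == "__main__":
    upload_all(20, show_progress=True)

Iterating the futures in submission order keeps the sketch close to the diff: the bar ticks as each result() call returns, so a slow early part delays the displayed progress even if later parts have already finished. concurrent.futures.as_completed would be the alternative if completion order mattered.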
3 changes: 2 additions & 1 deletion setup.cfg
@@ -34,6 +34,7 @@ install_requires =
    numpy>1.17.3, <2.0.0
    td-client>=1.1.0
    pytz>=2018.5
+   tqdm>=4.60.0

[options.extras_require]
spark =
@@ -65,7 +66,7 @@ exclude =
doc/conf.py

[isort]
-known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient
+known_third_party = IPython,msgpack,nox,numpy,pandas,pkg_resources,prestodb,pytz,setuptools,tdclient,tqdm
line_length=88
multi_line_output=3
include_trailing_comma=True