Commit

Merge pull request #411 from LSSTDESC/u/yymao/force-consistent-schema
yymao authored Jan 5, 2021
2 parents 3aefa29 + 19a9ac0 commit c50043d
Showing 1 changed file with 33 additions and 21 deletions.
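The branch name and the changes below suggest the problem being fixed: each partitioned parquet file's schema was previously inferred independently from its first chunk, so differing column order between chunks (or between partitions) could yield files that disagree on schema even though they hold the same columns. A minimal sketch of that ordering hazard, assuming only pyarrow; the two chunk dicts are invented for illustration:

    import pyarrow as pa

    # Two chunks carrying the same columns, but with keys in different order,
    # as an iterator over heterogeneous storage might yield them.
    chunk_a = {"ra": [1.0], "dec": [2.0]}
    chunk_b = {"dec": [2.0], "ra": [1.0]}

    # from_pydict keeps each dict's own key order, so the schemas disagree.
    print(pa.Table.from_pydict(chunk_a).schema.equals(
          pa.Table.from_pydict(chunk_b).schema))  # False: field order differs

    # The remedy the diff adopts: fix one column order, build from arrays.
    columns = sorted(chunk_a)
    table_a = pa.Table.from_arrays([pa.array(chunk_a[c]) for c in columns], columns)
    table_b = pa.Table.from_arrays([pa.array(chunk_b[c]) for c in columns], columns)
    print(table_a.schema.equals(table_b.schema))  # True

Sorting the requested columns once and building every chunk with pa.Table.from_arrays, as the diff does in _chunk_data_generator, pins the field order for all chunks and partitions.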
scripts/write_gcr_to_parquet.py
@@ -19,19 +19,6 @@
 __all__ = ["convert_cat_to_parquet"]
 
 
-if hasattr(pa.Table, "from_pydict"):
-    # only available in pyarrow >= 0.14.0
-    pa_table_from_pydict = pa.Table.from_pydict
-else:
-    def pa_table_from_pydict(mapping):
-        names = []
-        arrays = []
-        for k, v in mapping.items():
-            names.append(k)
-            arrays.append(pa.array(v))
-        return pa.Table.from_arrays(arrays, names)
-
-
 class Checkpoint():
     def __init__(self, path, checkpoint_dir=None):
         if checkpoint_dir is None:
@@ -74,8 +61,9 @@ def _has_run_no_lock(self):
 
 
 def _chunk_data_generator(cat, columns, native_filters=None):
+    columns = sorted(columns)
     for data in cat.get_quantities(columns, native_filters=native_filters, return_iterator=True):
-        table = pa_table_from_pydict(data)
+        table = pa.Table.from_arrays([pa.array(data[col]) for col in columns], columns)
         del data
         try:
             cat.close_all_file_handles()
@@ -84,7 +72,15 @@ def _chunk_data_generator(cat, columns, native_filters=None):
         yield table
 
 
-def _write_one_parquet_file(output_path, cat=None, get_quantities_kwargs=None, silent=False, checkpoint_dir=None):
+def _write_one_parquet_file(
+    output_path,
+    cat=None,
+    get_quantities_kwargs=None,
+    schema=None,
+    return_schema=False,
+    silent=False,
+    checkpoint_dir=None
+):
     my_print = (lambda *x: None) if silent else print
 
     checkpoint = Checkpoint(output_path, checkpoint_dir)
@@ -96,17 +92,27 @@ def _write_one_parquet_file(output_path, cat=None, get_quantities_kwargs=None, silent=False, checkpoint_dir=None):
     with checkpoint.run():
         my_print("Generating", output_path, time.strftime("[%H:%M:%S]"))
         if not get_quantities_kwargs:
-            with pq.ParquetWriter(output_path, cat.schema, flavor='spark') as pqwriter:
+            if schema is None:
+                schema = cat.schema
+            with pq.ParquetWriter(output_path, schema, flavor='spark') as pqwriter:
                 pqwriter.write_table(cat)
         else:
             chunk_iter = _chunk_data_generator(cat, **get_quantities_kwargs)
-            table = next(chunk_iter)
-            with pq.ParquetWriter(output_path, schema=table.schema, flavor='spark') as pqwriter:
-                pqwriter.write_table(table)
+            if schema is None:
+                table = next(chunk_iter)
+                schema = table.schema
+            else:
+                table = None
+            with pq.ParquetWriter(output_path, schema=schema, flavor='spark') as pqwriter:
+                if table is not None:
+                    pqwriter.write_table(table)
                 for table in chunk_iter:
                     pqwriter.write_table(table)
         my_print("Done with", output_path, time.strftime("[%H:%M:%S]"))
+
+    if return_schema:
+        return schema
 
 
 def convert_cat_to_parquet(reader,
                            output_filename=None,
@@ -224,20 +230,26 @@
         )
 
     elif partition == "iter":
+        schema = None
         for i, table in enumerate(_chunk_data_generator(cat, columns)):
-            _write_one_parquet_file(
+            schema = _write_one_parquet_file(
                 output_path=output_filename.format(i),
                 cat=table,
+                schema=schema,
+                return_schema=True,
                 silent=silent,
                 checkpoint_dir=checkpoint_dir,
             )
 
     elif partition_values:
+        schema = None
         for value in partition_values:
-            _write_one_parquet_file(
+            schema = _write_one_parquet_file(
                 output_path=output_filename.format(value),
                 cat=cat,
                 get_quantities_kwargs=dict(columns=columns, native_filters="{} == {}".format(partition, value)),
+                schema=schema,
+                return_schema=True,
                 silent=silent,
                 checkpoint_dir=checkpoint_dir,
             )
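Condensing the new control flow in _write_one_parquet_file into a standalone sketch may help; the helper name write_partition and the toy tables are hypothetical, but the schema threading mirrors the diff: take the schema from the first table written, hand it to every later ParquetWriter, and return it for the caller to pass along.

    import pyarrow as pa
    import pyarrow.parquet as pq

    def write_partition(path, tables, schema=None):
        # If no schema was handed in, take it from the first table; either
        # way, return the schema actually used so the caller can thread it
        # on to the next partition.
        tables = iter(tables)
        first = None
        if schema is None:
            first = next(tables)
            schema = first.schema
        with pq.ParquetWriter(path, schema, flavor="spark") as writer:
            if first is not None:
                writer.write_table(first)
            for table in tables:
                writer.write_table(table)
        return schema

    # Thread the schema through the loop, as convert_cat_to_parquet now
    # does; every file after the first is held to the first file's schema.
    schema = None
    for i, chunks in enumerate([[pa.table({"x": [1]})], [pa.table({"x": [2]})]]):
        schema = write_partition("part-{}.parquet".format(i), chunks, schema=schema)

Because ParquetWriter.write_table raises when a table's schema does not match the writer's, a partition that drifts now fails loudly instead of silently producing an inconsistent file set.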
