refactor(snowflake): replace custom temp table ddl for memtables with `read_parquet`
cpcloud authored and kszucs committed Aug 18, 2023
1 parent 297b449 commit 41df410
Showing 1 changed file with 1 addition and 43 deletions.
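
For orientation, the post-change body of `_register_in_memory_table` can be pieced together from the added and unchanged lines of the diff below. This is a sketch of the resulting flow (method shown outside its class, editorial comments added), not a verbatim copy of the file:

```python
def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
    import pyarrow.parquet as pq

    raw_name = op.name

    with self.begin() as con:
        # skip the work if the memtable has already been materialized
        if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None:
            tmpdir = tempfile.TemporaryDirectory()
            try:
                path = os.path.join(tmpdir.name, f"{raw_name}.parquet")
                # zstd keeps the parquet file small before upload
                pq.write_table(
                    op.data.to_pyarrow(schema=op.schema), path, compression="zstd"
                )
                # read_parquet now takes over staging and table creation,
                # replacing the hand-rolled PUT / CREATE TEMP TABLE / COPY INTO
                self.read_parquet(path, table_name=raw_name)
            finally:
                with contextlib.suppress(Exception):
                    shutil.rmtree(tmpdir.name)
```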
44 changes: 1 addition & 43 deletions ibis/backends/snowflake/__init__.py
@@ -434,18 +434,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
         import pyarrow.parquet as pq
 
         raw_name = op.name
-        table = self._quote(raw_name)
-
-        current_db = self.current_database
-        current_schema = self.current_schema
-        ident = f"{self._quote(current_db)}.{self._quote(current_schema)}.{table}"
 
         with self.begin() as con:
             if con.exec_driver_sql(f"SHOW TABLES LIKE '{raw_name}'").scalar() is None:
-                # 1. create a temporary stage for holding parquet files
-                stage = util.gen_name("stage")
-                con.exec_driver_sql(f"CREATE TEMP STAGE {stage}")
-
                 tmpdir = tempfile.TemporaryDirectory()
                 try:
                     path = os.path.join(tmpdir.name, f"{raw_name}.parquet")
@@ -454,44 +445,11 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
                     pq.write_table(
                         op.data.to_pyarrow(schema=op.schema), path, compression="zstd"
                     )
-
-                    # 2. copy the parquet file into the stage
-                    #
-                    # disable the automatic compression to gzip because we've
-                    # already compressed the data with zstd
-                    #
-                    # 99 is the limit on the number of threads use to upload data,
-                    # who knows why?
-                    con.exec_driver_sql(
-                        f"""
-                        PUT 'file://{path}' @{stage}
-                        PARALLEL = {min((os.cpu_count() or 2) // 2, 99)}
-                        AUTO_COMPRESS = FALSE
-                        """
-                    )
+                    self.read_parquet(path, table_name=raw_name)
                 finally:
                     with contextlib.suppress(Exception):
                         shutil.rmtree(tmpdir.name)
 
-                # 3. create a temporary table
-                schema = ", ".join(
-                    f"{self._quote(col)} {SnowflakeType.to_string(typ) + ' NOT NULL' * (not typ.nullable)}"
-                    for col, typ in op.schema.items()
-                )
-                con.exec_driver_sql(f"CREATE TEMP TABLE {ident} ({schema})")
-                # 4. copy the data into the table
-                columns = op.schema.names
-                column_names = ", ".join(map(self._quote, columns))
-                parquet_column_names = ", ".join(f"$1:{col}" for col in columns)
-                con.exec_driver_sql(
-                    f"""
-                    COPY INTO {ident} ({column_names})
-                    FROM (SELECT {parquet_column_names} FROM @{stage})
-                    FILE_FORMAT = (TYPE = PARQUET COMPRESSION = AUTO)
-                    PURGE = TRUE
-                    """
-                )
-
     def _get_temp_view_definition(
         self, name: str, definition: sa.sql.compiler.Compiled
     ) -> str:
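
For context on how this code path gets exercised, here is a minimal usage sketch. The `connect()` arguments are illustrative placeholders, not the backend's documented signature:

```python
import pandas as pd

import ibis

# placeholder credentials -- consult the Snowflake backend docs for the
# exact connect() arguments
con = ibis.snowflake.connect(
    user="me",
    password="***",
    account="my_account",
    database="MY_DB/MY_SCHEMA",
)

# executing an expression built on a memtable against Snowflake invokes
# _register_in_memory_table, which now writes a zstd-compressed parquet
# file and loads it via read_parquet
t = ibis.memtable(pd.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}))
print(con.execute(t.filter(t.a > 1)))
```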
