ruff format
andygrove committed Mar 3, 2024
1 parent 8161278 commit dac3148
Showing 15 changed files with 77 additions and 227 deletions.
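
This is a mechanical reformat, presumably produced by running `ruff format` over the repository. Ruff's formatter defaults to an 88-character line length (the same as Black), so statements that an earlier, narrower style had wrapped across several lines are collapsed onto one line wherever they now fit. A minimal sketch of the transformation, using the first hunk below; the before/after framing is illustrative, not part of the commit:

    # Before: wrapped to satisfy a narrower line limit
    raise TypeError(
        "`accum` must implement the abstract base class Accumulator"
    )

    # After ruff format: the whole call fits within the 88-column default
    raise TypeError("`accum` must implement the abstract base class Accumulator")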
4 changes: 1 addition & 3 deletions datafusion/__init__.py
@@ -208,9 +208,7 @@ def udaf(accum, input_type, return_type, state_type, volatility, name=None):
     Create a new User Defined Aggregate Function
     """
     if not issubclass(accum, Accumulator):
-        raise TypeError(
-            "`accum` must implement the abstract base class Accumulator"
-        )
+        raise TypeError("`accum` must implement the abstract base class Accumulator")
     if name is None:
         name = accum.__qualname__.lower()
     if isinstance(input_type, pa.lib.DataType):
4 changes: 1 addition & 3 deletions datafusion/cudf.py
@@ -68,9 +68,7 @@ def to_cudf_df(self, plan):
         elif isinstance(node, TableScan):
             return cudf.read_parquet(self.parquet_tables[node.table_name()])
         else:
-            raise Exception(
-                "unsupported logical operator: {}".format(type(node))
-            )
+            raise Exception("unsupported logical operator: {}".format(type(node)))
 
     def create_schema(self, schema_name: str, **kwargs):
         logger.debug(f"Creating schema: {schema_name}")
8 changes: 2 additions & 6 deletions datafusion/input/base.py
@@ -31,13 +31,9 @@ class BaseInputSource(ABC):
"""

@abstractmethod
def is_correct_input(
self, input_item: Any, table_name: str, **kwargs
) -> bool:
def is_correct_input(self, input_item: Any, table_name: str, **kwargs) -> bool:
pass

@abstractmethod
def build_table(
self, input_item: Any, table_name: str, **kwarg
) -> SqlTable:
def build_table(self, input_item: Any, table_name: str, **kwarg) -> SqlTable:
pass
4 changes: 1 addition & 3 deletions datafusion/input/location.py
@@ -72,9 +72,7 @@ def build_table(
                 for _ in reader:
                     num_rows += 1
             # TODO: Need to actually consume this row into reasonable columns
-            raise RuntimeError(
-                "TODO: Currently unable to support CSV input files."
-            )
+            raise RuntimeError("TODO: Currently unable to support CSV input files.")
         else:
             raise RuntimeError(
                 f"Input of format: `{format}` is currently not supported.\
4 changes: 1 addition & 3 deletions datafusion/pandas.py
@@ -64,9 +64,7 @@ def to_pandas_df(self, plan):
         elif isinstance(node, TableScan):
             return pd.read_parquet(self.parquet_tables[node.table_name()])
         else:
-            raise Exception(
-                "unsupported logical operator: {}".format(type(node))
-            )
+            raise Exception("unsupported logical operator: {}".format(type(node)))
 
     def create_schema(self, schema_name: str, **kwargs):
         logger.debug(f"Creating schema: {schema_name}")
12 changes: 3 additions & 9 deletions datafusion/polars.py
@@ -51,9 +51,7 @@ def to_polars_df(self, plan):
             args = [self.to_polars_expr(expr) for expr in node.projections()]
             return inputs[0].select(*args)
         elif isinstance(node, Aggregate):
-            groupby_expr = [
-                self.to_polars_expr(expr) for expr in node.group_by_exprs()
-            ]
+            groupby_expr = [self.to_polars_expr(expr) for expr in node.group_by_exprs()]
             aggs = []
             for expr in node.aggregate_exprs():
                 expr = expr.to_variant()
@@ -67,17 +65,13 @@ def to_polars_df(self, plan):
                         )
                     )
                 else:
-                    raise Exception(
-                        "Unsupported aggregate function {}".format(expr)
-                    )
+                    raise Exception("Unsupported aggregate function {}".format(expr))
             df = inputs[0].groupby(groupby_expr).agg(aggs)
             return df
         elif isinstance(node, TableScan):
             return polars.read_parquet(self.parquet_tables[node.table_name()])
         else:
-            raise Exception(
-                "unsupported logical operator: {}".format(type(node))
-            )
+            raise Exception("unsupported logical operator: {}".format(type(node)))
 
     def create_schema(self, schema_name: str, **kwargs):
         logger.debug(f"Creating schema: {schema_name}")
12 changes: 3 additions & 9 deletions datafusion/tests/generic.py
@@ -50,9 +50,7 @@ def data_datetime(f):
         datetime.datetime.now() - datetime.timedelta(days=1),
         datetime.datetime.now() + datetime.timedelta(days=1),
     ]
-    return pa.array(
-        data, type=pa.timestamp(f), mask=np.array([False, True, False])
-    )
+    return pa.array(data, type=pa.timestamp(f), mask=np.array([False, True, False]))
 
 
 def data_date32():
@@ -61,9 +59,7 @@ def data_date32():
         datetime.date(1980, 1, 1),
         datetime.date(2030, 1, 1),
     ]
-    return pa.array(
-        data, type=pa.date32(), mask=np.array([False, True, False])
-    )
+    return pa.array(data, type=pa.date32(), mask=np.array([False, True, False]))
 
 
 def data_timedelta(f):
@@ -72,9 +68,7 @@ def data_timedelta(f):
         datetime.timedelta(days=1),
         datetime.timedelta(seconds=1),
     ]
-    return pa.array(
-        data, type=pa.duration(f), mask=np.array([False, True, False])
-    )
+    return pa.array(data, type=pa.duration(f), mask=np.array([False, True, False]))
 
 
 def data_binary_other():
34 changes: 8 additions & 26 deletions datafusion/tests/test_aggregation.py
@@ -81,9 +81,7 @@ def test_built_in_aggregation(df):
     assert result.column(2) == pa.array([4])
     assert result.column(3) == pa.array([6])
     assert result.column(4) == pa.array([[4, 4, 6]])
-    np.testing.assert_array_almost_equal(
-        result.column(5), np.average(values_a)
-    )
+    np.testing.assert_array_almost_equal(result.column(5), np.average(values_a))
     np.testing.assert_array_almost_equal(
         result.column(6), np.corrcoef(values_a, values_b)[0][1]
     )
@@ -101,35 +99,20 @@ def test_built_in_aggregation(df):
     )
     np.testing.assert_array_almost_equal(result.column(11), np.max(values_a))
     np.testing.assert_array_almost_equal(result.column(12), np.mean(values_b))
-    np.testing.assert_array_almost_equal(
-        result.column(13), np.median(values_b)
-    )
+    np.testing.assert_array_almost_equal(result.column(13), np.median(values_b))
     np.testing.assert_array_almost_equal(result.column(14), np.min(values_a))
     np.testing.assert_array_almost_equal(
         result.column(15), np.sum(values_b.to_pylist())
     )
-    np.testing.assert_array_almost_equal(
-        result.column(16), np.std(values_a, ddof=1)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(17), np.std(values_b, ddof=0)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(18), np.std(values_c, ddof=1)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(19), np.var(values_a, ddof=1)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(20), np.var(values_b, ddof=0)
-    )
-    np.testing.assert_array_almost_equal(
-        result.column(21), np.var(values_c, ddof=1)
-    )
+    np.testing.assert_array_almost_equal(result.column(16), np.std(values_a, ddof=1))
+    np.testing.assert_array_almost_equal(result.column(17), np.std(values_b, ddof=0))
+    np.testing.assert_array_almost_equal(result.column(18), np.std(values_c, ddof=1))
+    np.testing.assert_array_almost_equal(result.column(19), np.var(values_a, ddof=1))
+    np.testing.assert_array_almost_equal(result.column(20), np.var(values_b, ddof=0))
+    np.testing.assert_array_almost_equal(result.column(21), np.var(values_c, ddof=1))
 
 
 def test_bit_add_or_xor(df):
-
     df = df.aggregate(
         [],
         [
@@ -147,7 +130,6 @@ def test_bit_add_or_xor(df):
 
 
 def test_bool_and_or(df):
-
     df = df.aggregate(
         [],
         [
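
Note the asymmetry in the second hunk above: the `result.column(15)` assertion stays wrapped while the neighboring std/var assertions collapse. The formatter only joins a statement when the single-line result fits the limit; collapsed, that line comes to 89 characters (including its four-space indent), one over the 88-character default:

    # Hypothetical single-line form: 89 characters, so ruff format
    # keeps the original three-line version instead.
    np.testing.assert_array_almost_equal(result.column(15), np.sum(values_b.to_pylist()))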
5 changes: 1 addition & 4 deletions datafusion/tests/test_config.py
@@ -35,10 +35,7 @@ def test_get_then_set(config):
 
 def test_get_all(config):
     config_dict = config.get_all()
-    assert (
-        config_dict["datafusion.catalog.create_default_catalog_and_schema"]
-        == "true"
-    )
+    assert config_dict["datafusion.catalog.create_default_catalog_and_schema"] == "true"
 
 
 def test_get_invalid_config(config):
12 changes: 3 additions & 9 deletions datafusion/tests/test_context.py
@@ -36,9 +36,7 @@ def test_create_context_no_args():
 
 
 def test_create_context_with_all_valid_args():
-    runtime = (
-        RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000)
-    )
+    runtime = RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000)
     config = (
         SessionConfig()
         .with_create_default_catalog_and_schema(True)
@@ -357,9 +355,7 @@ def test_read_json_compressed(ctx, tmp_path):
     with gzip.open(gzip_path, "wb") as gzipped_file:
         gzipped_file.writelines(csv_file)
 
-    df = ctx.read_json(
-        gzip_path, file_extension=".gz", file_compression_type="gz"
-    )
+    df = ctx.read_json(gzip_path, file_extension=".gz", file_compression_type="gz")
     result = df.collect()
 
     assert result[0].column(0) == pa.array(["a", "b", "c"])
@@ -381,9 +377,7 @@ def test_read_csv_compressed(ctx, tmp_path):
     with gzip.open(gzip_path, "wb") as gzipped_file:
         gzipped_file.writelines(csv_file)
 
-    csv_df = ctx.read_csv(
-        gzip_path, file_extension=".gz", file_compression_type="gz"
-    )
+    csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
     csv_df.select(column("c1")).show()
 
 
54 changes: 17 additions & 37 deletions datafusion/tests/test_dataframe.py
@@ -151,9 +151,7 @@ def test_with_column(df):
 
 
 def test_with_column_renamed(df):
-    df = df.with_column("c", column("a") + column("b")).with_column_renamed(
-        "c", "sum"
-    )
+    df = df.with_column("c", column("a") + column("b")).with_column_renamed("c", "sum")
 
     result = df.collect()[0]
 
@@ -218,9 +216,7 @@ def test_distinct():
         [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
         names=["a", "b"],
     )
-    df_b = ctx.create_dataframe([[batch]]).sort(
-        column("a").sort(ascending=True)
-    )
+    df_b = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
 
     assert df_a.collect() == df_b.collect()
 
@@ -251,19 +247,15 @@ def test_window_functions(df):
"cume_dist",
),
f.alias(
f.window(
"ntile", [literal(2)], order_by=[f.order_by(column("c"))]
),
f.window("ntile", [literal(2)], order_by=[f.order_by(column("c"))]),
"ntile",
),
f.alias(
f.window("lag", [column("b")], order_by=[f.order_by(column("b"))]),
"previous",
),
f.alias(
f.window(
"lead", [column("b")], order_by=[f.order_by(column("b"))]
),
f.window("lead", [column("b")], order_by=[f.order_by(column("b"))]),
"next",
),
f.alias(
@@ -275,9 +267,7 @@ def test_window_functions(df):
"first_value",
),
f.alias(
f.window(
"last_value", [column("b")], order_by=[f.order_by(column("b"))]
),
f.window("last_value", [column("b")], order_by=[f.order_by(column("b"))]),
"last_value",
),
f.alias(
@@ -418,12 +408,14 @@ def test_optimized_logical_plan(aggregate_df):
 def test_execution_plan(aggregate_df):
     plan = aggregate_df.execution_plan()
 
-    expected = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[SUM(test.c2)]\n"  # noqa: E501
+    expected = (
+        "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[SUM(test.c2)]\n"
+    )  # noqa: E501
 
     assert expected == plan.display()
 
     # Check the number of partitions is as expected.
-    assert type(plan.partition_count) is int
+    assert isinstance(plan.partition_count, int)
 
     expected = (
         "ProjectionExec: expr=[c1@0 as c1, SUM(test.c2)@1 as SUM(test.c2)]\n"
@@ -477,9 +469,7 @@ def test_intersect():
         [pa.array([3]), pa.array([6])],
         names=["a", "b"],
     )
-    df_c = ctx.create_dataframe([[batch]]).sort(
-        column("a").sort(ascending=True)
-    )
+    df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
 
     df_a_i_b = df_a.intersect(df_b).sort(column("a").sort(ascending=True))
 
@@ -505,9 +495,7 @@ def test_except_all():
         [pa.array([1, 2]), pa.array([4, 5])],
         names=["a", "b"],
     )
-    df_c = ctx.create_dataframe([[batch]]).sort(
-        column("a").sort(ascending=True)
-    )
+    df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
 
     df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True))
 
@@ -542,9 +530,7 @@ def test_union(ctx):
         [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])],
         names=["a", "b"],
     )
-    df_c = ctx.create_dataframe([[batch]]).sort(
-        column("a").sort(ascending=True)
-    )
+    df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
 
     df_a_u_b = df_a.union(df_b).sort(column("a").sort(ascending=True))
 
@@ -568,9 +554,7 @@ def test_union_distinct(ctx):
         [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
         names=["a", "b"],
     )
-    df_c = ctx.create_dataframe([[batch]]).sort(
-        column("a").sort(ascending=True)
-    )
+    df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True))
 
     df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True))
 
@@ -650,7 +634,7 @@ def test_empty_to_arrow_table(df):
 def test_to_pylist(df):
     # Convert datafusion dataframe to Python list
     pylist = df.to_pylist()
-    assert type(pylist) == list
+    assert isinstance(pylist, list)
     assert pylist == [
         {"a": 1, "b": 4, "c": 8},
         {"a": 2, "b": 5, "c": 5},
@@ -661,7 +645,7 @@ def test_to_pylist(df):
 def test_to_pydict(df):
     # Convert datafusion dataframe to Python dictionary
     pydict = df.to_pydict()
-    assert type(pydict) == dict
+    assert isinstance(pydict, dict)
     assert pydict == {"a": [1, 2, 3], "b": [4, 5, 6], "c": [8, 5, 8]}
 
 
@@ -702,9 +686,7 @@ def test_write_parquet(df, tmp_path):
"compression, compression_level",
[("gzip", 6), ("brotli", 7), ("zstd", 15)],
)
def test_write_compressed_parquet(
df, tmp_path, compression, compression_level
):
def test_write_compressed_parquet(df, tmp_path, compression, compression_level):
path = tmp_path

df.write_parquet(
@@ -744,9 +726,7 @@ def test_write_compressed_parquet_wrong_compression_level(
 
 
 @pytest.mark.parametrize("compression", ["brotli", "zstd", "wrong"])
-def test_write_compressed_parquet_missing_compression_level(
-    df, tmp_path, compression
-):
+def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compression):
     path = tmp_path
 
     with pytest.raises(ValueError):
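
A side note on the three type-check rewrites in this file (`partition_count`, `to_pylist`, `to_pydict`): a formatter only reflows code, so swapping `type(x) == list` for an `isinstance` check reads as a lint-driven cleanup (pycodestyle's E721, which Ruff also reports) folded into the same commit rather than an effect of `ruff format` itself. The idiomatic form checks the value directly:

    # Flagged by E721: compares type objects and rejects subclasses
    assert type(pylist) == list

    # Idiomatic instance check; list subclasses also pass
    assert isinstance(pylist, list)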
(The remaining 4 changed files are not shown in this view.)
