Skip to content

Commit

Permalink
Merge branch 'main' into test/datasets/sql-doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman authored Nov 30, 2023
2 parents 47dfa6a + 48e2e76 commit 8d6acac
Show file tree
Hide file tree
Showing 33 changed files with 44 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BioSequenceDataset(AbstractDataset[list, list]):
...
>>>
>>> dataset = BioSequenceDataset(
... filepath="ls_orchid.fasta",
... filepath=tmp_path / "ls_orchid.fasta",
... load_args={"format": "fasta"},
... save_args={"format": "fasta"},
... )
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/email/message_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class EmailMessageDataset(AbstractVersionedDataset[Message, Message]):
>>> msg["From"] = '"sin studly17"'
>>> msg["To"] = '"strong bad"'
>>>
>>> dataset = EmailMessageDataset(filepath="test")
>>> dataset = EmailMessageDataset(filepath=tmp_path / "test")
>>> dataset.save(msg)
>>> reloaded = dataset.load()
>>> assert msg.__dict__ == reloaded.__dict__
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/geopandas/geojson_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class GeoJSONDataset(
... {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]},
... geometry=[Point(1, 1), Point(2, 4)],
... )
>>> dataset = GeoJSONDataset(filepath="test.geojson", save_args=None)
>>> dataset = GeoJSONDataset(filepath=tmp_path / "test.geojson", save_args=None)
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class HoloviewsWriter(AbstractVersionedDataset[HoloViews, NoReturn]):
>>> from kedro_datasets.holoviews import HoloviewsWriter
>>>
>>> curve = hv.Curve(range(10))
>>> holoviews_writer = HoloviewsWriter(filepath="/tmp/holoviews")
>>> holoviews_writer = HoloviewsWriter(filepath=tmp_path / "holoviews")
>>>
>>> holoviews_writer.save(curve)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/json/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class JSONDataset(AbstractVersionedDataset[Any, Any]):
>>>
>>> data = {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}
>>>
>>> dataset = JSONDataset(filepath="test.json")
>>> dataset = JSONDataset(filepath=tmp_path / "test.json")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data == reloaded
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/networkx/gml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class GMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):
>>> from kedro_datasets.networkx import GMLDataset
>>> import networkx as nx
>>> graph = nx.complete_graph(100)
>>> graph_dataset = GMLDataset(filepath="test.gml")
>>> graph_dataset = GMLDataset(filepath=tmp_path / "test.gml")
>>> graph_dataset.save(graph)
>>> reloaded = graph_dataset.load()
>>> assert nx.is_isomorphic(graph, reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/networkx/graphml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GraphMLDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):
>>> from kedro_datasets.networkx import GraphMLDataset
>>> import networkx as nx
>>> graph = nx.complete_graph(100)
>>> graph_dataset = GraphMLDataset(filepath="test.graphml")
>>> graph_dataset = GraphMLDataset(filepath=tmp_path / "test.graphml")
>>> graph_dataset.save(graph)
>>> reloaded = graph_dataset.load()
>>> assert nx.is_isomorphic(graph, reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/networkx/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class JSONDataset(AbstractVersionedDataset[networkx.Graph, networkx.Graph]):
>>> from kedro_datasets.networkx import JSONDataset
>>> import networkx as nx
>>> graph = nx.complete_graph(100)
>>> graph_dataset = JSONDataset(filepath="test.json")
>>> graph_dataset = JSONDataset(filepath=tmp_path / "test.json")
>>> graph_dataset.save(graph)
>>> reloaded = graph_dataset.load()
>>> assert nx.is_isomorphic(graph, reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CSVDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = CSVDataset(filepath="test.csv")
>>> dataset = CSVDataset(filepath=tmp_path / "test.csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/deltatable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class DeltaTableDataset(AbstractDataset):
>>> import pandas as pd
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>> dataset = DeltaTableDataset(filepath="test")
>>> dataset = DeltaTableDataset(filepath=tmp_path / "test")
>>>
>>> dataset.save(data)
>>> reloaded = dataset.load()
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/excel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class ExcelDataset(
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = ExcelDataset(filepath="test.xlsx")
>>> dataset = ExcelDataset(filepath=tmp_path / "test.xlsx")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/feather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class FeatherDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = FeatherDataset(filepath="test.feather")
>>> dataset = FeatherDataset(filepath=tmp_path / "test.feather")
>>>
>>> dataset.save(data)
>>> reloaded = dataset.load()
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/generic_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GenericDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = GenericDataset(
... filepath="test.csv", file_format="csv", save_args={"index": False}
... filepath=tmp_path / "test.csv", file_format="csv", save_args={"index": False}
... )
>>> dataset.save(data)
>>> reloaded = dataset.load()
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/hdf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class HDFDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = HDFDataset(filepath="test.h5", key="data")
>>> dataset = HDFDataset(filepath=tmp_path / "test.h5", key="data")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class JSONDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = JSONDataset(filepath="test.json")
>>> dataset = JSONDataset(filepath=tmp_path / "test.json")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ParquetDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = ParquetDataset(filepath="test.parquet")
>>> dataset = ParquetDataset(filepath=tmp_path / "test.parquet")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pandas/xml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class XMLDataset(AbstractVersionedDataset[pd.DataFrame, pd.DataFrame]):
>>>
>>> data = pd.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = XMLDataset(filepath="test.xml")
>>> dataset = XMLDataset(filepath=tmp_path / "test.xml")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.equals(reloaded)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ class PartitionedDataset(AbstractDataset[dict[str, Any], dict[str, Callable[[],
>>>
>>> # Save it as small partitions with DAY_OF_MONTH as the partition key
>>> dataset = PartitionedDataset(
... path="df_with_partition", dataset="pandas.CSVDataset", filename_suffix=".csv"
... path=tmp_path / "df_with_partition",
... dataset="pandas.CSVDataset",
... filename_suffix=".csv",
... )
>>> # This will create a folder `df_with_partition` and save multiple files
>>> # with the dict key + filename_suffix as filename, i.e. 1.csv, 2.csv etc.
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/pickle/pickle_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class PickleDataset(AbstractVersionedDataset[Any, Any]):
>>> assert data.equals(reloaded)
>>>
>>> dataset = PickleDataset(
... filepath="test.pickle.lz4",
... filepath=tmp_path / "test.pickle.lz4",
... backend="compress_pickle",
... load_args={"compression": "lz4"},
... save_args={"compression": "lz4"},
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/plotly/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JSONDataset(
>>> import plotly.express as px
>>>
>>> fig = px.bar(x=["a", "b", "c"], y=[1, 3, 2])
>>> dataset = JSONDataset(filepath="test.json")
>>> dataset = JSONDataset(filepath=tmp_path / "test.json")
>>> dataset.save(fig)
>>> reloaded = dataset.load()
>>> assert fig == reloaded
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/plotly/plotly_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class PlotlyDataset(JSONDataset):
>>> df_data = pd.DataFrame([[0, 1], [1, 0]], columns=("x1", "x2"))
>>>
>>> dataset = PlotlyDataset(
... filepath="scatter_plot.json",
... filepath=tmp_path / "scatter_plot.json",
... plotly_args={
... "type": "scatter",
... "fig": {"x": "x1", "y": "x2"},
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/polars/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class CSVDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
>>>
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = CSVDataset(filepath="test.csv")
>>> dataset = CSVDataset(filepath=tmp_path / "test.csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class EagerPolarsDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
>>>
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = EagerPolarsDataset(filepath="test.parquet", file_format="parquet")
>>> dataset = EagerPolarsDataset(filepath=tmp_path / "test.parquet", file_format="parquet")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class LazyPolarsDataset(AbstractVersionedDataset[pl.LazyFrame, PolarsFrame]):
>>>
>>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
>>>
>>> dataset = LazyPolarsDataset(filepath="test.csv")
>>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data.frame_equal(reloaded)
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/spark/deltatable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class DeltaTableDataset(AbstractDataset[None, DeltaTable]):
>>>
>>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
>>>
>>> dataset = SparkDataset(filepath="test_data", file_format="delta")
>>> dataset = SparkDataset(filepath=tmp_path / "test_data", file_format="delta")
>>> dataset.save(spark_df)
>>> deltatable_dataset = DeltaTableDataset(filepath="test_data")
>>> deltatable_dataset = DeltaTableDataset(filepath=tmp_path / "test_data")
>>> delta_table = deltatable_dataset.load()
>>>
>>> delta_table.update()
Expand Down
14 changes: 8 additions & 6 deletions kedro-datasets/kedro_datasets/spark/spark_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""``AbstractVersionedDataset`` implementation to access Spark dataframes using
``pyspark``.
"""
from __future__ import annotations

import json
import logging
import os
from copy import deepcopy
from fnmatch import fnmatch
from functools import partial
from pathlib import PurePosixPath
from typing import Any, Optional
from typing import Any
from warnings import warn

import fsspec
Expand Down Expand Up @@ -62,8 +64,8 @@ def _parse_glob_pattern(pattern: str) -> str:
return "/".join(clean)


def _split_filepath(filepath: str) -> tuple[str, str]:
split_ = filepath.split("://", 1)
def _split_filepath(filepath: str | os.PathLike) -> tuple[str, str]:
split_ = str(filepath).split("://", 1)
if len(split_) == 2: # noqa: PLR2004
return split_[0] + "://", split_[1]
return "", split_[0]
Expand Down Expand Up @@ -100,7 +102,7 @@ def _dbfs_glob(pattern: str, dbutils: Any) -> list[str]:
return sorted(matched)


def _get_dbutils(spark: SparkSession) -> Optional[Any]:
def _get_dbutils(spark: SparkSession) -> Any:
"""Get the instance of 'dbutils', or None if it could not be found."""
dbutils = globals().get("dbutils")
if dbutils:
Expand Down Expand Up @@ -245,11 +247,11 @@ class SparkDataset(AbstractVersionedDataset[DataFrame, DataFrame]):
>>>
>>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema)
>>>
>>> dataset = SparkDataset(filepath="test_data")
>>> dataset = SparkDataset(filepath=tmp_path / "test_data")
>>> dataset.save(spark_df)
>>> reloaded = dataset.load()
>>>
>>> assert Row(name='Bob', age=12) in reloaded.take(4)
>>> assert Row(name="Bob", age=12) in reloaded.take(4)
"""

# this dataset cannot be used with ``ParallelRunner``,
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/svmlight/svmlight_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class SVMLightDataset(AbstractVersionedDataset[_DI, _DO]):
>>> # Features and labels.
>>> data = (np.array([[0, 1], [2, 3.14159]]), np.array([7, 3]))
>>>
>>> dataset = SVMLightDataset(filepath="test.svm")
>>> dataset = SVMLightDataset(filepath=tmp_path / "test.svm")
>>> dataset.save(data)
>>> reloaded_features, reloaded_labels = dataset.load()
>>> assert (data[0] == reloaded_features).all()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class TensorFlowModelDataset(AbstractVersionedDataset[tf.keras.Model, tf.keras.M
>>> import tensorflow as tf
>>> import numpy as np
>>>
>>> dataset = TensorFlowModelDataset("data/06_models/tensorflow_model.h5")
>>> dataset = TensorFlowModelDataset(tmp_path / "data/06_models/tensorflow_model.h5")
>>> model = tf.keras.Model()
>>> predictions = model.predict([...])
>>>
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/text/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class TextDataset(AbstractVersionedDataset[str, str]):
>>>
>>> string_to_write = "This will go in a file."
>>>
>>> dataset = TextDataset(filepath="test.md")
>>> dataset = TextDataset(filepath=tmp_path / "test.md")
>>> dataset.save(string_to_write)
>>> reloaded = dataset.load()
>>> assert string_to_write == reloaded
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/tracking/json_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class JSONDataset(json_dataset.JSONDataset):
>>>
>>> data = {"col1": 1, "col2": 0.23, "col3": 0.002}
>>>
>>> dataset = JSONDataset(filepath="test.json")
>>> dataset = JSONDataset(filepath=tmp_path / "test.json")
>>> dataset.save(data)
"""
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/tracking/metrics_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class MetricsDataset(json_dataset.JSONDataset):
>>>
>>> data = {"col1": 1, "col2": 0.23, "col3": 0.002}
>>>
>>> dataset = MetricsDataset(filepath="test.json")
>>> dataset = MetricsDataset(filepath=tmp_path / "test.json")
>>> dataset.save(data)
"""
Expand Down
4 changes: 2 additions & 2 deletions kedro-datasets/kedro_datasets/video/video_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class VideoDataset(AbstractDataset[AbstractVideo, AbstractVideo]):
... imgs.append(Image.fromarray(frame))
... frame -= 1
...
>>> video = VideoDataset(filepath="my_video.mp4")
>>> video = VideoDataset(filepath=tmp_path / "my_video.mp4")
>>> video.save(SequenceVideo(imgs, fps=25))
Expand All @@ -263,7 +263,7 @@ class VideoDataset(AbstractDataset[AbstractVideo, AbstractVideo]):
... yield Image.fromarray(frame)
... frame -= 1
...
>>> video = VideoDataset(filepath="my_video.mp4")
>>> video = VideoDataset(filepath=tmp_path / "my_video.mp4")
>>> video.save(GeneratorVideo(gen(), fps=25, length=None))
"""
Expand Down
2 changes: 1 addition & 1 deletion kedro-datasets/kedro_datasets/yaml/yaml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class YAMLDataset(AbstractVersionedDataset[dict, dict]):
>>>
>>> data = {"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]}
>>>
>>> dataset = YAMLDataset(filepath="test.yaml")
>>> dataset = YAMLDataset(filepath=tmp_path / "test.yaml")
>>> dataset.save(data)
>>> reloaded = dataset.load()
>>> assert data == reloaded
Expand Down

0 comments on commit 8d6acac

Please sign in to comment.