Skip to content

Commit

Permalink
feat(synthesizer): support for different types from datasource (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
portellaa committed Mar 12, 2024
1 parent eedd5e3 commit ca1b324
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/ydata/sdk/datasources/_models/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ class DataSourceType(StringEnum):
TIMESERIES = "timeseries"
"""The [`DataSource`][ydata.sdk.datasources.datasource.DataSource] has a temporal dimension.
"""
MULTITABLE = "multiTable"
"""The [`DataSource`][ydata.sdk.datasources.datasource.DataSource] is a multi-table RDBMS.
"""
2 changes: 1 addition & 1 deletion src/ydata/sdk/synthesizers/multitable.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def fit(self, X: DataSource,
X (DataSource): DataSource to Train
"""

self._fit_from_datasource(X)
self._fit_from_datasource(X, datatype=DataSourceType.MULTITABLE)

def sample(self, frac: Union[int, float] = 1, write_connector: Optional[Union[Connector, UID]] = None) -> None:
"""Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer]
Expand Down
22 changes: 9 additions & 13 deletions src/ydata/sdk/synthesizers/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from time import sleep
from typing import Dict, List, Optional, Union
from uuid import uuid4
from warnings import warn

from pandas import DataFrame as pdDataFrame
from pandas import read_csv
Expand All @@ -17,7 +16,6 @@
DataTypeMissingError, EmptyDataError, FittingError, InputError)
from ydata.sdk.common.logger import create_logger
from ydata.sdk.common.types import UID, Project
from ydata.sdk.common.warnings import DataSourceTypeWarning
from ydata.sdk.connectors import LocalConnector
from ydata.sdk.datasources import DataSource, LocalDataSource
from ydata.sdk.datasources._models.attributes import DataSourceAttrs
Expand Down Expand Up @@ -104,12 +102,11 @@ def fit(self, X: Union[DataSource, pdDataFrame],
if self._already_fitted():
raise AlreadyFittedError()

_datatype = DataSourceType(datatype) if isinstance(
X, pdDataFrame) else DataSourceType(X.datatype)
datatype = DataSourceType(datatype)

dataset_attrs = self._init_datasource_attributes(
sortbykey, entities, generate_cols, exclude_cols, dtypes)
self._validate_datasource_attributes(X, dataset_attrs, _datatype, target)
self._validate_datasource_attributes(X, dataset_attrs, datatype, target)

# If the training data is a pandas dataframe, we first need to create a data source and then the instance
if isinstance(X, pdDataFrame):
Expand All @@ -121,12 +118,9 @@ def fit(self, X: Union[DataSource, pdDataFrame],
self._logger.info(
f'created local connector. creating datasource with {connector}')
_X = LocalDataSource(connector=connector, project=self._project,
datatype=_datatype, client=self._client)
datatype=datatype, client=self._client)
self._logger.info(f'created datasource {_X}')
else:
if datatype != _datatype:
warn("When the training data is a DataSource, the argument `datatype` is ignored.",
DataSourceTypeWarning)
_X = X

if dsState(_X.status.state) != dsState.AVAILABLE:
Expand All @@ -137,7 +131,7 @@ def fit(self, X: Union[DataSource, pdDataFrame],
dataset_attrs = DataSourceAttrs(**dataset_attrs)

self._fit_from_datasource(
X=_X, dataset_attrs=dataset_attrs, target=target,
X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target,
anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on)

@staticmethod
Expand All @@ -164,7 +158,6 @@ def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_a
if datatype is None:
raise DataTypeMissingError(
"Argument `datatype` is mandatory for pandas.DataFrame training data")
datatype = DataSourceType(datatype)
else:
columns = [c.name for c in X.metadata.columns]

Expand Down Expand Up @@ -232,6 +225,7 @@ def _metadata_to_payload(
def _fit_from_datasource(
self,
X: DataSource,
datatype: DataSourceType,
privacy_level: Optional[PrivacyLevel] = None,
dataset_attrs: Optional[DataSourceAttrs] = None,
target: Optional[str] = None,
Expand All @@ -245,9 +239,11 @@ def _fit_from_datasource(
if privacy_level:
payload['privacyLevel'] = privacy_level.value

if X.metadata is not None and X.datatype is not None:
if X.metadata is not None:
payload['metadata'] = self._metadata_to_payload(
DataSourceType(X.datatype), X.metadata, dataset_attrs, target)
datatype, X.metadata, dataset_attrs, target)

payload['type'] = str(datatype.value)

if anonymize is not None:
payload["extraData"]["anonymize"] = anonymize
Expand Down

0 comments on commit ca1b324

Please sign in to comment.