From ca1b32446391899c5a7e8c48121188d3fff9ac02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Portela=20Afonso?= Date: Tue, 12 Mar 2024 14:07:40 +0000 Subject: [PATCH] feat(synthesizer): support for different types from datasource (#93) --- src/ydata/sdk/datasources/_models/datatype.py | 3 +++ src/ydata/sdk/synthesizers/multitable.py | 2 +- src/ydata/sdk/synthesizers/synthesizer.py | 22 ++++++++----------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/ydata/sdk/datasources/_models/datatype.py b/src/ydata/sdk/datasources/_models/datatype.py index c815e4ba..efa0f1b1 100644 --- a/src/ydata/sdk/datasources/_models/datatype.py +++ b/src/ydata/sdk/datasources/_models/datatype.py @@ -8,3 +8,6 @@ class DataSourceType(StringEnum): TIMESERIES = "timeseries" """The [`DataSource`][ydata.sdk.datasources.datasource.DataSource] has a temporal dimension. """ + MULTITABLE = "multiTable" + """The [`DataSource`][ydata.sdk.datasources.datasource.DataSource] is a multi table RDBMS. + """ diff --git a/src/ydata/sdk/synthesizers/multitable.py b/src/ydata/sdk/synthesizers/multitable.py index cc3e21e3..e9021528 100644 --- a/src/ydata/sdk/synthesizers/multitable.py +++ b/src/ydata/sdk/synthesizers/multitable.py @@ -63,7 +63,7 @@ def fit(self, X: DataSource, X (DataSource): DataSource to Train """ - self._fit_from_datasource(X) + self._fit_from_datasource(X, datatype=DataSourceType.MULTITABLE) def sample(self, frac: Union[int, float] = 1, write_connector: Optional[Union[Connector, UID]] = None) -> None: """Sample from a [`MultiTableSynthesizer`][ydata.sdk.synthesizers.MultiTableSynthesizer] diff --git a/src/ydata/sdk/synthesizers/synthesizer.py b/src/ydata/sdk/synthesizers/synthesizer.py index a1ccd7d0..5801eb0e 100644 --- a/src/ydata/sdk/synthesizers/synthesizer.py +++ b/src/ydata/sdk/synthesizers/synthesizer.py @@ -3,7 +3,6 @@ from time import sleep from typing import Dict, List, Optional, Union from uuid import uuid4 -from warnings import warn from pandas import DataFrame as pdDataFrame from pandas import read_csv @@ -17,7 +16,6 @@ DataTypeMissingError, EmptyDataError, FittingError, InputError) from ydata.sdk.common.logger import create_logger from ydata.sdk.common.types import UID, Project -from ydata.sdk.common.warnings import DataSourceTypeWarning from ydata.sdk.connectors import LocalConnector from ydata.sdk.datasources import DataSource, LocalDataSource from ydata.sdk.datasources._models.attributes import DataSourceAttrs @@ -104,12 +102,11 @@ def fit(self, X: Union[DataSource, pdDataFrame], if self._already_fitted(): raise AlreadyFittedError() - _datatype = DataSourceType(datatype) if isinstance( - X, pdDataFrame) else DataSourceType(X.datatype) + datatype = DataSourceType(datatype) dataset_attrs = self._init_datasource_attributes( sortbykey, entities, generate_cols, exclude_cols, dtypes) - self._validate_datasource_attributes(X, dataset_attrs, _datatype, target) + self._validate_datasource_attributes(X, dataset_attrs, datatype, target) # If the training data is a pandas dataframe, we first need to create a data source and then the instance if isinstance(X, pdDataFrame): @@ -121,12 +118,9 @@ def fit(self, X: Union[DataSource, pdDataFrame], self._logger.info( f'created local connector. creating datasource with {connector}') _X = LocalDataSource(connector=connector, project=self._project, - datatype=_datatype, client=self._client) + datatype=datatype, client=self._client) self._logger.info(f'created datasource {_X}') else: - if datatype != _datatype: - warn("When the training data is a DataSource, the argument `datatype` is ignored.", - DataSourceTypeWarning) _X = X if dsState(_X.status.state) != dsState.AVAILABLE: @@ -137,7 +131,7 @@ def fit(self, X: Union[DataSource, pdDataFrame], dataset_attrs = DataSourceAttrs(**dataset_attrs) self._fit_from_datasource( - X=_X, dataset_attrs=dataset_attrs, target=target, + X=_X, datatype=datatype, dataset_attrs=dataset_attrs, target=target, anonymize=anonymize, privacy_level=privacy_level, condition_on=condition_on) @staticmethod @@ -164,7 +158,6 @@ def _validate_datasource_attributes(X: Union[DataSource, pdDataFrame], dataset_a if datatype is None: raise DataTypeMissingError( "Argument `datatype` is mandatory for pandas.DataFrame training data") - datatype = DataSourceType(datatype) else: columns = [c.name for c in X.metadata.columns] @@ -232,6 +225,7 @@ def _metadata_to_payload( def _fit_from_datasource( self, X: DataSource, + datatype: DataSourceType, privacy_level: Optional[PrivacyLevel] = None, dataset_attrs: Optional[DataSourceAttrs] = None, target: Optional[str] = None, @@ -245,9 +239,11 @@ def _fit_from_datasource( if privacy_level: payload['privacyLevel'] = privacy_level.value - if X.metadata is not None and X.datatype is not None: + if X.metadata is not None: payload['metadata'] = self._metadata_to_payload( - DataSourceType(X.datatype), X.metadata, dataset_attrs, target) + datatype, X.metadata, dataset_attrs, target) + + payload['type'] = str(datatype.value) if anonymize is not None: payload["extraData"]["anonymize"] = anonymize