forked from autogluon/autogluon
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[eda] Initial PR to add constructs (autogluon#2209)
[EDA] added base constructs
- Loading branch information
1 parent
4e977c4
commit 1f1aa10
Showing
23 changed files
with
755 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#!/usr/bin/env python | ||
########################### | ||
# This code block is a HACK (!), but is necessary to avoid code duplication. Do NOT alter these lines. | ||
import importlib.util | ||
import os | ||
|
||
from setuptools import setup | ||
|
||
filepath = os.path.abspath(os.path.dirname(__file__)) | ||
filepath_import = os.path.join(filepath, '..', 'core', 'src', 'autogluon', 'core', '_setup_utils.py') | ||
spec = importlib.util.spec_from_file_location("ag_min_dependencies", filepath_import) | ||
ag = importlib.util.module_from_spec(spec) | ||
# Identical to `from autogluon.core import _setup_utils as ag`, but works without `autogluon.core` being installed. | ||
spec.loader.exec_module(ag) | ||
########################### | ||
|
||
version = ag.load_version_file() | ||
version = ag.update_version(version, use_file_if_exists=False, create_file=True) | ||
|
||
submodule = 'eda' | ||
install_requires = [ | ||
# version ranges added in ag.get_dependency_version_ranges() | ||
'numpy', | ||
'scipy', | ||
'scikit-learn', | ||
'pandas', | ||
'matplotlib', | ||
'missingno>=0.5.1,<0.6', | ||
'phik>=0.12.2,<0.13', | ||
'seaborn>=0.12.0,<0.13', | ||
'ipython>=8.0,<9.0', | ||
'ipywidgets>=8.0,<9.0', | ||
] | ||
|
||
extras_require = dict() | ||
|
||
test_requirements = [ | ||
'pytest' | ||
] | ||
|
||
test_requirements = list(set(test_requirements)) | ||
extras_require['tests'] = test_requirements | ||
|
||
install_requires = ag.get_dependency_version_ranges(install_requires) | ||
for key in extras_require: | ||
extras_require[key] = ag.get_dependency_version_ranges(extras_require[key]) | ||
|
||
if __name__ == '__main__': | ||
ag.create_version_file(version=version, submodule=submodule) | ||
setup_args = ag.default_setup_args(version=version, submodule=submodule) | ||
setup( | ||
install_requires=install_requires, | ||
extras_require=extras_require, | ||
**setup_args, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__import__("pkg_resources").declare_namespace(__name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .state import AnalysisState |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
import phik | ||
from .base import Namespace |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
from abc import ABC, abstractmethod | ||
from typing import List, Union, Tuple | ||
|
||
from pandas import DataFrame | ||
|
||
from ..state import AnalysisState, StateCheckMixin | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AbstractAnalysis(ABC, StateCheckMixin): | ||
|
||
def __init__(self, | ||
parent: Union[None, AbstractAnalysis] = None, | ||
children: List[AbstractAnalysis] = [], | ||
state: AnalysisState = None, | ||
**kwargs) -> None: | ||
|
||
self.parent = parent | ||
self.children: List[AbstractAnalysis] = children | ||
self.state: AnalysisState = state | ||
for c in self.children: | ||
c.parent = self | ||
c.state = self.state | ||
self.args = kwargs | ||
|
||
def _gather_args(self) -> AnalysisState: | ||
chain = [self] | ||
while chain[0].parent is not None: | ||
chain.insert(0, chain[0].parent) | ||
args = {} | ||
for node in chain: | ||
args = AnalysisState({**args, **node.args}) | ||
return args | ||
|
||
def available_datasets(self, args: AnalysisState) -> Tuple[str, DataFrame]: | ||
""" | ||
Generator which iterates only through the datasets provided in arguments | ||
Parameters | ||
---------- | ||
args: AnalysisState | ||
arguments passed into the call. These are different from `self.args` in a way that it's arguments assembled from the | ||
parents and shadowed via children (allows to isolate reused parameters in upper arguments declarations. | ||
Returns | ||
------- | ||
tuple of dataset name (train_data, test_data or tuning_data) and dataset itself | ||
""" | ||
for ds in ['train_data', 'test_data', 'tuning_data', 'val_data']: | ||
if ds in args and args[ds] is not None: | ||
df: DataFrame = args[ds] | ||
yield ds, df | ||
|
||
def _get_state_from_parent(self) -> AnalysisState: | ||
state = self.state | ||
if state is None: | ||
if self.parent is None: | ||
state = AnalysisState() | ||
else: | ||
state = self.parent.state | ||
return state | ||
|
||
@abstractmethod | ||
def can_handle(self, state: AnalysisState, args: AnalysisState) -> bool: | ||
""" | ||
Checks if state and args has all the required parameters for fitting. | ||
See also :func:`at_least_one_key_must_be_present` and :func:`all_keys_must_be_present` helpers | ||
to construct more complex logic. | ||
Parameters | ||
---------- | ||
state: AnalysisState | ||
state to be updated by this fit function | ||
args: AnalysisState | ||
analysis properties assembled from root of analysis hierarchy to this component (with lower levels shadowing upper level args). | ||
Returns | ||
------- | ||
`True` if all the pre-requisites for fitting are present | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def _fit(self, state: AnalysisState, args: AnalysisState, **fit_kwargs) -> None: | ||
""" | ||
@override | ||
Component-specific internal processing. | ||
This method is designed to be overridden by the component developer | ||
Parameters | ||
---------- | ||
state: AnalysisState | ||
state to be updated by this fit function | ||
args: AnalysisState | ||
analysis properties assembled from root of analysis hierarchy to this component (with lower levels shadowing upper level args). | ||
fit_kwargs | ||
arguments passed into fit call | ||
""" | ||
raise NotImplementedError | ||
|
||
def fit(self, **kwargs) -> AnalysisState: | ||
""" | ||
Fit the analysis tree. | ||
Parameters | ||
---------- | ||
kwargs | ||
fit arguments | ||
Returns | ||
------- | ||
state produced by fit | ||
""" | ||
self.state = self._get_state_from_parent() | ||
if self.parent is not None: | ||
assert self.state is not None, "Inner analysis fit() is called while parent has no state. Please call top-level analysis fit instead" | ||
_args = self._gather_args() | ||
if self.can_handle(self.state, _args): | ||
self._fit(self.state, _args, **kwargs) | ||
for c in self.children: | ||
c.fit(**kwargs) | ||
return self.state | ||
|
||
|
||
class BaseAnalysis(AbstractAnalysis): | ||
|
||
def __init__(self, | ||
parent: Union[None, AbstractAnalysis] = None, | ||
children: List[AbstractAnalysis] = [], | ||
**kwargs) -> None: | ||
super().__init__(parent, children, **kwargs) | ||
|
||
def can_handle(self, state: AnalysisState, args: AnalysisState) -> bool: | ||
return True | ||
|
||
def _fit(self, state: AnalysisState, args: AnalysisState, **fit_kwargs): | ||
pass | ||
|
||
|
||
class Namespace(AbstractAnalysis): | ||
|
||
def can_handle(self, state: AnalysisState, args: AnalysisState) -> bool: | ||
return True | ||
|
||
def __init__(self, | ||
namespace: str = None, | ||
parent: Union[None, AbstractAnalysis] = None, | ||
children: List[AbstractAnalysis] = [], | ||
**kwargs) -> None: | ||
super().__init__(parent, children, **kwargs) | ||
self.namespace = namespace | ||
|
||
def fit(self, **kwargs) -> AnalysisState: | ||
assert self.parent is not None, "Namespace must be wrapped into other analysis. You can use BaseAnalysis of one is needed" | ||
return super().fit(**kwargs) | ||
|
||
def _fit(self, state: AnalysisState, args: AnalysisState, **fit_kwargs): | ||
pass | ||
|
||
def _get_state_from_parent(self) -> AnalysisState: | ||
state = super()._get_state_from_parent() | ||
if self.namespace not in state: | ||
state[self.namespace] = {} | ||
return state[self.namespace] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import logging | ||
from typing import List | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AnalysisState(dict): | ||
"""Dictionary wrapper which enables dot.notation access to dictionary attributes and dynamic code assist in jupyter""" | ||
__getattr__ = dict.get | ||
__delattr__ = dict.__delitem__ | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
for arg in args: | ||
if isinstance(arg, dict): | ||
for k, v in arg.items(): | ||
self[k] = v | ||
|
||
for k, v in kwargs.items(): | ||
self[k] = v | ||
|
||
def __setattr__(self, name: str, value) -> None: | ||
if isinstance(value, dict): | ||
value = AnalysisState(value) | ||
self[name] = value | ||
|
||
def __setitem__(self, key, value) -> None: | ||
if isinstance(value, dict): | ||
value = AnalysisState(value) | ||
super().__setitem__(key, value) | ||
|
||
@property | ||
def __dict__(self): | ||
return self | ||
|
||
|
||
class StateCheckMixin: | ||
def at_least_one_key_must_be_present(self, state: AnalysisState, keys: List[str]): | ||
""" | ||
Checks if at least one key is present in the state | ||
Parameters | ||
---------- | ||
state: AnalysisState | ||
state object to perform check on | ||
keys: List[str] | ||
list of the keys to check | ||
Returns | ||
------- | ||
True if at least one key from the `keys` list is present in the state | ||
""" | ||
for k in keys: | ||
if k in state: | ||
return True | ||
logger.warning(f'{self.__class__.__name__}: at least one of the following keys must be present: {keys}') | ||
return False | ||
|
||
def all_keys_must_be_present(self, state: AnalysisState, keys: List[str]): | ||
""" | ||
Checks if all the keys are present in the state | ||
Parameters | ||
---------- | ||
state: AnalysisState | ||
state object to perform check on | ||
keys: List[str] | ||
list of the keys to check | ||
Returns | ||
------- | ||
True if all the key from the `keys` list are present in the state | ||
""" | ||
keys_not_present = [k for k in keys if k not in state.keys()] | ||
can_handle = len(keys_not_present) == 0 | ||
if not can_handle: | ||
logger.warning(f'{self.__class__.__name__}: all of the following keys must be present: {keys}. The following keys are missing: {keys_not_present}') | ||
return can_handle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .layouts import SimpleVerticalLinearLayout, SimpleHorizontalLayout, TabLayout, MarkdownSectionComponent |
Oops, something went wrong.