diff --git a/bamt/builders/composite_builder.py b/bamt/builders/composite_builder.py index 0033871..64b725c 100644 --- a/bamt/builders/composite_builder.py +++ b/bamt/builders/composite_builder.py @@ -133,7 +133,7 @@ def overwrite_vertex( node = GaussianNode(name=node_instance.name, regressor=regressor) elif ( len(node_instance.cont_parents + node_instance.disc_parents) < 1 - and type(node).__name__ == "CompositeDiscreteNode" + and type(node_instance).__name__ == "CompositeDiscreteNode" ): node = DiscreteNode(name=node_instance.name) else: diff --git a/bamt/builders/evo_builder.py b/bamt/builders/evo_builder.py index 0da067e..0b7c39d 100644 --- a/bamt/builders/evo_builder.py +++ b/bamt/builders/evo_builder.py @@ -230,4 +230,4 @@ def search(self, data: DataFrame, **kwargs) -> List[Tuple[str, str]]: @staticmethod def _convert_to_strings(nested_list): - return [[str(item) for item in inner_list] for inner_list in nested_list] + return [tuple([str(item) for item in inner_list]) for inner_list in nested_list] diff --git a/bamt/networks/base.py b/bamt/networks/base.py index d363d27..9d46054 100644 --- a/bamt/networks/base.py +++ b/bamt/networks/base.py @@ -649,9 +649,18 @@ def wrapper(): seq_df = seq_df[(seq_df[positive_columns] >= 0).all(axis=1)] seq_df.reset_index(inplace=True, drop=True) seq = seq_df.to_dict("records") + sample_output = pd.DataFrame.from_dict(seq, orient="columns") if as_df: - return pd.DataFrame.from_dict(seq, orient="columns") + if self.has_logit or self.type == "Composite": + for node in self.nodes: + for feature_key, encoder in node.encoders: + sample_output[feature_key] = encoder[ + feature_key + ].inverse_transform(sample_output[feature_key]) + pass + + return sample_output else: return seq diff --git a/bamt/nodes/base.py b/bamt/nodes/base.py index af25abc..13e2757 100644 --- a/bamt/nodes/base.py +++ b/bamt/nodes/base.py @@ -96,7 +96,7 @@ def encode_categorical_data_if_any(func): @wraps(func) def wrapper(self, data, *args, **kwargs): for column in self.disc_parents + [self.name]: - if data[column].dtype == "object" or data[column].dtype == "str": + if data[column].dtype in ("object", "str"): encoder = LabelEncoder() data[column] = encoder.fit_transform(data[column]) self.encoders[column] = encoder diff --git a/docs/source/conf.py b/docs/source/conf.py index f17a7c1..92a1e5f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,41 +13,41 @@ import sys from pathlib import Path -sys.path.insert(0, str(Path(__file__).parent / '../../')) +sys.path.insert(0, str(Path(__file__).parent / "../../")) # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -project = 'BAMT' -copyright = '2023, NSS lab' -author = 'NSS lab' -release = '0.1.0' +project = "BAMT" +copyright = "2023, NSS lab" +author = "NSS lab" +release = "0.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration extensions = [ "myst_parser", - 'sphinx_rtd_theme', - 'sphinx.ext.autodoc', - 'sphinx.ext.coverage', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx.ext.mathjax', - 'sphinx.ext.autosummary', - 'sphinx.ext.autodoc.typehints', - 'sphinx.ext.graphviz', - 'sphinx.ext.todo' + "sphinx_rtd_theme", + "sphinx.ext.autodoc", + "sphinx.ext.coverage", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.mathjax", + "sphinx.ext.autosummary", + "sphinx.ext.autodoc.typehints", + "sphinx.ext.graphviz", + "sphinx.ext.todo", ] -templates_path = ['_templates'] +templates_path = ["_templates"] exclude_patterns = [] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # html_static_path = ['_static'] # -- Extension configuration ------------------------------------------------- @@ -67,13 +67,13 @@ napoleon_attr_annotations = False autodoc_default_options = { - 'members': True, - 'undoc-members': False, - 'show-inheritance': True, - 'member-order': 'bysource', - 'ignore-module-all': True, + "members": True, + "undoc-members": False, + "show-inheritance": True, + "member-order": "bysource", + "ignore-module-all": True, } -autoclass_content = 'class' -autodoc_typehints = 'signature' -autodoc_typehints_format = 'short' -autodoc_mock_imports = ['objgraph', 'memory_profiler', 'gprof2dot', 'snakeviz'] +autoclass_content = "class" +autodoc_typehints = "signature" +autodoc_typehints_format = "short" +autodoc_mock_imports = ["objgraph", "memory_profiler", "gprof2dot", "snakeviz"]