Skip to content

Commit

Permalink
added decoding to sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzkaminski committed Aug 2, 2023
1 parent 18f3c55 commit f909764
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 30 deletions.
2 changes: 1 addition & 1 deletion bamt/builders/composite_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def overwrite_vertex(
node = GaussianNode(name=node_instance.name, regressor=regressor)
elif (
len(node_instance.cont_parents + node_instance.disc_parents) < 1
and type(node).__name__ == "CompositeDiscreteNode"
and type(node_instance).__name__ == "CompositeDiscreteNode"
):
node = DiscreteNode(name=node_instance.name)
else:
Expand Down
2 changes: 1 addition & 1 deletion bamt/builders/evo_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,4 @@ def search(self, data: DataFrame, **kwargs) -> List[Tuple[str, str]]:

@staticmethod
def _convert_to_strings(nested_list):
return [[str(item) for item in inner_list] for inner_list in nested_list]
return [tuple([str(item) for item in inner_list]) for inner_list in nested_list]
11 changes: 10 additions & 1 deletion bamt/networks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,18 @@ def wrapper():
seq_df = seq_df[(seq_df[positive_columns] >= 0).all(axis=1)]
seq_df.reset_index(inplace=True, drop=True)
seq = seq_df.to_dict("records")
sample_output = pd.DataFrame.from_dict(seq, orient="columns")

if as_df:
return pd.DataFrame.from_dict(seq, orient="columns")
if self.has_logit or self.type == "Composite":
for node in self.nodes:
for feature_key, encoder in node.encoders:
sample_output[feature_key] = encoder[
feature_key
].inverse_transform(sample_output[feature_key])
pass

return sample_output
else:
return seq

Expand Down
2 changes: 1 addition & 1 deletion bamt/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def encode_categorical_data_if_any(func):
@wraps(func)
def wrapper(self, data, *args, **kwargs):
for column in self.disc_parents + [self.name]:
if data[column].dtype == "object" or data[column].dtype == "str":
if data[column].dtype in ("object", "str"):
encoder = LabelEncoder()
data[column] = encoder.fit_transform(data[column])
self.encoders[column] = encoder
Expand Down
52 changes: 26 additions & 26 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,41 @@
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent / '../../'))
sys.path.insert(0, str(Path(__file__).parent / "../../"))

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = 'BAMT'
copyright = '2023, NSS lab'
author = 'NSS lab'
release = '0.1.0'
project = "BAMT"
copyright = "2023, NSS lab"
author = "NSS lab"
release = "0.1.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

extensions = [
"myst_parser",
'sphinx_rtd_theme',
'sphinx.ext.autodoc',
'sphinx.ext.coverage',
'sphinx.ext.napoleon',
'sphinx.ext.viewcode',
'sphinx.ext.mathjax',
'sphinx.ext.autosummary',
'sphinx.ext.autodoc.typehints',
'sphinx.ext.graphviz',
'sphinx.ext.todo'
"sphinx_rtd_theme",
"sphinx.ext.autodoc",
"sphinx.ext.coverage",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.mathjax",
"sphinx.ext.autosummary",
"sphinx.ext.autodoc.typehints",
"sphinx.ext.graphviz",
"sphinx.ext.todo",
]

templates_path = ['_templates']
templates_path = ["_templates"]
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = 'sphinx_rtd_theme'
html_theme = "sphinx_rtd_theme"
# html_static_path = ['_static']

# -- Extension configuration -------------------------------------------------
Expand All @@ -67,13 +67,13 @@
napoleon_attr_annotations = False

autodoc_default_options = {
'members': True,
'undoc-members': False,
'show-inheritance': True,
'member-order': 'bysource',
'ignore-module-all': True,
"members": True,
"undoc-members": False,
"show-inheritance": True,
"member-order": "bysource",
"ignore-module-all": True,
}
autoclass_content = 'class'
autodoc_typehints = 'signature'
autodoc_typehints_format = 'short'
autodoc_mock_imports = ['objgraph', 'memory_profiler', 'gprof2dot', 'snakeviz']
autoclass_content = "class"
autodoc_typehints = "signature"
autodoc_typehints_format = "short"
autodoc_mock_imports = ["objgraph", "memory_profiler", "gprof2dot", "snakeviz"]

0 comments on commit f909764

Please sign in to comment.