Skip to content

Commit

Permalink
Merge pull request #15 from AutoResearch/add-standard-operators-and_f…
Browse files Browse the repository at this point in the history
…unctions

bug: sympyfy might crash
  • Loading branch information
younesStrittmatter authored Sep 9, 2023
2 parents d04b2d0 + 9df9e6f commit e978ccf
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 35 deletions.
72 changes: 38 additions & 34 deletions src/equation_tree/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def sample(n, prior, max_num_variables, file=None):


def burn(
prior,
max_num_variables,
file,
n=100_000,
alpha=1,
prior,
max_num_variables,
file,
n=100_000,
alpha=1,
):
adjusted_prior = load(prior, max_num_variables, file)
sample_ = sample(n, adjusted_prior, max_num_variables)
Expand Down Expand Up @@ -75,9 +75,9 @@ def sample_trees(n, prior, max_num_variables):


def __sample_tree_raw_fast(
prior,
tree_depth,
max_num_variables,
prior,
tree_depth,
max_num_variables,
):
equation_tree = EquationTree.from_prior_fast(prior, tree_depth, max_num_variables)

Expand All @@ -99,15 +99,19 @@ def __sample_tree_raw_fast(

# Check if duplicate constants
if (
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
):
return None

# Check if more variables than max:
if equation_tree.n_variables > max_num_variables:
return None

# Check if tree depth is exact
if len(equation_tree.structure) != tree_depth:
return None

if not equation_tree.check_validity():
return None

Expand All @@ -122,8 +126,8 @@ def __sample_tree_raw_fast(


def __sample_tree_raw(
prior,
max_num_variables,
prior,
max_num_variables,
):
equation_tree = EquationTree.from_prior(prior, max_num_variables)

Expand All @@ -145,8 +149,8 @@ def __sample_tree_raw(

# Check if duplicate constants
if (
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
):
return None

Expand All @@ -168,9 +172,9 @@ def __sample_tree_raw(


def _sample_tree_iter_fast(
prior,
tree_depth,
max_num_variables,
prior,
tree_depth,
max_num_variables,
):
for _ in range(MAX_ITER):
equation_tree = __sample_tree_raw_fast(
Expand All @@ -183,8 +187,8 @@ def _sample_tree_iter_fast(


def _sample_tree_iter(
prior,
max_num_variables,
prior,
max_num_variables,
):
for _ in range(MAX_ITER):
equation_tree = __sample_tree_raw(
Expand All @@ -196,12 +200,12 @@ def _sample_tree_iter(


def sample_tree_raw_from_priors(
max_num_constants: int = 0,
max_num_variables: int = 1,
feature_priors: Optional[Dict] = None,
function_priors: PriorType = DEFAULT_FUNCTION_SPACE,
operator_priors: PriorType = DEFAULT_OPERATOR_SPACE,
structure_priors: PriorType = {},
max_num_constants: int = 0,
max_num_variables: int = 1,
feature_priors: Optional[Dict] = None,
function_priors: PriorType = DEFAULT_FUNCTION_SPACE,
operator_priors: PriorType = DEFAULT_OPERATOR_SPACE,
structure_priors: PriorType = {},
):
"""
Sample a tree from priors, simplify and check if valid tree
Expand Down Expand Up @@ -264,8 +268,8 @@ def sample_tree_raw_from_priors(

# Check if duplicate constants
if (
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
equation_tree.n_non_numeric_constants
> equation_tree.n_non_numeric_constants_unique
):
return None

Expand All @@ -281,7 +285,7 @@ def sample_tree_raw_from_priors(
return None

if not equation_tree.check_possible(
_feature_priors, _function_priors, _operator_priors, _structure_priors
_feature_priors, _function_priors, _operator_priors, _structure_priors
):
return None

Expand All @@ -293,12 +297,12 @@ def sample_tree_raw_from_priors(


def sample_tree_from_priors_iter(
max_num_constants: int = 0,
max_num_variables: int = 1,
feature_priors: Optional[Dict] = None,
function_priors: PriorType = DEFAULT_FUNCTION_SPACE,
operator_priors: PriorType = DEFAULT_OPERATOR_SPACE,
structure_priors: PriorType = {},
max_num_constants: int = 0,
max_num_variables: int = 1,
feature_priors: Optional[Dict] = None,
function_priors: PriorType = DEFAULT_FUNCTION_SPACE,
operator_priors: PriorType = DEFAULT_OPERATOR_SPACE,
structure_priors: PriorType = {},
):
for _ in range(MAX_ITER):
equation_tree = sample_tree_raw_from_priors(
Expand Down
2 changes: 1 addition & 1 deletion src/equation_tree/util/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def set_priors(priors=None, space=None):
if priors:
if not set(priors.keys()).issubset(set(space)):
raise Exception(f"Priors {priors} are not subset of space {space}")
total_custom_prior = sum(math.floor(p*100)/100 for p in (priors.values()))
total_custom_prior = sum(math.floor(p * 100) / 100 for p in (priors.values()))
if total_custom_prior > 1:
raise ValueError(f"Sum of custom priors {priors} is greater than 1")

Expand Down

0 comments on commit e978ccf

Please sign in to comment.