Skip to content

Commit

Permalink
Merge pull request #16 from AutoResearch/add-standard-operators-and_f…
Browse files Browse the repository at this point in the history
…unctions

chore: add naming of datafile in export srbench
  • Loading branch information
younesStrittmatter authored Sep 13, 2023
2 parents e978ccf + d0293c8 commit e6847c9
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/equation_tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import sympy
import yaml
from sympy import simplify, symbols, sympify
from sympy import simplify, symbols, sympify, dotprint

from equation_tree.src.tree_node import (
NodeKind,
Expand Down Expand Up @@ -822,6 +822,7 @@ def save_meta_srbench(self, path, name_dataset, name_target="y"):
def export_to_srbench(
self,
folder: str,
data_file_name: Optional[str] = None,
num_samples: int = 1000,
name_target: str = "y",
ranges: Optional[Dict] = None,
Expand All @@ -831,9 +832,20 @@ def export_to_srbench(
"""
Creates a folder and adds data and metadata to the folder that can be used with sr bench:
https://cavalab.org/srbench/
Args:
folder: Name of the folder
data_file_name: Name of the datafile (if none same as folder name)
num_samples: Number of samples
name_target: Name of the tartget
ranges: A dictionary with the ranges for the variables in form of a dict
default_range: Default range to fall back to if no range for a
specific variable is given
random_state: The random seed to be used
"""
if data_file_name is None:
data_file_name = folder
os.mkdir(folder)
path_data = f"{folder}/data.tsv.gz"
path_data = f"{folder}/{data_file_name}.tsv.gz"
path_meta = f"{folder}/metadata.yaml"
self.save_samples_srbench(
path_data, num_samples, ranges, default_range, random_state
Expand Down Expand Up @@ -899,6 +911,14 @@ def check_validity(
zero_representations, log_representations, division_representations, verbose
)

def draw_tree(self, out):
try:
from graphviz import Source
except ImportError:
print('drawing uses requires `graphviz` to be installed: `pip install graphviz`')
src = Source(dotprint(self.sympy_expr))
src.render(out, view=False)

def check_possible(
self,
feature_priors: Dict,
Expand Down

0 comments on commit e6847c9

Please sign in to comment.