Skip to content

Commit

Permalink
added generative model wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Fuest committed Oct 2, 2024
1 parent a929cfb commit e5cbd4b
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions generator/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from generator.options import Options


class GenerativeModelWrapper:
class GenerativeModel:
"""
A wrapper class for generative models with a scikit-learn-like API.
A wrapper class for generative models.
"""

def __init__(self, model_name: str, model_params: dict = None):
Expand Down Expand Up @@ -41,15 +41,14 @@ def _initialize_model(self):
"acgan": ACGAN,
"diffusion_ts": Diffusion_TS,
"diffcharge": DDPM,
# Add other models as needed
}
if self.model_name in model_dict:
model_class = model_dict[self.model_name]
self.model = model_class(self.opt)
else:
raise ValueError(f"Model {self.model_name} not recognized.")

def fit(self, X, y=None):
def fit(self, X):
"""
Train the model on the given dataset.
Expand Down

0 comments on commit e5cbd4b

Please sign in to comment.