Skip to content

Commit

Permalink
Capture wattile version Fixes #322 (#324)
Browse files Browse the repository at this point in the history
* Captures the Wattile version as a prop on the Wattile ob
Logs the Wattile version to the output.out artifact post training

* Adds metadata.json file to model artifacts

* Fixes formatting errors
  • Loading branch information
smithcommajoseph authored Aug 19, 2024
1 parent e7fea12 commit e17d91b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 0 deletions.
10 changes: 10 additions & 0 deletions wattile/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from os import path

import toml

b_path = path.dirname(__file__)
proj_path = path.abspath(path.join(b_path, "..", "pyproject.toml"))

with open(proj_path, "r") as f:
config = toml.load(f)
version = config["tool"]["poetry"]["version"]
2 changes: 2 additions & 0 deletions wattile/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd

import wattile.data_processing as bp
from wattile import version as wattile_version
from wattile.data_reading import read_dataset_from_file
from wattile.models import ModelFactory

Expand Down Expand Up @@ -35,6 +36,7 @@ def init_logging(local_results_dir):
logger.addHandler(hdlr)
logger.setLevel(logging.INFO)
logger.info("PID: {}".format(PID))
logger.info("Trained with Wattile version: {}".format(wattile_version))


def create_input_dataframe(configs):
Expand Down
9 changes: 9 additions & 0 deletions wattile/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter

from wattile import version as wattile_version
from wattile.error import ConfigsError
from wattile.util import factors
from wattile.visualization import timeseries_comparison
Expand Down Expand Up @@ -126,6 +127,13 @@ def apply_normalization(self, data: pd.DataFrame) -> pd.DataFrame:

return data

def write_metadata(self):
"""Write the metadata to a json file"""
path = os.path.join(self.file_prefix, "metadata.json")
metadata = {"wattile_version": wattile_version}
with open(path, "w") as fp:
json.dump(metadata, fp, indent=1)

def main(self, train_df, val_df):
"""
Main executable for prepping data for input to RNN model.
Expand Down Expand Up @@ -161,6 +169,7 @@ def train(self, train_df: pd.DataFrame, val_df: pd.DataFrame) -> None:
val_loader = self.to_data_loader(val_data, val_batch_size, shuffle=True)

self.run_training(train_loader, val_loader, val_df)
self.write_metadata()

# Create visualization
if self.configs["data_output"]["plot_comparison"]:
Expand Down

0 comments on commit e17d91b

Please sign in to comment.