From ada34dcfef09aed60151ff2400f4cfe6c972269a Mon Sep 17 00:00:00 2001 From: Aiden Grossman Date: Thu, 26 Sep 2024 05:03:12 +0000 Subject: [PATCH] Add support for emitting hparam info from train+_bc script This patch adds support for emitting hyperparameter information from the train_bc script into Tensorboard. This specifically focuses on values specified on the command line (i.e., ones that would vary during an experiment). --- compiler_opt/rl/train_bc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/compiler_opt/rl/train_bc.py b/compiler_opt/rl/train_bc.py index 14dc80f2..ebeb0cf7 100644 --- a/compiler_opt/rl/train_bc.py +++ b/compiler_opt/rl/train_bc.py @@ -29,8 +29,10 @@ from compiler_opt.rl import registry from compiler_opt.rl import trainer +from tensorflow import summary from tf_agents.agents import tf_agent from tf_agents.policies import tf_policy +from tensorboard.plugins.hparams import api as hp from typing import Dict @@ -87,6 +89,14 @@ def train_eval(agent_config_type=agent_config.BCAgentConfig, # Save final policy. saver.save(root_dir) + # Save (command line specified) hyperparameter information. + with summary.create_file_writer(_ROOT_DIR.value).as_default(): + hparams = {} + for gin_binding in _GIN_BINDINGS.value: + param_name, param_value = gin_binding.split('=') + hparams[param_name] = param_value + hp.hparams(hparams) + def main(_): gin.parse_config_files_and_bindings(