Skip to content

Commit

Permalink
FEAT: implement git_friendly flag for YAMLSummary callback (#540)
Browse files Browse the repository at this point in the history
* DOC: improve `YAMLSummary` docstring
  • Loading branch information
redeboer authored Jan 29, 2025
1 parent f7631d8 commit 762ead5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
17 changes: 14 additions & 3 deletions docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@
" callback=CallbackList([\n",
" CSVSummary(\"traceback-1D.csv\"),\n",
" YAMLSummary(\"fit-result-1D.yaml\"),\n",
" YAMLSummary(\"fit-result-1D-git-friendly.yaml\", git_friendly=True),\n",
" TFSummary(),\n",
" ])\n",
")\n",
Expand All @@ -733,12 +734,22 @@
"source_hidden": true
},
"tags": [
"remove-cell"
"remove-input"
]
},
"outputs": [],
"source": [
"assert fit_result.minimum_valid"
"import yaml\n",
"\n",
"assert fit_result.minimum_valid\n",
"with open(\"fit-result-1D.yaml\") as stream:\n",
" yaml_result = yaml.safe_load(stream)\n",
"with open(\"fit-result-1D-git-friendly.yaml\") as stream:\n",
" yaml_result_git = yaml.safe_load(stream)\n",
"assert \"time\" in yaml_result\n",
"assert \"time\" not in yaml_result_git\n",
"yaml_result.pop(\"time\")\n",
"assert yaml_result_git == yaml_result"
]
},
{
Expand Down Expand Up @@ -1285,7 +1296,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
19 changes: 17 additions & 2 deletions src/tensorwaves/optimizer/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,25 @@ def on_function_call_end(


class YAMLSummary(Callback, Loadable):
"""Write current fit parameters and the estimator value to a YAML file."""
"""Write current fit parameters and the estimator value to a YAML file.
def __init__(self, filename: Path | str, step_size: int = 10) -> None:
Arguments:
filename: The name of output YAML file to write the logs to.
step_size: The number of function calls between each log entry.
git_friendly: If `True`, entries that are differ per run in reproducible fits,
such as :code:`time`, are omitted from the log.
"""

def __init__(
self,
filename: Path | str,
step_size: int = 10,
git_friendly: bool = False,
) -> None:
self.__step_size = step_size
self.__filename = filename
self.__stream: IO | None = None
self.__git_friendly = git_friendly

def __del__(self) -> None:
_close_stream(self.__stream)
Expand Down Expand Up @@ -322,6 +335,8 @@ def __dump_to_yaml(self, logs: dict[str, Any]) -> None:
cast_logs["parameters"] = {
p: _cast_value(v) for p, v in logs["parameters"].items()
}
if self.__git_friendly:
cast_logs.pop("time", None)
yaml.dump(
cast_logs,
self.__stream,
Expand Down

0 comments on commit 762ead5

Please sign in to comment.