Skip to content

Commit

Permalink
fix: move refinement params dict to start metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Apr 12, 2024
1 parent fa5c303 commit e6bc16b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
3 changes: 3 additions & 0 deletions pdf_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
roi: Optional[Tuple] = None,
norm_region: Optional[Tuple] = None,
offline=False,
metadata=None,
**kwargs,
):
if offline:
Expand Down Expand Up @@ -75,6 +76,8 @@ def __init__(
roi_key=self.roi_key,
roi=self.roi,
)
metadata = metadata or {}
md.update(metadata)
super().__init__(*args, metadata=md, **_default_kwargs)

def measurement_plan(self, point: ArrayLike) -> Tuple[str, List, Dict]:
Expand Down
6 changes: 4 additions & 2 deletions pdf_agents/gsas.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
cif_paths: List[Union[str, Path]],
refinement_params: List[dict],
inst_param_path: Union[str, Path],
metadata: dict = None,
**kwargs,
):
self._cif_paths = cif_paths
Expand All @@ -57,7 +58,9 @@ def __init__(
self._recent_x = None
self._recent_y = None
self._recent_uid = None
super().__init__(**kwargs)
metadata = metadata or {}
metadata.update(refinement_params)
super().__init__(metadata=metadata, **kwargs)
self.report_on_tell = True

@property
Expand Down Expand Up @@ -184,7 +187,6 @@ def report(self) -> Dict[str, ArrayLike]:
observable=self._recent_y,
cif_paths=self.cif_paths,
inst_param_path=self.inst_param_path,
refinement_params=self.refinement_params,
gsas_rwps=np.array(gsas_rwps),
gsas_ycalcs=np.stack(gsas_ycalcs),
gsas_ydiffs=np.stack(gsas_ydiffs),
Expand Down
2 changes: 1 addition & 1 deletion pdf_agents/startup_scripts/mmm5-tax-day/gsas.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

report_producer = RefinementAgent.get_default_producer()


agent = RefinementAgent(
# GSAS Args
cif_paths=["/src/pdf-agents/assets/fcc.cif"],
Expand Down

0 comments on commit e6bc16b

Please sign in to comment.