Skip to content

Commit

Permalink
FIX: setup kmeans agent to work rel coords
Browse files Browse the repository at this point in the history
  • Loading branch information
maffettone committed Dec 13, 2023
1 parent b40b6bc commit a99a34c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
10 changes: 8 additions & 2 deletions pdf_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class PDFBaseAgent(Agent, ABC):
def __init__(
self,
*args,
motor_names: List[str] = ["Grid_X"],
motor_names: List[str] = ["xstage", "ystage"],
motor_origins: List[Tuple[float, float]] = [(0.0, 0.0), (0.0, 0.0)],
motor_resolution: float = 0.0002,
data_key: str = "chi_I",
roi_key: str = "chi_Q",
Expand All @@ -34,6 +35,7 @@ def __init__(
self._rkvs = redis.Redis(host="info.pdf.nsls2.bnl.gov", port=6379, db=0) # redis key value store
self._motor_names = motor_names
self._motor_resolution = motor_resolution
self._motor_origins = np.array(motor_origins)
self._data_key = data_key
self._roi_key = roi_key
self._roi = roi
Expand Down Expand Up @@ -84,6 +86,7 @@ def measurement_plan(self, point: ArrayLike) -> Tuple[str, List, Dict]:
return "agent_redisAware_XRDcount", [point], {}

def unpack_run(self, run) -> Tuple[Union[float, ArrayLike], Union[float, ArrayLike]]:
"""Subtracts background and returns motor positions and data"""
y = run.primary.data[self.data_key].read().flatten()
if self.background is not None:
y = y - self.background[1]
Expand All @@ -94,7 +97,10 @@ def unpack_run(self, run) -> Tuple[Union[float, ArrayLike], Union[float, ArrayLi
y = y[idx_min:idx_max]
try:
x = np.array(
[run.start["more_info"][motor_name][motor_name]["value"] for motor_name in self.motor_names]
[
run.start["more_info"][motor_name][f"OT_Stage_2_{motor_name[0].upper()}"]["value"]
for motor_name in self.motor_names
]
)
except KeyError:
x = np.array([run.start[motor_name][motor_name]["value"] for motor_name in self.motor_names])
Expand Down
19 changes: 0 additions & 19 deletions pdf_agents/monarch_bmm_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,22 +236,3 @@ def subject_ask(self, batch_size=1) -> Tuple[Sequence[dict[str, ArrayLike]], Seq
)
docs = [dict(suggestion=suggestion, **_default_doc) for suggestion in kept_suggestions]
return docs, kept_suggestions

def tell(self, x, y):
"""Update tell using relative info"""
x = x - self.pdf_origin[0]
doc = super().tell(x, y)
doc["absolute_position_offset"] = self.pdf_origin[0]
return doc

def ask(self, batch_size=1) -> Tuple[Sequence[dict[str, ArrayLike]], Sequence[ArrayLike]]:
"""Update ask with relative info"""
docs, suggestions = super().ask(batch_size=batch_size)
for doc in docs:
doc["absolute_position_offset"] = self.pdf_origin[0]
return docs, suggestions

def measurement_plan(self, relative_point: ArrayLike) -> Tuple[str, List, dict]:
"""Send measurement plan absolute point from reltive position"""
absolute_point = relative_point + self.pdf_origin[0]
return super().measurement_plan(absolute_point)
13 changes: 13 additions & 0 deletions pdf_agents/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ def server_registrations(self) -> None:
self._register_method("clear_caches")
return super().server_registrations()

def tell(self, x, y):
"""Update tell using relative info"""
x = x - self._motor_origins
doc = super().tell(x, y)
doc["absolute_position_offset"] = self._motor_origins
return doc

@classmethod
def hud_from_report(
cls,
Expand Down Expand Up @@ -215,11 +222,17 @@ def ask(self, batch_size=1):
latest_data=self.tell_cache[-1],
requested_batch_size=batch_size,
redundant_points_discarded=batch_size - len(kept_suggestions),
absolute_position_offset=self._motor_origins,
)
docs = [dict(suggestion=suggestion, **base_doc) for suggestion in kept_suggestions]

return docs, kept_suggestions

def measurement_plan(self, relative_point: ArrayLike):
"""Send measurement plan absolute point from reltive position"""
absolute_point = relative_point + self._motor_origins
return super().measurement_plan(absolute_point)


def current_dist_gen(x, px):
"""from distribution defined by p(x), produce a discrete generator.
Expand Down

0 comments on commit a99a34c

Please sign in to comment.