Skip to content

Commit

Permalink
Allow bulk_fetch_trial_data to return mix of successes/failures (#2339)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2339

Currently bulk_fetch_trial_data (which we use for efficiency) will fail if any individual metric fetch fails. This can cause issues if Ax tries to fetch 2 metrics (say training loss and eval loss) and one is ready but the other is not. In the current setup Ax will report both as failures to fetch, in the new setup Ax will fetch the available one and return Err for the second

Reviewed By: sunnyshen321

Differential Revision: D55927149

fbshipit-source-id: 8d0ae14bfa8b2d593c4ab8810f51276afe62f50d
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Apr 9, 2024
1 parent a161618 commit 6720be4
Showing 1 changed file with 56 additions and 46 deletions.
102 changes: 56 additions & 46 deletions ax/metrics/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,54 +121,64 @@ def bulk_fetch_trial_data(

res = {}
for metric in tb_metrics:
records = [
{
"trial_index": trial.index,
"arm_name": arm_name,
"metric_name": metric.name,
self.map_key_info.key: t.step,
"mean": (
t.tensor_proto.double_val[0]
if t.tensor_proto.double_val
else t.tensor_proto.float_val[0]
),
"sem": float("nan"),
}
for run_name, tb_metrics in mul.PluginRunToTagToContent(
"scalars"
).items()
for tag in tb_metrics
if tag == metric.tag
for t in mul.Tensors(run_name, tag)
]

df = (
pd.DataFrame(records)
# If a metric has multiple records for the same arm, metric, and
# step (sometimes caused by restarts, etc) take the mean
.groupby(["arm_name", "metric_name", self.map_key_info.key])
.mean()
.reset_index()
)
try:
records = [
{
"trial_index": trial.index,
"arm_name": arm_name,
"metric_name": metric.name,
self.map_key_info.key: t.step,
"mean": (
t.tensor_proto.double_val[0]
if t.tensor_proto.double_val
else t.tensor_proto.float_val[0]
),
"sem": float("nan"),
}
for run_name, tb_metrics in mul.PluginRunToTagToContent(
"scalars"
).items()
for tag in tb_metrics
if tag == metric.tag
for t in mul.Tensors(run_name, tag)
]

df = (
pd.DataFrame(records)
# If a metric has multiple records for the same arm, metric, and
# step (sometimes caused by restarts, etc) take the mean
.groupby(["arm_name", "metric_name", self.map_key_info.key])
.mean()
.reset_index()
)

# Apply per-metric post-processing
# Apply cumulative "best" (min if lower_is_better)
if metric.cumulative_best:
if metric.lower_is_better:
df["mean"] = df["mean"].cummin()
else:
df["mean"] = df["mean"].cummax()

# Apply smoothing
if metric.smoothing > 0:
df["mean"] = df["mean"].ewm(alpha=metric.smoothing).mean()

res[metric.name] = Ok(
MapData(
df=df,
map_key_infos=[self.map_key_info],
# Apply per-metric post-processing
# Apply cumulative "best" (min if lower_is_better)
if metric.cumulative_best:
if metric.lower_is_better:
df["mean"] = df["mean"].cummin()
else:
df["mean"] = df["mean"].cummax()

# Apply smoothing
if metric.smoothing > 0:
df["mean"] = df["mean"].ewm(alpha=metric.smoothing).mean()

# Accumulate successfully extracted timeseries
res[metric.name] = Ok(
MapData(
df=df,
map_key_infos=[self.map_key_info],
)
)

except Exception as e:
res[metric.name] = Err(
MetricFetchE(
message=f"Failed to fetch data for {metric.name}",
exception=e,
)
)
)

return res

Expand Down

0 comments on commit 6720be4

Please sign in to comment.