Skip to content

Commit

Permalink
Improve Stats compute perf
Browse files Browse the repository at this point in the history
  • Loading branch information
Kh4L committed Sep 9, 2024
1 parent 546f1a2 commit 4ffd56e
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torch_frame/data/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,25 @@ def compute(
sep: str | None = None,
) -> Any:
if self == StatType.MEAN:
flattened = np.hstack(np.hstack(ser.values))
val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values
flattened = np.hstack(val) if val.ndim > 1 else val
finite_mask = np.isfinite(flattened)
if not finite_mask.any():
# NOTE: We may just error out here if eveything is NaN
return np.nan
return np.mean(flattened[finite_mask]).item()

elif self == StatType.STD:
flattened = np.hstack(np.hstack(ser.values))
val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values
flattened = np.hstack(val) if val.ndim > 1 else val
finite_mask = np.isfinite(flattened)
if not finite_mask.any():
return np.nan
return np.std(flattened[finite_mask]).item()

elif self == StatType.QUANTILES:
flattened = np.hstack(np.hstack(ser.values))
val = np.hstack(ser.values) if ser.values.ndim > 1 else ser.values
flattened = np.hstack(val) if val.ndim > 1 else val
finite_mask = np.isfinite(flattened)
if not finite_mask.any():
return [np.nan, np.nan, np.nan, np.nan, np.nan]
Expand Down

0 comments on commit 4ffd56e

Please sign in to comment.