Skip to content

Commit 1aa4cc7

Browse files
committed
Unsafe fixes
1 parent 05eb39f commit 1aa4cc7

38 files changed

+870
-560
lines changed

adaptive/learner/average_learner.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
from __future__ import annotations
22

33
from math import sqrt
4-
from typing import Callable
4+
from typing import TYPE_CHECKING, Callable
55

66
import cloudpickle
77
import numpy as np
88

99
from adaptive.learner.base_learner import BaseLearner
1010
from adaptive.notebook_integration import ensure_holoviews
11-
from adaptive.types import Float, Int, Real
1211
from adaptive.utils import (
1312
assign_defaults,
1413
cache_latest,
1514
partial_function_from_dataframe,
1615
)
1716

17+
if TYPE_CHECKING:
18+
from adaptive.types import Float, Int, Real
19+
1820
try:
19-
import pandas
21+
import pandas as pd
2022

2123
with_pandas = True
2224

@@ -47,6 +49,7 @@ class AverageLearner(BaseLearner):
4749
Points that still have to be evaluated.
4850
npoints : int
4951
Number of evaluated points.
52+
5053
"""
5154

5255
def __init__(
@@ -57,7 +60,8 @@ def __init__(
5760
min_npoints: int = 2,
5861
) -> None:
5962
if atol is None and rtol is None:
60-
raise Exception("At least one of `atol` and `rtol` should be set.")
63+
msg = "At least one of `atol` and `rtol` should be set."
64+
raise Exception(msg)
6165
if atol is None:
6266
atol = np.inf
6367
if rtol is None:
@@ -92,7 +96,7 @@ def to_dataframe( # type: ignore[override]
9296
function_prefix: str = "function.",
9397
seed_name: str = "seed",
9498
y_name: str = "y",
95-
) -> pandas.DataFrame:
99+
) -> pd.DataFrame:
96100
"""Return the data as a `pandas.DataFrame`.
97101
98102
Parameters
@@ -116,10 +120,12 @@ def to_dataframe( # type: ignore[override]
116120
------
117121
ImportError
118122
If `pandas` is not installed.
123+
119124
"""
120125
if not with_pandas:
121-
raise ImportError("pandas is not installed.")
122-
df = pandas.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
126+
msg = "pandas is not installed."
127+
raise ImportError(msg)
128+
df = pd.DataFrame(sorted(self.data.items()), columns=[seed_name, y_name])
123129
df.attrs["inputs"] = [seed_name]
124130
df.attrs["output"] = y_name
125131
if with_default_function_args:
@@ -128,12 +134,12 @@ def to_dataframe( # type: ignore[override]
128134

129135
def load_dataframe( # type: ignore[override]
130136
self,
131-
df: pandas.DataFrame,
137+
df: pd.DataFrame,
132138
with_default_function_args: bool = True,
133139
function_prefix: str = "function.",
134140
seed_name: str = "seed",
135141
y_name: str = "y",
136-
):
142+
) -> None:
137143
"""Load data from a `pandas.DataFrame`.
138144
139145
If ``with_default_function_args`` is True, then ``learner.function``'s
@@ -153,11 +159,14 @@ def load_dataframe( # type: ignore[override]
153159
The ``seed_name`` used in ``to_dataframe``, by default "seed"
154160
y_name : str, optional
155161
The ``y_name`` used in ``to_dataframe``, by default "y"
162+
156163
"""
157164
self.tell_many(df[seed_name].values, df[y_name].values)
158165
if with_default_function_args:
159166
self.function = partial_function_from_dataframe(
160-
self.function, df, function_prefix
167+
self.function,
168+
df,
169+
function_prefix,
161170
)
162171

163172
def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]]:
@@ -168,7 +177,7 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]
168177
points = list(
169178
set(range(self.n_requested + n))
170179
- set(self.data)
171-
- set(self.pending_points)
180+
- set(self.pending_points),
172181
)[:n]
173182

174183
loss_improvements = [self._loss_improvement(n) / n] * n
@@ -199,7 +208,8 @@ def mean(self) -> Float:
199208
@property
200209
def std(self) -> Float:
201210
"""The corrected sample standard deviation of the values
202-
in `data`."""
211+
in `data`.
212+
"""
203213
n = self.npoints
204214
if n < self.min_npoints:
205215
return np.inf
@@ -211,10 +221,7 @@ def std(self) -> Float:
211221

212222
@cache_latest
213223
def loss(self, real: bool = True, *, n=None) -> Float:
214-
if n is None:
215-
n = self.npoints if real else self.n_requested
216-
else:
217-
n = n
224+
n = (self.npoints if real else self.n_requested) if n is None else n
218225
if n < self.min_npoints:
219226
return np.inf
220227
standard_error = self.std / sqrt(n)
@@ -232,7 +239,7 @@ def _loss_improvement(self, n: int) -> Float:
232239
else:
233240
return np.inf
234241

235-
def remove_unfinished(self):
242+
def remove_unfinished(self) -> None:
236243
"""Remove uncomputed data from the learner."""
237244
self.pending_points = set()
238245

@@ -242,7 +249,9 @@ def plot(self):
242249
Returns
243250
-------
244251
holoviews.element.Histogram
245-
A histogram of the evaluated data."""
252+
A histogram of the evaluated data.
253+
254+
"""
246255
hv = ensure_holoviews()
247256
vals = [v for v in self.data.values() if v is not None]
248257
if not vals:

0 commit comments

Comments
 (0)