1
1
from __future__ import annotations
2
2
3
3
from math import sqrt
4
- from typing import Callable
4
+ from typing import TYPE_CHECKING , Callable
5
5
6
6
import cloudpickle
7
7
import numpy as np
8
8
9
9
from adaptive .learner .base_learner import BaseLearner
10
10
from adaptive .notebook_integration import ensure_holoviews
11
- from adaptive .types import Float , Int , Real
12
11
from adaptive .utils import (
13
12
assign_defaults ,
14
13
cache_latest ,
15
14
partial_function_from_dataframe ,
16
15
)
17
16
17
+ if TYPE_CHECKING :
18
+ from adaptive .types import Float , Int , Real
19
+
18
20
try :
19
- import pandas
21
+ import pandas as pd
20
22
21
23
with_pandas = True
22
24
@@ -47,6 +49,7 @@ class AverageLearner(BaseLearner):
47
49
Points that still have to be evaluated.
48
50
npoints : int
49
51
Number of evaluated points.
52
+
50
53
"""
51
54
52
55
def __init__ (
@@ -57,7 +60,8 @@ def __init__(
57
60
min_npoints : int = 2 ,
58
61
) -> None :
59
62
if atol is None and rtol is None :
60
- raise Exception ("At least one of `atol` and `rtol` should be set." )
63
+ msg = "At least one of `atol` and `rtol` should be set."
64
+ raise Exception (msg )
61
65
if atol is None :
62
66
atol = np .inf
63
67
if rtol is None :
@@ -92,7 +96,7 @@ def to_dataframe( # type: ignore[override]
92
96
function_prefix : str = "function." ,
93
97
seed_name : str = "seed" ,
94
98
y_name : str = "y" ,
95
- ) -> pandas .DataFrame :
99
+ ) -> pd .DataFrame :
96
100
"""Return the data as a `pandas.DataFrame`.
97
101
98
102
Parameters
@@ -116,10 +120,12 @@ def to_dataframe( # type: ignore[override]
116
120
------
117
121
ImportError
118
122
If `pandas` is not installed.
123
+
119
124
"""
120
125
if not with_pandas :
121
- raise ImportError ("pandas is not installed." )
122
- df = pandas .DataFrame (sorted (self .data .items ()), columns = [seed_name , y_name ])
126
+ msg = "pandas is not installed."
127
+ raise ImportError (msg )
128
+ df = pd .DataFrame (sorted (self .data .items ()), columns = [seed_name , y_name ])
123
129
df .attrs ["inputs" ] = [seed_name ]
124
130
df .attrs ["output" ] = y_name
125
131
if with_default_function_args :
@@ -128,12 +134,12 @@ def to_dataframe( # type: ignore[override]
128
134
129
135
def load_dataframe ( # type: ignore[override]
130
136
self ,
131
- df : pandas .DataFrame ,
137
+ df : pd .DataFrame ,
132
138
with_default_function_args : bool = True ,
133
139
function_prefix : str = "function." ,
134
140
seed_name : str = "seed" ,
135
141
y_name : str = "y" ,
136
- ):
142
+ ) -> None :
137
143
"""Load data from a `pandas.DataFrame`.
138
144
139
145
If ``with_default_function_args`` is True, then ``learner.function``'s
@@ -153,11 +159,14 @@ def load_dataframe( # type: ignore[override]
153
159
The ``seed_name`` used in ``to_dataframe``, by default "seed"
154
160
y_name : str, optional
155
161
The ``y_name`` used in ``to_dataframe``, by default "y"
162
+
156
163
"""
157
164
self .tell_many (df [seed_name ].values , df [y_name ].values )
158
165
if with_default_function_args :
159
166
self .function = partial_function_from_dataframe (
160
- self .function , df , function_prefix
167
+ self .function ,
168
+ df ,
169
+ function_prefix ,
161
170
)
162
171
163
172
def ask (self , n : int , tell_pending : bool = True ) -> tuple [list [int ], list [Float ]]:
@@ -168,7 +177,7 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[list[int], list[Float]
168
177
points = list (
169
178
set (range (self .n_requested + n ))
170
179
- set (self .data )
171
- - set (self .pending_points )
180
+ - set (self .pending_points ),
172
181
)[:n ]
173
182
174
183
loss_improvements = [self ._loss_improvement (n ) / n ] * n
@@ -199,7 +208,8 @@ def mean(self) -> Float:
199
208
@property
200
209
def std (self ) -> Float :
201
210
"""The corrected sample standard deviation of the values
202
- in `data`."""
211
+ in `data`.
212
+ """
203
213
n = self .npoints
204
214
if n < self .min_npoints :
205
215
return np .inf
@@ -211,10 +221,7 @@ def std(self) -> Float:
211
221
212
222
@cache_latest
213
223
def loss (self , real : bool = True , * , n = None ) -> Float :
214
- if n is None :
215
- n = self .npoints if real else self .n_requested
216
- else :
217
- n = n
224
+ n = (self .npoints if real else self .n_requested ) if n is None else n
218
225
if n < self .min_npoints :
219
226
return np .inf
220
227
standard_error = self .std / sqrt (n )
@@ -232,7 +239,7 @@ def _loss_improvement(self, n: int) -> Float:
232
239
else :
233
240
return np .inf
234
241
235
- def remove_unfinished (self ):
242
+ def remove_unfinished (self ) -> None :
236
243
"""Remove uncomputed data from the learner."""
237
244
self .pending_points = set ()
238
245
@@ -242,7 +249,9 @@ def plot(self):
242
249
Returns
243
250
-------
244
251
holoviews.element.Histogram
245
- A histogram of the evaluated data."""
252
+ A histogram of the evaluated data.
253
+
254
+ """
246
255
hv = ensure_holoviews ()
247
256
vals = [v for v in self .data .values () if v is not None ]
248
257
if not vals :
0 commit comments