Skip to content

Commit

Permalink
Reformat / lint repository with new dev dependency versions (#2248)
Browse files Browse the repository at this point in the history
* update dev requirements with new pre commit hook lint dependency versions

* black reformatting

* fix flake8 checks
  • Loading branch information
dennisbader authored Feb 24, 2024
1 parent 4600453 commit 6fbb670
Show file tree
Hide file tree
Showing 66 changed files with 555 additions and 463 deletions.
1 change: 0 additions & 1 deletion darts/ad/aggregators/or_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
is flagged as anomalous (logical OR).
"""


from typing import Sequence

from darts import TimeSeries
Expand Down
4 changes: 2 additions & 2 deletions darts/ad/anomaly_model/forecasting_am.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def _prepare_covariates(
series: Sequence[TimeSeries],
name_covariates: str,
) -> Sequence[TimeSeries]:
"""Convert `covariates` into Sequence, if not already, and checks if their length is equal to the one of `series`.
"""Convert `covariates` into Sequence, if not already, and checks if their length is equal to the one of
`series`.
Parameters
----------
Expand Down Expand Up @@ -515,7 +516,6 @@ def _predict_with_forecasting(
start: Union[pd.Timestamp, float, int] = None,
num_samples: int = 1,
) -> TimeSeries:

"""Compute the historical forecasts that would have been obtained by this model on the `series`.
`retrain` is set to False if possible (this is not supported by all models). If set to True, it will always
Expand Down
4 changes: 1 addition & 3 deletions darts/ad/detectors/quantile_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def _prep_quantile(q):
return (
q.tolist()
if isinstance(q, np.ndarray)
else [q]
if not isinstance(q, Sequence)
else q
else [q] if not isinstance(q, Sequence) else q
)

low = _prep_quantile(low_quantile)
Expand Down
4 changes: 1 addition & 3 deletions darts/ad/detectors/threshold_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def _prep_thresholds(q):
return (
q.tolist()
if isinstance(q, np.ndarray)
else [q]
if not isinstance(q, Sequence)
else q
else [q] if not isinstance(q, Sequence) else q
)

low = _prep_thresholds(low_threshold)
Expand Down
2 changes: 1 addition & 1 deletion darts/ad/scorers/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def score_from_prediction(
_assert_same_length(list_actual_series, list_pred_series)

anomaly_scores = []
for (s1, s2) in zip(list_actual_series, list_pred_series):
for s1, s2 in zip(list_actual_series, list_pred_series):
_sanity_check_two_series(s1, s2)
s1 = self._assert_deterministic(s1, "actual_series")
s2 = self._assert_deterministic(s2, "pred_series")
Expand Down
1 change: 1 addition & 0 deletions darts/dataprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Pipeline
--------
"""

from copy import deepcopy
from typing import Iterator, Sequence, Union

Expand Down
1 change: 1 addition & 0 deletions darts/dataprocessing/transformers/midas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Mixed-data sampling (MIDAS) Transformer
---------------------------------------
"""

from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
Expand Down
1 change: 0 additions & 1 deletion darts/dataprocessing/transformers/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
It can be added to a ``TimeSeries`` using e.g., the :meth:`TimeSeries.with_hierarchy` method.
"""


from typing import Any, Mapping, Optional

import numpy as np
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Static Covariates Transformer
------
"""

from collections import OrderedDict
from typing import Any, Dict, List, Optional, Sequence, Tuple

Expand Down
10 changes: 6 additions & 4 deletions darts/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def __init__(self, multivariate: bool = True):

def pre_proces_fn(extracted_dir, dataset_path):
with open(Path(extracted_dir, "LD2011_2014.txt")) as fin:
with open(dataset_path, "wt", newline="\n") as fout:
with open(dataset_path, "w", newline="\n") as fout:
for line in fin:
fout.write(line.replace(",", ".").replace(";", ","))

Expand Down Expand Up @@ -622,9 +622,11 @@ def pre_proces_fn(extracted_dir, dataset_path):
uri="https://github.com/fivethirtyeight/uber-tlc-foil-response/raw/"
"63bb878b76f47f69b4527d50af57aac26dead983/"
"uber-trip-data/uber-raw-data-janjune-15.csv.zip",
hash="9ed84ebe0df4bc664748724b633b3fe6"
if sample_freq == "hourly"
else "24f9fd67e4b9e53f0214a90268cd9bee",
hash=(
"9ed84ebe0df4bc664748724b633b3fe6"
if sample_freq == "hourly"
else "24f9fd67e4b9e53f0214a90268cd9bee"
),
header_time="Pickup_date",
format_time="%Y-%m-%d %H:%M:%S",
pre_process_zipped_csv_fn=pre_proces_fn,
Expand Down
1 change: 1 addition & 0 deletions darts/explainability/explainability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
A `_ForecastingModelExplainer` takes a fitted forecasting model as input and generates explanations for it.
"""

from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, Union

Expand Down
13 changes: 6 additions & 7 deletions darts/explainability/shap_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,6 @@ def shap_explanations(
horizons: Optional[Sequence[int]] = None,
target_components: Optional[Sequence[str]] = None,
) -> Dict[int, Dict[str, shap.Explanation]]:

"""
Return a dictionary of dictionaries of shap.Explanation instances:
- the first dimension corresponds to the n forecasts ahead we want to explain (Horizon).
Expand Down Expand Up @@ -760,14 +759,14 @@ def _create_regression_model_shap_X(
X, indexes = create_lagged_prediction_data(
target_series=target_series if lags_list else None,
past_covariates=past_covariates if lags_past_covariates_list else None,
future_covariates=future_covariates
if lags_future_covariates_list
else None,
future_covariates=(
future_covariates if lags_future_covariates_list else None
),
lags=lags_list,
lags_past_covariates=lags_past_covariates_list if past_covariates else None,
lags_future_covariates=lags_future_covariates_list
if future_covariates
else None,
lags_future_covariates=(
lags_future_covariates_list if future_covariates else None
),
uses_static_covariates=self.model.uses_static_covariates,
last_static_covariates_shape=self.model._static_covariates_shape,
)
Expand Down
19 changes: 12 additions & 7 deletions darts/explainability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,18 @@ def _check_valid_input(
all(
[
series[idx].columns.to_list() == target_components,
past_covariates[idx].columns.to_list() == past_covariates_components
if past_covariates is not None
else True,
future_covariates[idx].columns.to_list()
== future_covariates_components
if future_covariates is not None
else True,
(
past_covariates[idx].columns.to_list()
== past_covariates_components
if past_covariates is not None
else True
),
(
future_covariates[idx].columns.to_list()
== future_covariates_components
if future_covariates is not None
else True
),
]
),
"Columns names must be identical between TimeSeries list components (multi-TimeSeries).",
Expand Down
5 changes: 1 addition & 4 deletions darts/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def wrapper_multi_ts_support(*args, **kwargs):
pred_series = (
kwargs["pred_series"]
if "pred_series" in kwargs
else args[0]
if "actual_series" in kwargs
else args[1]
else args[0] if "actual_series" in kwargs else args[1]
)

n_jobs = kwargs.pop("n_jobs", signature(func).parameters["n_jobs"].default)
Expand Down Expand Up @@ -1134,7 +1132,6 @@ def rho_risk(
n_jobs: int = 1,
verbose: bool = False
) -> float:

""":math:`\\rho`-risk (rho-risk or quantile risk).
Given a time series of actual values :math:`y_t` of length :math:`T` and a time series of stochastic predictions
Expand Down
28 changes: 16 additions & 12 deletions darts/models/forecasting/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,17 +193,19 @@ def _predict(
if series is not None:
self.model = self.model.apply(
series.values(copy=False),
exog=historic_future_covariates.values(copy=False)
if historic_future_covariates
else None,
exog=(
historic_future_covariates.values(copy=False)
if historic_future_covariates
else None
),
)

if num_samples == 1:
forecast = self.model.forecast(
steps=n,
exog=future_covariates.values(copy=False)
if future_covariates
else None,
exog=(
future_covariates.values(copy=False) if future_covariates else None
),
)
else:
forecast = self.model.simulate(
Expand All @@ -212,18 +214,20 @@ def _predict(
initial_state=self.model.states.predicted[-1, :],
random_state=self._random_state,
anchor="end",
exog=future_covariates.values(copy=False)
if future_covariates
else None,
exog=(
future_covariates.values(copy=False) if future_covariates else None
),
)

# restoring statsmodels results object state
if series is not None:
self.model = self.model.apply(
self._orig_training_series.values(copy=False),
exog=self.training_historic_future_covariates.values(copy=False)
if self.training_historic_future_covariates
else None,
exog=(
self.training_historic_future_covariates.values(copy=False)
if self.training_historic_future_covariates
else None
),
)

return self._build_forecast_series(forecast)
Expand Down
20 changes: 11 additions & 9 deletions darts/models/forecasting/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,12 @@ def fit(
for model in self.forecasting_models:
model._fit_wrapper(
series=series,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
past_covariates=(
past_covariates if model.supports_past_covariates else None
),
future_covariates=(
future_covariates if model.supports_future_covariates else None
),
)

return self
Expand All @@ -364,9 +364,11 @@ def ensemble(

if isinstance(predictions, Sequence):
return [
self._target_average(p, ts)
if not predict_likelihood_parameters
else self._params_average(p, ts)
(
self._target_average(p, ts)
if not predict_likelihood_parameters
else self._params_average(p, ts)
)
for p, ts in zip(predictions, series)
]
else:
Expand Down
2 changes: 0 additions & 2 deletions darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def __init__(
name: str,
**kwargs,
):

"""PyTorch module implementing a block RNN to be used in `BlockRNNModel`.
PyTorch module implementing a simple block RNN with the specified `name` layer.
Expand Down Expand Up @@ -196,7 +195,6 @@ def __init__(
dropout: float = 0.0,
**kwargs,
):

"""Block Recurrent Neural Network Model (RNNs).
This is a neural network model that uses an RNN encoder to encode fixed-length input chunks, and
Expand Down
8 changes: 5 additions & 3 deletions darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ def min_train_series_length(self) -> int:
# for other regression models
return max(
3,
-self.lags["target"][0] + self.output_chunk_length + 1
if "target" in self.lags
else self.output_chunk_length,
(
-self.lags["target"][0] + self.output_chunk_length + 1
if "target" in self.lags
else self.output_chunk_length
),
)
16 changes: 10 additions & 6 deletions darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,11 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non

self.model.fit(
y=series.values(copy=False).flatten(),
X=future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None,
X=(
future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None
),
)

return self
Expand All @@ -147,9 +149,11 @@ def _predict(
super()._predict(n, future_covariates, num_samples)
values = self.model.predict(
h=n,
X=future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None,
X=(
future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None
),
)["mean"]
return self._build_forecast_series(values)

Expand Down
12 changes: 6 additions & 6 deletions darts/models/forecasting/ensemble_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ def _make_multiple_predictions(
model._predict_wrapper(
n=n,
series=series,
past_covariates=past_covariates
if model.supports_past_covariates
else None,
future_covariates=future_covariates
if model.supports_future_covariates
else None,
past_covariates=(
past_covariates if model.supports_past_covariates else None
),
future_covariates=(
future_covariates if model.supports_future_covariates else None
),
num_samples=num_samples if model._is_probabilistic else 1,
predict_likelihood_parameters=predict_likelihood_parameters,
)
Expand Down
1 change: 0 additions & 1 deletion darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def __init__(
kwargs: Optional[Dict[str, Any]] = None,
**fit_kwargs
):

"""Exponential Smoothing
This is a wrapper around
Expand Down
6 changes: 3 additions & 3 deletions darts/models/forecasting/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _find_relevant_timestamp_attributes(series: TimeSeries) -> set:
# check for yearly seasonality
if _check_approximate_seasonality(series, 12, 1, 0):
relevant_attributes.add("month")
elif type(series.freq) == pd.tseries.offsets.Day:
elif type(series.freq) is pd.tseries.offsets.Day:
# check for yearly seasonality
if _check_approximate_seasonality(series, 365, 5, 20):
relevant_attributes.update({"month", "day"})
Expand All @@ -115,7 +115,7 @@ def _find_relevant_timestamp_attributes(series: TimeSeries) -> set:
# check for weekly seasonality
elif _check_approximate_seasonality(series, 7, 0, 0):
relevant_attributes.add("weekday")
elif type(series.freq) == pd.tseries.offsets.Hour:
elif type(series.freq) is pd.tseries.offsets.Hour:
# check for yearly seasonality
if _check_approximate_seasonality(series, 8760, 100, 100):
relevant_attributes.update({"month", "day", "hour"})
Expand All @@ -128,7 +128,7 @@ def _find_relevant_timestamp_attributes(series: TimeSeries) -> set:
# check for daily seasonality
elif _check_approximate_seasonality(series, 24, 1, 1):
relevant_attributes.add("hour")
elif type(series.freq) == pd.tseries.offsets.Minute:
elif type(series.freq) is pd.tseries.offsets.Minute:
# check for daily seasonality
if _check_approximate_seasonality(series, 1440, 20, 50):
relevant_attributes.update({"hour", "minute"})
Expand Down
Loading

0 comments on commit 6fbb670

Please sign in to comment.