remove njit from simple functions (#754)
jmoralez authored Jan 12, 2024
1 parent 99178e4 commit ddbf943
Showing 6 changed files with 122 additions and 205 deletions.
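The pattern repeated throughout the diff: helpers simple enough to express with NumPy primitives lose their `@njit(nogil=NOGIL, cache=CACHE)` decorator and their explicit loops. As an illustration, the `_intervals` rewrite reconstructed from the first hunk below (a sketch assembled from the visible diff lines; `NOGIL` and `CACHE` are module-level settings in the notebook, not shown here):

import numpy as np

# Before: numba-compiled loop (the decorator and the loop are the removed lines).
# @njit(nogil=NOGIL, cache=CACHE)
# def _intervals(x):
#     """Compute the intervals between non zero elements of a vector."""
#     y = []
#     ctr = 1
#     for val in x:
#         if val == 0:
#             ctr += 1
#         else:
#             y.append(ctr)
#             ctr = 1
#     return np.array(y)

# After: plain NumPy, no compilation step.
def _intervals(x: np.ndarray) -> np.ndarray:
    """Compute the intervals between non zero elements of a vector."""
    nonzero_idxs = np.where(x != 0)[0]
    return np.diff(nonzero_idxs + 1, prepend=0)

# Both versions return the gaps between successive non-zero entries,
# counting from position 0, e.g. [0, 2, 0, 0, 5] -> [2, 3].
print(_intervals(np.array([0, 2, 0, 0, 5])))
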
153 changes: 60 additions & 93 deletions nbs/src/core/models.ipynb
@@ -3146,36 +3146,23 @@
" return mse\n",
"\n",
"\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _ses_forecast(x: np.ndarray, alpha: float) -> Tuple[float, np.ndarray]:\n",
" \"\"\"One step ahead forecast with simple exponential smoothing.\"\"\"\n",
" forecast, _, fitted = _ses_fcst_mse(x, alpha)\n",
" return forecast, fitted\n",
"\n",
"\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _demand(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"Extract the positive elements of a vector.\"\"\"\n",
" return x[x > 0]\n",
"\n",
"\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _intervals(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute the intervals between non zero elements of a vector.\"\"\"\n",
" y = []\n",
" nonzero_idxs = np.where(x != 0)[0]\n",
" return np.diff(nonzero_idxs + 1, prepend=0)\n",
"\n",
" ctr = 1\n",
" for val in x:\n",
" if val == 0:\n",
" ctr += 1\n",
" else:\n",
" y.append(ctr)\n",
" ctr = 1\n",
"\n",
" return np.array(y)\n",
"\n",
"\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _probability(x: np.ndarray) -> np.ndarray:\n",
" \"\"\"Compute the element probabilities of being non zero.\"\"\"\n",
" return (x != 0).astype(np.int32)\n",
@@ -3197,15 +3184,13 @@
" return forecast, fitted\n",
"\n",
"\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _chunk_sums(array: np.ndarray, chunk_size: int) -> np.ndarray:\n",
" \"\"\"Splits an array into chunks and returns the sum of each chunk.\"\"\"\n",
" n = array.size\n",
" n_chunks = n // chunk_size\n",
" sums = np.empty(n_chunks)\n",
" for i, start in enumerate(range(0, n, chunk_size)):\n",
" sums[i] = array[start : start + chunk_size].sum()\n",
" return sums"
" \"\"\"Splits an array into chunks and returns the sum of each chunk.\n",
" \n",
" Incomplete chunks are discarded\"\"\"\n",
" n_chunks = array.size // chunk_size\n",
" n_elems = n_chunks * chunk_size\n",
" return array[:n_elems].reshape(n_chunks, chunk_size).sum(axis=1)"
]
},
{
@@ -3215,16 +3200,14 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _ses(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" alpha: float, # smoothing parameter\n",
" ): \n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" alpha: float, # smoothing parameter\n",
") -> Dict[str, np.ndarray]: \n",
" fcst, _, fitted_vals = _ses_fcst_mse(y, alpha)\n",
" mean = _repeat_val(val=fcst, h=h)\n",
" fcst = {'mean': mean}\n",
" fcst = {'mean': _repeat_val(val=fcst, h=h)}\n",
" if fitted:\n",
" fcst['fitted'] = fitted_vals\n",
" return fcst"
@@ -3850,14 +3833,13 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _seasonal_exponential_smoothing(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" season_length: int, # length of season\n",
" alpha: float, # smoothing parameter\n",
" ):\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" season_length: int, # length of season\n",
" alpha: float, # smoothing parameter\n",
") -> Dict[str, np.ndarray]:\n",
" n = y.size\n",
" if n < season_length:\n",
" return {'mean': np.full(h, np.nan, np.float32)}\n",
@@ -3866,7 +3848,7 @@
" for i in range(season_length):\n",
" init_idx = (i + n % season_length)\n",
" season_vals[i], fitted_vals[init_idx::season_length] = _ses_forecast(y[init_idx::season_length], alpha)\n",
" out = _repeat_val_seas(season_vals=season_vals, h=h, season_length=season_length)\n",
" out = _repeat_val_seas(season_vals=season_vals, h=h)\n",
" fcst = {'mean': out}\n",
" if fitted:\n",
" fcst['fitted'] = fitted_vals\n",
@@ -3982,7 +3964,7 @@
" forecasts : dict \n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" mean = _repeat_val_seas(self.model_['mean'], season_length=self.season_length, h=h)\n",
" mean = _repeat_val_seas(self.model_['mean'], h=h)\n",
" res = {'mean': mean}\n",
" if level is None:\n",
" return res\n",
@@ -4243,7 +4225,7 @@
" for i in range(season_length):\n",
" init_idx = (i + n % season_length)\n",
" season_vals[i], fitted_vals[init_idx::season_length] = _optimized_ses_forecast(y[init_idx::season_length], [(0.01, 0.99)])\n",
" out = _repeat_val_seas(season_vals=season_vals, h=h, season_length=season_length)\n",
" out = _repeat_val_seas(season_vals=season_vals, h=h)\n",
" fcst = {'mean': out}\n",
" if fitted:\n",
" fcst['fitted'] = fitted_vals\n",
@@ -4357,7 +4339,7 @@
" forecasts : dict \n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" mean = _repeat_val_seas(self.model_['mean'], season_length=self.season_length, h=h)\n",
" mean = _repeat_val_seas(self.model_['mean'], h=h)\n",
" res = {'mean': mean}\n",
" if level is None:\n",
" return res\n",
@@ -5011,17 +4993,13 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _historic_average(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" ):\n",
" mean = _repeat_val(val=y.mean(), h=h)\n",
" fcst = {'mean': mean}\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
") -> Dict[str, np.ndarray]:\n",
" fcst = {'mean': _repeat_val(val=y.mean(), h=h)}\n",
" if fitted:\n",
" #fitted_vals = np.full(y.size, np.nan, np.float32) # one-step ahead\n",
" #fitted_vals[1:] = y.cumsum()[:-1] / np.arange(1, y.size) \n",
" fitted_vals = _repeat_val(val=y.mean(), h=len(y))\n",
" fcst['fitted'] = fitted_vals\n",
" return fcst"
@@ -5764,15 +5742,14 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _random_walk_with_drift(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" ): \n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
") -> Dict[str, np.ndarray]: \n",
" slope = (y[-1] - y[0]) / (y.size - 1)\n",
" mean = slope * (1 + np.arange(h)) + y[-1]\n",
" fcst = {'mean': mean.astype(np.float32), \n",
" mean = slope * (1 + np.arange(h, dtype=np.float32)) + y[-1]\n",
" fcst = {'mean': mean.astype(np.float32, copy=False),\n",
" 'slope': np.array([slope], dtype=np.float32), \n",
" 'last_y': np.array([y[-1]], dtype=np.float32)}\n",
" if fitted:\n",
@@ -6237,8 +6214,7 @@
" forecasts : dict\n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" mean = _repeat_val_seas(season_vals=self.model_['mean'], \n",
" season_length=self.season_length, h=h)\n",
" mean = _repeat_val_seas(season_vals=self.model_['mean'], h=h)\n",
" res = {'mean': mean}\n",
" \n",
" if level is None:\n",
@@ -6514,13 +6490,12 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _window_average(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" window_size: int, # window size\n",
" ): \n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" window_size: int, # window size\n",
") -> Dict[str, np.ndarray]: \n",
" if fitted:\n",
" raise NotImplementedError('return fitted')\n",
" if y.size < window_size:\n",
@@ -6821,24 +6796,20 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _seasonal_window_average(\n",
" y: np.ndarray,\n",
" h: int,\n",
" fitted: bool,\n",
" season_length: int,\n",
" window_size: int,\n",
" ):\n",
" y: np.ndarray,\n",
" h: int,\n",
" fitted: bool,\n",
" season_length: int,\n",
" window_size: int,\n",
") -> Dict[str, np.ndarray]:\n",
" if fitted:\n",
" raise NotImplementedError('return fitted')\n",
" min_samples = season_length * window_size\n",
" if y.size < min_samples:\n",
" return {'mean': np.full(h, np.nan, np.float32)}\n",
" season_avgs = np.zeros(season_length, np.float32)\n",
" for i, value in enumerate(y[-min_samples:]):\n",
" season = i % season_length\n",
" season_avgs[season] += value / window_size\n",
" out = _repeat_val_seas(season_vals=season_avgs, h=h, season_length=season_length)\n",
" season_avgs = y[-min_samples:].reshape(window_size, season_length).mean(axis=0)\n",
" out = _repeat_val_seas(season_vals=season_avgs, h=h)\n",
" return {'mean': out}"
]
},
@@ -6943,8 +6914,7 @@
" forecasts : dict \n",
" Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n",
" \"\"\"\n",
" mean = _repeat_val_seas(season_vals=self.model_['mean'], \n",
" season_length=self.season_length, h=h)\n",
" mean = _repeat_val_seas(season_vals=self.model_['mean'], h=h)\n",
" res = {'mean': mean}\n",
" if level is None:\n",
" return res\n",
@@ -7472,7 +7442,6 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _croston_classic(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
@@ -8082,12 +8051,11 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _croston_sba(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
" ):\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: bool, # fitted values\n",
") -> Dict[str, np.ndarray]:\n",
" if fitted:\n",
" raise NotImplementedError('return fitted')\n",
" mean = _croston_classic(y, h, fitted)\n",
@@ -8691,14 +8659,13 @@
"outputs": [],
"source": [
"#| exporti\n",
"@njit(nogil=NOGIL, cache=CACHE)\n",
"def _tsb(\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: int, # fitted values\n",
" alpha_d: float,\n",
" alpha_p: float,\n",
" ):\n",
" y: np.ndarray, # time series\n",
" h: int, # forecasting horizon\n",
" fitted: int, # fitted values\n",
" alpha_d: float,\n",
" alpha_p: float,\n",
") -> Dict[str, np.ndarray]:\n",
" if fitted:\n",
" raise NotImplementedError('return fitted')\n",
" if (y == 0).all():\n",
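The models.ipynb hunks above also swap explicit accumulation loops for reshape-based NumPy calls. A standalone sketch checking two of those rewrites against the loop formulations they replace (toy inputs chosen for the example, not data from the repository):

import numpy as np

# _chunk_sums: reshape into (n_chunks, chunk_size) rows and sum each row;
# any trailing, incomplete chunk is discarded.
def _chunk_sums(array: np.ndarray, chunk_size: int) -> np.ndarray:
    n_chunks = array.size // chunk_size
    n_elems = n_chunks * chunk_size
    return array[:n_elems].reshape(n_chunks, chunk_size).sum(axis=1)

print(_chunk_sums(np.arange(7), 3))  # [ 3 12]; the seventh element is dropped

# _seasonal_window_average: the old accumulation loop and the new reshape/mean agree.
y = np.arange(12, dtype=np.float32)
season_length, window_size = 6, 2
min_samples = season_length * window_size

old = np.zeros(season_length, np.float32)  # previous loop formulation
for i, value in enumerate(y[-min_samples:]):
    old[i % season_length] += value / window_size

new = y[-min_samples:].reshape(window_size, season_length).mean(axis=0)
assert np.allclose(old, new)  # both give [3. 4. 5. 6. 7. 8.]
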
3 changes: 2 additions & 1 deletion nbs/src/theta.ipynb
@@ -962,7 +962,7 @@
" res[f'hi-{lv}'] = np.quantile(samples, max_q, axis=1)\n",
" \n",
" if obj.get('decompose', False):\n",
" seas_forecast = _repeat_val_seas(obj['seas_forecast']['mean'], h=h, season_length=obj['m'])\n",
" seas_forecast = _repeat_val_seas(obj['seas_forecast']['mean'], h=h)\n",
" for key in res:\n",
" if obj['decomposition_type'] == 'multiplicative':\n",
" res[key] = res[key] * seas_forecast\n",
@@ -1087,6 +1087,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "cd79a3e6",
"metadata": {},
"outputs": [],
"source": [
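Both files also drop the `season_length` argument from their `_repeat_val_seas` calls, passing only `season_vals` and `h`. The helper's body is not part of this diff; a plausible sketch of the post-change behavior, assuming the season length is now inferred from `season_vals.size` (an assumption for illustration, not the repository's actual implementation):

import numpy as np

def _repeat_val_seas(season_vals: np.ndarray, h: int) -> np.ndarray:
    # Assumed behavior: cycle the per-season values until h forecasts are produced.
    # The season length is read from season_vals itself rather than passed separately.
    reps = h // season_vals.size + 1
    return np.tile(season_vals, reps)[:h]

print(_repeat_val_seas(np.array([10.0, 20.0, 30.0]), h=5))  # [10. 20. 30. 10. 20.]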