diff --git a/jesse/indicators/keltner.py b/jesse/indicators/keltner.py index 2883afde2..8769a2e89 100644 --- a/jesse/indicators/keltner.py +++ b/jesse/indicators/keltner.py @@ -1,45 +1,60 @@ from collections import namedtuple - import numpy as np - +from numba import njit from jesse.helpers import get_candle_source, slice_candles from jesse.indicators.ma import ma KeltnerChannel = namedtuple('KeltnerChannel', ['upperband', 'middleband', 'lowerband']) +@njit def _atr(high: np.ndarray, low: np.ndarray, close: np.ndarray, period: int) -> np.ndarray: - tr = np.empty_like(high) + """ + Calculate ATR using Numba + """ + n = len(close) + tr = np.empty(n) + atr_vals = np.full(n, np.nan) + + # Calculate True Range tr[0] = high[0] - low[0] - # Compute true range for the rest of the candles using vectorized operations - tr[1:] = np.maximum( - np.maximum(high[1:] - low[1:], np.abs(high[1:] - close[:-1])), - np.abs(low[1:] - close[:-1]) - ) - - atr_vals = np.empty_like(tr, dtype=float) - # Not enough data for ATR in the first period-1 candles - atr_vals[:period-1] = float('nan') - # The first ATR value is a simple average + for i in range(1, n): + hl = high[i] - low[i] + hc = abs(high[i] - close[i-1]) + lc = abs(low[i] - close[i-1]) + tr[i] = max(max(hl, hc), lc) + + if n < period: + return atr_vals + + # First ATR value is the simple average of the first 'period' true ranges atr_vals[period-1] = np.mean(tr[:period]) - - # Wilder's smoothing method for subsequent values using vectorized operation with lfilter - if len(tr) > period: - alpha = 1 / period - from scipy.signal import lfilter - A0 = atr_vals[period - 1] # initial ATR from simple average - x = tr[period:] - # Set the initial condition such that y[0] = (1 - alpha)*A0 + alpha*tr[period] - zi = [(1 - alpha) * A0] - y, _ = lfilter([alpha], [1, -(1 - alpha)], x, zi=zi) - atr_vals[period:] = y + + # Calculate subsequent ATR values using Wilder's smoothing + for i in range(period, n): + atr_vals[i] = (atr_vals[i-1] * (period - 1) + tr[i]) / period + return atr_vals -def keltner(candles: np.ndarray, period: int = 20, multiplier: float = 2, matype: int = 1, source_type: str = "close", - sequential: bool = False) -> KeltnerChannel: +@njit +def _calculate_keltner(source: np.ndarray, high: np.ndarray, low: np.ndarray, close: np.ndarray, + ma_values: np.ndarray, period: int, multiplier: float) -> tuple: """ - Keltner Channels + Core Keltner Channel calculation using Numba + """ + atr_vals = _atr(high, low, close, period) + + up = ma_values + atr_vals * multiplier + low = ma_values - atr_vals * multiplier + + return up, ma_values, low + + +def keltner(candles: np.ndarray, period: int = 20, multiplier: float = 2, matype: int = 1, + source_type: str = "close", sequential: bool = False) -> KeltnerChannel: + """ + Keltner Channels using Numba for optimization :param candles: np.ndarray :param period: int - default: 20 @@ -50,16 +65,20 @@ def keltner(candles: np.ndarray, period: int = 20, multiplier: float = 2, matype :return: KeltnerChannel(upperband, middleband, lowerband) """ - candles = slice_candles(candles, sequential) source = get_candle_source(candles, source_type=source_type) - e = ma(source, period=period, matype=matype, sequential=True) - a = _atr(candles[:, 3], candles[:, 4], candles[:, 2], period) - - up = e + a * multiplier - mid = e - low = e - a * multiplier + ma_values = ma(source, period=period, matype=matype, sequential=True) + + up, mid, low = _calculate_keltner( + source, + candles[:, 3], # high + candles[:, 4], # low + candles[:, 2], # close + ma_values, + period, + multiplier + ) if sequential: return KeltnerChannel(up, mid, low)