Skip to content

Commit

Permalink
Optimize indicators with Numba for improved performance
Browse files Browse the repository at this point in the history
- Refactor ADX, ATR, Bollinger Bands Width, and Ichimoku Cloud indicators
- Implement Numba-accelerated core calculation functions
- Improve computational efficiency using Numba's JIT compilation
- Maintain consistent interface and return types
  • Loading branch information
saleh-mir committed Feb 13, 2025
1 parent 519e1b4 commit 7a01f99
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 157 deletions.
162 changes: 92 additions & 70 deletions jesse/indicators/adx.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,96 @@
from typing import Union

import numpy as np

from numba import njit
from jesse.helpers import slice_candles


@njit
def _wilder_smooth(arr: np.ndarray, period: int) -> np.ndarray:
"""
Wilder's smoothing helper function
"""
n = len(arr)
result = np.full(n, np.nan)
# First value is sum of first "period" values
result[period] = np.sum(arr[1:period + 1])
# Apply smoothing formula
for i in range(period + 1, n):
result[i] = result[i - 1] - (result[i - 1] / period) + arr[i]
return result


@njit
def _calculate_adx(high: np.ndarray, low: np.ndarray, close: np.ndarray, period: int) -> np.ndarray:
"""
Core ADX calculation using Numba
"""
n = len(close)
TR = np.zeros(n)
plusDM = np.zeros(n)
minusDM = np.zeros(n)

# Calculate True Range and Directional Movement
for i in range(1, n):
hl = high[i] - low[i]
hc = abs(high[i] - close[i-1])
lc = abs(low[i] - close[i-1])
TR[i] = max(max(hl, hc), lc)

h_diff = high[i] - high[i-1]
l_diff = low[i-1] - low[i]

if h_diff > l_diff and h_diff > 0:
plusDM[i] = h_diff
else:
plusDM[i] = 0

if l_diff > h_diff and l_diff > 0:
minusDM[i] = l_diff
else:
minusDM[i] = 0

# Smooth the TR and DM values
tr_smooth = _wilder_smooth(TR, period)
plus_dm_smooth = _wilder_smooth(plusDM, period)
minus_dm_smooth = _wilder_smooth(minusDM, period)

# Calculate DI+ and DI-
DI_plus = np.full(n, np.nan)
DI_minus = np.full(n, np.nan)
DX = np.full(n, np.nan)

for i in range(period, n):
if tr_smooth[i] != 0:
DI_plus[i] = 100 * plus_dm_smooth[i] / tr_smooth[i]
DI_minus[i] = 100 * minus_dm_smooth[i] / tr_smooth[i]

if (DI_plus[i] + DI_minus[i]) != 0:
DX[i] = 100 * abs(DI_plus[i] - DI_minus[i]) / (DI_plus[i] + DI_minus[i])
else:
DX[i] = 0
else:
DI_plus[i] = 0
DI_minus[i] = 0
DX[i] = 0

# Calculate ADX
ADX = np.full(n, np.nan)
start_index = period * 2

if start_index < n:
# Calculate first ADX value
ADX[start_index] = np.mean(DX[period:start_index])

# Calculate subsequent ADX values
for i in range(start_index + 1, n):
ADX[i] = (ADX[i-1] * (period - 1) + DX[i]) / period

return ADX


def adx(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Union[float, np.ndarray]:
"""
ADX - Average Directional Movement Index using vectorized matrix operations for smoothing.
ADX - Average Directional Movement Index using Numba for optimization.
:param candles: np.ndarray, expected 2D array with OHLCV data where index 3 is high, index 4 is low, and index 2 is close
:param period: int - default: 14
Expand All @@ -16,77 +99,16 @@ def adx(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Unio
"""
if len(candles.shape) < 2:
raise ValueError("adx indicator requires a 2D array of candles")

candles = slice_candles(candles, sequential)

if len(candles) <= period:
return np.nan if sequential else np.nan

high = candles[:, 3]
low = candles[:, 4]
close = candles[:, 2]
n = len(close)
if n <= period:
return np.nan if sequential else np.nan

# Initialize arrays
TR = np.zeros(n)
plusDM = np.zeros(n)
minusDM = np.zeros(n)
result = _calculate_adx(high, low, close, period)

# Vectorized True Range computation for indices 1 to n-1
true_range = np.maximum(
np.maximum(high[1:] - low[1:], np.abs(high[1:] - close[:-1])),
np.abs(low[1:] - close[:-1])
)
TR[1:] = true_range

# Directional movements
diff_high = high[1:] - high[:-1]
diff_low = low[:-1] - low[1:]
plusDM[1:] = np.where((diff_high > diff_low) & (diff_high > 0), diff_high, 0)
minusDM[1:] = np.where((diff_low > diff_high) & (diff_low > 0), diff_low, 0)

# Wilder's smoothing parameters
a = 1 / period
discount = 1 - a

# Vectorized Wilder smoothing using matrix operations
def wilder_smooth(arr):
S = np.empty(n)
S[:period] = np.nan
init = np.sum(arr[1:period+1])
S[period] = init
M = n - period - 1 # number of elements to smooth after index 'period'
if M > 0:
X = arr[period+1:]
# Construct lower-triangular matrix where element (i, j) = discount^(i - j) for i >= j
T = np.tril(np.power(discount, np.subtract.outer(np.arange(M), np.arange(M))))
offsets = np.arange(1, M + 1) # discount exponent for the initial term
S[period+1:] = init * (discount ** offsets) + a * (T @ X)
return S

tr_smoothed = wilder_smooth(TR)
plusDM_smoothed = wilder_smooth(plusDM)
minusDM_smoothed = wilder_smooth(minusDM)

# Compute DI+ and DI-
DI_plus = np.full(n, np.nan)
DI_minus = np.full(n, np.nan)
valid = np.arange(period, n)
DI_plus[valid] = np.where(tr_smoothed[valid] == 0, 0, 100 * plusDM_smoothed[valid] / tr_smoothed[valid])
DI_minus[valid] = np.where(tr_smoothed[valid] == 0, 0, 100 * minusDM_smoothed[valid] / tr_smoothed[valid])
dd = DI_plus[valid] + DI_minus[valid]
DX = np.full(n, np.nan)
DX[valid] = np.where(dd == 0, 0, 100 * np.abs(DI_plus[valid] - DI_minus[valid]) / dd)

# Compute ADX smoothing
ADX = np.full(n, np.nan)
start_index = period * 2
if start_index < n:
first_adx = np.mean(DX[period:start_index])
ADX[start_index] = first_adx
M_adx = n - start_index - 1
if M_adx > 0:
Y = DX[start_index+1:]
T_adx = np.tril(np.power(discount, np.subtract.outer(np.arange(M_adx), np.arange(M_adx))))
offsets_adx = np.arange(1, M_adx + 1)
ADX[start_index+1:] = first_adx * (discount ** offsets_adx) + a * (T_adx @ Y)
result = ADX if sequential else ADX[-1]
return result
return result if sequential else result[-1]
67 changes: 35 additions & 32 deletions jesse/indicators/atr.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,42 @@
from typing import Union

import numpy as np

from numba import njit
from jesse.helpers import slice_candles


@njit
def _atr(high: np.ndarray, low: np.ndarray, close: np.ndarray, period: int) -> np.ndarray:
"""
Calculate ATR using Numba
"""
n = len(close)
tr = np.empty(n)
atr_values = np.full(n, np.nan)

# Calculate True Range
tr[0] = high[0] - low[0]
for i in range(1, n):
hl = high[i] - low[i]
hc = abs(high[i] - close[i-1])
lc = abs(low[i] - close[i-1])
tr[i] = max(max(hl, hc), lc)

if n < period:
return atr_values

# First ATR value is the simple average of the first 'period' true ranges
atr_values[period-1] = np.mean(tr[:period])

# Calculate subsequent ATR values using Wilder's smoothing
for i in range(period, n):
atr_values[i] = (atr_values[i-1] * (period - 1) + tr[i]) / period

return atr_values


def atr(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Union[float, np.ndarray]:
"""
ATR - Average True Range
ATR - Average True Range using Numba for optimization
:param candles: np.ndarray
:param period: int - default: 14
Expand All @@ -21,32 +50,6 @@ def atr(candles: np.ndarray, period: int = 14, sequential: bool = False) -> Unio
low = candles[:, 4]
close = candles[:, 2]

# Compute previous close by shifting the close array; for the first element, use itself
prev_close = np.empty_like(close)
prev_close[0] = close[0]
prev_close[1:] = close[:-1]

# Calculate True Range
tr = np.maximum(high - low, np.maximum(np.abs(high - prev_close), np.abs(low - prev_close)))
tr[0] = high[0] - low[0] # ensure first element is high - low

# Initialize ATR array
atr_values = np.empty_like(tr)
# For indices with insufficient data, set to NaN
atr_values[:period-1] = np.nan
# First ATR value is the simple average of the first 'period' true ranges
atr_values[period-1] = np.mean(tr[:period])

# Compute subsequent ATR values using Wilder's smoothing method (vectorized implementation)
y0 = atr_values[period-1] # initial ATR value (simple average of first period true ranges)
n_rest = len(tr) - period
if n_rest > 0:
alpha = 1.0 / period
beta = (period - 1) / period # equivalent to 1 - alpha
indices = np.arange(1, n_rest + 1)
first_term = y0 * (beta ** indices)
# Create a lower-triangular matrix where L[i, j] = beta^(i - j) for j<=i
L = np.tril(beta ** (np.subtract.outer(np.arange(n_rest), np.arange(n_rest))))
atr_values[period:] = first_term + alpha * np.dot(L, tr[period:])

return atr_values if sequential else atr_values[-1]
result = _atr(high, low, close, period)

return result if sequential else result[-1]
66 changes: 46 additions & 20 deletions jesse/indicators/bollinger_bands_width.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,63 @@
from typing import Union
import numpy as np
from numba import njit
from jesse.helpers import get_candle_source, slice_candles


@njit
def _bb_width(source: np.ndarray, period: int, mult: float) -> np.ndarray:
"""
Calculate Bollinger Bands Width using Numba
"""
n = len(source)
bbw = np.full(n, np.nan)

if n < period:
return bbw

# Pre-calculate sum and sum of squares for optimization
sum_x = np.zeros(n - period + 1)
sum_x2 = np.zeros(n - period + 1)

# Initial window
sum_x[0] = np.sum(source[:period])
sum_x2[0] = np.sum(source[:period] ** 2)

# Rolling window calculations
for i in range(1, n - period + 1):
sum_x[i] = sum_x[i-1] - source[i-1] + source[i+period-1]
sum_x2[i] = sum_x2[i-1] - source[i-1]**2 + source[i+period-1]**2

# Calculate mean and standard deviation
mean = sum_x / period
std = np.sqrt((sum_x2 / period) - (mean ** 2))

# Calculate BBW
for i in range(period - 1, n):
idx = i - period + 1
basis = mean[idx]
upper = basis + mult * std[idx]
lower = basis - mult * std[idx]
bbw[i] = (upper - lower) / basis

return bbw


def bollinger_bands_width(candles: np.ndarray, period: int = 20, mult: float = 2.0, source_type: str = "close", sequential: bool = False) -> Union[float, np.ndarray]:
"""
BBW - Bollinger Bands Width - Bollinger Bands Bandwidth
BBW - Bollinger Bands Width - Bollinger Bands Bandwidth using Numba for optimization
:param candles: np.ndarray
:param period: int - default: 20
:param mult: float - default: 2
:param source_type: str - default: "close"
:param sequential: bool - default: False
:return: BollingerBands(upperband, middleband, lowerband)
:return: float | np.ndarray
"""
candles = slice_candles(candles, sequential)
source = get_candle_source(candles, source_type=source_type)

if sequential:
n = len(source)
bbw = np.full(n, np.nan)
if n >= period:
windows = np.lib.stride_tricks.sliding_window_view(source, window_shape=period)
basis = np.mean(windows, axis=1)
std = np.std(windows, axis=1, ddof=0)
# Compute Bollinger Bands Width using vectorized operation
bbw[period - 1:] = (( (basis + mult * std) - (basis - mult * std) ) / basis)
return bbw
else:
window = source[-period:]
basis = np.mean(window)
std = np.std(window, ddof=0)
upper = basis + mult * std
lower = basis - mult * std
return ((upper - lower) / basis)

result = _bb_width(source, period, mult)

return result if sequential else result[-1]
Loading

0 comments on commit 7a01f99

Please sign in to comment.