diff --git a/lumibot/backtesting/backtesting_broker.py b/lumibot/backtesting/backtesting_broker.py index 0b879a6df..8e5f3a775 100644 --- a/lumibot/backtesting/backtesting_broker.py +++ b/lumibot/backtesting/backtesting_broker.py @@ -416,7 +416,7 @@ def submit_orders(self, orders, is_multileg=False, **kwargs): # Check that orders is a list and not zero if not orders or not isinstance(orders, list) or len(orders) == 0: # Log an error and return an empty list - logging.error("No orders to submit to broker when calling submit_orders") + logger.error("No orders to submit to broker when calling submit_orders") return [] results = [] diff --git a/lumibot/backtesting/polygon_backtesting.py b/lumibot/backtesting/polygon_backtesting.py index 86b3c2a31..1674b992e 100644 --- a/lumibot/backtesting/polygon_backtesting.py +++ b/lumibot/backtesting/polygon_backtesting.py @@ -2,6 +2,7 @@ import traceback from collections import OrderedDict, defaultdict from datetime import date, timedelta +from typing import Optional from polygon.exceptions import BadResponse from termcolor import colored @@ -16,7 +17,19 @@ class PolygonDataBacktesting(PandasData): """ - Backtesting implementation of Polygon + A backtesting data source implementation for Polygon.io, backed by a local DuckDB cache. + + This class fetches data in "minute" or "day" bars from Polygon, stores it locally in + DuckDB for reuse, then surfaces the data to LumiBot for historical/backtesting usage. + + Attributes + ---------- + MAX_STORAGE_BYTES : Optional[int] + If set, indicates the maximum number of bytes we want to store in memory for + self.pandas_data. Exceeding this triggers LRU eviction. + + polygon_client : PolygonClient + A rate-limited REST client for Polygon. """ def __init__( @@ -24,45 +37,76 @@ def __init__( datetime_start, datetime_end, pandas_data=None, - api_key=None, - max_memory=None, + api_key: Optional[str] = None, + max_memory: Optional[int] = None, **kwargs, ): + """ + Constructor for the PolygonDataBacktesting class. + + Parameters + ---------- + datetime_start : datetime + The start datetime for the backtest. + datetime_end : datetime + The end datetime for the backtest. + pandas_data : dict or OrderedDict, optional + Pre-loaded data, if any. Typically None, meaning we fetch from scratch. + api_key : str, optional + Polygon.io API key. If not provided, it may fall back to lumibot.credentials. + max_memory : int, optional + Maximum bytes to store in memory. Exceeding triggers LRU eviction. + kwargs : dict + Additional arguments passed to the parent PandasData constructor. + """ super().__init__( - datetime_start=datetime_start, datetime_end=datetime_end, pandas_data=pandas_data, api_key=api_key, **kwargs + datetime_start=datetime_start, + datetime_end=datetime_end, + pandas_data=pandas_data, + api_key=api_key, + **kwargs ) - # Memory limit, off by default self.MAX_STORAGE_BYTES = max_memory - - # RESTClient API for Polygon.io polygon-api-client self.polygon_client = PolygonClient.create(api_key=api_key) - def _enforce_storage_limit(pandas_data: OrderedDict): + def _enforce_storage_limit(pandas_data: OrderedDict) -> None: + """ + Evict oldest data from self.pandas_data if we exceed the max memory storage. + This uses an LRU approach: pop the earliest inserted item until under limit. 
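        A minimal, self-contained sketch of the same eviction idea (the limit, symbols,
        and DataFrame sizes below are hypothetical and stand in for cached Data objects):

            from collections import OrderedDict
            import pandas as pd

            MAX_STORAGE_BYTES = 20_000                   # hypothetical limit in bytes
            cache = OrderedDict(
                (sym, pd.DataFrame({"close": range(1_000)})) for sym in ["SPY", "QQQ", "IWM"]
            )
            used = sum(df.memory_usage().sum() for df in cache.values())
            while used > MAX_STORAGE_BYTES:
                sym, df = cache.popitem(last=False)      # evict the oldest-inserted entry first
                used -= df.memory_usage().sum()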
+ """ storage_used = sum(data.df.memory_usage().sum() for data in pandas_data.values()) logging.info(f"{storage_used = :,} bytes for {len(pandas_data)} items") while storage_used > PolygonDataBacktesting.MAX_STORAGE_BYTES: - k, d = pandas_data.popitem(last=False) + k, d = pandas_data.popitem(last=False) # pop oldest mu = d.df.memory_usage().sum() storage_used -= mu logging.info(f"Storage limit exceeded. Evicted LRU data: {k} used {mu:,} bytes") - def _update_pandas_data(self, asset, quote, length, timestep, start_dt=None): + def _update_pandas_data( + self, + asset: Asset, + quote: Optional[Asset], + length: int, + timestep: str, + start_dt=None + ) -> None: """ - Get asset data and update the self.pandas_data dictionary. + Ensure we have enough data for (asset, quote) in self.pandas_data by fetching from + Polygon (via the local DuckDB cache) if needed. Parameters ---------- asset : Asset - The asset to get data for. - quote : Asset - The quote asset to use. For example, if asset is "SPY" and quote is "USD", the data will be for "SPY/USD". + The Asset to fetch data for. + quote : Asset, optional + The quote asset, e.g. USD for crypto. If None, defaults to Asset("USD","forex"). length : int - The number of data points to get. + The number of bars we want to make sure we have at minimum. timestep : str - The timestep to use. For example, "1minute" or "1hour" or "1day". - start_dt : datetime - The start datetime to use. If None, the current self.start_datetime will be used. + "minute" or "day". + start_dt : datetime, optional + If given, treat that as the "current" datetime. Otherwise we use self.get_datetime(). """ search_asset = asset asset_separated = asset @@ -73,105 +117,78 @@ def _update_pandas_data(self, asset, quote, length, timestep, start_dt=None): else: search_asset = (search_asset, quote_asset) - # Get the start datetime and timestep unit + # Determine needed start date range start_datetime, ts_unit = self.get_start_datetime_and_ts_unit( length, timestep, start_dt, start_buffer=START_BUFFER ) - # Check if we have data for this asset + + # If we already have data in self.pandas_data, check if it's enough if search_asset in self.pandas_data: asset_data = self.pandas_data[search_asset] asset_data_df = asset_data.df data_start_datetime = asset_data_df.index[0] - - # Get the timestep of the data data_timestep = asset_data.timestep - # If the timestep is the same, we don't need to update the data + # If timesteps match and we have a buffer, skip the fetch if data_timestep == ts_unit: - # Check if we have enough data (5 days is the buffer we subtracted from the start datetime) if (data_start_datetime - start_datetime) < START_BUFFER: return - # Always try to get the lowest timestep possible because we can always resample - # If day is requested then make sure we at least have data that's less than a day - if ts_unit == "day": - if data_timestep == "minute": - # Check if we have enough data (5 days is the buffer we subtracted from the start datetime) - if (data_start_datetime - start_datetime) < START_BUFFER: - return - else: - # We don't have enough data, so we need to get more (but in minutes) - ts_unit = "minute" - elif data_timestep == "hour": - # Check if we have enough data (5 days is the buffer we subtracted from the start datetime) - if (data_start_datetime - start_datetime) < START_BUFFER: - return - else: - # We don't have enough data, so we need to get more (but in hours) - ts_unit = "hour" - - # If hour is requested then make sure we at least have data that's less than an 
hour - if ts_unit == "hour": - if data_timestep == "minute": - # Check if we have enough data (5 days is the buffer we subtracted from the start datetime) - if (data_start_datetime - start_datetime) < START_BUFFER: - return - else: - # We don't have enough data, so we need to get more (but in minutes) - ts_unit = "minute" - - # Download data from Polygon + # If we request day but have minute, we might have enough + if ts_unit == "day" and data_timestep == "minute": + if (data_start_datetime - start_datetime) < START_BUFFER: + return + else: + # Otherwise, we must re-fetch as minute + ts_unit = "minute" + + # Otherwise, fetch from polygon_helper try: - # Get data from Polygon df = polygon_helper.get_price_data_from_polygon( - self._api_key, - asset_separated, - start_datetime, - self.datetime_end, + api_key=self._api_key, + asset=asset_separated, + start=start_datetime, + end=self.datetime_end, timespan=ts_unit, quote_asset=quote_asset, + force_cache_update=False, ) except BadResponse as e: - # Assuming e.message or similar attribute contains the error message - formatted_start_datetime = start_datetime.strftime("%Y-%m-%d") - formatted_end_datetime = self.datetime_end.strftime("%Y-%m-%d") + # Handle subscription or API key errors + formatted_start = start_datetime.strftime("%Y-%m-%d") + formatted_end = self.datetime_end.strftime("%Y-%m-%d") if "Your plan doesn't include this data timeframe" in str(e): error_message = colored( - "Polygon Access Denied: Your subscription does not allow you to backtest that far back in time. " - f"You requested data for {asset_separated} {ts_unit} bars " - f"from {formatted_start_datetime} to {formatted_end_datetime}. " - "Please consider either changing your backtesting timeframe to start later since your " - "subscription does not allow you to backtest that far back or upgrade your Polygon " - "subscription." - "You can upgrade your Polygon subscription at at https://polygon.io/?utm_source=affiliate&utm_campaign=lumi10 " - "Please use the full link to give us credit for the sale, it helps support this project. " - "You can use the coupon code 'LUMI10' for 10% off. ", - color="red") + f"Polygon Access Denied: Subscription does not allow that timeframe.\n" + f"Requested {asset_separated} {ts_unit} bars from {formatted_start} to {formatted_end}.\n" + f"Consider upgrading or adjusting your timeframe.\n", + color="red" + ) raise Exception(error_message) from e elif "Unknown API Key" in str(e): error_message = colored( - "Polygon Access Denied: Your API key is invalid. " - "Please check your API key and try again. " - "You can get an API key at https://polygon.io/?utm_source=affiliate&utm_campaign=lumi10 " - "Please use the full link to give us credit for the sale, it helps support this project. " - "You can use the coupon code 'LUMI10' for 10% off. 
", - color="red") + "Polygon Access Denied: Invalid API key.\n" + "Get an API key at https://polygon.io/?utm_source=affiliate&utm_campaign=lumi10\n" + "Use coupon code 'LUMI10' for 10% off.\n", + color="red" + ) raise Exception(error_message) from e else: - # Handle other BadResponse exceptions not related to plan limitations logging.error(traceback.format_exc()) raise except Exception as e: - # Handle all other exceptions logging.error(traceback.format_exc()) raise Exception("Error getting data from Polygon") from e - if (df is None) or df.empty: + if df is None or df.empty: return + + # Store newly fetched data in self.pandas_data data = Data(asset_separated, df, timestep=ts_unit, quote=quote_asset) pandas_data_update = self._set_pandas_data_keys([data]) - # Add the keys to the self.pandas_data dictionary self.pandas_data.update(pandas_data_update) + + # Enforce memory limit if self.MAX_STORAGE_BYTES: self._enforce_storage_limit(self.pandas_data) @@ -180,117 +197,131 @@ def _pull_source_symbol_bars( asset: Asset, length: int, timestep: str = "day", - timeshift: int = None, - quote: Asset = None, - exchange: str = None, - include_after_hours: bool = True, + timeshift: Optional[int] = None, + quote: Optional[Asset] = None, + exchange: Optional[str] = None, + include_after_hours: bool = True ): - # Get the current datetime and calculate the start datetime + """ + Overridden method to pull data using the local DuckDB caching approach. + + Parameters + ---------- + asset : Asset + length : int + timestep : str + "minute" or "day" + timeshift : int, optional + quote : Asset, optional + exchange : str, optional + include_after_hours : bool + Not used in the duckdb fetch, but required signature from parent. + + Returns + ------- + Bars in the PandasData parent format. + """ current_dt = self.get_datetime() - # Get data from Polygon self._update_pandas_data(asset, quote, length, timestep, current_dt) return super()._pull_source_symbol_bars( asset, length, timestep, timeshift, quote, exchange, include_after_hours ) - # Get pricing data for an asset for the entire backtesting period def get_historical_prices_between_dates( self, - asset, - timestep="minute", - quote=None, - exchange=None, - include_after_hours=True, + asset: Asset, + timestep: str = "minute", + quote: Optional[Asset] = None, + exchange: Optional[str] = None, + include_after_hours: bool = True, start_date=None, - end_date=None, + end_date=None ): - self._update_pandas_data(asset, quote, 1, timestep) + """ + Retrieve historical OHLCV data between start_date and end_date, caching in DuckDB. + Parameters + ---------- + asset : Asset + timestep : str + "minute" or "day". + quote : Asset, optional + exchange : str, optional + include_after_hours : bool + start_date : datetime, optional + end_date : datetime, optional + + Returns + ------- + pd.DataFrame or None + The bars for [start_date, end_date], or None if no data. 
+ """ + self._update_pandas_data(asset, quote, 1, timestep) response = super()._pull_source_symbol_bars_between_dates( asset, timestep, quote, exchange, include_after_hours, start_date, end_date ) - if response is None: return None - bars = self._parse_source_symbol_bars(response, asset, quote=quote) return bars - def get_last_price(self, asset, timestep="minute", quote=None, exchange=None, **kwargs): + def get_last_price( + self, + asset: Asset, + timestep: str = "minute", + quote: Optional[Asset] = None, + exchange: Optional[str] = None, + **kwargs + ): + """ + Return the last (most recent) price from local DuckDB data, ensuring data is updated. + + Parameters + ---------- + asset : Asset + timestep : str + "minute" or "day" + quote : Asset, optional + exchange : str, optional + + Returns + ------- + float + The last (close) price for the given asset. + """ try: dt = self.get_datetime() self._update_pandas_data(asset, quote, 1, timestep, dt) except Exception as e: print(f"Error get_last_price from Polygon: {e}") - print(f"Error get_last_price from Polygon: {asset=} {quote=} {timestep=} {dt=} {e}") + print(f"Asset={asset}, Quote={quote}, Timestep={timestep}, Dt={dt}, Exception={e}") return super().get_last_price(asset=asset, quote=quote, exchange=exchange) - def get_chains(self, asset: Asset, quote: Asset = None, exchange: str = None): + def get_chains( + self, + asset: Asset, + quote: Optional[Asset] = None, + exchange: Optional[str] = None + ): """ - Integrates the Polygon client library into the LumiBot backtest for Options Data in the same - structure as Interactive Brokers options chain data + Retrieve Option Chains from Polygon, with caching for the contract definitions. Parameters ---------- asset : Asset - The underlying asset to get data for. - quote : Asset - The quote asset to use. For example, if asset is "SPY" and quote is "USD", the data will be for "SPY/USD". - exchange : str - The exchange to get the data from. Example: "SMART" + The underlying symbol as a LumiBot Asset. + quote : Asset, optional + exchange : str, optional Returns ------- - dictionary of dictionary - Format: - - `Multiplier` (str) eg: `100` - - 'Chains' - paired Expiration/Strke info to guarentee that the stikes are valid for the specific - expiration date. - Format: - chains['Chains']['CALL'][exp_date] = [strike1, strike2, ...] - Expiration Date Format: 2023-07-31 + dict + A dictionary of calls and puts with their strikes by expiration date. 
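        Illustrative shape of the returned dictionary (all values hypothetical; strikes
        are listed per expiration date so they are guaranteed valid for that expiry):

            chains = {
                "Multiplier": 100,
                "Exchange": "BATO",
                "Chains": {
                    "CALL": {"2023-08-04": [400.0, 405.0, 410.0]},
                    "PUT":  {"2023-08-04": [400.0, 405.0, 410.0]},
                },
            }
            aug_calls = chains["Chains"]["CALL"]["2023-08-04"]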
""" - - # All Option Contracts | get_chains matching IBKR | - # {'Multiplier': 100, 'Exchange': "NYSE", - # 'Chains': {'CALL': {: [100.00, 101.00]}}, 'PUT': defaultdict(list)}} - option_contracts = { - "Multiplier": None, - "Exchange": None, - "Chains": {"CALL": defaultdict(list), "PUT": defaultdict(list)}, - } - today = self.get_datetime().date() - real_today = date.today() - - # All Contracts | to match lumitbot, more inputs required from get_chains() - # If the strategy is using a recent backtest date, some contracts might not be expired yet, query those too - expired_list = [True, False] if real_today - today <= timedelta(days=31) else [True] - polygon_contracts = [] - for expired in expired_list: - polygon_contracts.extend( - list( - self.polygon_client.list_options_contracts( - underlying_ticker=asset.symbol, - expiration_date_gte=today, - expired=expired, # Needed so BackTest can look at old contracts to find the expirations/strikes - limit=1000, - ) - ) - ) - - for polygon_contract in polygon_contracts: - # Return to Loop and Skip if Multipler is not 100 because non-standard contracts are not supported - if polygon_contract.shares_per_contract != 100: - continue - - # Contract Data | Attributes - exchange = polygon_contract.primary_exchange - right = polygon_contract.contract_type.upper() - exp_date = polygon_contract.expiration_date # Format: '2023-08-04' - strike = polygon_contract.strike_price - option_contracts["Multiplier"] = polygon_contract.shares_per_contract - option_contracts["Exchange"] = exchange - option_contracts["Chains"][right][exp_date].append(strike) - - return option_contracts \ No newline at end of file + from lumibot.tools.polygon_helper import get_option_chains_with_cache + return get_option_chains_with_cache( + polygon_client=self.polygon_client, + asset=asset, + current_date=self.get_datetime().date() + ) diff --git a/lumibot/entities/data.py b/lumibot/entities/data.py index b040e7e1b..fd24c2138 100644 --- a/lumibot/entities/data.py +++ b/lumibot/entities/data.py @@ -1,6 +1,7 @@ import datetime import logging import re +from typing import Union, Optional, Dict, Any, List import pandas as pd from lumibot import LUMIBOT_DEFAULT_PYTZ as DEFAULT_PYTZ @@ -11,152 +12,113 @@ class Data: - """Input and manage Pandas dataframes for backtesting. + """ + A container for a single asset's time-series data (OHLCV, etc.) used in LumiBot backtesting. + + This class wraps a Pandas DataFrame and ensures consistent formatting, indexing, + time-zone alignment, plus iteration and slicing used by LumiBot's backtest engine. Parameters ---------- - asset : Asset Object - Asset to which this data is attached. - df : dataframe - Pandas dataframe containing OHLCV etc. trade data. Loaded by user - from csv. - Index is date and must be pandas datetime64. - Columns are strictly ["open", "high", "low", "close", "volume"] - quote : Asset Object - The quote asset for this data. If not provided, then the quote asset will default to USD. - date_start : Datetime or None - Starting date for this data, if not provided then first date in - the dataframe. - date_end : Datetime or None - Ending date for this data, if not provided then last date in - the dataframe. - trading_hours_start : datetime.time or None - If not supplied, then default is 0001 hrs. - trading_hours_end : datetime.time or None - If not supplied, then default is 2359 hrs. + asset : Asset + The asset (symbol + type) that this data represents. + df : pd.DataFrame + A DataFrame of OHLCV or related columns. 
Must have a DatetimeIndex + or a recognized date/time column that can be set as index. + Required columns: ["open", "high", "low", "close", "volume"] (case-insensitive). + date_start : datetime, optional + The earliest datetime we want to keep in df. If None, uses the min index in df. + date_end : datetime, optional + The latest datetime we want to keep in df. If None, uses the max index in df. + trading_hours_start : datetime.time, optional + The earliest time in a day we will keep in minute data. Default 00:00 for "minute" data. + For "day" data, this is overridden to 00:00 internally. + trading_hours_end : datetime.time, optional + The latest time in a day we will keep in minute data. Default 23:59 for "minute" data. + For "day" data, this is overridden to 23:59:59.999999 internally. timestep : str - Either "minute" (default) or "day" - localize_timezone : str or None - If not None, then localize the timezone of the dataframe to the - given timezone as a string. The values can be any supported by tz_localize, - e.g. "US/Eastern", "UTC", etc. + Either "minute" or "day". + quote : Asset, optional + If the asset is crypto or forex, specify the quote asset. E.g. for BTC/USD, quote=USD. + timezone : str, optional + E.g. "US/Eastern". If not None, we localize or convert to that timezone as needed. Attributes ---------- - asset : Asset Object - Asset object to which this data is attached. - sybmol : str - The underlying or stock symbol as a string. - df : dataframe - Pandas dataframe containing OHLCV etc trade data. Loaded by user - from csv. - Index is date and must be pandas datetime64. - Columns are strictly ["open", "high", "low", "close", "volume"] - date_start : Datetime or None - Starting date for this data, if not provided then first date in - the dataframe. - date_end : Datetime or None - Ending date for this data, if not provided then last date in - the dataframe. - trading_hours_start : datetime.time or None - If not supplied, then default is 0001 hrs. - trading_hours_end : datetime.time or None - If not supplied, then default is 2359 hrs. + asset : Asset + The asset this data belongs to. + symbol : str + The same as asset.symbol. + df : pd.DataFrame + The underlying time-series data with columns: open, high, low, close, volume + and a DatetimeIndex with tz=UTC. + date_start : datetime + date_end : datetime + trading_hours_start : datetime.time + trading_hours_end : datetime.time timestep : str - Either "minute" (default) or "day" - datalines : dict - Keys are column names like `datetime` or `close`, values are - numpy arrays. - iter_index : Pandas Series - Datetime in the index, range count in values. Used to retrieve - the current df iteration for this data and datetime. + "minute" or "day". + datalines : Dict[str, Dataline] + A dictionary of columns -> Dataline objects for faster iteration. + iter_index : pd.Series + A mapping from the df's index to a consecutive range, used for fast lookups. Methods ------- - set_times - Sets the start and end time for the data. - repair_times_and_fill - After all time series merged, adjust the local dataframe to reindex and fill nan's. - columns - Adjust date and column names to lower case. - set_date_format - Ensure datetime in local datetime64 format. - set_dates - Set start and end dates. - trim_data - Trim the dataframe to match the desired backtesting dates. - to_datalines - Create numpy datalines from existing date index and columns. - get_iter_count - Returns the current index number (len) given a date. 
- check_data (wrapper) - Validates if the provided date, length, timeshift, and timestep - will return data. Runs function if data, returns None if no data. - get_last_price - Gets the last price from the current date. - _get_bars_dict - Returns bars in the form of a dict. - get_bars - Returns bars in the form of a dataframe. + repair_times_and_fill(idx: pd.DatetimeIndex) -> None + Reindex the df to a given index, forward-fill, etc., then update datalines/iter_index. + get_last_price(dt: datetime, length=1, timeshift=0) -> float + Return the last known price at dt. If dt is between open/close of bar, returns open vs close. + get_bars(dt: datetime, length=1, timestep="minute", timeshift=0) -> pd.DataFrame + Return the last 'length' bars up to dt, optionally aggregated to day if needed. + get_bars_between_dates(timestep="minute", start_date=None, end_date=None) -> pd.DataFrame + Return bars for a date range. """ MIN_TIMESTEP = "minute" - TIMESTEP_MAPPING = [ + TIMESTEP_MAPPING: List[Dict[str, Any]] = [ {"timestep": "day", "representations": ["1D", "day"]}, {"timestep": "minute", "representations": ["1M", "minute"]}, ] def __init__( self, - asset, - df, - date_start=None, - date_end=None, - trading_hours_start=datetime.time(0, 0), - trading_hours_end=datetime.time(23, 59), - timestep="minute", - quote=None, - timezone=None, + asset: Asset, + df: pd.DataFrame, + date_start: Optional[datetime.datetime] = None, + date_end: Optional[datetime.datetime] = None, + trading_hours_start: datetime.time = datetime.time(0, 0), + trading_hours_end: datetime.time = datetime.time(23, 59), + timestep: str = "minute", + quote: Optional[Asset] = None, + timezone: Optional[str] = None, ): self.asset = asset self.symbol = self.asset.symbol + # Crypto must have a quote asset if self.asset.asset_type == "crypto" and quote is None: raise ValueError( - f"A crypto asset {self.symbol} was added to data without a corresponding" - f"`quote` asset. Please add the quote asset. For example, if trying to add " - f"`BTCUSD` to data, you would need to add `USD` as the quote asset." - f"Quote must be provided for crypto assets." + f"Missing quote asset for crypto {self.symbol}. For BTC/USD, quote=Asset('USD','forex')." ) else: self.quote = quote - # Throw an error if the quote is not an asset object if self.quote is not None and not isinstance(self.quote, Asset): - raise ValueError( - f"The quote asset for Data must be an Asset object. You provided a {type(self.quote)} object." - ) + raise ValueError(f"quote must be an Asset object, got {type(self.quote)}") if timestep not in ["minute", "day"]: - raise ValueError( - f"Timestep must be either 'minute' or 'day', the value you enetered ({timestep}) is not currently supported." 
- ) + raise ValueError(f"timestep must be 'minute' or 'day', got {timestep}") self.timestep = timestep self.df = self.columns(df) - # Check if the index is datetime (it has to be), and if it's not then try to find it in the columns - if str(self.df.index.dtype).startswith("datetime") is False: + # If index isn't datetime, try a known column + if not str(self.df.index.dtype).startswith("datetime"): date_cols = [ - "Date", - "date", - "Time", - "time", - "Datetime", - "datetime", - "timestamp", - "Timestamp", + "Date", "date", "Time", "time", "Datetime", "datetime", + "timestamp", "Timestamp", ] for date_col in date_cols: if date_col in self.df.columns: @@ -164,13 +126,16 @@ def __init__( self.df = self.df.set_index(date_col) break - if timezone is not None: + if timezone: self.df.index = self.df.index.tz_localize(timezone) self.df = self.set_date_format(self.df) self.df = self.df.sort_index() - self.trading_hours_start, self.trading_hours_end = self.set_times(trading_hours_start, trading_hours_end) + # Force times if day-based data + self.trading_hours_start, self.trading_hours_end = self.set_times( + trading_hours_start, trading_hours_end + ) self.date_start, self.date_end = self.set_dates(date_start, date_end) self.df = self.trim_data( @@ -178,48 +143,53 @@ def __init__( self.date_start, self.date_end, self.trading_hours_start, - self.trading_hours_end, + self.trading_hours_end ) self.datetime_start = self.df.index[0] self.datetime_end = self.df.index[-1] - def set_times(self, trading_hours_start, trading_hours_end): - """Set the start and end times for the data. The default is 0001 hrs to 2359 hrs. - - Parameters - ---------- - trading_hours_start : datetime.time - The start time of the trading hours. - - trading_hours_end : datetime.time - The end time of the trading hours. + def set_times( + self, + trading_hours_start: datetime.time, + trading_hours_end: datetime.time + ) -> (datetime.time, datetime.time): + """ + Adjust the trading hours for day-based data. If day, set them to full day range. + If minute, allow user-supplied hours. Returns ------- - trading_hours_start : datetime.time - The start time of the trading hours. - - trading_hours_end : datetime.time - The end time of the trading hours. + (trading_hours_start, trading_hours_end) """ - # Set the trading hours start and end times. if self.timestep == "minute": - ts = trading_hours_start - te = trading_hours_end + return trading_hours_start, trading_hours_end else: - ts = datetime.time(0, 0) - te = datetime.time(23, 59, 59, 999999) - return ts, te + # day timeframe + return datetime.time(0, 0), datetime.time(23, 59, 59, 999999) - def columns(self, df): - # Select columns to use, change to lower case, rename `date` if necessary. + def columns(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Convert recognized columns (open, high, low, close, volume) to lowercase, + leaving other columns alone. + + Returns + ------- + pd.DataFrame + """ df.columns = [ - col.lower() if col.lower() in ["open", "high", "low", "close", "volume"] else col for col in df.columns + col.lower() if col.lower() in ["open", "high", "low", "close", "volume"] else col + for col in df.columns ] - return df - def set_date_format(self, df): + def set_date_format(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Ensure the index is named 'datetime', is typed as a DatetimeIndex, and is localized or converted to UTC. 
+ + Returns + ------- + pd.DataFrame + """ df.index.name = "datetime" df.index = pd.to_datetime(df.index) if not df.index.tzinfo: @@ -228,11 +198,21 @@ def set_date_format(self, df): df.index = df.index.tz_convert(DEFAULT_PYTZ) return df - def set_dates(self, date_start, date_end): - # Set the start and end dates of the data. + def set_dates( + self, + date_start: Optional[datetime.datetime], + date_end: Optional[datetime.datetime] + ) -> (datetime.datetime, datetime.datetime): + """ + Resolve the date_start, date_end range. If None, use df.index min/max. + + Returns + ------- + (date_start, date_end) + """ for dt in [date_start, date_end]: if dt and not isinstance(dt, datetime.datetime): - raise TypeError(f"Start and End dates must be entries as full datetimes. {dt} " f"was entered") + raise TypeError(f"date_start/date_end must be datetime. Got {dt}.") if not date_start: date_start = self.df.index.min() @@ -242,47 +222,62 @@ def set_dates(self, date_start, date_end): date_start = to_datetime_aware(date_start) date_end = to_datetime_aware(date_end) + # For day-based data, set to 0:00 and 23:59:59 date_start = date_start.replace(hour=0, minute=0, second=0, microsecond=0) date_end = date_end.replace(hour=23, minute=59, second=59, microsecond=999999) - return ( - date_start, - date_end, - ) + return date_start, date_end + + def trim_data( + self, + df: pd.DataFrame, + date_start: datetime.datetime, + date_end: datetime.datetime, + trading_hours_start: datetime.time, + trading_hours_end: datetime.time + ) -> pd.DataFrame: + """ + Clip df to [date_start, date_end], and if minute-based, also clip to the trading_hours. - def trim_data(self, df, date_start, date_end, trading_hours_start, trading_hours_end): - # Trim the dataframe to match the desired backtesting dates. + Raises + ------ + ValueError + If the resulting df is empty. + Returns + ------- + pd.DataFrame + """ df = df.loc[(df.index >= date_start) & (df.index <= date_end), :] if self.timestep == "minute": df = df.between_time(trading_hours_start, trading_hours_end) if df.empty: raise ValueError( - f"When attempting to load a dataframe for {self.asset}, " - f"an empty dataframe was returned. This is likely due " - f"to your backtesting start and end dates not being " - f"within the start and end dates of the data provided. " - f"\nPlease check that your at least one of your start " - f"or end dates for backtesting is within the range of " - f"your start and end dates for your data. " + f"No data remains for {self.asset} after trimming to date range " + f"{date_start} - {date_end} and hours {trading_hours_start}-{trading_hours_end}." ) return df - # ./lumibot/build/__editable__.lumibot-3.1.14-py3-none-any/lumibot/entities/data.py:280: - # FutureWarning: Downcasting object dtype arrays on .fillna, .ffill, .bfill is deprecated and will change in a future version. - # Call result.infer_objects(copy=False) instead. - # To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)` + def repair_times_and_fill(self, idx: pd.DatetimeIndex) -> None: + """ + Reindex df to match idx, forward-fill, set volume=0 where missing, etc. + Then re-create datalines for iteration. - def repair_times_and_fill(self, idx): - # Trim the global index so that it is within the local data. + Parameters + ---------- + idx : pd.DatetimeIndex + A global index that might include more timestamps than we originally had. 
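        Toy illustration of the reindex/forward-fill step (timestamps and prices are
        hypothetical; the real method additionally zeroes missing volume and fills
        missing open/high/low from close):

            import pandas as pd

            local_idx = pd.to_datetime(
                ["2023-01-03 09:31", "2023-01-03 09:33"]
            ).tz_localize("America/New_York")
            global_idx = pd.date_range(
                "2023-01-03 09:31", periods=3, freq="min", tz="America/New_York"
            )
            local = pd.DataFrame({"close": [10.2, 10.4], "volume": [100, 80]}, index=local_idx)
            aligned = local.reindex(global_idx, method="ffill")  # 09:32 inherits the 09:31 row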
+ """ idx = idx[(idx >= self.datetime_start) & (idx <= self.datetime_end)] - - # After all time series merged, adjust the local dataframe to reindex and fill nan's. df = self.df.reindex(idx, method="ffill") + + # Fill volume=0 if missing df.loc[df["volume"].isna(), "volume"] = 0 - df.loc[:, ~df.columns.isin(["open", "high", "low"])] = df.loc[ - :, ~df.columns.isin(["open", "high", "low"]) - ].ffill() + + # forward fill close, then set open/high/low if missing to the close + df.loc[:, ~df.columns.isin(["open", "high", "low"])] = ( + df.loc[:, ~df.columns.isin(["open", "high", "low"])].ffill() + ) for col in ["open", "high", "low"]: df.loc[df[col].isna(), col] = df.loc[df[col].isna(), "close"] @@ -292,110 +287,99 @@ def repair_times_and_fill(self, idx): self.iter_index = pd.Series(iter_index.index, index=iter_index) self.iter_index_dict = self.iter_index.to_dict() - self.datalines = dict() + self.datalines = {} self.to_datalines() - def to_datalines(self): - self.datalines.update( - { - "datetime": Dataline( - self.asset, - "datetime", - self.df.index.to_numpy(), - self.df.index.dtype, - ) - } - ) + def to_datalines(self) -> None: + """ + Convert each df column into a Dataline object for performance in backtesting loops. + """ + self.datalines.update({ + "datetime": Dataline( + self.asset, "datetime", self.df.index.to_numpy(), self.df.index.dtype + ) + }) setattr(self, "datetime", self.datalines["datetime"].dataline) for column in self.df.columns: - self.datalines.update( - { - column: Dataline( - self.asset, - column, - self.df[column].to_numpy(), - self.df[column].dtype, - ) - } + self.datalines[column] = Dataline( + self.asset, + column, + self.df[column].to_numpy(), + self.df[column].dtype ) setattr(self, column, self.datalines[column].dataline) - def get_iter_count(self, dt): - # Return the index location for a given datetime. + def get_iter_count(self, dt: datetime.datetime) -> int: + """ + Return the integer index location for dt, or the last known date if dt not exact. - # Check if the date is in the dataframe, if not then get the last - # known data (this speeds up the process) - i = None + Parameters + ---------- + dt : datetime.datetime - # Check if we have the iter_index_dict, if not then repair the times and fill (which will create the iter_index_dict) - if getattr(self, "iter_index_dict", None) is None: + Returns + ------- + int + The integer location of dt in self.iter_index_dict. + """ + if not hasattr(self, "iter_index_dict") or self.iter_index_dict is None: self.repair_times_and_fill(self.df.index) - # Search for dt in self.iter_index_dict if dt in self.iter_index_dict: - i = self.iter_index_dict[dt] + return self.iter_index_dict[dt] else: - # If not found, get the last known data - i = self.iter_index.asof(dt) - - return i + return self.iter_index.asof(dt) def check_data(func): - # Validates if the provided date, length, timeshift, and timestep - # will return data. Runs function if data, returns None if no data. - def checker(self, *args, **kwargs): - if type(kwargs.get("length", 1)) not in [int, float]: - raise TypeError(f"Length must be an integer. {type(kwargs.get('length', 1))} was provided.") + """ + Decorator for data-checking around get_last_price, get_bars, etc. + Ensures dt is within range and enough data is available for length/timeshift. + """ + def checker(self: "Data", *args, **kwargs): dt = args[0] - - # Check if the iter date is outside of this data's date range. 
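            # Worked example of the history check below (numbers are hypothetical): if dt
            # maps to row i = 120, then length=30 and timeshift=0 give
            # data_index = i + 1 - length - timeshift = 91 >= 0, so enough bars exist;
            # length=200 would give data_index = -79 and only log a warning before the
            # wrapped function runs.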
if dt < self.datetime_start: raise ValueError( - f"The date you are looking for ({dt}) for ({self.asset}) is outside of the data's date range ({self.datetime_start} to {self.datetime_end}). This could be because the data for this asset does not exist for the date you are looking for, or something else." + f"Requested dt {dt} is before data start {self.datetime_start} for {self.asset}" ) - # Search for dt in self.iter_index_dict - if getattr(self, "iter_index_dict", None) is None: + if not hasattr(self, "iter_index_dict") or self.iter_index_dict is None: self.repair_times_and_fill(self.df.index) if dt in self.iter_index_dict: i = self.iter_index_dict[dt] else: - # If not found, get the last known data i = self.iter_index.asof(dt) length = kwargs.get("length", 1) timeshift = kwargs.get("timeshift", 0) + if not isinstance(length, (int, float)): + raise TypeError(f"length must be int, got {type(length)}") + data_index = i + 1 - length - timeshift - is_data = data_index >= 0 - if not is_data: - # Log a warning + if data_index < 0: logging.warning( - f"The date you are looking for ({dt}) is outside of the data's date range ({self.datetime_start} to {self.datetime_end}) after accounting for a length of {kwargs.get('length', 1)} and a timeshift of {kwargs.get('timeshift', 0)}. Keep in mind that the length you are requesting must also be available in your data, in this case we are {data_index} rows away from the data you need." + f"Requested dt {dt} for {self.asset} is out of range after length={length}, timeshift={timeshift}." ) - res = func(self, *args, **kwargs) - # print(f"Results last price: {res}") - return res + return func(self, *args, **kwargs) return checker @check_data - def get_last_price(self, dt, length=1, timeshift=0): - """Returns the last known price of the data. + def get_last_price(self, dt: datetime.datetime, length: int = 1, timeshift: int = 0) -> float: + """ + Return the last known price at dt. If dt is after the bar's own index, + we consider the close. If dt matches the bar's index exactly, consider open. Parameters ---------- dt : datetime.datetime - The datetime to get the last price. length : int - The number of periods to get the last price. - timestep : str - The frequency of the data to get the last price. + How many bars back we want (mostly for the check_data process). timeshift : int - The number of periods to shift the data. + Shifts the index lookup. Returns ------- @@ -404,252 +388,235 @@ def get_last_price(self, dt, length=1, timeshift=0): iter_count = self.get_iter_count(dt) open_price = self.datalines["open"].dataline[iter_count] close_price = self.datalines["close"].dataline[iter_count] + # If dt > the bar's index, we consider it "after the bar closed" price = close_price if dt > self.datalines["datetime"].dataline[iter_count] else open_price - return price + return float(price) @check_data - def get_quote(self, dt, length=1, timeshift=0): - """Returns the last known price of the data. + def get_quote( + self, dt: datetime.datetime, length: int = 1, timeshift: int = 0 + ) -> dict: + """ + Return a dict with open, high, low, close, volume, bid/ask info, etc. Parameters ---------- dt : datetime.datetime - The datetime to get the last price. length : int - The number of periods to get the last price. - timestep : str - The frequency of the data to get the last price. timeshift : int - The number of periods to shift the data. 
Returns ------- dict """ - iter_count = self.get_iter_count(dt) - open = round(self.datalines["open"].dataline[iter_count], 2) - high = round(self.datalines["high"].dataline[iter_count], 2) - low = round(self.datalines["low"].dataline[iter_count], 2) - close = round(self.datalines["close"].dataline[iter_count], 2) - bid = round(self.datalines["bid"].dataline[iter_count], 2) - ask = round(self.datalines["ask"].dataline[iter_count], 2) - volume = round(self.datalines["volume"].dataline[iter_count], 0) - bid_size = round(self.datalines["bid_size"].dataline[iter_count], 0) - bid_condition = round(self.datalines["bid_condition"].dataline[iter_count], 0) - bid_exchange = round(self.datalines["bid_exchange"].dataline[iter_count], 0) - ask_size = round(self.datalines["ask_size"].dataline[iter_count], 0) - ask_condition = round(self.datalines["ask_condition"].dataline[iter_count], 0) - ask_exchange = round(self.datalines["ask_exchange"].dataline[iter_count], 0) + i = self.get_iter_count(dt) + def r(col: str, decimals=2): + return round(self.datalines[col].dataline[i], decimals) if col in self.datalines else None return { - "open": open, - "high": high, - "low": low, - "close": close, - "volume": volume, - "bid": bid, - "ask": ask, - "bid_size": bid_size, - "bid_condition": bid_condition, - "bid_exchange": bid_exchange, - "ask_size": ask_size, - "ask_condition": ask_condition, - "ask_exchange": ask_exchange + "open": r("open", 2), + "high": r("high", 2), + "low": r("low", 2), + "close": r("close", 2), + "volume": r("volume", 0), + "bid": r("bid", 2), + "ask": r("ask", 2), + "bid_size": r("bid_size", 0), + "bid_condition": r("bid_condition", 0), + "bid_exchange": r("bid_exchange", 0), + "ask_size": r("ask_size", 0), + "ask_condition": r("ask_condition", 0), + "ask_exchange": r("ask_exchange", 0), } @check_data - def _get_bars_dict(self, dt, length=1, timestep=None, timeshift=0): - """Returns a dictionary of the data. + def _get_bars_dict( + self, + dt: datetime.datetime, + length: int = 1, + timestep: Optional[str] = None, + timeshift: int = 0 + ) -> dict: + """ + Return a dict of numpy arrays for each column from [start_row:end_row]. Parameters ---------- dt : datetime.datetime - The datetime to get the data. length : int - The number of periods to get the data. - timestep : str - The frequency of the data to get the data. + timestep : str, unused here timeshift : int - The number of periods to shift the data. Returns ------- dict - + e.g. {"datetime": [...], "open": [...], ...} """ - - # Get bars. end_row = self.get_iter_count(dt) - timeshift start_row = end_row - length - if start_row < 0: start_row = 0 - # Cast both start_row and end_row to int start_row = int(start_row) end_row = int(end_row) - dict = {} + bars_dict = {} for dl_name, dl in self.datalines.items(): - dict[dl_name] = dl.dataline[start_row:end_row] - - return dict + bars_dict[dl_name] = dl.dataline[start_row:end_row] + return bars_dict - def _get_bars_between_dates_dict(self, timestep=None, start_date=None, end_date=None): - """Returns a dictionary of all the data available between the start and end dates. + def _get_bars_between_dates_dict( + self, + timestep: Optional[str] = None, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None + ) -> dict: + """ + Return a dict of arrays for all bars between [start_date, end_date]. Parameters ---------- - timestep : str - The frequency of the data to get the data. 
+ timestep : str, unused here start_date : datetime.datetime - The start date to get the data for. end_date : datetime.datetime - The end date to get the data for. Returns ------- dict """ - end_row = self.get_iter_count(end_date) start_row = self.get_iter_count(start_date) - if start_row < 0: start_row = 0 - # Cast both start_row and end_row to int start_row = int(start_row) end_row = int(end_row) - dict = {} + d = {} for dl_name, dl in self.datalines.items(): - dict[dl_name] = dl.dataline[start_row:end_row] + d[dl_name] = dl.dataline[start_row:end_row] + return d - return dict - - def get_bars(self, dt, length=1, timestep=MIN_TIMESTEP, timeshift=0): - """Returns a dataframe of the data. + @check_data + def get_bars( + self, + dt: datetime.datetime, + length: int = 1, + timestep: str = MIN_TIMESTEP, + timeshift: int = 0 + ) -> Union[pd.DataFrame, None]: + """ + Return a pd.DataFrame of the last 'length' bars up to dt, aggregated if needed. Parameters ---------- dt : datetime.datetime - The datetime to get the data. length : int - The number of periods to get the data. timestep : str - The frequency of the data to get the data. Only minute and day are supported. + Either "minute" or "day". If local data is minute-based but we want "day", we resample. timeshift : int - The number of periods to shift the data. Returns ------- - pandas.DataFrame - + pd.DataFrame or None """ - # Parse the timestep - quantity, timestep = parse_timestep_qty_and_unit(timestep) + quantity, parsed_timestep = parse_timestep_qty_and_unit(timestep) num_periods = length - if timestep == "minute" and self.timestep == "day": - raise ValueError("You are requesting minute data from a daily data source. This is not supported.") - - if timestep != "minute" and timestep != "day": - raise ValueError(f"Only minute and day are supported for timestep. 
You provided: {timestep}") + if parsed_timestep == "minute" and self.timestep == "day": + raise ValueError("Cannot request minute data from a day-only dataset.") + if parsed_timestep not in ["minute", "day"]: + raise ValueError(f"Only 'minute' or 'day' supported, got {parsed_timestep}.") - agg_column_map = { + agg_map = { "open": "first", "high": "max", "low": "min", "close": "last", "volume": "sum", } - if timestep == "day" and self.timestep == "minute": - # If the data is minute data and we are requesting daily data then multiply the length by 1440 - length = length * 1440 + + if parsed_timestep == "day" and self.timestep == "minute": + # We have minute-level data but want daily bars + length = length * 1440 # approximate: 1440 minutes in a day unit = "D" data = self._get_bars_dict(dt, length=length, timestep="minute", timeshift=timeshift) - - elif timestep == 'day' and self.timestep == 'day': + elif parsed_timestep == "day" and self.timestep == "day": unit = "D" - data = self._get_bars_dict(dt, length=length, timestep=timestep, timeshift=timeshift) - + data = self._get_bars_dict(dt, length=length, timestep="day", timeshift=timeshift) else: - unit = "min" # Guaranteed to be minute timestep at this point + # both are "minute" + unit = "min" length = length * quantity - data = self._get_bars_dict(dt, length=length, timestep=timestep, timeshift=timeshift) + data = self._get_bars_dict(dt, length=length, timestep="minute", timeshift=timeshift) if data is None: return None - df = pd.DataFrame(data).assign(datetime=lambda df: pd.to_datetime(df['datetime'])).set_index('datetime') + df = pd.DataFrame(data).assign( + datetime=lambda df_: pd.to_datetime(df_["datetime"]) + ).set_index("datetime") + if "dividend" in df.columns: - agg_column_map["dividend"] = "sum" - df_result = df.resample(f"{quantity}{unit}").agg(agg_column_map) + agg_map["dividend"] = "sum" - # Drop any rows that have NaN values (this can happen if the data is not complete, eg. weekends) - df_result = df_result.dropna() + df_result = df.resample(f"{quantity}{unit}").agg(agg_map) + df_result.dropna(inplace=True) - # Remove partial day data from the current day, which can happen if the data is in minute timestep. - if timestep == "day" and self.timestep == "minute": + # If minute-based source, remove partial day data for the last day + if parsed_timestep == "day" and self.timestep == "minute": df_result = df_result[df_result.index < dt.replace(hour=0, minute=0, second=0, microsecond=0)] - # The original df_result may include more rows when timestep is day and self.timestep is minute. - # In this case, we only want to return the last n rows. - df_result = df_result.tail(n=int(num_periods)) - + # Return only the last 'num_periods' rows + df_result = df_result.tail(int(num_periods)) return df_result - def get_bars_between_dates(self, timestep=MIN_TIMESTEP, exchange=None, start_date=None, end_date=None): - """Returns a dataframe of all the data available between the start and end dates. + def get_bars_between_dates( + self, + timestep: str = MIN_TIMESTEP, + exchange: Optional[str] = None, + start_date: Optional[datetime.datetime] = None, + end_date: Optional[datetime.datetime] = None + ) -> Union[pd.DataFrame, None]: + """ + Return all bars in [start_date, end_date], resampled if needed. Parameters ---------- timestep : str - The frequency of the data to get the data. Only minute and day are supported. - exchange : str - The exchange to get the data for. - start_date : datetime.datetime - The start date to get the data for. 
- end_date : datetime.datetime - The end date to get the data for. + "minute" or "day" + exchange : str, optional + Not used here, but part of LumiBot's function signature. + start_date : datetime + end_date : datetime Returns ------- - pandas.DataFrame + pd.DataFrame or None """ - if timestep == "minute" and self.timestep == "day": - raise ValueError("You are requesting minute data from a daily data source. This is not supported.") - - if timestep != "minute" and timestep != "day": - raise ValueError(f"Only minute and day are supported for timestep. You provided: {timestep}") + raise ValueError("Cannot request minute bars from day-only dataset.") + if timestep not in ["minute", "day"]: + raise ValueError(f"Only 'minute' or 'day' supported, got {timestep}.") if timestep == "day" and self.timestep == "minute": - dict = self._get_bars_between_dates_dict(timestep=timestep, start_date=start_date, end_date=end_date) - - if dict is None: + d = self._get_bars_between_dates_dict( + timestep=timestep, start_date=start_date, end_date=end_date + ) + if d is None: return None - - df = pd.DataFrame(dict).set_index("datetime") - + df = pd.DataFrame(d).set_index("datetime") + # Resample up to daily df_result = df.resample("D").agg( - { - "open": "first", - "high": "max", - "low": "min", - "close": "last", - "volume": "sum", - } + {"open": "first", "high": "max", "low": "min", "close": "last", "volume": "sum"} ) - return df_result else: - dict = self._get_bars_between_dates_dict(timestep=timestep, start_date=start_date, end_date=end_date) - - if dict is None: + d = self._get_bars_between_dates_dict( + timestep=timestep, start_date=start_date, end_date=end_date + ) + if d is None: return None - - df = pd.DataFrame(dict).set_index("datetime") + df = pd.DataFrame(d).set_index("datetime") return df diff --git a/lumibot/tools/indicators.py b/lumibot/tools/indicators.py index 04f475669..f84b0ace3 100644 --- a/lumibot/tools/indicators.py +++ b/lumibot/tools/indicators.py @@ -5,8 +5,10 @@ import webbrowser from datetime import datetime from decimal import Decimal +from typing import Dict, Optional import pandas as pd +import numpy as np import plotly.graph_objects as go import pytz import quantstats_lumi as qs @@ -681,89 +683,75 @@ def create_tearsheet( strat_name: str, tearsheet_file: str, benchmark_df: pd.DataFrame, - benchmark_asset, # This is causing a circular import: Asset, + benchmark_asset: Optional[str], show_tearsheet: bool, save_tearsheet: bool, risk_free_rate: float, - strategy_parameters: dict = None, -): - # If show tearsheet is False, then we don't want to open the tearsheet in the browser - # IMS create the tearsheet even if we are not showinbg it + strategy_parameters: Optional[Dict] = None, +) -> Optional[str]: + """ + Creates a performance tearsheet for a given strategy compared to a benchmark. + If data is invalid (NaN or Inf) we skip creating the tearsheet. + """ + if not save_tearsheet: - logging.info("save_tearsheet is False, not creating the tearsheet file.") - return + logging.info("save_tearsheet=False, skipping tearsheet.") + return None - logging.info("\nCreating tearsheet...") + logging.info("Creating tearsheet...") - # Check if df1 or df2 are empty and return if they are if strategy_df is None or benchmark_df is None or strategy_df.empty or benchmark_df.empty: - logging.error("No data to create tearsheet, skipping") - return + logging.warning("Strategy or benchmark data is empty. 
Skipping tearsheet.") + return None + # Merge your data or do whatever transforms you need _strategy_df = strategy_df.copy() _benchmark_df = benchmark_df.copy() - # Convert _strategy_df and _benchmark_df indexes to a date object instead of datetime - _strategy_df.index = pd.to_datetime(_strategy_df.index) - - # Merge the strategy and benchmark dataframes on the index column - df = pd.merge(_strategy_df, _benchmark_df, left_index=True, right_index=True, how="outer") - - df.index = pd.to_datetime(df.index) - df["portfolio_value"] = df["portfolio_value"].ffill() - - # If the portfolio_value is NaN, backfill it because sometimes the benchmark starts before the strategy - df["portfolio_value"] = df["portfolio_value"].bfill() - - df["symbol_cumprod"] = df["symbol_cumprod"].ffill() - df.loc[df.index[0], "symbol_cumprod"] = 1 - - df = df.resample("D").last() - df["strategy"] = df["portfolio_value"].bfill().pct_change(fill_method=None).fillna(0) - df["benchmark"] = df["symbol_cumprod"].bfill().pct_change(fill_method=None).fillna(0) - - # Merge the strategy and benchmark columns into a new dataframe called df_final - df_final = df.loc[:, ["strategy", "benchmark"]] - - # df_final = df.loc[:, ["strategy", "benchmark"]] - df_final.index = pd.to_datetime(df_final.index) - df_final.index = df_final.index.tz_localize(None) - - # Check if df_final is empty and return if it is - if df_final.empty or df_final["benchmark"].isnull().all() or df_final["strategy"].isnull().all(): - logging.warning("No data to create tearsheet, skipping") - return - - # Uncomment for debugging - # _df1.to_csv(f"df1.csv") - # _df2.to_csv(f"df2.csv") - # df.to_csv(f"df.csv") - # df_final.to_csv(f"df_final.csv") - - bm_text = f"Compared to {benchmark_asset}" if benchmark_asset else "" - title = f"{strat_name} {bm_text}" - - # Check if all the values are equal to 0 - if df_final["benchmark"].sum() == 0: - logging.error("Not enough data to create a tearsheet, at least 2 days of data are required. Skipping") - return - - # Check if all the values are equal to 0 - if df_final["strategy"].sum() == 0: - logging.error("Not enough data to create a tearsheet, at least 2 days of data are required. Skipping") - return - - # Set the name of the benchmark column so that quantstats can use it in the report - df_final["benchmark"].name = str(benchmark_asset) - - # Run quantstats reports surpressing any logs because it can be noisy for no reason + # Convert to daily returns or however you normally compute these + # (Placeholder: adapt to your actual code) + _strategy_df["strategy"] = _strategy_df["portfolio_value"].pct_change().fillna(0) + _benchmark_df["benchmark"] = _benchmark_df["symbol_cumprod"].pct_change().fillna(0) + + # Combine them into a single DataFrame for quantstats + df_final = pd.concat([_strategy_df["strategy"], _benchmark_df["benchmark"]], axis=1).dropna() + + # -- HERE IS THE SIMPLE “VALIDITY CHECK” BEFORE TEARSHEET -- + # 1) If there's not enough data, skip + if len(df_final) < 2: + logging.warning("Not enough data to create a tearsheet. Need at least 2 rows.") + return None + + # 2) If there's any Inf/NaN left, skip + # We can do it by checking df_final for isna() or isinf(). + # Note that isinf() is not built into DataFrame, so we do replace or apply. + # We'll do it in a quick & dirty way: + if df_final.isna().any().any(): + logging.warning("NaN detected in final data. Skipping tearsheet.") + return None + if np.isinf(df_final.values).any(): + logging.warning("Infinity detected in final data. 
Skipping tearsheet.") + return None + + # 3) If the total variance is zero (meaning no changes), skip + if df_final["strategy"].sum() == 0 or df_final["benchmark"].sum() == 0: + logging.warning("No significant variation in data (sum=0). Skipping tearsheet.") + return None + + # If we got this far, we try creating the tearsheet + df_final["benchmark"].name = str(benchmark_asset) if benchmark_asset else "benchmark" + title = f"{strat_name} vs. {benchmark_asset}" if benchmark_asset else strat_name + + logging.info("Data check passed, generating tearsheet...") + + # Now we safely call quantstats with no console spam with open(os.devnull, "w") as f, contextlib.redirect_stdout(f), contextlib.redirect_stderr(f): - result = qs.reports.html( + qs.reports.html( df_final["strategy"], df_final["benchmark"], title=title, output=tearsheet_file, - download_filename=tearsheet_file, # Consider if you need a different name for clarity + download_filename=tearsheet_file, rf=risk_free_rate, parameters=strategy_parameters, ) @@ -772,8 +760,8 @@ def create_tearsheet( url = "file://" + os.path.abspath(str(tearsheet_file)) webbrowser.open(url) - return result - + logging.info(f"Tearsheet created: {tearsheet_file}") + return tearsheet_file def get_risk_free_rate(dt: datetime = None): try: diff --git a/lumibot/tools/polygon_helper.py b/lumibot/tools/polygon_helper.py index 469203f40..5769cb299 100644 --- a/lumibot/tools/polygon_helper.py +++ b/lumibot/tools/polygon_helper.py @@ -1,7 +1,21 @@ -# This file contains helper functions for getting data from Polygon.io +""" +polygon_helper.py +----------------- +Caches minute/day data from Polygon in DuckDB, avoiding repeated downloads +by truncating the end date to the last fully closed trading day if timespan="minute." + +Changes: +1. Using Python's logging instead of print statements where needed. +2. Skipping days strictly before start.date() to avoid re-checking older days. +3. 24-hour placeholders for data accuracy. +4. Additional debugging around re-download logic and bounding queries. +5. Preserving all original docstrings, comments, and functions (including _store_placeholder_day). +6. Restoring parallel download in get_price_data_from_polygon() using concurrent futures. 
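A quick, hypothetical illustration of the cache layout described below (the
'price_data' table and DUCKDB_DB_PATH are defined later in this module; the query
assumes the table has already been created by a prior download):

    import duckdb

    con = duckdb.connect(str(DUCKDB_DB_PATH), read_only=True)
    print(con.execute(
        "SELECT symbol, timespan, MIN(datetime), MAX(datetime), COUNT(*) "
        "FROM price_data GROUP BY symbol, timespan"
    ).fetchall())
    con.close()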
+""" + import logging import time -from datetime import date, datetime, timedelta +from datetime import date, datetime, timedelta, time as dtime from pathlib import Path import os from urllib3.exceptions import MaxRetryError @@ -9,7 +23,7 @@ import pandas as pd import pandas_market_calendars as mcal -from lumibot import LUMIBOT_CACHE_FOLDER +from lumibot import LUMIBOT_CACHE_FOLDER, LUMIBOT_DEFAULT_PYTZ from lumibot.entities import Asset # noinspection PyPackageRequirements @@ -17,57 +31,72 @@ from typing import Iterator from termcolor import colored from tqdm import tqdm +from collections import defaultdict + +import duckdb +import concurrent.futures +import threading -from lumibot import LUMIBOT_CACHE_FOLDER -from lumibot.entities import Asset -from lumibot import LUMIBOT_DEFAULT_PYTZ from lumibot.credentials import POLYGON_API_KEY +logger = logging.getLogger(__name__) # <--- Our module-level logger + MAX_POLYGON_DAYS = 30 -# Define a cache dictionary to store schedules and a global dictionary for buffered schedules +# ------------------------------------------------------------------------------ +# 1) Choose a single DuckDB path for all scripts to share +# ------------------------------------------------------------------------------ +DUCKDB_DB_PATH = Path(LUMIBOT_CACHE_FOLDER) / "polygon_duckdb" / "polygon_cache.duckdb" +DUCKDB_DB_PATH.parent.mkdir(parents=True, exist_ok=True) + +logger.debug(f"Using DUCKDB_DB_PATH = {DUCKDB_DB_PATH.resolve()}") + + +# ------------------------------------------------------------------------------ +# We'll store bars in a single table 'price_data' with columns: +# symbol, timespan, datetime, open, high, low, close, volume +# ------------------------------------------------------------------------------ schedule_cache = {} buffered_schedules = {} +# Lock to handle concurrency for rate limits (useful on Polygon free plan). +RATE_LIMIT_LOCK = threading.Lock() + def get_cached_schedule(cal, start_date, end_date, buffer_days=30): """ - Fetch schedule with a buffer at the end. This is done to reduce the number of calls to the calendar API (which is slow). + Get trading schedule from 'cal' (pandas_market_calendars) with a buffer + to reduce repeated calls. Caches in memory for the session. 
""" global buffered_schedules - buffer_end = end_date + timedelta(days=buffer_days) cache_key = (cal.name, start_date, end_date) - # Check if the required range is in the schedule cache if cache_key in schedule_cache: return schedule_cache[cache_key] - # Convert start_date and end_date to pd.Timestamp for comparison start_timestamp = pd.Timestamp(start_date) end_timestamp = pd.Timestamp(end_date) - # Check if we have the buffered schedule for this calendar if cal.name in buffered_schedules: buffered_schedule = buffered_schedules[cal.name] - # Check if the current buffered schedule covers the required range - if buffered_schedule.index.min() <= start_timestamp and buffered_schedule.index.max() >= end_timestamp: - filtered_schedule = buffered_schedule[(buffered_schedule.index >= start_timestamp) & ( - buffered_schedule.index <= end_timestamp)] + if (buffered_schedule.index.min() <= start_timestamp and + buffered_schedule.index.max() >= end_timestamp): + filtered_schedule = buffered_schedule[ + (buffered_schedule.index >= start_timestamp) + & (buffered_schedule.index <= end_timestamp) + ] schedule_cache[cache_key] = filtered_schedule return filtered_schedule - # Fetch and cache the new buffered schedule buffered_schedule = cal.schedule(start_date=start_date, end_date=buffer_end) - buffered_schedules[cal.name] = buffered_schedule # Store the buffered schedule for this calendar + buffered_schedules[cal.name] = buffered_schedule - # Filter the schedule to only include the requested date range - filtered_schedule = buffered_schedule[(buffered_schedule.index >= start_timestamp) - & (buffered_schedule.index <= end_timestamp)] - - # Cache the filtered schedule for quick lookup + filtered_schedule = buffered_schedule[ + (buffered_schedule.index >= start_timestamp) + & (buffered_schedule.index <= end_timestamp) + ] schedule_cache[cache_key] = filtered_schedule - return filtered_schedule @@ -81,517 +110,626 @@ def get_price_data_from_polygon( force_cache_update: bool = False, ): """ - Queries Polygon.io for pricing data for the given asset and returns a DataFrame with the data. Data will be - cached in the LUMIBOT_CACHE_FOLDER/polygon folder so that it can be reused later and we don't have to query - Polygon.io every time we run a backtest. - - If the Polygon response has missing bars for a date, the missing bars will be added as empty (all NaN) rows - to the cache file to avoid querying Polygon for the same missing bars in the future. Note that means if - a request is for a future time then we won't make a request to Polygon for it later when that data might - be available. That should result in an error rather than missing data from Polygon, but just in case a - problem occurs and you want to ensure that the data is up to date, you can set force_cache_update=True. - - Parameters - ---------- - api_key : str - The API key for Polygon.io - asset : Asset - The asset we are getting data for - start : datetime - The start date/time for the data we want - end : datetime - The end date/time for the data we want - timespan : str - The timespan for the data we want. Default is "minute" but can also be "second", "hour", "day", "week", - "month", "quarter" - quote_asset : Asset - The quote asset for the asset we are getting data for. This is only needed for Forex assets. 
- - Returns - ------- - pd.DataFrame - A DataFrame with the pricing data for the asset - - """ - - # Check if we already have data for this asset in the feather file - cache_file = build_cache_filename(asset, timespan) - # Check whether it might be stale because of splits. - force_cache_update = validate_cache(force_cache_update, asset, cache_file, api_key) - - df_all = None - # Load from the cache file if it exists. - if cache_file.exists() and not force_cache_update: - logging.debug(f"Loading pricing data for {asset} / {quote_asset} with '{timespan}' timespan from cache file...") - df_all = load_cache(cache_file) - - # Check if we need to get more data - missing_dates = get_missing_dates(df_all, asset, start, end) - if not missing_dates: - # TODO: Do this upstream so we don't repeatedly call for known-to-be-missing bars. - # Drop the rows with all NaN values that were added to the feather for symbols that have missing bars. - df_all.dropna(how="all", inplace=True) - return df_all - - # print(f"\nGetting pricing data for {asset} / {quote_asset} with '{timespan}' timespan from Polygon...") - - # RESTClient connection for Polygon Stock-Equity API; traded_asset is standard - # Add "trace=True" to see the API calls printed to the console for debugging - polygon_client = PolygonClient.create(api_key=api_key) - symbol = get_polygon_symbol(asset, polygon_client, quote_asset) # Will do a Polygon query for option contracts - - # Check if symbol is None, which means we couldn't find the option contract - if symbol is None: - return None - - # To reduce calls to Polygon, we call on full date ranges instead of including hours/minutes - # get the full range of data we need in one call and ensure that there won't be any intraday gaps in the data. - # Option data won't have any extended hours data so the padding is extra important for those. - poly_start = missing_dates[0] # Data will start at 8am UTC (4am EST) - poly_end = missing_dates[-1] # Data will end at 23:59 UTC (7:59pm EST) - - # Initialize tqdm progress bar - total_days = (missing_dates[-1] - missing_dates[0]).days + 1 - total_queries = (total_days // MAX_POLYGON_DAYS) + 1 - description = f"\nDownloading data for {asset} / {quote_asset} '{timespan}' from Polygon..." - pbar = tqdm(total=total_queries, desc=description, dynamic_ncols=True) - - # Polygon only returns 50k results per query (~30days of 24hr 1min-candles) so we need to break up the query into - # multiple queries if we are requesting more than 30 days of data - delta = timedelta(days=MAX_POLYGON_DAYS) - while poly_start <= missing_dates[-1]: - if poly_end > (poly_start + delta): - poly_end = poly_start + delta - - result = polygon_client.get_aggs( - ticker=symbol, - from_=poly_start, # polygon-api-client docs say 'from' but that is a reserved word in python - to=poly_end, - # In Polygon, multiplier is the number of "timespans" in each candle, so if you want 5min candles - # returned you would set multiplier=5 and timespan="minute". This is very different from the - # asset.multiplier setting for option contracts. - multiplier=1, - timespan=timespan, - limit=50000, # Max limit for Polygon - ) - - # Update progress bar after each query - pbar.update(1) - - if result: - df_all = update_polygon_data(df_all, result) + Fetches minute/day data from Polygon for 'asset' between 'start' and 'end'. + Stores in DuckDB so subsequent calls won't re-download the same days. 
+ + If timespan="minute" and you request 'end' = today, it will truncate + to the last fully closed trading day to avoid repeated partial-day fetches. + """ - poly_start = poly_end + timedelta(days=1) - poly_end = poly_start + delta + # --- TRUNCATION LOGIC (minute data) --- + if timespan == "minute": + today_utc = pd.Timestamp.utcnow().date() + if end.date() >= today_utc: + new_end = (today_utc - timedelta(days=1)) + end = datetime.combine(new_end, dtime(23, 59), tzinfo=end.tzinfo or LUMIBOT_DEFAULT_PYTZ) + logger.info(f"Truncating 'end' to {end.isoformat()} for minute data (avoid partial day).") - # Close the progress bar when done - pbar.close() + if not end: + end = datetime.now(tz=LUMIBOT_DEFAULT_PYTZ) - # Recheck for missing dates so they can be added in the feather update. - missing_dates = get_missing_dates(df_all, asset, start, end) - update_cache(cache_file, df_all, missing_dates) + # 1) Load existing data from DuckDB + existing_df = _load_from_duckdb(asset, timespan, start, end) + asset_key = _asset_key(asset) + logger.info(f"Loaded {len(existing_df)} rows from DuckDB initially (symbol={asset_key}, timespan={timespan}).") - # TODO: Do this upstream so we don't have to reload feather repeatedly for known-to-be-missing bars. - # Drop the rows with all NaN values that were added to the feather for symbols that have missing bars. - if df_all is not None: - df_all.dropna(how="all", inplace=True) + # 2) Possibly clear existing data if force_cache_update + if force_cache_update: + logger.critical(f"Forcing cache update for {asset} from {start} to {end}") + existing_df = pd.DataFrame() - return df_all + # 3) Which days are missing? + missing_dates = get_missing_dates(existing_df, asset, start, end) + logger.info(f"Missing {len(missing_dates)} trading days for symbol={asset_key}, timespan={timespan}.") + if missing_dates: + logger.info(f"Inserting placeholder rows for {len(missing_dates)} missing days on {asset_key}...") + for md in missing_dates: + logger.debug(f"Placing placeholders for {md} on {asset_key}") + _store_placeholder_day(asset, timespan, md) + + if not missing_dates and not existing_df.empty: + logger.info(f"No missing days, returning existing data of {len(existing_df)} rows.") + # -- Drop placeholders before returning + return _drop_placeholder_rows(existing_df) # <-- NEW COMMENT + elif not missing_dates and existing_df.empty: + logger.info("No missing days but existing DF is empty -> returning empty.") + return existing_df + + # 4) Download from Polygon in parallel ~30-day chunks + polygon_client = PolygonClient.create(api_key=api_key) + symbol = get_polygon_symbol(asset, polygon_client, quote_asset=quote_asset) + if not symbol: + logger.error("get_polygon_symbol returned None. Possibly invalid or expired option.") + return None -def validate_cache(force_cache_update: bool, asset: Asset, cache_file: Path, api_key: str): - """ - If the list of splits for a stock have changed then we need to invalidate its cache - because all of the prices will have changed (because we're using split adjusted prices). - Get the splits data from Polygon only once per day per stock. - Use the timestamp on the splits feather file to determine if we need to get the splits again. - When invalidating we delete the cache file and return force_cache_update=True too. 
- """ - if asset.asset_type not in [Asset.AssetType.STOCK, Asset.AssetType.OPTION]: - return force_cache_update - cached_splits = pd.DataFrame() - splits_file_stale = True - splits_file_path = Path(str(cache_file).rpartition(".feather")[0] + "_splits.feather") - if splits_file_path.exists(): - splits_file_stale = datetime.fromtimestamp(splits_file_path.stat().st_mtime).date() != date.today() - if splits_file_stale: - cached_splits = pd.read_feather(splits_file_path) - if splits_file_stale or force_cache_update: - polygon_client = PolygonClient.create(api_key=api_key) - # Need to get the splits in execution order to make the list comparable across invocations. - splits = polygon_client.list_splits(ticker=asset.symbol, sort="execution_date", order="asc") - if isinstance(splits, Iterator): - # Convert the generator to a list so DataFrame will make a row per item. - splits_df = pd.DataFrame(list(splits)) - if splits_file_path.exists() and cached_splits.eq(splits_df).all().all(): - # No need to rewrite contents. Just update the timestamp. - splits_file_path.touch() - else: - logging.info(f"Invalidating cache for {asset.symbol} because its splits have changed.") - force_cache_update = True - cache_file.unlink(missing_ok=True) - # Create the directory if it doesn't exist - cache_file.parent.mkdir(parents=True, exist_ok=True) - splits_df.to_feather(splits_file_path) - else: - logging.warn(f"Unexpected response getting splits for {asset.symbol} from Polygon. Response: {splits}") - return force_cache_update + # Instead of sequential downloading, do parallel chunk downloads: + chunk_list = _group_missing_dates(missing_dates) + results_list = [] + + logger.info(f"Downloading data in parallel for {len(chunk_list)} chunk(s) on {symbol}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: + future_to_range = {} + for (start_chunk, end_chunk) in chunk_list: + future = executor.submit( + _fetch_polygon_data_chunk, + polygon_client, + symbol, + start_chunk, + end_chunk, + timespan + ) + future_to_range[future] = (start_chunk, end_chunk) + with tqdm(total=len(chunk_list), desc=f"Downloading data for {symbol} (parallel)", dynamic_ncols=True) as pbar: + for fut in concurrent.futures.as_completed(future_to_range): + data_chunk = fut.result() + if data_chunk: + results_list.extend(data_chunk) + pbar.update(1) -def get_trading_dates(asset: Asset, start: datetime, end: datetime): - """ - Get a list of trading days for the asset between the start and end dates - Parameters - ---------- - asset : Asset - Asset we are getting data for - start : datetime - Start date for the data requested - end : datetime - End date for the data requested + logger.info(f"Polygon returned {len(results_list)} bars total for symbol={symbol}, timespan={timespan}.") - Returns - ------- + # 5) Transform raw bars -> DataFrame + combined_df = _transform_polygon_data(results_list) + logger.info(f"combined_df has {len(combined_df)} rows after transform.") - """ - # Crypto Asset Calendar - if asset.asset_type == Asset.AssetType.CRYPTO: - # Crypto trades every day, 24/7 so we don't need to check the calendar - return [start.date() + timedelta(days=x) for x in range((end.date() - start.date()).days + 1)] - - # Stock/Option Asset for Backtesting - Assuming NYSE trading days - elif ( - asset.asset_type == Asset.AssetType.INDEX - or asset.asset_type == Asset.AssetType.STOCK - or asset.asset_type == Asset.AssetType.OPTION - ): - cal = mcal.get_calendar("NYSE") + # 6) Store new data in DuckDB + if not combined_df.empty: + 
_store_in_duckdb(asset, timespan, combined_df) + _fill_partial_days(asset, timespan, combined_df) + else: + logger.critical("combined_df is empty; no data to store.") - # Forex Asset for Backtesting - Forex trades weekdays, 24hrs starting Sunday 5pm EST - # Calendar: "CME_FX" - elif asset.asset_type == Asset.AssetType.FOREX: - cal = mcal.get_calendar("CME_FX") + # 7) Reload final data for the full range + final_df = _load_from_duckdb(asset, timespan, start, end) + if final_df is not None and not final_df.empty: + final_df.dropna(how="all", inplace=True) - else: - raise ValueError(f"Unsupported asset type for polygon: {asset.asset_type}") + logger.info(f"Final DF has {len(final_df)} rows for {asset.symbol}, timespan={timespan}.") - # Get the trading days between the start and end dates - df = get_cached_schedule(cal, start.date(), end.date()) - trading_days = df.index.date.tolist() - return trading_days + # -- Drop placeholder rows from final before returning to tests + return _drop_placeholder_rows(final_df) # <-- NEW COMMENT def get_polygon_symbol(asset, polygon_client, quote_asset=None): """ - Get the symbol for the asset in a format that Polygon will understand - Parameters - ---------- - asset : Asset - Asset we are getting data for - polygon_client : RESTClient - The RESTClient connection for Polygon Stock-Equity API - quote_asset : Asset - The quote asset for the asset we are getting data for - - Returns - ------- - str - The symbol for the asset in a format that Polygon will understand - """ - # Crypto Asset for Backtesting + Convert a LumiBot Asset into a Polygon-compatible symbol, e.g.: + - STOCK: "SPY" + - OPTION: "O:SPY20250114C00570000" + - FOREX: "C:EURUSD" + - CRYPTO: "X:BTCUSD" + """ + from datetime import date + if asset.asset_type == Asset.AssetType.CRYPTO: quote_asset_symbol = quote_asset.symbol if quote_asset else "USD" - symbol = f"X:{asset.symbol}{quote_asset_symbol}" + return f"X:{asset.symbol}{quote_asset_symbol}" - # Stock-Equity Asset for Backtesting elif asset.asset_type == Asset.AssetType.STOCK: - symbol = asset.symbol + return asset.symbol elif asset.asset_type == Asset.AssetType.INDEX: - symbol = f"I:{asset.symbol}" + return f"I:{asset.symbol}" - # Forex Asset for Backtesting elif asset.asset_type == Asset.AssetType.FOREX: - # If quote_asset is None, throw an error - if quote_asset is None: - raise ValueError(f"quote_asset is required for asset type {asset.asset_type}") + if not quote_asset: + logger.error("No quote_asset provided for FOREX.") + return None + return f"C:{asset.symbol}{quote_asset.symbol}" - symbol = f"C:{asset.symbol}{quote_asset.symbol}" - - # Option Asset for Backtesting - Do a query to Polygon to get the ticker elif asset.asset_type == Asset.AssetType.OPTION: - # Needed so BackTest both old and existing contracts real_today = date.today() - expired = True if asset.expiration < real_today else False - - # Query for the historical Option Contract ticker backtest is looking for + expired = asset.expiration < real_today contracts = list( polygon_client.list_options_contracts( underlying_ticker=asset.symbol, expiration_date=asset.expiration, - contract_type=asset.right.lower(), + contract_type=asset.right.lower(), # 'call' or 'put' strike_price=asset.strike, expired=expired, - limit=10, + limit=100, ) ) + if not contracts: + msg = f"Unable to find option contract for {asset}" + logger.error(colored(msg, "red")) + return None + return contracts[0].ticker - if len(contracts) == 0: - text = colored(f"Unable to find option contract for {asset}", 
"red") - logging.debug(text) - return + else: + logger.error(f"Unsupported asset type: {asset.asset_type}") + return None - # Example: O:SPY230802C00457000 - symbol = contracts[0].ticker - elif asset.asset_type == Asset.AssetType.INDEX: - symbol = f"I:{asset.symbol}" +def validate_cache(force_cache_update: bool, asset: Asset, cache_file: Path, api_key: str): + """ + Placeholder if you want advanced checks for dividends, splits, etc. + Currently returns force_cache_update as is. + """ + return force_cache_update + +def get_trading_dates(asset: Asset, start: datetime, end: datetime): + """ + Return a list of valid daily sessions for the asset's exchange (or 7-day for CRYPTO). + """ + if asset.asset_type == Asset.AssetType.CRYPTO: + return [ + start.date() + timedelta(days=x) + for x in range((end.date() - start.date()).days + 1) + ] + + elif asset.asset_type in (Asset.AssetType.INDEX, Asset.AssetType.STOCK, Asset.AssetType.OPTION): + cal = mcal.get_calendar("NYSE") + elif asset.asset_type == Asset.AssetType.FOREX: + cal = mcal.get_calendar("CME_FX") else: - raise ValueError(f"Unsupported asset type for polygon: {asset.asset_type}") + raise ValueError(f"[ERROR] get_trading_dates: unsupported asset type {asset.asset_type}") - return symbol + df = get_cached_schedule(cal, start.date(), end.date()) + return df.index.date.tolist() -def build_cache_filename(asset: Asset, timespan: str): - """Helper function to create the cache filename for a given asset and timespan""" +def _get_trading_days(asset: Asset, start: datetime, end: datetime): + return get_trading_dates(asset, start, end) - lumibot_polygon_cache_folder = Path(LUMIBOT_CACHE_FOLDER) / "polygon" - # If It's an option then also add the expiration date, strike price and right to the filename - if asset.asset_type == "option": - if asset.expiration is None: - raise ValueError(f"Expiration date is required for option {asset} but it is None") +def get_missing_dates(df_all, asset, start: datetime, end: datetime): + """ + Identify which daily sessions are missing from df_all. + If asset is OPTION, only consider days up to expiration. - # Make asset.expiration datetime into a string like "YYMMDD" - expiry_string = asset.expiration.strftime("%y%m%d") - uniq_str = f"{asset.symbol}_{expiry_string}_{asset.strike}_{asset.right}" - else: - uniq_str = asset.symbol + We skip days strictly before start.date(). 
+ """ + trading_days = _get_trading_days(asset, start, end) + logger.debug(f"get_missing_dates: computed trading_days={trading_days}") + + if asset.asset_type == Asset.AssetType.OPTION: + trading_days = [d for d in trading_days if d <= asset.expiration] + logger.debug(f"get_missing_dates: filtered for option expiration => {trading_days}") + + start_date_only = start.date() + end_date_only = end.date() + trading_days = [d for d in trading_days if d >= start_date_only and d <= end_date_only] + logger.debug(f"get_missing_dates: after bounding by start/end => {trading_days}") + + if df_all is None or df_all.empty: + logger.debug("get_missing_dates: df_all is empty => all trading_days are missing") + return trading_days + + existing_days = pd.Series(df_all.index.date).unique() + logger.debug(f"get_missing_dates: existing_days in df_all={existing_days}") - cache_filename = f"{asset.asset_type}_{uniq_str}_{timespan}.feather" - cache_file = lumibot_polygon_cache_folder / cache_filename - return cache_file + missing = sorted(set(trading_days) - set(existing_days)) + logger.debug(f"get_missing_dates: missing={missing}") + return missing -def get_missing_dates(df_all, asset, start, end): +def _load_from_duckdb(asset: Asset, timespan: str, start: datetime, end: datetime) -> pd.DataFrame: """ - Check if we have data for the full range - Later Query to Polygon will pad an extra full day to start/end dates so that there should never - be any gap with intraday data missing. + Load from DuckDB if data is stored. Return a DataFrame with datetime index. + If no table or no matching rows, returns empty DataFrame. - Parameters - ---------- - df_all : pd.DataFrame - Data loaded from the cache file - asset : Asset - Asset we are getting data for - start : datetime - Start date for the data requested - end : datetime - End date for the data requested + Additional debugging to see the actual query. + """ + conn = duckdb.connect(str(DUCKDB_DB_PATH), read_only=False) + asset_key = _asset_key(asset) + + query = f""" + SELECT * + FROM price_data + WHERE symbol='{asset_key}' + AND timespan='{timespan}' + AND datetime >= '{start.isoformat()}' + AND datetime <= '{end.isoformat()}' + ORDER BY datetime + """ + logger.debug(f"_load_from_duckdb: SQL=\n{query}") + + try: + df = conn.execute(query).fetchdf() + if df.empty: + logger.debug(f"_load_from_duckdb: No rows found in DB for symbol={asset_key}, timespan={timespan}") + return df + + df["datetime"] = pd.to_datetime(df["datetime"], utc=True) + df.set_index("datetime", inplace=True) + df.sort_index(inplace=True) + + logger.debug(f"_load_from_duckdb: loaded {len(df)} rows for symbol={asset_key}, timespan={timespan}") + if not df.empty: + logger.debug(f"_load_from_duckdb: min timestamp={df.index.min()}, max timestamp={df.index.max()}") + unique_dates = pd.Series(df.index.date).unique() + logger.debug(f"_load_from_duckdb: unique dates in loaded data => {unique_dates}") - Returns - ------- - list[datetime.date] - A list of dates that we need to get data for + return df + + except duckdb.CatalogException: + logger.debug(f"_load_from_duckdb: Table does not exist yet for symbol={asset_key}, timespan={timespan}") + return pd.DataFrame() + finally: + conn.close() + + +def _store_in_duckdb(asset: Asset, timespan: str, df_in: pd.DataFrame): + """ + Insert newly fetched data into DuckDB 'price_data'. + Upsert logic: only insert rows not already present. 
""" - trading_dates = get_trading_dates(asset, start, end) + if df_in.empty: + logger.debug("_store_in_duckdb called with empty DataFrame. No insert performed.") + return + + new_df = df_in.copy(deep=True) + columns_needed = ["datetime", "open", "high", "low", "close", "volume", "symbol", "timespan"] + for c in columns_needed: + if c not in new_df.columns: + new_df.loc[:, c] = None + + if new_df.index.name == "datetime": + if "datetime" in new_df.columns: + new_df.drop(columns=["datetime"], inplace=True) + new_df.reset_index(drop=False, inplace=True) + + new_df = new_df[columns_needed] + + asset_key = _asset_key(asset) + new_df["symbol"] = asset_key + new_df["timespan"] = timespan + + conn = duckdb.connect(str(DUCKDB_DB_PATH), read_only=False) + schema_ddl = """ + CREATE TABLE IF NOT EXISTS price_data ( + symbol VARCHAR, + timespan VARCHAR, + datetime TIMESTAMP, + open DOUBLE, + high DOUBLE, + low DOUBLE, + close DOUBLE, + volume DOUBLE + ); + """ + conn.execute(schema_ddl) - # For Options, don't need any dates passed the expiration date - if asset.asset_type == "option": - trading_dates = [x for x in trading_dates if x <= asset.expiration] + conn.execute("DROP TABLE IF EXISTS tmp_table") + conn.execute( + """ + CREATE TEMPORARY TABLE tmp_table( + symbol VARCHAR, + timespan VARCHAR, + datetime TIMESTAMP, + open DOUBLE, + high DOUBLE, + low DOUBLE, + close DOUBLE, + volume DOUBLE + ); + """ + ) - if df_all is None or not len(df_all) or df_all.empty: - return trading_dates + conn.register("df_newdata", new_df) + insert_sql = """ + INSERT INTO tmp_table + SELECT symbol, timespan, datetime, open, high, low, close, volume + FROM df_newdata; + """ + conn.execute(insert_sql) + + upsert_sql = f""" + INSERT INTO price_data + SELECT t.* + FROM tmp_table t + LEFT JOIN price_data p + ON t.symbol = p.symbol + AND t.timespan = p.timespan + AND t.datetime = p.datetime + WHERE p.symbol IS NULL + """ + conn.execute(upsert_sql) - # It is possible to have full day gap in the data if previous queries were far apart - # Example: Query for 8/1/2023, then 8/31/2023, then 8/7/2023 - # Whole days are easy to check for because we can just check the dates in the index - dates = pd.Series(df_all.index.date).unique() - missing_dates = sorted(set(trading_dates) - set(dates)) + check_sql = f""" + SELECT COUNT(*) + FROM price_data + WHERE symbol='{asset_key}' AND timespan='{timespan}' + """ + count_after = conn.execute(check_sql).fetchone()[0] + logger.debug(f"Upsert completed. Now {count_after} total rows in 'price_data' " + f"for symbol='{asset_key}', timespan='{timespan}'.") + conn.close() - # TODO: This code works AFAIK, But when i enable it the tests for "test_polygon_missing_day_caching" and - # i don't know why nor how to fix this code or the tests. So im leaving it disabled for now. If you have problems - # with NANs in cached polygon data, you can try to enable this code and fix the tests. - # # Find any dates with nan values in the df_all DataFrame - # missing_dates += df_all[df_all.isnull().all(axis=1)].index.date.tolist() - # - # # make sure the dates are unique - # missing_dates = list(set(missing_dates)) - # missing_dates.sort() - # - # # finally, filter out any dates that are not in start/end range (inclusive) - # missing_dates = [d for d in missing_dates if start.date() <= d <= end.date()] +def _transform_polygon_data(results_list): + """ + Combine chunk results into one DataFrame, rename columns, set datetime index, localize to UTC. 
+ """ + if not results_list: + return pd.DataFrame() - return missing_dates + df = pd.DataFrame(results_list) + if df.empty: + return df + rename_cols = {"o": "open", "h": "high", "l": "low", "c": "close", "v": "volume"} + df = df.rename(columns=rename_cols, errors="ignore") -def load_cache(cache_file): - """Load the data from the cache file and return a DataFrame with a DateTimeIndex""" - df_feather = pd.read_feather(cache_file) + if "t" in df.columns: + df["datetime"] = pd.to_datetime(df["t"], unit="ms") + df.drop(columns=["t"], inplace=True) + elif "timestamp" in df.columns: + df["datetime"] = pd.to_datetime(df["timestamp"], unit="ms") + df.drop(columns=["timestamp"], inplace=True) - # Set the 'datetime' column as the index of the DataFrame - df_feather.set_index("datetime", inplace=True) + df.set_index("datetime", inplace=True) + df.sort_index(inplace=True) - df_feather.index = pd.to_datetime( - df_feather.index - ) # TODO: Is there some way to speed this up? It takes several times longer than just reading the feather file - df_feather = df_feather.sort_index() + if df.index.tzinfo is None: + df.index = df.index.tz_localize("UTC") - # Check if the index is already timezone aware - if df_feather.index.tzinfo is None: - # Set the timezone to UTC - df_feather.index = df_feather.index.tz_localize("UTC") + return df - return df_feather +def get_option_chains_with_cache(polygon_client: RESTClient, asset: Asset, current_date: date): + """ + Returns option chain data (calls+puts) from Polygon. Not stored in DuckDB by default. + """ + option_contracts = { + "Multiplier": None, + "Exchange": None, + "Chains": {"CALL": defaultdict(list), "PUT": defaultdict(list)}, + } + real_today = date.today() + expired_list = [True, False] if real_today - current_date <= timedelta(days=31) else [True] + + polygon_contracts_list = [] + for expired in expired_list: + polygon_contracts_list.extend( + list( + polygon_client.list_options_contracts( + underlying_ticker=asset.symbol, + expiration_date_gte=current_date, + expired=expired, + limit=1000, + ) + ) + ) -def update_cache(cache_file, df_all, missing_dates=None): - """Update the cache file with the new data. Missing dates are added as empty (all NaN) - rows before it is saved to the cache file. + for pc in polygon_contracts_list: + if pc.shares_per_contract != 100: + continue + exchange = pc.primary_exchange + right = pc.contract_type.upper() + exp_date = pc.expiration_date + strike = pc.strike_price - Parameters - ---------- - cache_file : Path - The path to the cache file - df_all : pd.DataFrame - The DataFrame with the data we want to cache - missing_dates : list[datetime.date] - A list of dates that are missing bars from Polygon""" + option_contracts["Multiplier"] = pc.shares_per_contract + option_contracts["Exchange"] = exchange + option_contracts["Chains"][right][exp_date].append(strike) - if df_all is None: - df_all = pd.DataFrame() + return option_contracts - if missing_dates: - missing_df = pd.DataFrame( - [datetime(year=d.year, month=d.month, day=d.day, tzinfo=LUMIBOT_DEFAULT_PYTZ) for d in missing_dates], - columns=["datetime"]) - missing_df.set_index("datetime", inplace=True) - # Set the timezone to UTC - missing_df.index = missing_df.index.tz_convert("UTC") - df_concat = pd.concat([df_all, missing_df]).sort_index() - # Let's be careful and check for duplicates to avoid corrupting the feather file. 
- if df_concat.index.duplicated().any(): - logging.warn(f"Duplicate index entries found when trying to update Polygon cache {cache_file}") - if df_all.index.duplicated().any(): - logging.warn("The duplicate index entries were already in df_all") - else: - # All good, persist with the missing dates added - df_all = df_concat - - if len(df_all) > 0: - # Create the directory if it doesn't exist - cache_file.parent.mkdir(parents=True, exist_ok=True) - - # Reset the index to convert DatetimeIndex to a regular column - df_all_reset = df_all.reset_index() - - # Save the data to a feather file - df_all_reset.to_feather(cache_file) - - -def update_polygon_data(df_all, result): - """ - Update the DataFrame with the new data from Polygon - Parameters - ---------- - df_all : pd.DataFrame - A DataFrame with the data we already have - result : list - A List of dictionaries with the new data from Polygon - Format: [{'o': 1.0, 'h': 2.0, 'l': 3.0, 'c': 4.0, 'v': 5.0, 't': 116120000000}] - """ - df = pd.DataFrame(result) - if not df.empty: - # Rename columns - df = df.rename( - columns={ - "o": "open", - "h": "high", - "l": "low", - "c": "close", - "v": "volume", - } + +def _fetch_polygon_data_chunk(polygon_client, symbol, chunk_start, chunk_end, timespan): + """ + Fetch data for one chunk, locking if needed for rate limit on the free plan. + """ + with RATE_LIMIT_LOCK: + results = polygon_client.get_aggs( + ticker=symbol, + from_=chunk_start, + to=chunk_end, + multiplier=1, + timespan=timespan, + limit=50000, ) + return results if results else [] - # Create a datetime column and set it as the index - timestamp_col = "t" if "t" in df.columns else "timestamp" - df = df.assign(datetime=pd.to_datetime(df[timestamp_col], unit="ms")) - df = df.set_index("datetime").sort_index() - # Set the timezone to UTC - df.index = df.index.tz_localize("UTC") +def _group_missing_dates(missing_dates): + """ + Group consecutive missing days into ~30-day chunks for fewer polygon calls. + We return a list of (start_datetime, end_datetime) pairs in UTC. + """ + if not missing_dates: + return [] - if df_all is None or df_all.empty: - df_all = df + missing_dates = sorted(missing_dates) + grouped = [] + + chunk_start = missing_dates[0] + chunk_end = chunk_start + + for d in missing_dates[1:]: + if (d - chunk_end).days <= 1: + chunk_end = d else: - df_all = pd.concat([df_all, df]).sort_index() - df_all = df_all[~df_all.index.duplicated(keep="first")] # Remove any duplicate rows + grouped.append((chunk_start, chunk_end)) + chunk_start = d + chunk_end = d - return df_all + grouped.append((chunk_start, chunk_end)) + final_chunks = [] + delta_30 = timedelta(days=30) + active_start, active_end = grouped[0] -class PolygonClient(RESTClient): - ''' Rate Limited RESTClient with factory method ''' + for (s, e) in grouped[1:]: + if e - active_start <= delta_30: + if e > active_end: + active_end = e + else: + final_chunks.append((active_start, active_end)) + active_start, active_end = s, e + final_chunks.append((active_start, active_end)) - WAIT_SECONDS_RETRY = 60 + range_list = [] + for (s, e) in final_chunks: + start_dt = datetime(s.year, s.month, s.day, tzinfo=LUMIBOT_DEFAULT_PYTZ) + end_dt = datetime(e.year, e.month, e.day, 23, 59, tzinfo=LUMIBOT_DEFAULT_PYTZ) + range_list.append((start_dt, end_dt)) + + return range_list - @classmethod - def create(cls, *args, **kwargs) -> RESTClient: - """ - Factory method to create a RESTClient or PolygonClient instance. 
- The method uses environment variables to determine default values for the API key - and subscription type. If the `api_key` is not provided in `kwargs`, it defaults - to the value of the `POLYGON_API_KEY` environment variable. - If the environment variable is not set, it defaults to False. +def _asset_key(asset: Asset) -> str: + """ + Construct a unique symbol key for storing in DuckDB. For OPTIONS, do e.g.: + "SPY_250114_577_CALL" + """ + if asset.asset_type == Asset.AssetType.OPTION: + if not asset.expiration: + raise ValueError("Option asset requires expiration date.") + expiry_str = asset.expiration.strftime("%y%m%d") + return f"{asset.symbol}_{expiry_str}_{asset.strike}_{asset.right.upper()}" + else: + return asset.symbol - Keyword Arguments: - api_key : str, optional - The API key to authenticate with the service. Defaults to the value of the - `POLYGON_API_KEY` environment variable if not provided. - Returns: - RESTClient - An instance of RESTClient or PolygonClient. +def _store_placeholder_day(asset: Asset, timespan: str, single_date: date): + """ + Insert *FULL DAY* (24-hour) placeholder rows into DuckDB for the given day, + so we don't keep re-downloading it if it truly has no data (or partial data). - Examples: - --------- - Using default environment variables: + Data Accuracy: + - Real data overwrites these placeholders if available. + - We never lose data or skip times. - >>> client = PolygonClient.create() + We carefully create naive midnights and localize them to UTC + to avoid the "Inferred time zone not equal to passed time zone" error. + """ + import pytz # For explicit UTC usage + + logger.debug(f"Storing placeholder *24-hour UTC* rows for date={single_date} " + f"on symbol={_asset_key(asset)}, timespan={timespan}") + + naive_start = datetime(single_date.year, single_date.month, single_date.day, 0, 0, 0) + naive_end = naive_start + timedelta(days=1, microseconds=-1) + + day_start = pytz.UTC.localize(naive_start) + day_end = pytz.UTC.localize(naive_end) + + logger.debug(f"_store_placeholder_day: day_start (UTC)={day_start}, day_end (UTC)={day_end}") + + try: + # Optionally, for stocks, we could insert only 9:30–16:00 placeholders + if (asset.asset_type in (Asset.AssetType.STOCK, Asset.AssetType.OPTION) and timespan == "minute"): + # 9:30–16:00 Eastern, converted to UTC + # For more robust, consider using a calendar for half-days, etc. + # But this is an example of partial day placeholders: + open_eastern = datetime(single_date.year, single_date.month, single_date.day, 9, 30) + close_eastern = datetime(single_date.year, single_date.month, single_date.day, 16, 0) + from_date = pd.Timestamp(open_eastern, tz="America/New_York").tz_convert("UTC") + to_date = pd.Timestamp(close_eastern, tz="America/New_York").tz_convert("UTC") + rng = pd.date_range(start=from_date, end=to_date, freq="min", tz="UTC") + else: + rng = pd.date_range(start=day_start, end=day_end, freq="min", tz="UTC") + except Exception as e: + logger.critical(f"date_range failed for day={single_date} with error: {e}") + raise - Providing an API key explicitly: + if len(rng) == 0: + logger.debug(f"_store_placeholder_day: no minutes from {day_start} to {day_end}??? 
skipping.") + return - >>> client = PolygonClient.create(api_key='your_api_key_here') + df_placeholder = pd.DataFrame( + { + "datetime": rng, + "open": [None]*len(rng), + "high": [None]*len(rng), + "low": [None]*len(rng), + "close": [None]*len(rng), + "volume": [None]*len(rng), + } + ).set_index("datetime") + + logger.debug(f"_store_placeholder_day: day={single_date}, inserting {len(df_placeholder)} placeholders.") + logger.debug(f"min placeholder={df_placeholder.index.min()}, max placeholder={df_placeholder.index.max()}") + + _store_in_duckdb(asset, timespan, df_placeholder) - """ - if 'api_key' not in kwargs: - kwargs['api_key'] = POLYGON_API_KEY +def _fill_partial_days(asset: Asset, timespan: str, newly_fetched: pd.DataFrame): + """ + After we download real data for certain days, fill in placeholders + for any missing minutes in each day of 'newly_fetched'. + We do a 24h approach, so re-store placeholders in case the day only got partial data. + """ + if newly_fetched.empty: + return + + days_updated = pd.Series(newly_fetched.index.date).unique() + for day in days_updated: + logger.debug(f"_fill_partial_days: day={day}, calling _store_placeholder_day(24h) again.") + _store_placeholder_day(asset, timespan, day) + + +class PolygonClient(RESTClient): + """ + Thin subclass of polygon.RESTClient that retries on MaxRetryError with a cooldown. + Helps with free-tier rate limits. + """ + WAIT_SECONDS_RETRY = 60 + + @classmethod + def create(cls, *args, **kwargs) -> RESTClient: + if "api_key" not in kwargs: + kwargs["api_key"] = POLYGON_API_KEY return cls(*args, **kwargs) def _get(self, *args, **kwargs): while True: try: return super()._get(*args, **kwargs) - except MaxRetryError as e: - url = urlunparse(urlparse(kwargs['path'])._replace(query="")) - - message = ( - "Polygon rate limit reached.\n\n" - f"REST API call affected: {url}\n\n" - f"Sleeping for {PolygonClient.WAIT_SECONDS_RETRY} seconds seconds before trying again.\n\n" - "If you want to avoid this, consider a paid subscription with Polygon at https://polygon.io/?utm_source=affiliate&utm_campaign=lumi10\n" - "Please use the full link to give us credit for the sale, it helps support this project.\n" - "You can use the coupon code 'LUMI10' for 10% off." + url = urlunparse(urlparse(kwargs["path"])._replace(query="")) + msg = ( + "Polygon rate limit reached. " + f"Sleeping {PolygonClient.WAIT_SECONDS_RETRY} seconds.\n" + f"REST API call: {url}\n\n" + "Consider upgrading to a paid subscription at https://polygon.io\n" + "Use code 'LUMI10' for 10% off." ) + logging.critical(msg) + logging.critical(f"Error: {e}") + time.sleep(PolygonClient.WAIT_SECONDS_RETRY) - colored_message = colored(message, "red") - logging.error(colored_message) - logging.debug(f"Error: {e}") - time.sleep(PolygonClient.WAIT_SECONDS_RETRY) +# ----------------------------------------------------------------------- +# Additional Helper: _drop_placeholder_rows +# ----------------------------------------------------------------------- +def _drop_placeholder_rows(df_in: pd.DataFrame) -> pd.DataFrame: + """ + Removes placeholder rows (where open/close/volume are all NaN), + returning only real data to tests or strategies. + The placeholders remain in DuckDB so re-downloading is avoided. 
+ """ + if df_in.empty: + return df_in + + # If everything is NaN in open, close, high, low, volume → mark as placeholders + mask_real = ~( + df_in["open"].isna() & df_in["close"].isna() & df_in["volume"].isna() + ) + return df_in.loc[mask_real].copy() \ No newline at end of file diff --git a/setup.py b/setup.py index 6ab0eb97a..46d76950c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="lumibot", - version="3.8.26", + version="3.8.27", author="Robert Grzesik", author_email="rob@lumiwealth.com", description="Backtesting and Trading Library, Made by Lumiwealth", diff --git a/tests/test_polygon_helper.py b/tests/test_polygon_helper.py index 8d1387f39..764cd1a36 100644 --- a/tests/test_polygon_helper.py +++ b/tests/test_polygon_helper.py @@ -1,3 +1,16 @@ +""" +test_polygon_helper.py +---------------------- +Tests for the new DuckDB-based 'polygon_helper.py'. These tests: + 1) Check missing dates, trading dates, and get_polygon_symbol as before. + 2) Validate get_price_data_from_polygon(...) with a mock PolygonClient, ensuring it + stores data in DuckDB and then reads from DuckDB (caching). + 3) Provide coverage for the DuckDB-specific helpers (like _asset_key, _load_from_duckdb, + _store_in_duckdb, and _transform_polygon_data). + 4) Remove references to the old feather-based caching logic (build_cache_filename, + load_cache, update_cache, update_polygon_data) that no longer exist in the new code. +""" + import datetime from pathlib import Path @@ -7,483 +20,322 @@ import pytz from lumibot.entities import Asset +# We'll import everything as `ph` for polygon_helper from lumibot.tools import polygon_helper as ph +############################################################################### +# HELPER CLASSES / FIXTURES +############################################################################### + class FakeContract: - def __init__(self, ticker): + """Fake contract object simulating a contract returned by polygon_client.list_options_contracts(...)""" + def __init__(self, ticker: str): self.ticker = ticker -class TestPolygonHelpers: - def test_build_cache_filename(self, mocker, tmpdir): - asset = Asset("SPY") - timespan = "1D" - mocker.patch.object(ph, "LUMIBOT_CACHE_FOLDER", tmpdir) - expected = tmpdir / "polygon" / "stock_SPY_1D.feather" - assert ph.build_cache_filename(asset, timespan) == expected - - expire_date = datetime.date(2023, 8, 1) - option_asset = Asset("SPY", asset_type="option", expiration=expire_date, strike=100, right="CALL") - expected = tmpdir / "polygon" / "option_SPY_230801_100_CALL_1D.feather" - assert ph.build_cache_filename(option_asset, timespan) == expected - - # Bad option asset with no expiration - option_asset = Asset("SPY", asset_type="option", strike=100, right="CALL") - with pytest.raises(ValueError): - ph.build_cache_filename(option_asset, timespan) +@pytest.fixture +def ephemeral_duckdb(tmp_path): + """ + A fixture that points polygon_helper's DUCKDB_DB_PATH at a temporary file + within 'tmp_path'. Ensures each test runs with a blank ephemeral DB. + Restores the original DUCKDB_DB_PATH afterwards. 
+ """ + original_path = ph.DUCKDB_DB_PATH + test_db_path = tmp_path / "polygon_cache.duckdb" + ph.DUCKDB_DB_PATH = test_db_path + yield test_db_path + ph.DUCKDB_DB_PATH = original_path + - def test_missing_dates(self): - # Setup some basics +############################################################################### +# TEST: Missing Dates, Trading Dates, get_polygon_symbol +############################################################################### + + +class TestPolygonHelpersBasic: + """ + Tests for get_missing_dates, get_trading_dates, get_polygon_symbol. + """ + + def test_get_missing_dates(self): + """Check that get_missing_dates(...) handles typical stock data and option expiration logic.""" asset = Asset("SPY") - start_date = datetime.datetime(2023, 8, 1, 9, 30) # Tuesday + start_date = datetime.datetime(2023, 8, 1, 9, 30) end_date = datetime.datetime(2023, 8, 1, 10, 0) - # Empty DataFrame - missing_dates = ph.get_missing_dates(pd.DataFrame(), asset, start_date, end_date) - assert len(missing_dates) == 1 - assert datetime.date(2023, 8, 1) in missing_dates - - # Small dataframe that meets start/end criteria - index = pd.date_range(start_date, end_date, freq="1min") - df_all = pd.DataFrame( - { - "open": np.random.uniform(0, 100, len(index)).round(2), - "close": np.random.uniform(0, 100, len(index)).round(2), - "volume": np.random.uniform(0, 10000, len(index)).round(2), - }, - index=index, - ) - missing_dates = ph.get_missing_dates(df_all, asset, start_date, end_date) - assert not missing_dates - - # Small dataframe that does not meet start/end criteria - end_date = datetime.datetime(2023, 8, 2, 13, 0) # Weds - missing_dates = ph.get_missing_dates(df_all, asset, start_date, end_date) - assert missing_dates - assert datetime.date(2023, 8, 2) in missing_dates - - # Asking for data beyond option expiration - We have all the data - end_date = datetime.datetime(2023, 8, 3, 13, 0) - expire_date = datetime.date(2023, 8, 2) - index = pd.date_range(start_date, end_date, freq="1min") - df_all = pd.DataFrame( - { - "open": np.random.uniform(0, 100, len(index)).round(2), - "close": np.random.uniform(0, 100, len(index)).round(2), - "volume": np.random.uniform(0, 10000, len(index)).round(2), - }, - index=index, - ) - option_asset = Asset("SPY", asset_type="option", expiration=expire_date, strike=100, right="CALL") - missing_dates = ph.get_missing_dates(df_all, option_asset, start_date, end_date) - assert not missing_dates + # 1) With empty DataFrame => entire date is missing + missing = ph.get_missing_dates(pd.DataFrame(), asset, start_date, end_date) + assert len(missing) == 1 + assert datetime.date(2023, 8, 1) in missing + + # 2) Full coverage => no missing + idx = pd.date_range(start_date, end_date, freq="1min") + df_cover = pd.DataFrame({ + "open": np.random.uniform(0, 100, len(idx)), + "close": np.random.uniform(0, 100, len(idx)), + "volume": np.random.uniform(0, 10000, len(idx)) + }, index=idx) + missing2 = ph.get_missing_dates(df_cover, asset, start_date, end_date) + assert not missing2 + + # 3) Extended range => next day missing + end_date2 = datetime.datetime(2023, 8, 2, 13, 0) + missing3 = ph.get_missing_dates(df_cover, asset, start_date, end_date2) + assert len(missing3) == 1 + assert datetime.date(2023, 8, 2) in missing3 + + # 4) Option expiration scenario + option_exp_date = datetime.date(2023, 8, 2) + option_asset = Asset("SPY", asset_type="option", expiration=option_exp_date, + strike=100, right="CALL") + extended_end = datetime.datetime(2023, 8, 3, 13, 0) + idx2 = 
pd.date_range(start_date, extended_end, freq="1min") + df_all2 = pd.DataFrame({ + "open": np.random.uniform(0, 100, len(idx2)), + "close": np.random.uniform(0, 100, len(idx2)), + "volume": np.random.uniform(0, 10000, len(idx2)) + }, index=idx2) + + missing_opt = ph.get_missing_dates(df_all2, option_asset, start_date, extended_end) + # Because option expires 8/2 => no missing for 8/3 even though there's data for that day + assert not missing_opt def test_get_trading_dates(self): - # Unsupported Asset Type - asset = Asset("SPY", asset_type="future") - start_date = datetime.datetime(2023, 7, 1, 9, 30) # Saturday - end_date = datetime.datetime(2023, 7, 10, 10, 0) # Monday + """Test get_trading_dates(...) with stock, option, forex, crypto, plus an unsupported type.""" + # 1) Future => raises ValueError + asset_fut = Asset("SPY", asset_type="future") + sdate = datetime.datetime(2023, 7, 1, 9, 30) + edate = datetime.datetime(2023, 7, 10, 10, 0) with pytest.raises(ValueError): - ph.get_trading_dates(asset, start_date, end_date) - - # Stock Asset - asset = Asset("SPY") - start_date = datetime.datetime(2023, 7, 1, 9, 30) # Saturday - end_date = datetime.datetime(2023, 7, 10, 10, 0) # Monday - trading_dates = ph.get_trading_dates(asset, start_date, end_date) - assert datetime.date(2023, 7, 1) not in trading_dates, "Market is closed on Saturday" - assert datetime.date(2023, 7, 3) in trading_dates - assert datetime.date(2023, 7, 4) not in trading_dates, "Market is closed on July 4th" - assert datetime.date(2023, 7, 9) not in trading_dates, "Market is closed on Sunday" - assert datetime.date(2023, 7, 10) in trading_dates - assert datetime.date(2023, 7, 11) not in trading_dates, "Outside of end_date" - - # Option Asset - expire_date = datetime.date(2023, 8, 1) - option_asset = Asset("SPY", asset_type="option", expiration=expire_date, strike=100, right="CALL") - start_date = datetime.datetime(2023, 7, 1, 9, 30) # Saturday - end_date = datetime.datetime(2023, 7, 10, 10, 0) # Monday - trading_dates = ph.get_trading_dates(option_asset, start_date, end_date) - assert datetime.date(2023, 7, 1) not in trading_dates, "Market is closed on Saturday" - assert datetime.date(2023, 7, 3) in trading_dates - assert datetime.date(2023, 7, 4) not in trading_dates, "Market is closed on July 4th" - assert datetime.date(2023, 7, 9) not in trading_dates, "Market is closed on Sunday" - - # Forex Asset - Trades weekdays opens Sunday at 5pm and closes Friday at 5pm - forex_asset = Asset("ES", asset_type="forex") - start_date = datetime.datetime(2023, 7, 1, 9, 30) # Saturday - end_date = datetime.datetime(2023, 7, 10, 10, 0) # Monday - trading_dates = ph.get_trading_dates(forex_asset, start_date, end_date) - assert datetime.date(2023, 7, 1) not in trading_dates, "Market is closed on Saturday" - assert datetime.date(2023, 7, 4) in trading_dates - assert datetime.date(2023, 7, 10) in trading_dates - assert datetime.date(2023, 7, 11) not in trading_dates, "Outside of end_date" - - # Crypto Asset - Trades 24/7 - crypto_asset = Asset("BTC", asset_type="crypto") - start_date = datetime.datetime(2023, 7, 1, 9, 30) # Saturday - end_date = datetime.datetime(2023, 7, 10, 10, 0) # Monday - trading_dates = ph.get_trading_dates(crypto_asset, start_date, end_date) - assert datetime.date(2023, 7, 1) in trading_dates - assert datetime.date(2023, 7, 4) in trading_dates - assert datetime.date(2023, 7, 10) in trading_dates + ph.get_trading_dates(asset_fut, sdate, edate) + + # 2) Stock => NYSE + asset_stk = Asset("SPY") + tdates = 
ph.get_trading_dates(asset_stk, sdate, edate) + assert datetime.date(2023, 7, 1) not in tdates # Saturday + assert datetime.date(2023, 7, 3) in tdates + assert datetime.date(2023, 7, 4) not in tdates # Holiday + assert datetime.date(2023, 7, 9) not in tdates # Sunday + assert datetime.date(2023, 7, 10) in tdates + + # 3) Option => same as stock, but eventually truncated by expiration in get_missing_dates + op_asset = Asset("SPY", asset_type="option", expiration=datetime.date(2023, 8, 1), + strike=100, right="CALL") + tdates_op = ph.get_trading_dates(op_asset, sdate, edate) + assert datetime.date(2023, 7, 3) in tdates_op + + # 4) Forex => "CME_FX" + fx_asset = Asset("EURUSD", asset_type="forex") + tdates_fx = ph.get_trading_dates(fx_asset, sdate, edate) + # e.g. 7/1 is Saturday => not included + assert datetime.date(2023, 7, 1) not in tdates_fx + + # 5) Crypto => 24/7 + c_asset = Asset("BTC", asset_type="crypto") + tdates_c = ph.get_trading_dates(c_asset, sdate, edate) + assert datetime.date(2023, 7, 1) in tdates_c # Saturday => included for crypto def test_get_polygon_symbol(self, mocker): - polygon_client = mocker.MagicMock() + """Test get_polygon_symbol(...) for Stock, Index, Forex, Crypto, and Option.""" + poly_mock = mocker.MagicMock() - # ------- Unsupported Asset Type - asset = Asset("SPY", asset_type="future") + # 1) Future => ValueError + fut_asset = Asset("ZB", asset_type="future") with pytest.raises(ValueError): - ph.get_polygon_symbol(asset, polygon_client) - - # ------- Stock - asset = Asset("SPY") - assert ph.get_polygon_symbol(asset, polygon_client) == "SPY" + ph.get_polygon_symbol(fut_asset, poly_mock) - # ------- Index - asset = Asset("SPX", asset_type="index") - assert ph.get_polygon_symbol(asset, polygon_client) == "I:SPX" + # 2) Stock => "SPY" + st_asset = Asset("SPY", asset_type="stock") + assert ph.get_polygon_symbol(st_asset, poly_mock) == "SPY" - # ------- Option - expire_date = datetime.date(2023, 8, 1) - option_asset = Asset("SPY", asset_type="option", expiration=expire_date, strike=100, right="CALL") - # Option with no contracts - Error - polygon_client.list_options_contracts.return_value = [] + # 3) Index => "I:SPX" + idx_asset = Asset("SPX", asset_type="index") + assert ph.get_polygon_symbol(idx_asset, poly_mock) == "I:SPX" - # Option with contracts - Works - expected_ticker = "O:SPY230801C00100000" - polygon_client.list_options_contracts.return_value = [FakeContract(expected_ticker)] - assert ph.get_polygon_symbol(option_asset, polygon_client) == expected_ticker + # 4) Forex => must pass quote_asset or error + fx_asset = Asset("EUR", asset_type="forex") + with pytest.raises(ValueError): + ph.get_polygon_symbol(fx_asset, poly_mock) + quote = Asset("USD", asset_type="forex") + sym_fx = ph.get_polygon_symbol(fx_asset, poly_mock, quote_asset=quote) + assert sym_fx == "C:EURUSD" - # -------- Crypto + # 5) Crypto => "X:BTCUSD" if no quote crypto_asset = Asset("BTC", asset_type="crypto") - assert ph.get_polygon_symbol(crypto_asset, polygon_client) == "X:BTCUSD" - - # -------- Forex - forex_asset = Asset("ES", asset_type="forex") - # Errors without a Quote Asset - with pytest.raises(ValueError): - ph.get_polygon_symbol(forex_asset, polygon_client) - # Works with a Quote Asset - quote_asset = Asset("USD", asset_type="forex") - assert ph.get_polygon_symbol(forex_asset, polygon_client, quote_asset) == "C:ESUSD" - - def test_load_data_from_cache(self, tmpdir): - # Setup some basics - cache_file = tmpdir / "stock_SPY_1D.feather" - - # No cache file - with 
pytest.raises(FileNotFoundError): - ph.load_cache(cache_file) - - # Cache file exists - df = pd.DataFrame( - { - "close": [2, 3, 4, 5, 6], - "open": [1, 2, 3, 4, 5], - "datetime": [ - "2023-07-01 09:30:00-04:00", - "2023-07-01 09:31:00-04:00", - "2023-07-01 09:32:00-04:00", - "2023-07-01 09:33:00-04:00", - "2023-07-01 09:34:00-04:00", - ], - } - ) - df.to_feather(cache_file) - df_loaded = ph.load_cache(cache_file) - assert len(df_loaded) - assert df_loaded["close"].iloc[0] == 2 - assert df_loaded.index[0] == pd.DatetimeIndex(["2023-07-01 09:30:00-04:00"])[0] - - # Dataframe with no Timezone - df = pd.DataFrame( - { - "close": [2, 3, 4, 5, 6], - "open": [1, 2, 3, 4, 5], - "datetime": [ - "2023-07-01 09:30:00", - "2023-07-01 09:31:00", - "2023-07-01 09:32:00", - "2023-07-01 09:33:00", - "2023-07-01 09:34:00", - ], - } - ) - df.to_feather(cache_file) - df_loaded = ph.load_cache(cache_file) - assert len(df_loaded) - assert df_loaded["close"].iloc[0] == 2 - assert df_loaded.index[0] == pd.DatetimeIndex(["2023-07-01 09:30:00-00:00"])[0] - - def test_update_cache(self, tmpdir): - cache_file = Path(tmpdir / "polygon" / "stock_SPY_1D.feather") - df = pd.DataFrame( - { - "close": [2, 3, 4, 5, 6], - "open": [1, 2, 3, 4, 5], - "datetime": [ - "2023-07-01 09:30:00-04:00", - "2023-07-01 09:31:00-04:00", - "2023-07-01 09:32:00-04:00", - "2023-07-01 09:33:00-04:00", - "2023-07-01 09:34:00-04:00", - ], - } - ) - - # Empty DataFrame, don't write cache file - ph.update_cache(cache_file, df_all=pd.DataFrame()) - assert not cache_file.exists() - - # No changes in data, write file just in case we got comparison wrong. - ph.update_cache(cache_file, df_all=df) - assert cache_file.exists() - - # Changes in data, write cache file - ph.update_cache(cache_file, df_all=df) - assert cache_file.exists() - - def test_update_polygon_data(self): - # Test with empty dataframe and no new data - df_all = None - poly_result = [] - df_new = ph.update_polygon_data(df_all, poly_result) - assert not df_new - - # Test with empty dataframe and new data - poly_result = [ - {"o": 1, "h": 4, "l": 1, "c": 2, "v": 100, "t": 1690896600000}, - {"o": 5, "h": 8, "l": 3, "c": 7, "v": 100, "t": 1690896660000}, - ] - df_all = None - df_new = ph.update_polygon_data(df_all, poly_result) - assert len(df_new) == 2 - assert df_new["close"].iloc[0] == 2 - assert df_new.index[0] == pd.DatetimeIndex(["2023-08-01 13:30:00-00:00"])[0] - - # Test with existing dataframe and new data - df_all = df_new - poly_result = [ - {"o": 9, "h": 12, "l": 7, "c": 10, "v": 100, "t": 1690896720000}, - {"o": 13, "h": 16, "l": 11, "c": 14, "v": 100, "t": 1690896780000}, - ] - df_new = ph.update_polygon_data(df_all, poly_result) - assert len(df_new) == 4 - assert df_new["close"].iloc[0] == 2 - assert df_new["close"].iloc[2] == 10 - assert df_new.index[0] == pd.DatetimeIndex(["2023-08-01 13:30:00-00:00"])[0] - assert df_new.index[2] == pd.DatetimeIndex(["2023-08-01 13:32:00-00:00"])[0] - - # Test with some overlapping rows - df_all = df_new - poly_result = [ - {"o": 17, "h": 20, "l": 15, "c": 18, "v": 100, "t": 1690896780000}, - {"o": 21, "h": 24, "l": 19, "c": 22, "v": 100, "t": 1690896840000}, + assert ph.get_polygon_symbol(crypto_asset, poly_mock) == "X:BTCUSD" + + # 6) Option => if no contracts => returns None + poly_mock.list_options_contracts.return_value = [] + op_asset = Asset("SPY", asset_type="option", expiration=datetime.date(2024, 1, 14), + strike=577, right="CALL") + sym_none = ph.get_polygon_symbol(op_asset, poly_mock) + assert sym_none is None + + # 7) Option 
=> valid => returns the first + poly_mock.list_options_contracts.return_value = [FakeContract("O:SPY240114C00577000")] + sym_op = ph.get_polygon_symbol(op_asset, poly_mock) + assert sym_op == "O:SPY240114C00577000" + + +############################################################################### +# TEST: get_price_data_from_polygon(...) with a Mock PolygonClient +############################################################################### + + +class TestPriceDataCache: + """ + Tests get_price_data_from_polygon(...) to confirm: + - It queries Polygon on first call + - It caches data in DuckDB + - It does not re-query Polygon on second call (unless force_cache_update=True) + """ + + def test_get_price_data_from_polygon(self, mocker, tmp_path, ephemeral_duckdb): + """Ensures we store data on first call, then read from DuckDB on second call.""" + # Mock the PolygonClient class + poly_mock = mocker.MagicMock() + mocker.patch.object(ph, "PolygonClient", poly_mock) + + # We'll override the LUMIBOT_CACHE_FOLDER if needed, in case your code references it + mocker.patch.object(ph, "LUMIBOT_CACHE_FOLDER", tmp_path) + + # If it's an option, let's pretend there's a valid contract + poly_mock().list_options_contracts.return_value = [FakeContract("O:SPY230801C00100000")] + + # aggregator bars + bars = [ + {"o": 10, "h": 11, "l": 9, "c": 10.5, "v": 500, "t": 1690876800000}, + {"o": 12, "h": 14, "l": 10, "c": 13, "v": 600, "t": 1690876860000}, ] - df_new = ph.update_polygon_data(df_all, poly_result) - assert len(df_new) == 5 - assert df_new["close"].iloc[0] == 2 - assert df_new["close"].iloc[2] == 10 - assert df_new["close"].iloc[4] == 22 - assert df_new.index[0] == pd.DatetimeIndex(["2023-08-01 13:30:00-00:00"])[0] - assert df_new.index[2] == pd.DatetimeIndex(["2023-08-01 13:32:00-00:00"])[0] - assert df_new.index[4] == pd.DatetimeIndex(["2023-08-01 13:34:00-00:00"])[0] - - -class TestPolygonPriceData: - def test_get_price_data_from_polygon(self, mocker, tmpdir): - # Ensure we don't accidentally call the real Polygon API - mock_polyclient = mocker.MagicMock() - mocker.patch.object(ph, "PolygonClient", mock_polyclient) - mocker.patch.object(ph, "LUMIBOT_CACHE_FOLDER", tmpdir) - - # Options Contracts to return - option_ticker = "O:SPY230801C00100000" - mock_polyclient().list_options_contracts.return_value = [FakeContract(option_ticker)] - - # Basic Setup - api_key = "abc123" + poly_mock.create().get_aggs.return_value = bars + asset = Asset("SPY") - tz_e = pytz.timezone("US/Eastern") - start_date = tz_e.localize(datetime.datetime(2023, 8, 2, 6, 30)) # Include PreMarket - end_date = tz_e.localize(datetime.datetime(2023, 8, 2, 13, 0)) + start = datetime.datetime(2023, 8, 2, 9, 30, tzinfo=pytz.UTC) + end = datetime.datetime(2023, 8, 2, 16, 0, tzinfo=pytz.UTC) timespan = "minute" - expected_cachefile = ph.build_cache_filename(asset, timespan) - - assert not expected_cachefile.exists() - assert not expected_cachefile.parent.exists() - - # Fake some data from Polygon - mock_polyclient.create().get_aggs.return_value = [ - {"o": 1, "h": 4, "l": 1, "c": 2, "v": 100, "t": 1690876800000}, # 8/1/2023 8am UTC (start - 1day) - {"o": 5, "h": 8, "l": 3, "c": 7, "v": 100, "t": 1690876860000}, - {"o": 9, "h": 12, "l": 7, "c": 10, "v": 100, "t": 1690876920000}, - {"o": 13, "h": 16, "l": 11, "c": 14, "v": 100, "t": 1690986600000}, # 8/2/2023 at least 1 entry per date - {"o": 17, "h": 20, "l": 15, "c": 18, "v": 100, "t": 1690986660000}, - {"o": 21, "h": 24, "l": 19, "c": 22, "v": 100, "t": 1691105400000}, # 8/3/2023 11pm 
UTC (end + 1day) - ] - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan) - assert len(df) == 6 - assert df["close"].iloc[0] == 2 - assert mock_polyclient.create().get_aggs.call_count == 1 - assert expected_cachefile.exists() - - # Do the same query, but this time we should get the data from the cache - mock_polyclient.create().get_aggs.reset_mock() - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan) - assert len(df) == 6 - assert len(df.dropna()) == 6 - assert df["close"].iloc[0] == 2 - assert mock_polyclient.create().get_aggs.call_count == 0 - - # End time is moved out by a few hours, but it doesn't matter because we have all the data we need - mock_polyclient.create().get_aggs.reset_mock() - end_date = tz_e.localize(datetime.datetime(2023, 8, 2, 16, 0)) - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan) - assert len(df) == 6 - assert mock_polyclient.create().get_aggs.call_count == 0 - - # New day, new data - mock_polyclient.create().get_aggs.reset_mock() - start_date = tz_e.localize(datetime.datetime(2023, 8, 4, 6, 30)) - end_date = tz_e.localize(datetime.datetime(2023, 8, 4, 13, 0)) - mock_polyclient.create().get_aggs.return_value = [ - {"o": 5, "h": 8, "l": 3, "c": 7, "v": 100, "t": 1691136000000}, # 8/2/2023 8am UTC (start - 1day) - {"o": 9, "h": 12, "l": 7, "c": 10, "v": 100, "t": 1691191800000}, - ] - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan) - assert len(df) == 6 + 2 - assert mock_polyclient.create().get_aggs.call_count == 1 - - # Error case: Polygon returns nothing - like for a future date it doesn't know about - mock_polyclient.create().get_aggs.reset_mock() - mock_polyclient.create().get_aggs.return_value = [] - end_date = tz_e.localize(datetime.datetime(2023, 8, 31, 13, 0)) - - # Query a large range of dates and ensure we break up the Polygon API calls into - # multiple queries. 
- expected_cachefile.unlink() - mock_polyclient.create().get_aggs.reset_mock() - mock_polyclient.create().get_aggs.side_effect = [ - # First call for Auguest Data - [ - {"o": 5, "h": 8, "l": 3, "c": 7, "v": 100, "t": 1690876800000}, # 8/1/2023 8am UTC - {"o": 9, "h": 12, "l": 7, "c": 10, "v": 100, "t": 1693497600000}, # 8/31/2023 8am UTC - ], - # Second call for September Data - [ - {"o": 13, "h": 16, "l": 11, "c": 14, "v": 100, "t": 1693584000000}, # 9/1/2023 8am UTC - {"o": 17, "h": 20, "l": 15, "c": 18, "v": 100, "t": 1696176000000}, # 10/1/2023 8am UTC - ], - # Third call for October Data - [ - {"o": 21, "h": 24, "l": 19, "c": 22, "v": 100, "t": 1696262400000}, # 10/2/2023 8am UTC - {"o": 25, "h": 28, "l": 23, "c": 26, "v": 100, "t": 1698768000000}, # 10/31/2023 8am UTC - ], - ] - start_date = tz_e.localize(datetime.datetime(2023, 8, 1, 6, 30)) - end_date = tz_e.localize(datetime.datetime(2023, 10, 31, 13, 0)) # ~90 days - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan) - assert mock_polyclient.create().get_aggs.call_count == 3 - assert len(df) == 2 + 2 + 2 - - @pytest.mark.parametrize("timespan", ["day", "minute"]) - @pytest.mark.parametrize("force_cache_update", [True, False]) - def test_polygon_missing_day_caching(self, mocker, tmpdir, timespan, force_cache_update): - # Ensure we don't accidentally call the real Polygon API - mock_polyclient = mocker.MagicMock() - mocker.patch.object(ph, "PolygonClient", mock_polyclient) - mocker.patch.object(ph, "LUMIBOT_CACHE_FOLDER", tmpdir) - - # Basic Setup - api_key = "abc123" + # 1) First call => queries aggregator once + df_first = ph.get_price_data_from_polygon("fake_api", asset, start, end, timespan) + assert poly_mock.create().get_aggs.call_count == 1 + assert len(df_first) == 2 + + # 2) Second call => aggregator not called again if missing days=0 + poly_mock.create().get_aggs.reset_mock() + df_second = ph.get_price_data_from_polygon("fake_api", asset, start, end, timespan) + assert poly_mock.create().get_aggs.call_count == 0 + assert len(df_second) == 2 + + @pytest.mark.parametrize("force_update", [True, False]) + def test_force_cache_update(self, mocker, tmp_path, ephemeral_duckdb, force_update): + """force_cache_update => second call re-queries aggregator.""" + poly_mock = mocker.MagicMock() + mocker.patch.object(ph, "PolygonClient", poly_mock) + mocker.patch.object(ph, "LUMIBOT_CACHE_FOLDER", tmp_path) + + # aggregator data + bars = [{"o": 1, "h": 2, "l": 0.5, "c": 1.5, "v": 100, "t": 1690876800000}] + poly_mock.create().get_aggs.return_value = bars + asset = Asset("SPY") - tz_e = pytz.timezone("US/Eastern") - start_date = tz_e.localize(datetime.datetime(2023, 8, 2, 6, 30)) # Include PreMarket - end_date = tz_e.localize(datetime.datetime(2023, 8, 2, 13, 0)) - expected_cachefile = ph.build_cache_filename(asset, timespan) - assert not expected_cachefile.exists() - - # Fake some data from Polygon between start and end date - return_value = [] - if timespan == "day": - t = start_date - while t <= end_date: - return_value.append( - {"o": 1, "h": 4, "l": 1, "c": 2, "v": 100, "t": t.timestamp() * 1000} - ) - t += datetime.timedelta(days=1) - else: - t = start_date - while t <= end_date: - return_value.append( - {"o": 1, "h": 4, "l": 1, "c": 2, "v": 100, "t": t.timestamp() * 1000} - ) - t += datetime.timedelta(minutes=1) - - # Polygon is only called once for the same date range even when they are all missing. 
- mock_polyclient.create().get_aggs.return_value = return_value - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan, force_cache_update=force_cache_update) - - mock1 = mock_polyclient.create() - aggs = mock1.get_aggs - call_count = aggs.call_count - assert call_count == 1 - - assert expected_cachefile.exists() - if df is None: - df = pd.DataFrame() - assert len(df) == len(return_value) - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan, force_cache_update=force_cache_update) - if df is None: - df = pd.DataFrame() - assert len(df) == len(return_value) - if force_cache_update: - mock2 = mock_polyclient.create() - aggs = mock2.get_aggs - call_count = aggs.call_count - assert call_count == 2 + start = datetime.datetime(2023, 8, 2, 9, 30, tzinfo=pytz.UTC) + end = datetime.datetime(2023, 8, 2, 10, 0, tzinfo=pytz.UTC) + + # first call + df1 = ph.get_price_data_from_polygon("key", asset, start, end, "minute") + assert len(df1) == 1 + # aggregator called once + assert poly_mock.create().get_aggs.call_count == 1 + + # second call => aggregator depends on force_update + poly_mock.create().get_aggs.reset_mock() + df2 = ph.get_price_data_from_polygon("key", asset, start, end, "minute", force_cache_update=force_update) + + if force_update: + # aggregator called again + assert poly_mock.create().get_aggs.call_count == 1 else: - mock3 = mock_polyclient.create() - aggs = mock3.get_aggs - call_count = aggs.call_count - assert call_count == 1 - expected_cachefile.unlink() - - # Polygon is only called once for the same date range when some are missing. - mock_polyclient.create().get_aggs.reset_mock() - start_date = tz_e.localize(datetime.datetime(2023, 8, 1, 6, 30)) - end_date = tz_e.localize(datetime.datetime(2023, 10, 31, 13, 0)) # ~90 days - aggs_result_list = [ - # First call for August Data - [ - {"o": 5, "h": 8, "l": 3, "c": 7, "v": 100, "t": 1690876800000}, # 8/1/2023 8am UTC - {"o": 9, "h": 12, "l": 7, "c": 10, "v": 100, "t": 1693497600000}, # 8/31/2023 8am UTC - ], - # Second call for September Data - [ - {"o": 13, "h": 16, "l": 11, "c": 14, "v": 100, "t": 1693584000000}, # 9/1/2023 8am UTC - {"o": 17, "h": 20, "l": 15, "c": 18, "v": 100, "t": 1696176000000}, # 10/1/2023 8am UTC - {"o": 17, "h": 20, "l": 15, "c": 18, "v": 100, "t": 1696118400000}, # 10/1/2023 12am UTC - ], - # Third call for October Data - [ - {"o": 21, "h": 24, "l": 19, "c": 22, "v": 100, "t": 1696262400000}, # 10/2/2023 8am UTC - {"o": 25, "h": 28, "l": 23, "c": 26, "v": 100, "t": 1698768000000}, # 10/31/2023 8am UTC - ], + # aggregator not called again + assert poly_mock.create().get_aggs.call_count == 0 + + assert len(df2) == 1 + + +############################################################################### +# TEST: DuckDB-Specific Internals +############################################################################### + + +class TestDuckDBInternals: + """ + Tests for internal DuckDB methods: _asset_key, _transform_polygon_data, + _store_in_duckdb, _load_from_duckdb. We use ephemeral_duckdb to ensure + a fresh DB each test. + """ + + def test_asset_key(self): + """Check if _asset_key(...) returns the correct unique key for stocks vs. options.""" + st = Asset("SPY", asset_type="stock") + assert ph._asset_key(st) == "SPY" + + op = Asset("SPY", asset_type="option", + expiration=datetime.date(2024, 1, 14), + strike=577.0, right="CALL") + # e.g. 
=> "SPY_240114_577.0_CALL" + opt_key = ph._asset_key(op) + assert "SPY_240114_577.0_CALL" == opt_key + + # Missing expiration => error + bad_opt = Asset("SPY", asset_type="option", strike=100, right="CALL") + with pytest.raises(ValueError): + ph._asset_key(bad_opt) + + def test_transform_polygon_data(self): + """_transform_polygon_data(...) should parse aggregator JSON into a DataFrame with columns & UTC index.""" + # empty => empty DataFrame + empty_df = ph._transform_polygon_data([]) + assert empty_df.empty + + # non-empty + results = [ + {"o": 10, "h": 12, "l": 9, "c": 11, "v": 100, "t": 1690896600000}, + {"o": 12, "h": 15, "l": 11, "c": 14, "v": 200, "t": 1690896660000}, ] - mock_polyclient.create().get_aggs.side_effect = aggs_result_list + aggs_result_list if force_cache_update else aggs_result_list - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan, force_cache_update=force_cache_update) - assert mock_polyclient.create().get_aggs.call_count == 3 - assert expected_cachefile.exists() - assert len(df) == 7 - df = ph.get_price_data_from_polygon(api_key, asset, start_date, end_date, timespan, force_cache_update=force_cache_update) - assert len(df) == 7 - if force_cache_update: - assert mock_polyclient.create().get_aggs.call_count == 2 * 3 - else: - assert mock_polyclient.create().get_aggs.call_count == 3 - expected_cachefile.unlink() + df = ph._transform_polygon_data(results) + assert len(df) == 2 + assert "open" in df.columns and "close" in df.columns + assert df.index[0] == pd.to_datetime(1690896600000, unit="ms", utc=True) + + def test_store_and_load_duckdb(self, ephemeral_duckdb): + """ + Full test for _store_in_duckdb(...) + _load_from_duckdb(...). + 1) Insert a small DF. 2) Load it, check correctness. 3) Insert overlap => no duplication. + """ + asset_stk = Asset("SPY", asset_type="stock") + timespan = "minute" + + idx = pd.date_range("2025-01-01 09:30:00", periods=3, freq="1min", tz="UTC") + df_in = pd.DataFrame({ + "open": [10.0, 11.0, 12.0], + "high": [11.0, 12.0, 13.0], + "low": [9.0, 10.0, 11.0], + "close": [10.5, 11.5, 12.5], + "volume": [100, 200, 300], + }, index=idx) + + # 1) Store + ph._store_in_duckdb(asset_stk, timespan, df_in) + + # 2) Load + loaded = ph._load_from_duckdb(asset_stk, timespan, idx[0], idx[-1]) + assert len(loaded) == 3 + assert (loaded["open"] == df_in["open"]).all() + + # 3) Partial range + partial = ph._load_from_duckdb(asset_stk, timespan, idx[1], idx[2]) + assert len(partial) == 2 + + # 4) Insert overlap => no duplication + ph._store_in_duckdb(asset_stk, timespan, df_in) + reloaded = ph._load_from_duckdb(asset_stk, timespan, idx[0], idx[-1]) + assert len(reloaded) == 3 # still 3 \ No newline at end of file