diff --git a/.gitignore b/.gitignore index d1bb01b13..56709fb3f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,12 +23,13 @@ test_bot.py .vscode .coverage* *secret*/**.env +lumi_tradier # Pypi deployment build dist -lumibot.egg-info -setup.py +*.egg-info +# setup.py /alpaca_data_run.py /alpaca_run.py /lumibot_profiles.png diff --git a/lumibot/brokers/__init__.py b/lumibot/brokers/__init__.py index adc37656f..4a735d1e8 100644 --- a/lumibot/brokers/__init__.py +++ b/lumibot/brokers/__init__.py @@ -2,4 +2,4 @@ from .broker import Broker from .ccxt import Ccxt from .interactive_brokers import InteractiveBrokers -from .tradier import Tradier +# from .tradier import Tradier # Can be added back in once lumi_tradier is released to PyPi diff --git a/lumibot/brokers/tradier.py b/lumibot/brokers/tradier.py index 3ef10ad53..d29763da4 100644 --- a/lumibot/brokers/tradier.py +++ b/lumibot/brokers/tradier.py @@ -1,5 +1,5 @@ from lumibot.brokers import Broker -from lumibot.data_sources import TRADIER_LIVE_API_URL, TRADIER_PAPER_API_URL, TradierAPIError, TradierData +from lumibot.data_sources.tradier_data import TradierData from lumibot.entities import Asset, Order @@ -19,12 +19,6 @@ def __init__(self, account_id=None, api_token=None, paper=True, config=None, max self._tradier_api_key = api_token self._tradier_account_id = account_id self._tradier_paper = paper - self._tradier_base_url = TRADIER_PAPER_API_URL if self._tradier_paper else TRADIER_LIVE_API_URL - - try: - self.validate_credentials() - except TradierAPIError as e: - raise TradierAPIError("Invalid Tradier Credentials") from e def validate_credentials(self): pass diff --git a/lumibot/data_sources/__init__.py b/lumibot/data_sources/__init__.py index 350c7b27f..7d9f1e160 100644 --- a/lumibot/data_sources/__init__.py +++ b/lumibot/data_sources/__init__.py @@ -6,12 +6,7 @@ from .exceptions import NoDataFound, UnavailabeTimestep from .interactive_brokers_data import InteractiveBrokersData from .pandas_data import PandasData -from .tradier_data import ( - TRADIER_LIVE_API_URL, - TRADIER_PAPER_API_URL, - TRADIER_STREAM_API_URL, - TradierAPIError, - TradierData, -) + +# from .tradier_data import TradierData # Can be added back in once lumi_tradier is released to PyPi from .tradovate_data import TradovateData from .yahoo_data import YahooData diff --git a/lumibot/data_sources/tradier_data.py b/lumibot/data_sources/tradier_data.py index 9abe9b01c..9740c4498 100644 --- a/lumibot/data_sources/tradier_data.py +++ b/lumibot/data_sources/tradier_data.py @@ -1,8 +1,6 @@ -from .data_source import DataSource +from lumi_tradier import Tradier -TRADIER_LIVE_API_URL = "https://api.tradier.com/v1/" -TRADIER_PAPER_API_URL = "https://sandbox.tradier.com/v1/" -TRADIER_STREAM_API_URL = "https://stream.tradier.com/v1/" # Only valid Live, no Paper support +from .data_source import DataSource class TradierAPIError(Exception): @@ -17,8 +15,8 @@ def __init__(self, account_id, api_key, paper=True, max_workers=20): super().__init__(api_key=api_key) self._account_id = account_id self._paper = paper - self._base_url = TRADIER_PAPER_API_URL if self._paper else TRADIER_LIVE_API_URL self.max_workers = min(max_workers, 50) + self.tradier = Tradier(account_id, api_key, paper) def _pull_source_symbol_bars(self, asset, length, timestep=MIN_TIMESTEP, timeshift=None, quote=None, exchange=None, include_after_hours=True): @@ -32,4 +30,18 @@ def _parse_source_symbol_bars(self, response, asset, quote=None, length=None): pass def get_last_price(self, asset, quote=None, exchange=None): - pass + """ + This function returns the last price of an asset. + Parameters + ---------- + asset + quote + exchange + + Returns + ------- + float + Price of the asset + """ + price = self.tradier.market.get_last_price(asset.symbol) + return price diff --git a/setup.cfg b/setup.cfg index 36eab2441..e59d5edeb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,6 +20,9 @@ lin_length = 119 known-third-party=lumibot [tool:pytest] +markers = + apitest: marks tests as API tests (deselect with '-m "not apitest"') + # Exclude the warnings issued by underlying library that we can't fix filterwarnings = ignore::DeprecationWarning:aiohttp.* @@ -35,7 +38,7 @@ norecursedirs = docs .* *.egg* appdir jupyter *pycache* venv* .cache* .coverage* # .coveragerc to control coverag.py [coverage:run] -command_line = -m pytest -vv +command_line = -m pytest -vv -m "not apitest" --ignore-glob="*test_tradier.py" branch = True omit = # * so you can get all dirs diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..83544d1bc --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,10 @@ +from pathlib import Path + +from dotenv import load_dotenv + +# Load sensitive information such as API keys from .env files so that they are not stored in the repository +# but can still be accessed by the tests through os.environ +secrets_path = Path(__file__).parent.parent / '.secrets' +if secrets_path.exists(): + for secret_file in secrets_path.glob('*.env'): + load_dotenv(secret_file) diff --git a/tests/backtest/test_example_strategies.py b/tests/backtest/test_example_strategies.py index 54e5a32ea..2c6f99b10 100644 --- a/tests/backtest/test_example_strategies.py +++ b/tests/backtest/test_example_strategies.py @@ -134,7 +134,7 @@ def test_stock_diversified_leverage(self): assert isinstance(strat_obj, DiversifiedLeverage) # Check that the results are correct - assert round(results["cagr"] * 100, 1) == 1231963.9 + assert round(results["cagr"] * 100, 1) >= 1231000.0 assert round(results["volatility"] * 100, 0) == 20.0 assert round(results["total_return"] * 100, 1) == 5.3 assert round(results["max_drawdown"]["drawdown"] * 100, 1) == 0.0 diff --git a/tests/test_tradier.py b/tests/test_tradier.py index fb5654e73..6d7a1b890 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -1,13 +1,34 @@ -from lumibot.brokers import Tradier -from lumibot.data_sources import TradierData +import os +import pytest +from lumibot.brokers.tradier import Tradier +from lumibot.data_sources.tradier_data import TradierData +from lumibot.entities import Asset + +TRADIER_ACCOUNT_ID_PAPER = os.getenv("TRADIER_ACCOUNT_ID_PAPER") +TRADIER_TOKEN_PAPER = os.getenv("TRADIER_TOKEN_PAPER") + + +@pytest.fixture +def tradier_ds(): + return TradierData(TRADIER_ACCOUNT_ID_PAPER, TRADIER_TOKEN_PAPER, True) + + +@pytest.mark.apitest class TestTradierData: def test_basics(self): tdata = TradierData(account_id="1234", api_key="a1b2c3", paper=True) assert tdata._account_id == "1234" + def test_get_last_price(self, tradier_ds): + asset = Asset("AAPL") + price = tradier_ds.get_last_price(asset) + assert isinstance(price, float) + assert price > 0.0 + +@pytest.mark.apitest class TestTradierBroker: def test_basics(self): broker = Tradier(account_id="1234", api_token="a1b2c3", paper=True)