From 5c6f0e42fd56ee1e81907ad165dd38c6ace6c176 Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Tue, 16 Jan 2024 16:11:05 -0500 Subject: [PATCH 1/5] feat: :sparkles: implement get user profile interface --- asynctradier/common/__init__.py | 49 ++++++++++++++++++++ asynctradier/common/user_profile.py | 51 ++++++++++++++++++++ asynctradier/exceptions/__init__.py | 12 +++++ asynctradier/tradier.py | 36 +++++++++++++++ tests/test_common.py | 32 +++++++++++++ tests/test_tradier.py | 72 ++++++++++++++++++++++++++++- 6 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 asynctradier/common/user_profile.py diff --git a/asynctradier/common/__init__.py b/asynctradier/common/__init__.py index 21ee4a4..c3f8874 100644 --- a/asynctradier/common/__init__.py +++ b/asynctradier/common/__init__.py @@ -142,3 +142,52 @@ class OptionType(StrEnum): call = "call" put = "put" + + +class Classification(StrEnum): + """ + Enum class representing different classifications. + + Possible values: + - individual + - entity + - joint_survivor + - traditional_ira + - roth_ira + - rollover_ira + - sep_ira + """ + + individual = "individual" + entity = "entity" + joint_survivor = "joint_survivor" + traditional_ira = "traditional_ira" + roth_ira = "roth_ira" + rollover_ira = "rollover_ira" + sep_ira = "sep_ira" + + +class AccountStatus(StrEnum): + """ + Represents the status of an account. + + Attributes: + open (str): The account is open. + closed (str): The account is closed. + """ + + active = "active" + closed = "closed" + + +class AccountType(StrEnum): + """ + Represents the type of account. + + Attributes: + cash (str): The account is a cash account. + margin (str): The account is a margin account. + """ + + cash = "cash" + margin = "margin" diff --git a/asynctradier/common/user_profile.py b/asynctradier/common/user_profile.py new file mode 100644 index 0000000..7dda549 --- /dev/null +++ b/asynctradier/common/user_profile.py @@ -0,0 +1,51 @@ +from asynctradier.common import AccountStatus, AccountType, Classification + + +class UserAccount: + """ + Represents a user profile with various attributes. + + Attributes: + id (str): The ID of the user profile. + name (str): The name of the user profile. + account_number (str): The account number associated with the user profile. + classification (Classification): The classification of the user profile. + date_created (str): The date when the user profile was created. + day_trader (bool): Indicates whether the user is a day trader or not. + option_level (int): The option level of the user profile. + status (AccountStatus): The status of the user profile. + type (AccountType): The type of the user profile. + last_update_date (str): The date of the last update to the user profile. + """ + + def __init__(self, **kwargs): + self.id = kwargs.get("id") + self.name = kwargs.get("name") + self.account_number = kwargs.get("account_number") + self.classification = Classification(kwargs.get("classification")) + self.date_created = kwargs.get("date_created") + self.day_trader = kwargs.get("day_trader") + self.option_level = ( + int(kwargs.get("option_level")) if kwargs.get("option_level") else None + ) + self.status = ( + AccountStatus(kwargs.get("status")) if kwargs.get("status") else None + ) + self.type = AccountType(kwargs.get("type")) if kwargs.get("type") else None + self.last_update_date = kwargs.get("last_update_date") + + def __dict__(self): + return { + "id": self.id, + "name": self.name, + "account_number": self.account_number, + "classification": self.classification.value + if self.classification + else None, + "date_created": self.date_created, + "day_trader": self.day_trader, + "option_level": self.option_level, + "status": self.status.value if self.status else None, + "type": self.type.value if self.type else None, + "last_update_date": self.last_update_date, + } diff --git a/asynctradier/exceptions/__init__.py b/asynctradier/exceptions/__init__.py index 121d7fa..ffe5233 100644 --- a/asynctradier/exceptions/__init__.py +++ b/asynctradier/exceptions/__init__.py @@ -77,3 +77,15 @@ class BadRequestException(Exception): def __init__(self, code: int, msg: str) -> None: super().__init__(f"Request failed: {code}, msg: {msg}") + + +class APINotAvailable(Exception): + """ + Exception raised when the API is not available. + + Attributes: + msg (str): The error message. + """ + + def __init__(self, msg: str) -> None: + super().__init__(f"API is not available. {msg}") diff --git a/asynctradier/tradier.py b/asynctradier/tradier.py index 0f914f8..4b8d69a 100644 --- a/asynctradier/tradier.py +++ b/asynctradier/tradier.py @@ -9,7 +9,9 @@ from asynctradier.common.order import Order from asynctradier.common.position import Position from asynctradier.common.quote import Quote +from asynctradier.common.user_profile import UserAccount from asynctradier.exceptions import ( + APINotAvailable, InvalidExiprationDate, InvalidOptionType, InvalidParameter, @@ -38,6 +40,7 @@ def __init__(self, account_id: str, token: str, sandbox: bool = False) -> None: "https://api.tradier.com" if not sandbox else "https://sandbox.tradier.com" ) self.session = WebUtil(base_url, token) + self.sandbox = sandbox async def get_positions(self) -> List[Position]: """ @@ -627,3 +630,36 @@ async def option_lookup(self, symbol: str) -> List[str]: if response.get("symbols") is None: return [] return response.get("symbols")[0].get("options", []) + + async def get_user_profile(self) -> List[UserAccount]: + """ + Retrieves the user profile information. + + Returns: + A list of UserProfile objects representing the user's profile information. + """ + if self.sandbox: + raise APINotAvailable("get user profile is only available in production") + + url = "/v1/user/profile" + response = await self.session.get(url) + + if response.get("profile") is None: + return [] + + if not isinstance(response["profile"]["account"], list): + accounts = [response["profile"]["account"]] + else: + accounts = response["profile"]["account"] + + res: List[UserAccount] = [] + for account in accounts: + res.append( + UserAccount( + **account, + id=response["profile"]["id"], + name=response["profile"]["name"], + ) + ) + + return res diff --git a/tests/test_common.py b/tests/test_common.py index 3f1c59e..76dbaed 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,4 +1,7 @@ from asynctradier.common import ( + AccountStatus, + AccountType, + Classification, Duration, OptionType, OrderClass, @@ -11,6 +14,7 @@ from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order from asynctradier.common.quote import Quote +from asynctradier.common.user_profile import UserAccount from asynctradier.exceptions import InvalidExiprationDate, InvalidOptionType @@ -574,3 +578,31 @@ def test_expirations(): assert expirations.expiration_type == "weeklys" print(expirations.strikes) assert len(expirations.strikes) == len(expiration_info["strikes"]) + + +def test_userprofile(): + userprofile_info = { + "id": "id-gcostanza", + "name": "George Costanza", + "account_number": "VA000001", + "classification": "individual", + "date_created": "2016-08-01T21:08:55.000Z", + "day_trader": False, + "option_level": 6, + "status": "active", + "type": "margin", + "last_update_date": "2016-08-01T21:08:55.000Z", + } + + account = UserAccount(**userprofile_info) + + assert account.id == "id-gcostanza" + assert account.name == "George Costanza" + assert account.account_number == "123456789" + assert account.classification == Classification.individual + assert account.date_created == "2018-06-01T12:02:29.682Z" + assert account.day_trader is False + assert account.option_level == 6 + assert account.status == AccountStatus.active + assert account.type == AccountType.margin + assert account.last_update_date == "2018-06-01T12:02:29.682Z" diff --git a/tests/test_tradier.py b/tests/test_tradier.py index 1b5e5da..6a313ab 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -4,7 +4,7 @@ from asynctradier.common import Duration, OptionType, OrderSide, OrderType from asynctradier.common.option_contract import OptionContract -from asynctradier.exceptions import InvalidExiprationDate +from asynctradier.exceptions import APINotAvailable, InvalidExiprationDate from asynctradier.tradier import TradierClient @@ -1296,3 +1296,73 @@ def mock_get(path: str, params: dict = None): tradier_client.session.get.assert_called_once_with( "/v1/markets/options/lookup", params={"underlying": "SEFDF"} ) + + +@pytest.mark.asyncio +async def test_get_user_profile(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return { + "profile": { + "account": [ + { + "account_number": "VA000001", + "classification": "individual", + "date_created": "2016-08-01T21:08:55.000Z", + "day_trader": False, + "option_level": 6, + "status": "active", + "type": "margin", + "last_update_date": "2016-08-01T21:08:55.000Z", + }, + { + "account_number": "VA000002", + "classification": "traditional_ira", + "date_created": "2016-08-05T17:24:34.000Z", + "day_trader": False, + "option_level": 3, + "status": "active", + "type": "margin", + "last_update_date": "2016-08-05T17:24:34.000Z", + }, + { + "account_number": "VA000003", + "classification": "rollover_ira", + "date_created": "2016-08-01T21:08:56.000Z", + "day_trader": False, + "option_level": 2, + "status": "active", + "type": "cash", + "last_update_date": "2016-08-01T21:08:56.000Z", + }, + ], + "id": "id-gcostanza", + "name": "George Costanza", + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + accounts = await tradier_client.get_user_profile() + + assert len(accounts) == 3 + + for account, account_info in zip(accounts, mock_get("")["profile"]["account"]): + assert account.account_number == account_info["account_number"] + assert account.classification == account_info["classification"] + assert account.date_created == account_info["date_created"] + assert account.day_trader == account_info["day_trader"] + assert account.option_level == account_info["option_level"] + assert account.status == account_info["status"] + assert account.type == account_info["type"] + assert account.last_update_date == account_info["last_update_date"] + + tradier_client.session.get.assert_called_once_with("/v1/user/profile") + + +@pytest.mark.asyncio +async def test_get_user_profile_sanbox(tradier_client): + try: + await tradier_client.get_user_profile() + except APINotAvailable: + assert True From f764153328db5f01a0f4af74b80e4bccbbc12c99 Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Thu, 18 Jan 2024 23:02:55 -0500 Subject: [PATCH 2/5] feat: :sparkles: implement get balance interface --- asynctradier/common/__init__.py | 2 + asynctradier/common/account_balance.py | 210 +++++++++++++++++++++ asynctradier/tradier.py | 14 ++ tests/test_common.py | 243 ++++++++++++++++++++++++- tests/test_tradier.py | 200 +++++++++++++++++++- 5 files changed, 663 insertions(+), 6 deletions(-) create mode 100644 asynctradier/common/account_balance.py diff --git a/asynctradier/common/__init__.py b/asynctradier/common/__init__.py index c3f8874..6e2df4c 100644 --- a/asynctradier/common/__init__.py +++ b/asynctradier/common/__init__.py @@ -187,7 +187,9 @@ class AccountType(StrEnum): Attributes: cash (str): The account is a cash account. margin (str): The account is a margin account. + pdt (str): The account is a pattern day trader account. """ cash = "cash" margin = "margin" + pdt = "pdt" diff --git a/asynctradier/common/account_balance.py b/asynctradier/common/account_balance.py new file mode 100644 index 0000000..95207a4 --- /dev/null +++ b/asynctradier/common/account_balance.py @@ -0,0 +1,210 @@ +from asynctradier.common import AccountType + + +class CashAccountBalanceDetails: + """ + Represents the details of a cash account balance. + + Attributes: + cash_available (float): The amount of cash available in the account. + sweep (float): The amount of cash swept from the account. + unsettled_funds (float): The amount of funds that are currently unsettled. + """ + + def __init__(self, **kwargs): + self.cash_available = kwargs.get("cash_available", 0.0) + self.sweep = kwargs.get("sweep", 0.0) + self.unsettled_funds = kwargs.get("unsettled_funds", 0.0) + + def to_dict(self): + """ + Converts the AccountBalance object to a dictionary. + + Returns: + dict: A dictionary representation of the AccountBalance object. + """ + return { + "cash_available": self.cash_available, + "sweep": self.sweep, + "unsettled_funds": self.unsettled_funds, + } + + def __str__(self): + return f"CashAccountBalanceDetails(capacity={self.cash_available}, sweep={self.sweep}, unsettled_funds={self.unsettled_funds})" + + +class MarginAccountBalanceDetails: + """ + Represents the details of a margin account balance. + + Attributes: + fed_call (float): The federal call amount. + maintenance_call (float): The maintenance call amount. + option_buying_power (float): The buying power for options. + stock_buying_power (float): The buying power for stocks. + stock_short_value (float): The value of shorted stocks. + sweep (float): The sweep amount. + """ + + def __init__(self, **kwargs): + self.fed_call = kwargs.get("fed_call", 0.0) + self.maintenance_call = kwargs.get("maintenance_call", 0.0) + self.option_buying_power = kwargs.get("option_buying_power", 0.0) + self.stock_buying_power = kwargs.get("stock_buying_power", 0.0) + self.stock_short_value = kwargs.get("stock_short_value", 0.0) + self.sweep = kwargs.get("sweep", 0.0) + + def to_dict(self): + """ + Converts the AccountBalance object to a dictionary. + + Returns: + dict: A dictionary representation of the AccountBalance object. + """ + return { + "fed_call": self.fed_call, + "maintenance_call": self.maintenance_call, + "option_buying_power": self.option_buying_power, + "stock_buying_power": self.stock_buying_power, + "stock_short_value": self.stock_short_value, + "sweep": self.sweep, + } + + def __str__(self): + return f"MarginAccountBalanceDetails(fed_call={self.fed_call}, maintenance_call={self.maintenance_call}, option_buying_power={self.option_buying_power}, stock_buying_power={self.stock_buying_power}, stock_short_value={self.stock_short_value}, sweep={self.sweep})" + + +class PDTAccountBalanceDetails: + """ + Represents the account balance details for a Pattern Day Trader (PDT). + + Attributes: + fed_call (float): The amount of the Federal Call. + maintenance_call (float): The amount of the Maintenance Call. + option_buying_power (float): The buying power for options trading. + stock_buying_power (float): The buying power for stock trading. + stock_short_value (float): The value of shorted stocks. + """ + + def __init__(self, **kwargs): + self.fed_call = kwargs.get("fed_call", 0.0) + self.maintenance_call = kwargs.get("maintenance_call", 0.0) + self.option_buying_power = kwargs.get("option_buying_power", 0.0) + self.stock_buying_power = kwargs.get("stock_buying_power", 0.0) + self.stock_short_value = kwargs.get("stock_short_value", 0.0) + + def to_dict(self): + """ + Converts the AccountBalance object to a dictionary. + + Returns: + dict: A dictionary representation of the AccountBalance object. + """ + return { + "fed_call": self.fed_call, + "maintenance_call": self.maintenance_call, + "option_buying_power": self.option_buying_power, + "stock_buying_power": self.stock_buying_power, + "stock_short_value": self.stock_short_value, + } + + def __str__(self): + return f"PDTAccountBalanceDetails(fed_call={self.fed_call}, maintenance_call={self.maintenance_call}, option_buying_power={self.option_buying_power}, stock_buying_power={self.stock_buying_power}, stock_short_value={self.stock_short_value})" + + +class AccountBalance: + """ + Represents the balance of an account. + + Attributes: + option_short_value (float): The short value of options in the account. + total_equity (float): The total equity of the account. + account_number (str): The account number. + account_type (AccountType): The type of the account. + close_pl (float): The close profit/loss of the account. + current_requirement (float): The current requirement of the account. + equity (float): The equity of the account. + long_market_value (float): The long market value of the account. + market_value (float): The market value of the account. + open_pl (float): The open profit/loss of the account. + option_long_value (float): The long value of options in the account. + option_requirement (float): The option requirement of the account. + pending_orders_count (int): The count of pending orders in the account. + short_market_value (float): The short market value of the account. + stock_long_value (float): The long value of stocks in the account. + total_cash (float): The total cash in the account. + uncleared_funds (float): The uncleared funds in the account. + pending_cash (float): The pending cash in the account. + cash (CashAccountBalanceDetails): The details of the cash account balance (if account type is cash). + margin (MarginAccountBalanceDetails): The details of the margin account balance (if account type is margin). + pdt (PDTAccountBalanceDetails): The details of the PDT account balance (if account type is pdt). + """ + + def __init__(self, **kwargs): + self.option_short_value = kwargs.get("option_short_value") + self.total_equity = kwargs.get("total_equity") + self.account_number = kwargs.get("account_number") + self.account_type = ( + AccountType(kwargs.get("account_type")) + if kwargs.get("account_type") + else None + ) + self.close_pl = kwargs.get("close_pl") + self.current_requirement = kwargs.get("current_requirement") + self.equity = kwargs.get("equity") + self.long_market_value = kwargs.get("long_market_value") + self.market_value = kwargs.get("market_value") + self.open_pl = kwargs.get("open_pl") + self.option_long_value = kwargs.get("option_long_value") + self.option_requirement = kwargs.get("option_requirement") + self.pending_orders_count = kwargs.get("pending_orders_count") + self.short_market_value = kwargs.get("short_market_value") + self.stock_long_value = kwargs.get("stock_long_value") + self.total_cash = kwargs.get("total_cash") + self.uncleared_funds = kwargs.get("uncleared_funds") + self.pending_cash = kwargs.get("pending_cash") + + self.cash = ( + CashAccountBalanceDetails(**kwargs.get("cash")) + if kwargs.get("cash") + else None + ) + self.margin = ( + MarginAccountBalanceDetails(**kwargs.get("margin")) + if kwargs.get("margin") + else None + ) + self.pdt = ( + PDTAccountBalanceDetails(**kwargs.get("pdt")) if kwargs.get("pdt") else None + ) + + def to_dict(self): + """ + Converts the AccountBalance object to a dictionary. + + Returns: + dict: A dictionary representation of the AccountBalance object. + """ + return { + "option_short_value": self.option_short_value, + "total_equity": self.total_equity, + "account_number": self.account_number, + "account_type": self.account_type, + "close_pl": self.close_pl, + "current_requirement": self.current_requirement, + "equity": self.equity, + "long_market_value": self.long_market_value, + "market_value": self.market_value, + "open_pl": self.open_pl, + "option_long_value": self.option_long_value, + "option_requirement": self.option_requirement, + "pending_orders_count": self.pending_orders_count, + "short_market_value": self.short_market_value, + "stock_long_value": self.stock_long_value, + "total_cash": self.total_cash, + "uncleared_funds": self.uncleared_funds, + "pending_cash": self.pending_cash, + "cash": self.cash.to_dict() if self.cash else None, + "margin": self.margin.to_dict() if self.margin else None, + "pdt": self.pdt.to_dict() if self.pdt else None, + } diff --git a/asynctradier/tradier.py b/asynctradier/tradier.py index 4b8d69a..d9c5eae 100644 --- a/asynctradier/tradier.py +++ b/asynctradier/tradier.py @@ -4,6 +4,7 @@ import websockets from asynctradier.common import Duration, OptionType, OrderClass, OrderSide, OrderType +from asynctradier.common.account_balance import AccountBalance from asynctradier.common.expiration import Expiration from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order @@ -663,3 +664,16 @@ async def get_user_profile(self) -> List[UserAccount]: ) return res + + async def get_balance(self) -> AccountBalance: + """ + Retrieves the account balance. + + Returns: + AccountBalance: The account balance. + """ + url = f"/v1/accounts/{self.account_id}/balances" + response = await self.session.get(url) + return AccountBalance( + **response["balances"], + ) diff --git a/tests/test_common.py b/tests/test_common.py index 76dbaed..31a7b4e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -10,6 +10,12 @@ OrderType, QuoteType, ) +from asynctradier.common.account_balance import ( + AccountBalance, + CashAccountBalanceDetails, + MarginAccountBalanceDetails, + PDTAccountBalanceDetails, +) from asynctradier.common.expiration import Expiration from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order @@ -596,13 +602,240 @@ def test_userprofile(): account = UserAccount(**userprofile_info) - assert account.id == "id-gcostanza" - assert account.name == "George Costanza" - assert account.account_number == "123456789" + assert account.id == userprofile_info["id"] + assert account.name == userprofile_info["name"] + assert account.account_number == userprofile_info["account_number"] assert account.classification == Classification.individual - assert account.date_created == "2018-06-01T12:02:29.682Z" + assert account.date_created == userprofile_info["date_created"] assert account.day_trader is False assert account.option_level == 6 assert account.status == AccountStatus.active assert account.type == AccountType.margin - assert account.last_update_date == "2018-06-01T12:02:29.682Z" + assert account.last_update_date == userprofile_info["last_update_date"] + + +def test_cashbalancedetail(): + detail_info = { + "cash_available": 4343.38000000, + "sweep": 0, + "unsettled_funds": 1310.00000000, + } + + detail = CashAccountBalanceDetails(**detail_info) + + assert detail.cash_available == 4343.38000000 + assert detail.sweep == 0 + assert detail.unsettled_funds == 1310.00000000 + + +def test_marginbalancedetail(): + detail = { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + "sweep": 0, + } + + detail = MarginAccountBalanceDetails(**detail) + + assert detail.fed_call == 0 + assert detail.maintenance_call == 0 + assert detail.option_buying_power == 6363.860000000000000000000000 + assert detail.stock_buying_power == 12727.7200000000000000 + assert detail.stock_short_value == 0 + assert detail.sweep == 0 + + +def test_pdfbalancedetail(): + detail = { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + } + + detail = PDTAccountBalanceDetails(**detail) + + assert detail.fed_call == 0 + assert detail.maintenance_call == 0 + assert detail.option_buying_power == 6363.860000000000000000000000 + assert detail.stock_buying_power == 12727.7200000000000000 + assert detail.stock_short_value == 0 + + +def test_balance_margin(): + detail = { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "margin": { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + "sweep": 0, + }, + } + + balance = AccountBalance(**detail) + + assert balance.option_short_value == detail["option_short_value"] + assert balance.total_equity == detail["total_equity"] + assert balance.account_number == detail["account_number"] + assert balance.account_type == AccountType.margin + assert balance.close_pl == detail["close_pl"] + assert balance.current_requirement == detail["current_requirement"] + assert balance.equity == detail["equity"] + assert balance.long_market_value == detail["long_market_value"] + assert balance.market_value == detail["market_value"] + assert balance.open_pl == detail["open_pl"] + assert balance.option_long_value == detail["option_long_value"] + assert balance.option_requirement == detail["option_requirement"] + assert balance.pending_orders_count == detail["pending_orders_count"] + assert balance.short_market_value == detail["short_market_value"] + assert balance.stock_long_value == detail["stock_long_value"] + assert balance.total_cash == detail["total_cash"] + assert balance.uncleared_funds == detail["uncleared_funds"] + assert balance.pending_cash == detail["pending_cash"] + + assert balance.margin.fed_call == detail["margin"]["fed_call"] + assert balance.margin.maintenance_call == detail["margin"]["maintenance_call"] + assert balance.margin.option_buying_power == detail["margin"]["option_buying_power"] + assert balance.margin.stock_buying_power == detail["margin"]["stock_buying_power"] + assert balance.margin.stock_short_value == detail["margin"]["stock_short_value"] + assert balance.margin.sweep == detail["margin"]["sweep"] + + assert balance.cash is None + assert balance.pdt is None + + +def test_balance_cash(): + detail = { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "cash": { + "cash_available": 4343.38000000, + "sweep": 0, + "unsettled_funds": 1310.00000000, + }, + } + + balance = AccountBalance(**detail) + + assert balance.option_short_value == detail["option_short_value"] + assert balance.total_equity == detail["total_equity"] + assert balance.account_number == detail["account_number"] + assert balance.account_type == AccountType.margin + assert balance.close_pl == detail["close_pl"] + assert balance.current_requirement == detail["current_requirement"] + assert balance.equity == detail["equity"] + assert balance.long_market_value == detail["long_market_value"] + assert balance.market_value == detail["market_value"] + assert balance.open_pl == detail["open_pl"] + assert balance.option_long_value == detail["option_long_value"] + assert balance.option_requirement == detail["option_requirement"] + assert balance.pending_orders_count == detail["pending_orders_count"] + assert balance.short_market_value == detail["short_market_value"] + assert balance.stock_long_value == detail["stock_long_value"] + assert balance.total_cash == detail["total_cash"] + assert balance.uncleared_funds == detail["uncleared_funds"] + assert balance.pending_cash == detail["pending_cash"] + + assert balance.cash.cash_available == detail["cash"]["cash_available"] + assert balance.cash.sweep == detail["cash"]["sweep"] + assert balance.cash.unsettled_funds == detail["cash"]["unsettled_funds"] + + assert balance.margin is None + assert balance.pdt is None + + +def test_balance_pdt(): + detail = { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "pdt": { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + }, + } + balance = AccountBalance(**detail) + + assert balance.option_short_value == detail["option_short_value"] + assert balance.total_equity == detail["total_equity"] + assert balance.account_number == detail["account_number"] + assert balance.account_type == AccountType.margin + assert balance.close_pl == detail["close_pl"] + assert balance.current_requirement == detail["current_requirement"] + assert balance.equity == detail["equity"] + assert balance.long_market_value == detail["long_market_value"] + assert balance.market_value == detail["market_value"] + assert balance.open_pl == detail["open_pl"] + assert balance.option_long_value == detail["option_long_value"] + assert balance.option_requirement == detail["option_requirement"] + assert balance.pending_orders_count == detail["pending_orders_count"] + assert balance.short_market_value == detail["short_market_value"] + assert balance.stock_long_value == detail["stock_long_value"] + assert balance.total_cash == detail["total_cash"] + assert balance.uncleared_funds == detail["uncleared_funds"] + assert balance.pending_cash == detail["pending_cash"] + + assert balance.pdt.fed_call == detail["pdt"]["fed_call"] + assert balance.pdt.maintenance_call == detail["pdt"]["maintenance_call"] + assert balance.pdt.option_buying_power == detail["pdt"]["option_buying_power"] + + assert balance.margin is None + assert balance.cash is None diff --git a/tests/test_tradier.py b/tests/test_tradier.py index 6a313ab..4cb5a16 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -2,7 +2,7 @@ import pytest -from asynctradier.common import Duration, OptionType, OrderSide, OrderType +from asynctradier.common import AccountType, Duration, OptionType, OrderSide, OrderType from asynctradier.common.option_contract import OptionContract from asynctradier.exceptions import APINotAvailable, InvalidExiprationDate from asynctradier.tradier import TradierClient @@ -1366,3 +1366,201 @@ async def test_get_user_profile_sanbox(tradier_client): await tradier_client.get_user_profile() except APINotAvailable: assert True + + +@pytest.mark.asyncio +async def test_get_balance_margin(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "balances": { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "margin": { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + "sweep": 0, + }, + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + balance = await tradier_client.get_balance() + + assert balance.option_short_value == 0 + assert balance.total_equity == 17798.360000000000000000000000 + assert balance.account_number == "VA00000000" + assert balance.account_type == AccountType.margin + assert balance.close_pl == -4813.000000000000000000 + assert balance.current_requirement == 2557.00000000000000000000 + assert balance.equity == 0 + assert balance.long_market_value == 11434.50000000000000000000 + assert balance.market_value == 11434.50000000000000000000 + assert balance.open_pl == 546.900000000000000000000000 + assert balance.option_long_value == 8877.5000000000000000000 + assert balance.option_requirement == 0 + assert balance.pending_orders_count == 0 + assert balance.short_market_value == 0 + assert balance.stock_long_value == 2557.00000000000000000000 + assert balance.total_cash == 6363.860000000000000000000000 + assert balance.uncleared_funds == 0 + assert balance.pending_cash == 0 + assert balance.margin.fed_call == 0 + assert balance.margin.maintenance_call == 0 + assert balance.margin.option_buying_power == 6363.860000000000000000000000 + assert balance.margin.stock_buying_power == 12727.7200000000000000 + assert balance.margin.stock_short_value == 0 + assert balance.margin.sweep == 0 + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/balances" + ) + + +@pytest.mark.asyncio +async def test_get_balance_cash(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "balances": { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "cash": { + "cash_available": 4343.38000000, + "sweep": 0, + "unsettled_funds": 1310.00000000, + }, + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + balance = await tradier_client.get_balance() + + assert balance.option_short_value == 0 + assert balance.total_equity == 17798.360000000000000000000000 + assert balance.account_number == "VA00000000" + assert balance.account_type == AccountType.margin + assert balance.close_pl == -4813.000000000000000000 + assert balance.current_requirement == 2557.00000000000000000000 + assert balance.equity == 0 + assert balance.long_market_value == 11434.50000000000000000000 + assert balance.market_value == 11434.50000000000000000000 + assert balance.open_pl == 546.900000000000000000000000 + assert balance.option_long_value == 8877.5000000000000000000 + assert balance.option_requirement == 0 + assert balance.pending_orders_count == 0 + assert balance.short_market_value == 0 + assert balance.stock_long_value == 2557.00000000000000000000 + assert balance.total_cash == 6363.860000000000000000000000 + assert balance.uncleared_funds == 0 + assert balance.pending_cash == 0 + + assert balance.cash.cash_available == 4343.38000000 + assert balance.cash.sweep == 0 + assert balance.cash.unsettled_funds == 1310.00000000 + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/balances" + ) + + +@pytest.mark.asyncio +async def test_get_balance_pdt(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "balances": { + "option_short_value": 0, + "total_equity": 17798.360000000000000000000000, + "account_number": "VA00000000", + "account_type": "margin", + "close_pl": -4813.000000000000000000, + "current_requirement": 2557.00000000000000000000, + "equity": 0, + "long_market_value": 11434.50000000000000000000, + "market_value": 11434.50000000000000000000, + "open_pl": 546.900000000000000000000000, + "option_long_value": 8877.5000000000000000000, + "option_requirement": 0, + "pending_orders_count": 0, + "short_market_value": 0, + "stock_long_value": 2557.00000000000000000000, + "total_cash": 6363.860000000000000000000000, + "uncleared_funds": 0, + "pending_cash": 0, + "pdt": { + "fed_call": 0, + "maintenance_call": 0, + "option_buying_power": 6363.860000000000000000000000, + "stock_buying_power": 12727.7200000000000000, + "stock_short_value": 0, + }, + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + balance = await tradier_client.get_balance() + + assert balance.option_short_value == 0 + assert balance.total_equity == 17798.360000000000000000000000 + assert balance.account_number == "VA00000000" + assert balance.account_type == AccountType.margin + assert balance.close_pl == -4813.000000000000000000 + assert balance.current_requirement == 2557.00000000000000000000 + assert balance.equity == 0 + assert balance.long_market_value == 11434.50000000000000000000 + assert balance.market_value == 11434.50000000000000000000 + assert balance.open_pl == 546.900000000000000000000000 + assert balance.option_long_value == 8877.5000000000000000000 + assert balance.option_requirement == 0 + assert balance.pending_orders_count == 0 + assert balance.short_market_value == 0 + assert balance.stock_long_value == 2557.00000000000000000000 + assert balance.total_cash == 6363.860000000000000000000000 + assert balance.uncleared_funds == 0 + assert balance.pending_cash == 0 + + assert balance.pdt.fed_call == 0 + assert balance.pdt.maintenance_call == 0 + assert balance.pdt.option_buying_power == 6363.860000000000000000000000 + assert balance.pdt.stock_buying_power == 12727.7200000000000000 + assert balance.pdt.stock_short_value == 0 + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/balances" + ) From f0bee347bff58bd52ffc7b3bafe67aed25e8ffe5 Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Fri, 19 Jan 2024 00:43:05 -0500 Subject: [PATCH 3/5] feat: :sparkles: implement get history interface --- README.md | 6 +- asynctradier/common/__init__.py | 46 ++++++ asynctradier/common/event.py | 58 +++++++ asynctradier/exceptions/__init__.py | 14 ++ asynctradier/tradier.py | 89 ++++++++++- tests/test_common.py | 102 +++++++++++++ tests/test_tradier.py | 228 +++++++++++++++++++++++++++- 7 files changed, 536 insertions(+), 7 deletions(-) create mode 100644 asynctradier/common/event.py diff --git a/README.md b/README.md index a4ece06..e0cb628 100644 --- a/README.md +++ b/README.md @@ -25,13 +25,13 @@ if your are using poetry ### Account -:white_square_button: Get User Profile +:white_check_mark: Get User Profile -:white_square_button: Get Balances +:white_check_mark: Get Balances :white_check_mark: Get Positions -:white_square_button: Get History +:white_check_mark: Get History :white_square_button: Get Gain/Loss diff --git a/asynctradier/common/__init__.py b/asynctradier/common/__init__.py index 6e2df4c..65bcbc5 100644 --- a/asynctradier/common/__init__.py +++ b/asynctradier/common/__init__.py @@ -193,3 +193,49 @@ class AccountType(StrEnum): cash = "cash" margin = "margin" pdt = "pdt" + + +class EventType(StrEnum): + """ + Represents the type of an event. + + Attributes: + trade (str): The event is a trade event. + journal (str): The event is a journal event. + option (str): The event is an option event. + ach (str): The event is an ACH event. + wire (str): The event is a wire event. + dividend (str): The event is a dividend event. + fee (str): The event is a fee event. + tax (str): The event is a tax event. + check (str): The event is a check event. + transfer (str): The event is a transfer event. + adjustment (str): The event is an adjustment event. + interest (str): The event is an interest event. + """ + + trade = "trade" + journal = "journal" + option = "option" + ach = "ach" + wire = "wire" + dividend = "dividend" + fee = "fee" + tax = "tax" + check = "check" + transfer = "transfer" + adjustment = "adjustment" + interest = "interest" + + +class TradeType(StrEnum): + """ + Represents the type of a trade. + + Attributes: + buy (str): The trade is a buy trade. + sell (str): The trade is a sell trade. + """ + + equity = "Equity" + option = "Option" diff --git a/asynctradier/common/event.py b/asynctradier/common/event.py new file mode 100644 index 0000000..86b6cdd --- /dev/null +++ b/asynctradier/common/event.py @@ -0,0 +1,58 @@ +from asynctradier.common import EventType, TradeType + + +class Event: + """ + Represents an event. + + Attributes: + amount (float): The amount of the event. + date (str): The date of the event. + type (EventType): The type of the event. + description (str): The description of the event. + commision (float): The commission of the event. + price (float): The price of the event. + quantity (float): The quantity of the event. + symbol (str): The symbol of the event. + trade_type (TradeType): The type of trade. + """ + + def __init__(self, **kwargs): + self.amount = float(kwargs.get("amount")) if kwargs.get("amount") else 0.0 + self.date = kwargs.get("date") + self.type = EventType(kwargs.get("type")) if kwargs.get("type") else None + + detail = kwargs.get(self.type.value, {}) + + self.description = detail.get("description") + self.commision = ( + float(detail.get("commision")) if detail.get("commision") else 0.0 + ) + self.price = float(detail.get("price")) if detail.get("price") else 0.0 + self.quantity = float(detail.get("quantity")) if detail.get("quantity") else 0.0 + self.symbol = detail.get("symbol") + self.trade_type = ( + TradeType(detail.get("trade_type")) if detail.get("trade_type") else None + ) + + def to_dict(self): + """ + Converts the Event object to a dictionary. + + Returns: + dict: A dictionary representation of the Event object. + """ + return { + "amount": self.amount, + "date": self.date, + "type": self.type.value, + "description": self.description, + "commision": self.commision, + "price": self.price, + "symbol": self.symbol, + "trade_type": self.trade_type.value, + "quantity": self.quantity, + } + + def __str__(self): + return f"Event(amount={self.amount}, date={self.date}, type={self.type}, description={self.description}, commision={self.commision}, price={self.price}, symbol={self.symbol}, trade_type={self.trade_type})" diff --git a/asynctradier/exceptions/__init__.py b/asynctradier/exceptions/__init__.py index ffe5233..ffcc22a 100644 --- a/asynctradier/exceptions/__init__.py +++ b/asynctradier/exceptions/__init__.py @@ -89,3 +89,17 @@ class APINotAvailable(Exception): def __init__(self, msg: str) -> None: super().__init__(f"API is not available. {msg}") + + +class InvalidDateFormat(Exception): + """ + Exception raised when the date format is not valid. + + Attributes: + date (str): The invalid date. + """ + + def __init__(self, date: str) -> None: + super().__init__( + f"Date format {date} is not valid. Valid values is: YYYY-MM-DD" + ) diff --git a/asynctradier/tradier.py b/asynctradier/tradier.py index d9c5eae..b7d6cc7 100644 --- a/asynctradier/tradier.py +++ b/asynctradier/tradier.py @@ -3,8 +3,16 @@ import websockets -from asynctradier.common import Duration, OptionType, OrderClass, OrderSide, OrderType +from asynctradier.common import ( + Duration, + EventType, + OptionType, + OrderClass, + OrderSide, + OrderType, +) from asynctradier.common.account_balance import AccountBalance +from asynctradier.common.event import Event from asynctradier.common.expiration import Expiration from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order @@ -13,6 +21,7 @@ from asynctradier.common.user_profile import UserAccount from asynctradier.exceptions import ( APINotAvailable, + InvalidDateFormat, InvalidExiprationDate, InvalidOptionType, InvalidParameter, @@ -640,7 +649,9 @@ async def get_user_profile(self) -> List[UserAccount]: A list of UserProfile objects representing the user's profile information. """ if self.sandbox: - raise APINotAvailable("get user profile is only available in production") + raise APINotAvailable( + "please check the documentation for more details: https://documentation.tradier.com/brokerage-api/user/get-profile" + ) url = "/v1/user/profile" response = await self.session.get(url) @@ -677,3 +688,77 @@ async def get_balance(self) -> AccountBalance: return AccountBalance( **response["balances"], ) + + async def get_history( + self, + page: int = 1, + limit: int = 25, + event_type: Optional[EventType] = None, + start: Optional[str] = None, + end: Optional[str] = None, + symbol: Optional[str] = None, + exact_match: bool = False, + ) -> List[Event]: + if self.sandbox: + raise APINotAvailable( + "please check the documentation for more details: https://documentation.tradier.com/brokerage-api/accounts/get-account-balance" + ) + + if start is not None and not is_valid_expiration_date(start): + raise InvalidDateFormat(start) + + if end is not None and not is_valid_expiration_date(end): + raise InvalidDateFormat(end) + + if page is None or page < 1: + page = 1 + + if limit is None or limit < 1: + limit = 25 + + if exact_match is None: + exact_match = False + + url = f"/v1/accounts/{self.account_id}/history" + + params = { + "page": page, + "limit": limit, + "exactMatch": str(exact_match).lower(), + } + + if event_type is not None: + params["type"] = event_type.value + + if start is not None: + params["start"] = start + + if end is not None: + params["end"] = end + + if symbol is not None: + params["symbol"] = symbol + + response = await self.session.get(url, params=params) + + if response.get("history") is None: + return [] + + if response["history"].get("event") is None: + return [] + + if not isinstance(response["history"]["event"], list): + events = [response["history"]["event"]] + else: + events = response["history"]["event"] + + results: List[Event] = [] + + for event in events: + results.append( + Event( + **event, + ) + ) + + return results diff --git a/tests/test_common.py b/tests/test_common.py index 31a7b4e..e55efd2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3,6 +3,7 @@ AccountType, Classification, Duration, + EventType, OptionType, OrderClass, OrderSide, @@ -16,6 +17,7 @@ MarginAccountBalanceDetails, PDTAccountBalanceDetails, ) +from asynctradier.common.event import Event from asynctradier.common.expiration import Expiration from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order @@ -839,3 +841,103 @@ def test_balance_pdt(): assert balance.margin is None assert balance.cash is None + + +def test_event_trade(): + detail = { + "amount": 54.90, + "date": "2024-01-17T00:00:00Z", + "type": "trade", + "trade": { + "commission": 0.0000000000, + "description": "CALL TSLA 01/19/24 226.67", + "price": 0.550000, + "quantity": -1.00000000, + "symbol": "TSLA240119C00226670", + "trade_type": "option", + }, + } + + event = Event(**detail) + + assert event.amount == detail["amount"] + assert event.date == detail["date"] + assert event.type == EventType.trade + assert event.commision == detail["trade"]["commission"] + assert event.description == detail["trade"]["description"] + assert event.price == detail["trade"]["price"] + assert event.quantity == detail["trade"]["quantity"] + assert event.symbol == detail["trade"]["symbol"] + assert event.trade_type == detail["trade"]["trade_type"] + + +def test_event_ach(): + detail = { + "amount": 3000.00, + "date": "2023-12-19T00:00:00Z", + "type": "ach", + "ach": {"description": "ACH DEPOSIT", "quantity": 0.00000000}, + } + + event = Event(**detail) + + assert event.amount == detail["amount"] + assert event.date == detail["date"] + assert event.type == EventType.ach + assert event.description == detail["ach"]["description"] + assert event.quantity == detail["ach"]["quantity"] + + +def test_event_dividend(): + detail = { + "amount": 0.12, + "date": "2018-10-25T00:00:00Z", + "type": "dividend", + "dividend": {"description": "GENERAL ELECTRIC COMPANY", "quantity": 0.00000000}, + } + + event = Event(**detail) + + assert event.amount == detail["amount"] + assert event.date == detail["date"] + assert event.type == EventType.dividend + assert event.description == detail["dividend"]["description"] + assert event.quantity == detail["dividend"]["quantity"] + + +def test_event_option(): + detail = { + "amount": 0, + "date": "2018-09-21T00:00:00Z", + "type": "option", + "option": { + "option_type": "OPTEXP", + "description": "Expired", + "quantity": -1.00000000, + }, + } + + event = Event(**detail) + + assert event.amount == detail["amount"] + assert event.date == detail["date"] + assert event.type == EventType.option + assert event.description == detail["option"]["description"] + assert event.quantity == detail["option"]["quantity"] + + +def test_journal(): + detail = { + "amount": -3000.00, + "date": "2018-05-23T00:00:00Z", + "type": "journal", + "journal": {"description": "6YA-00005 TO 6YA-00102", "quantity": 0.00000000}, + } + + event = Event(**detail) + + assert event.amount == detail["amount"] + assert event.date == detail["date"] + assert event.type == EventType.journal + assert event.description == detail["journal"]["description"] + assert event.quantity == detail["journal"]["quantity"] diff --git a/tests/test_tradier.py b/tests/test_tradier.py index 4cb5a16..eabb977 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -2,9 +2,20 @@ import pytest -from asynctradier.common import AccountType, Duration, OptionType, OrderSide, OrderType +from asynctradier.common import ( + AccountType, + Duration, + EventType, + OptionType, + OrderSide, + OrderType, +) from asynctradier.common.option_contract import OptionContract -from asynctradier.exceptions import APINotAvailable, InvalidExiprationDate +from asynctradier.exceptions import ( + APINotAvailable, + InvalidDateFormat, + InvalidExiprationDate, +) from asynctradier.tradier import TradierClient @@ -1564,3 +1575,216 @@ def mock_get(path: str, params: dict = None): tradier_client.session.get.assert_called_once_with( f"/v1/accounts/{tradier_client.account_id}/balances" ) + + +@pytest.mark.asyncio +async def test_get_history_single(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return { + "history": { + "event": { + "amount": -3000.00, + "date": "2018-05-23T00:00:00Z", + "type": "journal", + "journal": { + "description": "6YA-00005 TO 6YA-00102", + "quantity": 0.00000000, + }, + } + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + history = await tradier_client.get_history() + + assert len(history) == 1 + + +@pytest.mark.asyncio +async def test_get_history_multiple(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return { + "history": { + "event": [ + { + "amount": -3000.00, + "date": "2018-05-23T00:00:00Z", + "type": "journal", + "journal": { + "description": "6YA-00005 TO 6YA-00102", + "quantity": 0.00000000, + }, + }, + { + "amount": 99.95, + "date": "2018-05-23T00:00:00Z", + "type": "trade", + "trade": { + "commission": 0.0000000000, + "description": "CALL GE 06\/22\/18 14", # noqa + "price": 1.000000, + "quantity": -1.00000000, + "symbol": "GE180622C00014000", + "trade_type": "Option", + }, + }, + ] + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + history = await tradier_client.get_history() + + assert len(history) == 2 + + +@pytest.mark.asyncio() +async def test_get_history_page(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_history(page=10) + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={"page": 10, "limit": 25, "exactMatch": "false"}, + ) + + +@pytest.mark.asyncio() +async def test_get_history_limit(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_history(limit=10) + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={"page": 1, "limit": 10, "exactMatch": "false"}, + ) + + +@pytest.mark.asyncio() +async def test_get_history_exact_match(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_history(exact_match=True) + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={"page": 1, "limit": 25, "exactMatch": "true"}, + ) + + +@pytest.mark.asyncio() +async def test_get_history_type(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_history(event_type=EventType.trade) + + for event_type in EventType: + await tradier_client.get_history(event_type=event_type) + tradier_client.session.get.assert_called_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={ + "page": 1, + "limit": 25, + "type": event_type.value, + "exactMatch": "false", + }, + ) + + +@pytest.mark.asyncio() +async def test_get_history_start(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_history(start="2020-01-01") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={ + "page": 1, + "limit": 25, + "start": "2020-01-01", + "exactMatch": "false", + }, + ) + + try: + await tradier_client.get_history(start="2020/01/01") + except InvalidDateFormat: + assert True + + +@pytest.mark.asyncio() +async def test_get_history_end(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_history(end="2020-01-01") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={"page": 1, "limit": 25, "end": "2020-01-01", "exactMatch": "false"}, + ) + + try: + await tradier_client.get_history(end="2020/01/01") + except InvalidDateFormat: + assert True + + +@pytest.mark.asyncio() +async def test_get_history_symbol(mocker, tradier_client): + tradier_client.sandbox = False + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_history(symbol="AAPL") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/history", + params={"page": 1, "limit": 25, "symbol": "AAPL", "exactMatch": "false"}, + ) + + +@pytest.mark.asyncio() +async def test_get_history_sanbox(mocker, tradier_client): + tradier_client.sandbox = True + + def mock_get(path: str, params: dict = None): + return {"history": {"event": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + try: + await tradier_client.get_history() + except APINotAvailable: + assert True From 5d88927c642e93fb2c89f2dfb35b9c4775943373 Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Fri, 19 Jan 2024 08:45:10 -0500 Subject: [PATCH 4/5] feat: :sparkles: implement get gainloss interface --- asynctradier/common/__init__.py | 4 +- asynctradier/common/event.py | 4 +- asynctradier/common/gain_loss.py | 54 ++++++++ asynctradier/tradier.py | 100 ++++++++++++++ tests/test_common.py | 53 +++++++ tests/test_tradier.py | 229 +++++++++++++++++++++++++++++++ 6 files changed, 441 insertions(+), 3 deletions(-) create mode 100644 asynctradier/common/gain_loss.py diff --git a/asynctradier/common/__init__.py b/asynctradier/common/__init__.py index 65bcbc5..48a2577 100644 --- a/asynctradier/common/__init__.py +++ b/asynctradier/common/__init__.py @@ -237,5 +237,5 @@ class TradeType(StrEnum): sell (str): The trade is a sell trade. """ - equity = "Equity" - option = "Option" + equity = "equity" + option = "option" diff --git a/asynctradier/common/event.py b/asynctradier/common/event.py index 86b6cdd..00bec0b 100644 --- a/asynctradier/common/event.py +++ b/asynctradier/common/event.py @@ -32,7 +32,9 @@ def __init__(self, **kwargs): self.quantity = float(detail.get("quantity")) if detail.get("quantity") else 0.0 self.symbol = detail.get("symbol") self.trade_type = ( - TradeType(detail.get("trade_type")) if detail.get("trade_type") else None + TradeType(detail.get("trade_type").lower()) + if detail.get("trade_type") + else None ) def to_dict(self): diff --git a/asynctradier/common/gain_loss.py b/asynctradier/common/gain_loss.py new file mode 100644 index 0000000..fa7a1c0 --- /dev/null +++ b/asynctradier/common/gain_loss.py @@ -0,0 +1,54 @@ +class ProfitLoss: + """ + ProfitLoss class for storing profit/loss information for a security. + + Attributes: + close_date (str): Date the position was closed + cost (float): Total cost of the position + gain_loss (float): Gain or loss on the position + gain_loss_percent (float): Gain or loss represented as percent + open_date (str): Date the position was opened + proceeds (float): Total amount received for the position + quantity (float): Quantity of shares/contracts + symbol (str): Symbol of the security held + term (int): Term in months position was held + """ + + def __init__(self, **kwargs): + self.close_date = kwargs.get("close_date") + self.cost = float(kwargs.get("cost")) if kwargs.get("cost") else 0.0 + self.gain_loss = ( + float(kwargs.get("gain_loss")) if kwargs.get("gain_loss") else 0.0 + ) + self.gain_loss_percent = ( + float(kwargs.get("gain_loss_percent")) + if kwargs.get("gain_loss_percent") + else 0.0 + ) + self.open_date = kwargs.get("open_date") + self.proceeds = float(kwargs.get("proceeds")) if kwargs.get("proceeds") else 0.0 + self.quantity = float(kwargs.get("quantity")) if kwargs.get("quantity") else 0.0 + self.symbol = kwargs.get("symbol") + self.term = int(kwargs.get("term")) if kwargs.get("term") else 0 + + def to_dict(self): + """ + Converts the ProfitLoss object to a dictionary. + + Returns: + dict: A dictionary representation of the ProfitLoss object. + """ + return { + "close_date": self.close_date, + "cost": self.cost, + "gain_loss": self.gain_loss, + "gain_loss_percent": self.gain_loss_percent, + "open_date": self.open_date, + "proceeds": self.proceeds, + "quantity": self.quantity, + "symbol": self.symbol, + "term": self.term, + } + + def __str__(self): + return f"ProfitLoss(close_date={self.close_date}, cost={self.cost}, gain_loss={self.gain_loss}, gain_loss_percent={self.gain_loss_percent}, open_date={self.open_date}, proceeds={self.proceeds}, quantity={self.quantity}, symbol={self.symbol}, term={self.term})" diff --git a/asynctradier/tradier.py b/asynctradier/tradier.py index b7d6cc7..d79ead9 100644 --- a/asynctradier/tradier.py +++ b/asynctradier/tradier.py @@ -14,6 +14,7 @@ from asynctradier.common.account_balance import AccountBalance from asynctradier.common.event import Event from asynctradier.common.expiration import Expiration +from asynctradier.common.gain_loss import ProfitLoss from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order from asynctradier.common.position import Position @@ -699,6 +700,21 @@ async def get_history( symbol: Optional[str] = None, exact_match: bool = False, ) -> List[Event]: + """ + Retrieves the account history. + + Args: + page (int, optional): The page number of the history to retrieve. Defaults to 1. + limit (int, optional): The number of events to retrieve per page. Defaults to 25. + event_type (EventType, optional): The type of event to retrieve. Defaults to None. + start (str, optional): The start date of the history to retrieve (YYYY-MM-DD). Defaults to None. + end (str, optional): The end date of the history to retrieve (YYYY-MM-DD). Defaults to None. + symbol (str, optional): The symbol of the event to retrieve. Defaults to None. + exact_match (bool, optional): Whether to perform an exact match on the symbol. Defaults to False. + + Returns: + List[Event]: A list of Event objects representing the account history. + """ if self.sandbox: raise APINotAvailable( "please check the documentation for more details: https://documentation.tradier.com/brokerage-api/accounts/get-account-balance" @@ -762,3 +778,87 @@ async def get_history( ) return results + + async def get_gainloss( + self, + page: int = 1, + limit: int = 25, + start: Optional[str] = None, + end: Optional[str] = None, + symbol: Optional[str] = None, + sort_by_close_date: bool = True, + desc: bool = True, + ) -> List[ProfitLoss]: + """ + Retrieves the gain/loss information for closed positions within a specified date range. + + Args: + page (int): The page number of the results to retrieve (default is 1). + limit (int): The maximum number of results per page (default is 25). + start (str, optional): The start date of the date range (format: "YYYY-MM-DD"). + end (str, optional): The end date of the date range (format: "YYYY-MM-DD"). + symbol (str, optional): The symbol of the positions to filter by. + sort_by_close_date (bool): Whether to sort the results by close date (default is False). + desc (bool): Whether to sort the results in descending order (default is True). + + Returns: + List[ProfitLoss]: A list of ProfitLoss objects representing the gain/loss information. + + Raises: + InvalidDateFormat: If the start or end date is not in the correct format. + + """ + + if start is not None and not is_valid_expiration_date(start): + raise InvalidDateFormat(start) + + if end is not None and not is_valid_expiration_date(end): + raise InvalidDateFormat(end) + + if page is None or page < 1: + page = 1 + + if limit is None or limit < 1: + limit = 25 + + url = f"/v1/accounts/{self.account_id}/gainloss" + + params = { + "page": page, + "limit": limit, + "sortBy": "closeDate" if sort_by_close_date else "openDate", + "sort": "desc" if desc else "asc", + } + + if start is not None: + params["start"] = start + + if end is not None: + params["end"] = end + + if symbol is not None: + params["symbol"] = symbol + + response = await self.session.get(url, params=params) + + if response.get("gainloss") is None: + return [] + + if response["gainloss"].get("closed_position") is None: + return [] + + if not isinstance(response["gainloss"]["closed_position"], list): + positions = [response["gainloss"]["closed_position"]] + else: + positions = response["gainloss"]["closed_position"] + + results: List[ProfitLoss] = [] + + for position in positions: + results.append( + ProfitLoss( + **position, + ) + ) + + return results diff --git a/tests/test_common.py b/tests/test_common.py index e55efd2..1809047 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -19,6 +19,7 @@ ) from asynctradier.common.event import Event from asynctradier.common.expiration import Expiration +from asynctradier.common.gain_loss import ProfitLoss from asynctradier.common.option_contract import OptionContract from asynctradier.common.order import Order from asynctradier.common.quote import Quote @@ -941,3 +942,55 @@ def test_journal(): assert event.type == EventType.journal assert event.description == detail["journal"]["description"] assert event.quantity == detail["journal"]["quantity"] + + +def test_gainloss_equity(): + detail = { + "close_date": "2018-09-19T00:00:00.000Z", + "cost": 913.95, + "gain_loss": 6.05, + "gain_loss_percent": 0.662, + "open_date": "2018-09-18T00:00:00.000Z", + "proceeds": 920.0, + "quantity": 100.0, + "symbol": "SNAP", + "term": 1, + } + + gainloss = ProfitLoss(**detail) + + assert gainloss.close_date == detail["close_date"] + assert gainloss.cost == detail["cost"] + assert gainloss.gain_loss == detail["gain_loss"] + assert gainloss.gain_loss_percent == detail["gain_loss_percent"] + assert gainloss.open_date == detail["open_date"] + assert gainloss.proceeds == detail["proceeds"] + assert gainloss.quantity == detail["quantity"] + assert gainloss.symbol == detail["symbol"] + assert gainloss.term == detail["term"] + + +def test_gainloss_option(): + detail = { + "close_date": "2018-06-25T00:00:00.000Z", + "cost": 25.05, + "gain_loss": -25.05, + "gain_loss_percent": -100.0, + "open_date": "2018-06-22T00:00:00.000Z", + "proceeds": 0.0, + "quantity": 1.0, + "symbol": "SPY180625C00276000", + "term": 3, + } + + gainloss = ProfitLoss(**detail) + + assert gainloss.close_date == detail["close_date"] + assert gainloss.cost == detail["cost"] + assert gainloss.gain_loss == detail["gain_loss"] + assert gainloss.gain_loss_percent == detail["gain_loss_percent"] + assert gainloss.open_date == detail["open_date"] + assert gainloss.proceeds == detail["proceeds"] + assert gainloss.quantity == detail["quantity"] + assert gainloss.symbol == detail["symbol"] + assert gainloss.term == detail["term"] diff --git a/tests/test_tradier.py b/tests/test_tradier.py index eabb977..6d9494b 100644 --- a/tests/test_tradier.py +++ b/tests/test_tradier.py @@ -1788,3 +1788,232 @@ def mock_get(path: str, params: dict = None): await tradier_client.get_history() except APINotAvailable: assert True + + +@pytest.mark.asyncio() +async def test_get_gainloss_single(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "gainloss": { + "closed_position": { + "close_date": "2018-10-31T00:00:00.000Z", + "cost": 12.7, + "gain_loss": -2.64, + "gain_loss_percent": -20.7874, + "open_date": "2018-06-19T00:00:00.000Z", + "proceeds": 10.06, + "quantity": 1.0, + "symbol": "GE", + "term": 134, + } + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + gainloss = await tradier_client.get_gainloss() + assert len(gainloss) == 1 + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={"page": 1, "limit": 25, "sortBy": "closeDate", "sort": "desc"}, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_multiple(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return { + "gainloss": { + "closed_position": [ + { + "close_date": "2018-10-31T00:00:00.000Z", + "cost": 12.7, + "gain_loss": -2.64, + "gain_loss_percent": -20.7874, + "open_date": "2018-06-19T00:00:00.000Z", + "proceeds": 10.06, + "quantity": 1.0, + "symbol": "GE", + "term": 134, + }, + { + "close_date": "2018-10-31T00:00:00.000Z", + "cost": 12.7, + "gain_loss": -2.64, + "gain_loss_percent": -20.7874, + "open_date": "2018-06-19T00:00:00.000Z", + "proceeds": 10.06, + "quantity": 1.0, + "symbol": "GE", + "term": 134, + }, + ] + } + } + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + gainloss = await tradier_client.get_gainloss() + assert len(gainloss) == 2 + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={"page": 1, "limit": 25, "sortBy": "closeDate", "sort": "desc"}, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_page(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_gainloss(page=10) + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={"page": 10, "limit": 25, "sortBy": "closeDate", "sort": "desc"}, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_limit(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + await tradier_client.get_gainloss(limit=10) + + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={"page": 1, "limit": 10, "sortBy": "closeDate", "sort": "desc"}, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_sort_by(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_gainloss(sort_by_close_date=True) + tradier_client.session.get.assert_called_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "desc", + }, + ) + + await tradier_client.get_gainloss(sort_by_close_date=False) + tradier_client.session.get.assert_called_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "openDate", + "sort": "desc", + }, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_sort(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_gainloss(desc=True) + tradier_client.session.get.assert_called_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "desc", + }, + ) + + await tradier_client.get_gainloss(desc=False) + tradier_client.session.get.assert_called_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "asc", + }, + ) + + +@pytest.mark.asyncio() +async def test_get_gainloss_start(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_gainloss(start="2020-01-01") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "desc", + "start": "2020-01-01", + }, + ) + + try: + await tradier_client.get_gainloss(start="2020/01/01") + except InvalidDateFormat: + assert True + + +@pytest.mark.asyncio() +async def test_get_gainloss_end(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_gainloss(end="2020-01-01") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "desc", + "end": "2020-01-01", + }, + ) + + try: + await tradier_client.get_gainloss(end="2020/01/01") + except InvalidDateFormat: + assert True + + +@pytest.mark.asyncio() +async def test_get_gainloss_symbol(mocker, tradier_client): + def mock_get(path: str, params: dict = None): + return {"gainloss": {"closed_position": []}} + + mocker.patch.object(tradier_client.session, "get", side_effect=mock_get) + + await tradier_client.get_gainloss(symbol="AAPL") + tradier_client.session.get.assert_called_once_with( + f"/v1/accounts/{tradier_client.account_id}/gainloss", + params={ + "page": 1, + "limit": 25, + "sortBy": "closeDate", + "sort": "desc", + "symbol": "AAPL", + }, + ) From 65872569a2ba3952b86a69606496371de2d298dd Mon Sep 17 00:00:00 2001 From: Jiakuan Li Date: Fri, 19 Jan 2024 08:45:47 -0500 Subject: [PATCH 5/5] docs: :memo: update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e0cb628..a21c748 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ if your are using poetry :white_check_mark: Get History -:white_square_button: Get Gain/Loss +:white_check_mark: Get Gain/Loss :white_check_mark: Get Orders