Skip to content

Commit

Permalink
Merge pull request #17 from jiak94/watchlist
Browse files Browse the repository at this point in the history
Support Watchlist API
  • Loading branch information
jiak94 authored May 16, 2024
2 parents c28f7c5 + 035683f commit 28addf4
Show file tree
Hide file tree
Showing 6 changed files with 467 additions and 1 deletion.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,19 @@ if your are using poetry
:white_check_mark: Market WebSocket

:white_check_mark: Account WebSocket

### Watchlist

:white_check_mark: Get Watchlists

:white_check_mark: Get Watchlist

:white_check_mark: Create Watchlist

:white_check_mark: Update Watchlist

:white_check_mark: Delete Watchlist

:white_check_mark: Add Symbols

:white_check_mark: Remove a Symbol
130 changes: 130 additions & 0 deletions asynctradier/clients/watchlist_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import List

from asynctradier.common.watchlist import Watchlist
from asynctradier.utils.webutils import WebUtil


class WatchlistClient:
"""
A client for interacting with the watchlist API.
Args:
session (WebUtil): The session object used for making HTTP requests.
account_id (str): The ID of the account.
token (str): The authentication token.
sandbox (bool, optional): Whether to use the sandbox environment. Defaults to False.
"""

def __init__(
self, session: WebUtil, account_id: str, token: str, sandbox: bool = False
) -> None:
self.session = session
self.account_id = account_id
self.token = token
self.sandbox = sandbox

async def get_watchlists(self) -> List[Watchlist]:
"""
Get all watchlists for the account.
Returns:
List[Watchlist]: A list of Watchlist objects.
"""
url = "/v1/watchlists"
response = await self.session.get(url)
watchlists = response.get("watchlists", {}).get("watchlist", [])
return [Watchlist(**watchlist) for watchlist in watchlists]

async def get_watchlist(self, watchlist_id: str) -> Watchlist:
"""
Get a specific watchlist by ID.
Args:
watchlist_id (str): The ID of the watchlist.
Returns:
Watchlist: The Watchlist object.
"""
url = f"/v1/watchlists/{watchlist_id}"
response = await self.session.get(url)
return Watchlist(**response.get("watchlist"))

async def create_watchlist(self, name: str, symbols: List[str]) -> Watchlist:
"""
Create a new watchlist.
Args:
name (str): The name of the watchlist.
symbols (List[str]): A list of symbols to add to the watchlist.
Returns:
Watchlist: The Watchlist object.
"""
url = "/v1/watchlists"
data = {"name": name, "symbols": ",".join(symbols).upper()}
response = await self.session.post(url, data=data)
return Watchlist(**response.get("watchlist"))

async def update_watchlist(
self, watchlist_id: str, name: str, symbols: List[str]
) -> Watchlist:
"""
Update an existing watchlist.
Args:
watchlist_id (str): The ID of the watchlist.
name (str): The new name of the watchlist.
symbols (List[str]): A list of symbols to add to the watchlist.
Returns:
Watchlist: The Watchlist object.
"""
url = f"/v1/watchlists/{watchlist_id}"
data = {"name": name, "symbols": ",".join(symbols).upper()}
response = await self.session.put(url, data=data)
return Watchlist(**response.get("watchlist"))

async def delete_watchlist(self, watchlist_id: str) -> None:
"""
Delete a watchlist by ID.
Args:
watchlist_id (str): The ID of the watchlist.
"""
url = f"/v1/watchlists/{watchlist_id}"
await self.session.delete(url)

async def add_symbols_to_watchlist(
self, watchlist_id: str, symbols: List[str]
) -> Watchlist:
"""
Add symbols to an existing watchlist.
Args:
watchlist_id (str): The ID of the watchlist.
symbols (List[str]): A list of symbols to add to the watchlist.
Returns:
Watchlist: The Watchlist object.
"""
url = f"/v1/watchlists/{watchlist_id}/symbols"
data = {"symbols": ",".join(symbols).upper()}
response = await self.session.post(url, data=data)
return Watchlist(**response.get("watchlist"))

async def remove_symbol_from_watchlist(
self, watchlist_id: str, symbol: str
) -> Watchlist:
"""
Remove a symbol from an existing watchlist.
Args:
watchlist_id (str): The ID of the watchlist.
symbol (str): The symbol to remove from the watchlist.
Returns:
Watchlist: The Watchlist object.
"""
url = f"/v1/watchlists/{watchlist_id}/symbols/{symbol.upper()}"
response = await self.session.delete(url)
return Watchlist(**response.get("watchlist"))
32 changes: 32 additions & 0 deletions asynctradier/common/watchlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
class WatchlistItem:
"""
Represents an item in a watchlist.
Attributes:
symbol (str): The symbol of the item.
id (str): The ID of the item.
"""

def __init__(self, **kwargs):
self.symbol = kwargs.get("symbol")
self.id = kwargs.get("id")


class Watchlist:
"""
Represents a watchlist.
Attributes:
id (str): The ID of the watchlist.
name (str): The name of the watchlist.
public_id (str): The public ID of the watchlist.
items (list): A list of WatchlistItem objects representing the items in the watchlist.
"""

def __init__(self, **kwargs):
self.id = kwargs.get("id")
self.name = kwargs.get("name")
self.public_id = kwargs.get("public_id")
self.items = [
WatchlistItem(**item) for item in kwargs.get("items", {}).get("item", [])
]
5 changes: 4 additions & 1 deletion asynctradier/tradier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from asynctradier.clients.marketdata_client import MarketDataClient
from asynctradier.clients.streaming_client import StreamingClient
from asynctradier.clients.trading_client import TradingClient
from asynctradier.clients.watchlist_client import WatchlistClient
from asynctradier.utils.webutils import WebUtil


class TradierClient(AccountClient, TradingClient, MarketDataClient, StreamingClient):
class TradierClient(
AccountClient, TradingClient, MarketDataClient, StreamingClient, WatchlistClient
):
"""
A client for interacting with the Tradier API.
Expand Down
49 changes: 49 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from asynctradier.common.quote import Quote
from asynctradier.common.security import Security
from asynctradier.common.user_profile import UserAccount
from asynctradier.common.watchlist import Watchlist, WatchlistItem
from asynctradier.exceptions import InvalidExiprationDate, InvalidOptionType


Expand Down Expand Up @@ -1267,3 +1268,51 @@ def test_etb():
assert etb.exchange == detail["exchange"]
assert etb.type == SecurityType.stock
assert etb.description == detail["description"]


def test_watchlist_item():
detail = {"symbol": "AAPL", "id": "aapl"}
item = WatchlistItem(**detail)
assert item.symbol == "AAPL"
assert item.id == "aapl"


def test_watchlist():
detail = {
"name": "My Watchlist",
"id": "my_watchlist",
"public_id": "public-6f8f625wti",
"items": {
"item": [
{"symbol": "AAPL", "id": "aapl"},
{"symbol": "IBM", "id": "ibm"},
{"symbol": "NFLX", "id": "nflx"},
]
},
}
watchlist = Watchlist(**detail)

assert watchlist.id == "my_watchlist"
assert watchlist.name == "My Watchlist"
assert watchlist.public_id == "public-6f8f625wti"
assert len(watchlist.items) == 3
assert isinstance(watchlist.items[0], WatchlistItem)
assert watchlist.items[0].symbol == "AAPL"
assert watchlist.items[0].id == "aapl"
assert watchlist.items[1].symbol == "IBM"
assert watchlist.items[1].id == "ibm"
assert watchlist.items[2].symbol == "NFLX"
assert watchlist.items[2].id == "nflx"


def test_watchlist_no_items():
detail = {
"name": "default",
"id": "default",
"public_id": "public-6f8f625wti",
}
watchlist = Watchlist(**detail)

assert watchlist.id == "default"
assert watchlist.name == "default"
assert watchlist.public_id == "public-6f8f625wti"
Loading

0 comments on commit 28addf4

Please sign in to comment.