Skip to content

Commit

Permalink
Use pydantic model to replace marshmallow model in trader
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Jul 20, 2024
1 parent 6358693 commit 5785d2e
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 36 deletions.
10 changes: 10 additions & 0 deletions api-tests/event/ignore_stock_news.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
POST http://127.0.0.1:8090/api/event/ignore_stock_news
accept: application/json
Content-Type: application/json

{
"news_id": "stock_sz_000034_2024-07-17 16:08:17"
}



2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ pydantic==2.6.4
arrow==1.2.3
openpyxl==3.1.1
demjson3==3.0.6
marshmallow-sqlalchemy==1.0.0
marshmallow==3.21.1
plotly==5.13.0
dash==2.8.1
jqdatapy==0.1.8
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@
"arrow==1.2.3",
"openpyxl==3.1.1",
"demjson3==3.0.6",
"marshmallow-sqlalchemy==1.0.0",
"marshmallow==3.21.1",
"plotly==5.13.0",
"dash==2.8.1",
"jqdatapy==0.1.8",
Expand Down
4 changes: 3 additions & 1 deletion src/zvt/domain/misc/stock_news.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from sqlalchemy import Column, String, JSON
from sqlalchemy import Column, String, JSON, Boolean
from sqlalchemy.orm import declarative_base

from zvt.contract import Mixin
Expand All @@ -21,6 +21,8 @@ class StockNews(NewsBase, Mixin):
news_content = Column(String)
#: 新闻解读
news_analysis = Column(JSON)
#: 用户设置为忽略
ignore_by_user = Column(Boolean, default=False)


register_schema(providers=["em"], db_name="stock_news", schema_base=NewsBase, entity_type="stock")
Expand Down
1 change: 1 addition & 0 deletions src/zvt/recorders/em/em_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ def get_news(entity_id, ps=200, index=1, start_timestamp=None, session=None, lat
"news_code": item.get("Art_Code", ""),
"news_url": item.get("Art_Url", ""),
"news_title": item.get("Art_Title", ""),
"ignore_by_user": False,
}
for index, item in enumerate(json_result)
if not start_timestamp
Expand Down
37 changes: 7 additions & 30 deletions src/zvt/trader/sim_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import math
from typing import List, Optional

from marshmallow_sqlalchemy import SQLAlchemyAutoSchema

from zvt.api.kdata import get_kdata, get_kdata_schema
from zvt.contract import IntervalLevel, TradableEntity, AdjustType
from zvt.contract.api import get_db_session, decode_entity_id
Expand All @@ -17,29 +15,13 @@
WrongKdataError,
)
from zvt.trader.trader_info_api import get_trader_info, clear_trader
from zvt.trader.trader_models import AccountStatsModel, PositionModel
from zvt.trader.trader_schemas import AccountStats, Position, Order, TraderInfo
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import to_pd_timestamp, to_time_str, TIME_FORMAT_ISO8601, is_same_date
from zvt.utils.utils import fill_domain_from_dict


# FIXME:better way for schema<->domain,now just dump to schema and use dict['field'] for operation
class AccountDayStatsSchema(SQLAlchemyAutoSchema):
class Meta:
model = AccountStats
include_relationships = True


class PositionSchema(SQLAlchemyAutoSchema):
class Meta:
model = Position
include_relationships = True


account_stats_schema = AccountDayStatsSchema()
position_schema = PositionSchema()


class SimAccountService(AccountService):
def __init__(
self,
Expand Down Expand Up @@ -77,7 +59,6 @@ def __init__(
self.real_time = real_time
self.kdata_use_begin_time = kdata_use_begin_time

self.account = None
self.account = self.init_account()

account_info = (
Expand Down Expand Up @@ -149,20 +130,16 @@ def load_account(self) -> AccountStats:
latest_record: AccountStats = records[0]

# create new orm object from latest record
account_dict = account_stats_schema.dump(latest_record)
del account_dict["id"]
del account_dict["positions"]
account_stats_model = AccountStatsModel.from_orm(latest_record)
account = AccountStats()
fill_domain_from_dict(account, account_dict)
fill_domain_from_dict(account, account_stats_model.model_dump(exclude={"id", "positions"}))

positions: List[Position] = []
for position_domain in latest_record.positions:
position_dict = position_schema.dump(position_domain)
self.logger.debug("current position:{}".format(position_dict))
del position_dict["id"]
del position_dict["account_stats"]
position_model = PositionModel.from_orm(position_domain)
self.logger.debug("current position:{}".format(position_model))
position = Position()
fill_domain_from_dict(position, position_dict)
fill_domain_from_dict(position, position_model.model_dump())
positions.append(position)

account.positions = positions
Expand Down Expand Up @@ -571,4 +548,4 @@ def order_by_amount(


# the __all__ is generated
__all__ = ["AccountDayStatsSchema", "PositionSchema", "AccountService", "SimAccountService"]
__all__ = ["AccountService", "SimAccountService"]
52 changes: 52 additions & 0 deletions src/zvt/trader/trader_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
from typing import List

from zvt.contract.model import MixinModel


class PositionModel(MixinModel):
#: 机器人名字
trader_name: str
#: 做多数量
long_amount: float
#: 可平多数量
available_long: float
#: 平均做多价格
average_long_price: float
#: 做空数量
short_amount: float
#: 可平空数量
available_short: float
#: 平均做空价格
average_short_price: float
#: 盈亏
profit: float
#: 盈亏比例
profit_rate: float
#: 市值 或者 占用的保证金(方便起见,总是100%)
value: float
#: 交易类型(0代表T+0,1代表T+1)
trading_t: int


class AccountStatsModel(MixinModel):
#: 投入金额
input_money: float
#: 机器人名字
trader_name: str
#: 具体仓位
positions: List[PositionModel]
#: 市值
value: float
#: 可用现金
cash: float
#: value + cash
all_value: float

#: 盈亏
profit: float
#: 盈亏比例
profit_rate: float

#: 收盘计算
closing: bool
2 changes: 1 addition & 1 deletion src/zvt/trading/trading_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class QueryStockQuoteSettingModel(CustomModel):
main_tags: Optional[List[str]] = Field(default=None)


class BuildQueryStockQuoteSettingModel(BaseModel):
class BuildQueryStockQuoteSettingModel(CustomModel):
stock_pool_name: str
main_tags: Optional[List[str]] = Field(default=None)

Expand Down

0 comments on commit 5785d2e

Please sign in to comment.