From 5785d2e139825f5a65b1416e95b3cb6fd713eb77 Mon Sep 17 00:00:00 2001 From: foolcage <5533061@qq.com> Date: Sat, 20 Jul 2024 22:00:35 +0800 Subject: [PATCH] Use pydantic model to replace marshmallow model in trader --- api-tests/event/ignore_stock_news.http | 10 +++++ requirements.txt | 2 - setup.py | 2 - src/zvt/domain/misc/stock_news.py | 4 +- src/zvt/recorders/em/em_api.py | 1 + src/zvt/trader/sim_account.py | 37 ++++-------------- src/zvt/trader/trader_models.py | 52 ++++++++++++++++++++++++++ src/zvt/trading/trading_models.py | 2 +- 8 files changed, 74 insertions(+), 36 deletions(-) create mode 100644 api-tests/event/ignore_stock_news.http create mode 100644 src/zvt/trader/trader_models.py diff --git a/api-tests/event/ignore_stock_news.http b/api-tests/event/ignore_stock_news.http new file mode 100644 index 00000000..0e533764 --- /dev/null +++ b/api-tests/event/ignore_stock_news.http @@ -0,0 +1,10 @@ +POST http://127.0.0.1:8090/api/event/ignore_stock_news +accept: application/json +Content-Type: application/json + +{ + "news_id": "stock_sz_000034_2024-07-17 16:08:17" +} + + + diff --git a/requirements.txt b/requirements.txt index b0158435..b40dbe9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,6 @@ pydantic==2.6.4 arrow==1.2.3 openpyxl==3.1.1 demjson3==3.0.6 -marshmallow-sqlalchemy==1.0.0 -marshmallow==3.21.1 plotly==5.13.0 dash==2.8.1 jqdatapy==0.1.8 diff --git a/setup.py b/setup.py index c64f5509..ade1b65b 100644 --- a/setup.py +++ b/setup.py @@ -47,8 +47,6 @@ "arrow==1.2.3", "openpyxl==3.1.1", "demjson3==3.0.6", - "marshmallow-sqlalchemy==1.0.0", - "marshmallow==3.21.1", "plotly==5.13.0", "dash==2.8.1", "jqdatapy==0.1.8", diff --git a/src/zvt/domain/misc/stock_news.py b/src/zvt/domain/misc/stock_news.py index 41d32587..24c407fe 100644 --- a/src/zvt/domain/misc/stock_news.py +++ b/src/zvt/domain/misc/stock_news.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from sqlalchemy import Column, String, JSON +from sqlalchemy import Column, String, JSON, Boolean from sqlalchemy.orm import declarative_base from zvt.contract import Mixin @@ -21,6 +21,8 @@ class StockNews(NewsBase, Mixin): news_content = Column(String) #: 新闻解读 news_analysis = Column(JSON) + #: 用户设置为忽略 + ignore_by_user = Column(Boolean, default=False) register_schema(providers=["em"], db_name="stock_news", schema_base=NewsBase, entity_type="stock") diff --git a/src/zvt/recorders/em/em_api.py b/src/zvt/recorders/em/em_api.py index d2c8d0f1..4301e1ff 100644 --- a/src/zvt/recorders/em/em_api.py +++ b/src/zvt/recorders/em/em_api.py @@ -745,6 +745,7 @@ def get_news(entity_id, ps=200, index=1, start_timestamp=None, session=None, lat "news_code": item.get("Art_Code", ""), "news_url": item.get("Art_Url", ""), "news_title": item.get("Art_Title", ""), + "ignore_by_user": False, } for index, item in enumerate(json_result) if not start_timestamp diff --git a/src/zvt/trader/sim_account.py b/src/zvt/trader/sim_account.py index 762091e3..cf54d06c 100644 --- a/src/zvt/trader/sim_account.py +++ b/src/zvt/trader/sim_account.py @@ -3,8 +3,6 @@ import math from typing import List, Optional -from marshmallow_sqlalchemy import SQLAlchemyAutoSchema - from zvt.api.kdata import get_kdata, get_kdata_schema from zvt.contract import IntervalLevel, TradableEntity, AdjustType from zvt.contract.api import get_db_session, decode_entity_id @@ -17,29 +15,13 @@ WrongKdataError, ) from zvt.trader.trader_info_api import get_trader_info, clear_trader +from zvt.trader.trader_models import AccountStatsModel, PositionModel from zvt.trader.trader_schemas import AccountStats, Position, Order, TraderInfo from zvt.utils.pd_utils import pd_is_not_null from zvt.utils.time_utils import to_pd_timestamp, to_time_str, TIME_FORMAT_ISO8601, is_same_date from zvt.utils.utils import fill_domain_from_dict -# FIXME:better way for schema<->domain,now just dump to schema and use dict['field'] for operation -class AccountDayStatsSchema(SQLAlchemyAutoSchema): - class Meta: - model = AccountStats - include_relationships = True - - -class PositionSchema(SQLAlchemyAutoSchema): - class Meta: - model = Position - include_relationships = True - - -account_stats_schema = AccountDayStatsSchema() -position_schema = PositionSchema() - - class SimAccountService(AccountService): def __init__( self, @@ -77,7 +59,6 @@ def __init__( self.real_time = real_time self.kdata_use_begin_time = kdata_use_begin_time - self.account = None self.account = self.init_account() account_info = ( @@ -149,20 +130,16 @@ def load_account(self) -> AccountStats: latest_record: AccountStats = records[0] # create new orm object from latest record - account_dict = account_stats_schema.dump(latest_record) - del account_dict["id"] - del account_dict["positions"] + account_stats_model = AccountStatsModel.from_orm(latest_record) account = AccountStats() - fill_domain_from_dict(account, account_dict) + fill_domain_from_dict(account, account_stats_model.model_dump(exclude={"id", "positions"})) positions: List[Position] = [] for position_domain in latest_record.positions: - position_dict = position_schema.dump(position_domain) - self.logger.debug("current position:{}".format(position_dict)) - del position_dict["id"] - del position_dict["account_stats"] + position_model = PositionModel.from_orm(position_domain) + self.logger.debug("current position:{}".format(position_model)) position = Position() - fill_domain_from_dict(position, position_dict) + fill_domain_from_dict(position, position_model.model_dump()) positions.append(position) account.positions = positions @@ -571,4 +548,4 @@ def order_by_amount( # the __all__ is generated -__all__ = ["AccountDayStatsSchema", "PositionSchema", "AccountService", "SimAccountService"] +__all__ = ["AccountService", "SimAccountService"] diff --git a/src/zvt/trader/trader_models.py b/src/zvt/trader/trader_models.py new file mode 100644 index 00000000..cbe3064f --- /dev/null +++ b/src/zvt/trader/trader_models.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +from typing import List + +from zvt.contract.model import MixinModel + + +class PositionModel(MixinModel): + #: 机器人名字 + trader_name: str + #: 做多数量 + long_amount: float + #: 可平多数量 + available_long: float + #: 平均做多价格 + average_long_price: float + #: 做空数量 + short_amount: float + #: 可平空数量 + available_short: float + #: 平均做空价格 + average_short_price: float + #: 盈亏 + profit: float + #: 盈亏比例 + profit_rate: float + #: 市值 或者 占用的保证金(方便起见,总是100%) + value: float + #: 交易类型(0代表T+0,1代表T+1) + trading_t: int + + +class AccountStatsModel(MixinModel): + #: 投入金额 + input_money: float + #: 机器人名字 + trader_name: str + #: 具体仓位 + positions: List[PositionModel] + #: 市值 + value: float + #: 可用现金 + cash: float + #: value + cash + all_value: float + + #: 盈亏 + profit: float + #: 盈亏比例 + profit_rate: float + + #: 收盘计算 + closing: bool diff --git a/src/zvt/trading/trading_models.py b/src/zvt/trading/trading_models.py index bc0e2cfd..ddf5da3d 100644 --- a/src/zvt/trading/trading_models.py +++ b/src/zvt/trading/trading_models.py @@ -17,7 +17,7 @@ class QueryStockQuoteSettingModel(CustomModel): main_tags: Optional[List[str]] = Field(default=None) -class BuildQueryStockQuoteSettingModel(BaseModel): +class BuildQueryStockQuoteSettingModel(CustomModel): stock_pool_name: str main_tags: Optional[List[str]] = Field(default=None)