Skip to content

Commit

Permalink
Add get_stock_tag_options api
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Jul 23, 2024
1 parent 877e8da commit db69c41
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 29 deletions.
4 changes: 4 additions & 0 deletions api-tests/tag/get_stock_tag_options.http
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
GET http://127.0.0.1:8090/api/work/get_stock_tag_options?entity_id=stock_sh_600733
accept: application/json


58 changes: 58 additions & 0 deletions src/zvt/broker/qmt/data_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# -*- coding: utf-8 -*-
import logging

import pandas as pd
from xtquant import xtdata

from zvt import init_log

logger = logging.getLogger(__name__)


def download_data():
period = "1d"
xtdata.download_sector_data()
stock_codes = xtdata.get_stock_list_in_sector("沪深A股")
stock_codes = sorted(stock_codes)
count = len(stock_codes)

for index, stock_code in enumerate(stock_codes):
logger.info(f"run to {index + 1}/{count}")

xtdata.download_history_data(stock_code, period=period)
logger.info(f"download {stock_code} {period} kdata ok")
records = xtdata.get_market_data(
stock_list=[stock_code],
period=period,
count=5,
dividend_type="front",
fill_data=False,
)
dfs = []
for col in records:
df = records[col].T
df.columns = [col]
dfs.append(df)
kdatas = pd.concat(dfs, axis=1)
logger.info(kdatas)

start_time = kdatas.index.to_list()[0]
xtdata.download_history_data(stock_code, period="tick", start_time=start_time)
logger.info(f"download {stock_code} tick from {start_time} ok")
# records = xtdata.get_market_data(
# stock_list=[stock_code],
# period="tick",
# count=5,
# fill_data=False,
# )
# logger.info(records[stock_code])

xtdata.download_financial_data2(
stock_list=stock_codes, table_list=["Capital"], start_time="", end_time="", callback=lambda x: print(x)
)
logger.info("download capital data ok")


if __name__ == "__main__":
init_log("qmt_data_manager.log")
download_data()
18 changes: 16 additions & 2 deletions src/zvt/broker/qmt/qmt_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from zvt.contract import IntervalLevel, AdjustType
from zvt.contract.api import decode_entity_id, df_to_db, get_db_session
from zvt.domain import StockQuote, Stock
from zvt.domain.quotes.stock.stock_quote import Stock1mQuote
from zvt.utils.pd_utils import pd_is_not_null
from zvt.utils.time_utils import to_time_str, current_date, to_pd_timestamp, now_pd_timestamp
from zvt.utils.time_utils import to_time_str, current_date, to_pd_timestamp, now_pd_timestamp, TIME_FORMAT_MINUTE

# https://dict.thinktrader.net/nativeApi/start_now.html?id=e2M5nZ

Expand Down Expand Up @@ -197,6 +198,8 @@ def on_data(datas, stock_df=entity_df):
lambda se: "{}_{}".format(se["entity_id"], to_time_str(se["timestamp"])), axis=1
)

df["volume"] = df["pvolume"]
df["avg_price"] = df["turnover"] / df["volume"]
# 换手率
df["turnover_rate"] = df["pvolume"] / df["float_volume"]
# 涨跌幅
Expand Down Expand Up @@ -224,6 +227,11 @@ def on_data(datas, stock_df=entity_df):
cost_time = time.time() - start_time
logger.info(f"Quotes cost_time:{cost_time} for {len(datas.keys())} stocks")

df["id"] = df[["entity_id", "timestamp"]].apply(
lambda se: "{}_{}".format(se["entity_id"], to_time_str(se["timestamp"], TIME_FORMAT_MINUTE)), axis=1
)
df_to_db(df, data_schema=Stock1mQuote, provider="qmt", force_update=True, drop_duplicates=False)

return on_data


Expand Down Expand Up @@ -256,13 +264,19 @@ def record_tick():
if not client.is_connected():
raise Exception("行情服务连接断开")
current_timestamp = now_pd_timestamp()
if current_timestamp.hour == 15 and current_timestamp.minute == 10:
if current_timestamp.hour >= 15 and current_timestamp.minute >= 10:
logger.info(f"record tick finished at: {current_timestamp}")
break


if __name__ == "__main__":
from apscheduler.schedulers.background import BackgroundScheduler

sched = BackgroundScheduler()
record_tick()
sched.add_job(func=record_tick, trigger="cron", hour=9, minute=18, day_of_week="mon-fri")
sched.start()
sched._thread.join()


# the __all__ is generated
Expand Down
23 changes: 22 additions & 1 deletion src/zvt/domain/quotes/stock/stock_quote.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,29 @@ class StockQuote(StockQuoteBase, Mixin):
total_cap = Column(Float)


class Stock1mQuote(StockQuoteBase, Mixin):
__tablename__ = "stock_1m_quote"
code = Column(String(length=32))
name = Column(String(length=32))

#: UNIX时间戳
time = Column(Integer)
#: 最新价
price = Column(Float)
#: 均价
avg_price = Column(Float)
# 涨跌幅
change_pct = Column(Float)
# 成交量
volume = Column(Float)
# 成交金额
turnover = Column(Float)
# 换手率
turnover_rate = Column(Float)


register_schema(providers=["qmt"], db_name="stock_quote", schema_base=StockQuoteBase, entity_type="stock")


# the __all__ is generated
__all__ = ["StockQuote"]
__all__ = ["StockQuote", "Stock1mQuote"]
6 changes: 3 additions & 3 deletions src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def record(self, entity, start, end, size, timestamps):
self.logger.info(f"no kdata for {entity.id}")


class EMStockKdataRecorder(BaseQmtKdataRecorder):
class QMTStockKdataRecorder(BaseQmtKdataRecorder):
entity_schema = Stock
data_schema = StockKdataCommon


if __name__ == "__main__":
# Stock.record_data(provider="exchange")
EMStockKdataRecorder(entity_id="stock_sz_000338", adjust_type=AdjustType.hfq).run()
QMTStockKdataRecorder(entity_id="stock_sz_000338", adjust_type=AdjustType.hfq).run()


# the __all__ is generated
__all__ = ["BaseQmtKdataRecorder", "EMStockKdataRecorder"]
__all__ = ["BaseQmtKdataRecorder", "QMTStockKdataRecorder"]
9 changes: 9 additions & 0 deletions src/zvt/rest/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ActivateSubTagsResultModel,
ActivateSubTagsModel,
BatchSetStockTagsModel,
StockTagOptions,
)
from zvt.tag.tag_schemas import StockTags, MainTagInfo, SubTagInfo, HiddenTagInfo, StockPoolInfo, StockPools
from zvt.utils.time_utils import current_date
Expand Down Expand Up @@ -152,6 +153,14 @@ def query_simple_stock_tags(query_simple_stock_tags_model: QuerySimpleStockTagsM
return result_tags


@work_router.get("/get_stock_tag_options", response_model=StockTagOptions)
def get_stock_tag_options(entity_id: str):
"""
Get stock tag options
"""
return tag_service.get_stock_tag_options(entity_id=entity_id)


@work_router.post("/set_stock_tags", response_model=StockTagsModel)
def set_stock_tags(set_stock_tags_model: SetStockTagsModel):
"""
Expand Down
54 changes: 31 additions & 23 deletions src/zvt/tag/tag_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from zvt.contract.model import MixinModel, CustomModel
from zvt.tag.common import StockPoolType, TagType, TagStatsQueryType
from zvt.tag.tag_utils import get_main_tags, get_sub_tags, get_hidden_tags, get_stock_pool_names
from zvt.tag.tag_utils import get_stock_pool_names


class TagInfoModel(MixinModel):
Expand Down Expand Up @@ -67,6 +67,13 @@ class TagParameter(CustomModel):
sub_tag_reason: Optional[str] = None


class StockTagOptions(CustomModel):
main_tag: Optional[str] = None
sub_tag: Optional[str] = None
main_tag_options: List[CreateTagInfoModel]
sub_tag_options: List[CreateTagInfoModel]


class SetStockTagsModel(CustomModel):
entity_id: str
main_tag: str
Expand All @@ -75,28 +82,28 @@ class SetStockTagsModel(CustomModel):
sub_tag_reason: Optional[str] = None
active_hidden_tags: Optional[Dict[str, str]] = None

@field_validator("main_tag")
@classmethod
def main_tag_must_be_in(cls, v: str) -> str:
if v not in get_main_tags():
raise ValueError(f"main_tag: {v} must be created at main_tag_info at first")
return v

@field_validator("sub_tag")
@classmethod
def sub_tag_must_be_in(cls, v: str) -> str:
if v and (v not in get_sub_tags()):
raise ValueError(f"sub_tag: {v} must be created at sub_tag_info at first")
return v

@field_validator("active_hidden_tags")
@classmethod
def hidden_tag_must_be_in(cls, v: Union[Dict[str, str], None]) -> Union[Dict[str, str], None]:
if v:
for item in v.keys():
if item not in get_hidden_tags():
raise ValueError(f"hidden_tag: {v} must be created at hidden_tag_info at first")
return v
# @field_validator("main_tag")
# @classmethod
# def main_tag_must_be_in(cls, v: str) -> str:
# if v not in get_main_tags():
# raise ValueError(f"main_tag: {v} must be created at main_tag_info at first")
# return v
#
# @field_validator("sub_tag")
# @classmethod
# def sub_tag_must_be_in(cls, v: str) -> str:
# if v and (v not in get_sub_tags()):
# raise ValueError(f"sub_tag: {v} must be created at sub_tag_info at first")
# return v
#
# @field_validator("active_hidden_tags")
# @classmethod
# def hidden_tag_must_be_in(cls, v: Union[Dict[str, str], None]) -> Union[Dict[str, str], None]:
# if v:
# for item in v.keys():
# if item not in get_hidden_tags():
# raise ValueError(f"hidden_tag: {v} must be created at hidden_tag_info at first")
# return v


class StockPoolModel(MixinModel):
Expand Down Expand Up @@ -233,6 +240,7 @@ class ActivateSubTagsResultModel(CustomModel):
"QuerySimpleStockTagsModel",
"BatchSetStockTagsModel",
"TagParameter",
"StockTagOptions",
"SetStockTagsModel",
"StockPoolModel",
"StockPoolInfoModel",
Expand Down
52 changes: 52 additions & 0 deletions src/zvt/tag/tag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BatchSetStockTagsModel,
TagParameter,
CreateTagInfoModel,
StockTagOptions,
)
from zvt.tag.tag_schemas import (
StockTags,
Expand Down Expand Up @@ -53,10 +54,60 @@ def stock_tags_need_update(stock_tags: StockTags, set_stock_tags_model: SetStock
return False


def get_stock_tag_options(entity_id):
with contract_api.DBSession(provider="zvt", data_schema=StockTags)() as session:
datas: List[StockTags] = StockTags.query_data(
entity_id=entity_id, order=StockTags.timestamp.desc(), limit=1, return_type="domain", session=session
)
main_tag_options = []
sub_tag_options = []
main_tag = None
sub_tag = None
stock_tags = None
if datas:
stock_tags = datas[0]
main_tag = stock_tags.main_tag
sub_tag = stock_tags.sub_tag
main_tag_options = [
CreateTagInfoModel(tag=tag, tag_reason=tag_reason) for tag, tag_reason in stock_tags.main_tags.items()
]
sub_tag_options = [
CreateTagInfoModel(tag=tag, tag_reason=tag_reason) for tag, tag_reason in stock_tags.sub_tags.items()
]

main_tags_info: List[MainTagInfo] = MainTagInfo.query_data(session=session, return_type="domain")
main_tag_options = main_tag_options + [
CreateTagInfoModel(tag=item.tag, tag_reason=item.tag_reason)
for item in main_tags_info
if not stock_tags or (item.tag not in stock_tags.main_tags)
]

sub_tags_info: List[SubTagInfo] = SubTagInfo.query_data(session=session, return_type="domain")
sub_tag_options = sub_tag_options + [
CreateTagInfoModel(tag=item.tag, tag_reason=item.tag_reason)
for item in sub_tags_info
if not stock_tags or (item.tag not in stock_tags.sub_tags)
]
return StockTagOptions(
main_tag=main_tag, sub_tag=sub_tag, main_tag_options=main_tag_options, sub_tag_options=sub_tag_options
)


def build_stock_tags(
set_stock_tags_model: SetStockTagsModel, timestamp: pd.Timestamp, set_by_user: bool, keep_current=False
):
logger.info(set_stock_tags_model)

main_tag_info = CreateTagInfoModel(
tag=set_stock_tags_model.main_tag, tag_reason=set_stock_tags_model.main_tag_reason
)
if not is_tag_info_existed(tag_info=main_tag_info, tag_type=TagType.main_tag):
build_tag_info(tag_info=main_tag_info, tag_type=TagType.main_tag)

sub_tag_info = CreateTagInfoModel(tag=set_stock_tags_model.sub_tag, tag_reason=set_stock_tags_model.sub_tag_reason)
if not is_tag_info_existed(tag_info=sub_tag_info, tag_type=TagType.sub_tag):
build_tag_info(tag_info=sub_tag_info, tag_type=TagType.sub_tag)

with contract_api.DBSession(provider="zvt", data_schema=StockTags)() as session:
entity_id = set_stock_tags_model.entity_id
main_tags = {}
Expand Down Expand Up @@ -608,6 +659,7 @@ def activate_sub_tags(activate_sub_tags_model: ActivateSubTagsModel):
# the __all__ is generated
__all__ = [
"stock_tags_need_update",
"get_stock_tag_options",
"build_stock_tags",
"build_tag_parameter",
"batch_set_stock_tags",
Expand Down

0 comments on commit db69c41

Please sign in to comment.