Skip to content

Commit

Permalink
1)add return_unfinished arg for recorder 2)improve tags
Browse files Browse the repository at this point in the history
  • Loading branch information
foolcage committed Apr 4, 2024
1 parent 95a2950 commit 4c440a5
Show file tree
Hide file tree
Showing 27 changed files with 622 additions and 415 deletions.
37 changes: 33 additions & 4 deletions examples/recorder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,53 @@ def run_data_recorder(
entity_ids=None,
retry_times=10,
sleeping_time=10,
return_unfinished=False,
**recorder_kv,
):
logger.info(f" record data: {domain.__name__}, entity_provider: {entity_provider}, data_provider: {data_provider}")

unfinished_entity_ids = entity_ids
while retry_times > 0:
email_action = EmailInformer()

try:
domain.record_data(
entity_ids=entity_ids, provider=data_provider, sleeping_time=sleeping_time, **recorder_kv
)
if return_unfinished:
unfinished_entity_ids = domain.record_data(
entity_ids=unfinished_entity_ids,
provider=data_provider,
sleeping_time=sleeping_time,
return_unfinished=return_unfinished,
**recorder_kv,
)
if not unfinished_entity_ids:
unfinished_entity_ids = []
logger.info(f"unfinished_entity_ids({len(unfinished_entity_ids)}): {unfinished_entity_ids}")
if unfinished_entity_ids:
time.sleep(60 * 2)
retry_times = retry_times - 1
if retry_times == 0:
email_action.send_message(
zvt_config["email_username"],
f"record {domain.__name__} error",
f"record {domain.__name__} error: {e}",
)
continue
else:
domain.record_data(
entity_ids=entity_ids,
provider=data_provider,
sleeping_time=sleeping_time,
return_unfinished=return_unfinished,
**recorder_kv,
)

msg = f"record {domain.__name__} success"
logger.info(msg)
email_action.send_message(zvt_config["email_username"], msg, msg)
break
except Exception as e:
logger.exception("report error:{}".format(e))
time.sleep(60 * 3)
time.sleep(60 * 2)
retry_times = retry_times - 1
if retry_times == 0:
email_action.send_message(
Expand Down
34 changes: 34 additions & 0 deletions scripts/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
from examples.recorder_utils import run_data_recorder
from zvt.api.selector import get_entity_ids_by_filter
from zvt.domain import StockEvents, BlockStock
from zvt.tag import StockAutoTagger

if __name__ == "__main__":
data_provider = "em"
sleeping_time = 0.5
normal_stock_ids = get_entity_ids_by_filter(
provider="em", ignore_delist=True, ignore_st=False, ignore_new_stock=False
)

run_data_recorder(
entity_ids=normal_stock_ids,
day_data=False,
domain=StockEvents,
data_provider=data_provider,
force_update=False,
sleeping_time=sleeping_time,
return_unfinished=True,
)

run_data_recorder(
entity_ids=normal_stock_ids,
day_data=True,
domain=BlockStock,
data_provider=data_provider,
force_update=False,
sleeping_time=sleeping_time,
return_unfinished=True,
)

StockAutoTagger().tag()
14 changes: 2 additions & 12 deletions scripts/report_stock.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,13 @@
# -*- coding: utf-8 -*-
import json

from examples.data_runner.kdata_runner import record_stock_data, record_stock_news, record_stock_events
from examples.reports.report_tops import report_top_stocks, report_top_blocks
from examples.data_runner.kdata_runner import record_stock_data
from examples.reports.report_tops import report_top_stocks
from examples.reports.report_vol_up import report_vol_up_stocks
from examples.utils import get_hot_topics
from zvt import zvt_config
from zvt.factors.top_stocks import compute_top_stocks
from zvt.informer import EmailInformer
from zvt.utils import current_date

if __name__ == "__main__":

# record_stock_news()

record_stock_data()
# record_stock_events()
compute_top_stocks()

report_top_stocks()
# report_top_blocks()
report_vol_up_stocks()
3 changes: 2 additions & 1 deletion src/zvt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,13 @@ def init_plugins():
logger.warning(f"failed to load plugin {name}", e)
logger.info(f"loaded plugins:{_plugins}")


def old_db_to_provider_dir(data_path):
files = os.listdir(data_path)
for file in files:
if file.endswith(".db"):
# Split the file name to extract the provider
provider = file.split('_')[0]
provider = file.split("_")[0]

# Define the destination directory
destination_dir = os.path.join(data_path, provider)
Expand Down
10 changes: 6 additions & 4 deletions src/zvt/api/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,19 @@ def get_entity_ids_by_filter(
entity_ids=None,
ignore_bj=False,
):
filters = [entity_schema.timestamp.isnot(None)]
if not target_date:
target_date = current_date()
filters = []
if ignore_new_stock:
if not target_date:
target_date = current_date()
pre_year = next_date(target_date, -365)
filters += [entity_schema.timestamp <= pre_year]
else:
filters += [entity_schema.timestamp <= target_date]
if target_date:
filters += [entity_schema.timestamp <= target_date]
if ignore_delist:
filters += [
entity_schema.name.not_like("%退%"),
entity_schema.name.not_like("%PT%"),
]

if ignore_st:
Expand Down
13 changes: 13 additions & 0 deletions src/zvt/contract/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
day_data=False,
entity_filters=None,
ignore_failed=True,
return_unfinished=False,
) -> None:
"""
:param code:
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
self.entity_ids = entity_ids
self.entity_filters = entity_filters
self.ignore_failed = ignore_failed
self.return_unfinished = return_unfinished

self.entity_session: Session = None
self.entities: List = None
Expand Down Expand Up @@ -191,6 +193,7 @@ def __init__(
fix_duplicate_way="add",
start_timestamp=None,
end_timestamp=None,
return_unfinished=False,
) -> None:
super().__init__(
force_update,
Expand All @@ -203,6 +206,7 @@ def __init__(
day_data=day_data,
entity_filters=entity_filters,
ignore_failed=ignore_failed,
return_unfinished=return_unfinished,
)

self.real_time = real_time
Expand Down Expand Up @@ -513,6 +517,11 @@ def run(self):
"recording data for entity_id:{},{},error:{}".format(entity_item.id, self.data_schema, e)
)
raising_exception = e
if self.return_unfinished:
self.on_finish()
unfinished_items = set(unfinished_items) - set(finished_items)
return [item.entity_id for item in unfinished_items]

finished_items = unfinished_items
break

Expand All @@ -522,6 +531,8 @@ def run(self):
break

self.on_finish()
if self.return_unfinished:
return []

if raising_exception:
raise raising_exception
Expand All @@ -547,6 +558,7 @@ def __init__(
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
return_unfinished=False,
) -> None:
super().__init__(
force_update,
Expand All @@ -563,6 +575,7 @@ def __init__(
fix_duplicate_way=fix_duplicate_way,
start_timestamp=start_timestamp,
end_timestamp=end_timestamp,
return_unfinished=return_unfinished,
)

self.level = IntervalLevel(level)
Expand Down
7 changes: 3 additions & 4 deletions src/zvt/contract/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def record_data(
"""
record data by the arguments
:param entity_id:
:param provider_index:
:param provider:
:param force_update:
Expand Down Expand Up @@ -286,12 +287,10 @@ def record_data(
kw[k] = kwargs[k]

r = recorder_class(**kw)
r.run()
return
return r.run()
else:
r = recorder_class(**kw)
r.run()
return
return r.run()
else:
print(f"no recorders for {cls.__name__}")

Expand Down
9 changes: 8 additions & 1 deletion src/zvt/factors/top_stocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,14 @@ def get_top_stocks(target_date, return_type="short"):
if datas:
assert len(datas) == 1
top_stock = datas[0]
if return_type == "short":
if return_type == "all":
short_stocks = json.loads(top_stock.short_stocks)
long_stocks = json.loads(top_stock.long_stocks)
small_vol_up_stocks = json.loads(top_stock.small_vol_up_stocks)
big_vol_up_stocks = json.loads(top_stock.big_vol_up_stocks)
all_stocks = list(set(short_stocks + long_stocks + small_vol_up_stocks + big_vol_up_stocks))
return all_stocks
elif return_type == "short":
stocks = json.loads(top_stock.short_stocks)
elif return_type == "long":
stocks = json.loads(top_stock.long_stocks)
Expand Down
2 changes: 2 additions & 0 deletions src/zvt/recorders/em/macro/em_treasury_yield_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
return_unfinished=False,
) -> None:
super().__init__(
force_update,
Expand All @@ -47,6 +48,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

def record(self, entity, start, end, size, timestamps):
Expand Down
5 changes: 2 additions & 3 deletions src/zvt/recorders/em/misc/em_stock_events_recorder.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
import pandas as pd

from zvt.contract.api import df_to_db
from zvt.contract.api import df_to_db, get_db_session, get_entities, get_data
from zvt.contract.recorder import FixedCycleDataRecorder
from zvt.domain import Stock
from zvt.domain.misc.stock_events import StockEvents
from zvt.domain.misc.stock_news import StockNews
from zvt.recorders.em import em_api
from zvt.utils import to_pd_timestamp, count_interval, now_pd_timestamp
from zvt.utils import to_pd_timestamp, count_interval, now_pd_timestamp, pd_is_not_null, now_time_str


class EMStockEventsRecorder(FixedCycleDataRecorder):
Expand Down
2 changes: 2 additions & 0 deletions src/zvt/recorders/em/quotes/em_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
adjust_type=AdjustType.qfq,
return_unfinished=False,
) -> None:
level = IntervalLevel(level)
self.adjust_type = AdjustType(adjust_type)
Expand All @@ -77,6 +78,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

def record(self, entity, start, end, size, timestamps):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
return_unfinished=False,
) -> None:
# 上证指数,深证成指,创业板指,科创板
support_codes = ["000001", "399001", "399006", "000688"]
Expand All @@ -59,6 +60,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

def record(self, entity, start, end, size, timestamps):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
compute_index_money_flow=False,
return_unfinished=False,
) -> None:
super().__init__(
force_update,
Expand All @@ -60,6 +61,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)
self.compute_index_money_flow = compute_index_money_flow
get_token(zvt_config["jq_username"], zvt_config["jq_password"], force=True)
Expand Down
2 changes: 2 additions & 0 deletions src/zvt/recorders/joinquant/quotes/jq_index_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
level=IntervalLevel.LEVEL_1DAY,
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
return_unfinished=False,
) -> None:
level = IntervalLevel(level)
self.data_schema = get_kdata_schema(entity_type="index", level=level)
Expand All @@ -67,6 +68,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

def init_entities(self):
Expand Down
2 changes: 2 additions & 0 deletions src/zvt/recorders/joinquant/quotes/jq_stock_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
adjust_type=AdjustType.qfq,
return_unfinished=False,
) -> None:
level = IntervalLevel(level)
adjust_type = AdjustType(adjust_type)
Expand All @@ -68,6 +69,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

self.adjust_type = adjust_type
Expand Down
2 changes: 2 additions & 0 deletions src/zvt/recorders/qmt/quotes/qmt_kdata_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
kdata_use_begin_time=False,
one_day_trading_minutes=24 * 60,
adjust_type=AdjustType.qfq,
return_unfinished=False,
) -> None:
level = IntervalLevel(level)
self.adjust_type = AdjustType(adjust_type)
Expand All @@ -78,6 +79,7 @@ def __init__(
level,
kdata_use_begin_time,
one_day_trading_minutes,
return_unfinished,
)

def record(self, entity, start, end, size, timestamps):
Expand Down
Loading

0 comments on commit 4c440a5

Please sign in to comment.