forked from zhangpeng0v0/suangseqiu
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathssq_model.py
83 lines (74 loc) · 2.43 KB
/
ssq_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# -*- coding:utf-8 -*-
"""
Author: Niuzepeng
"""
import os
import numpy as np
import pandas as pd
from config import *
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
# Hint shown when the training CSV has not been downloaded yet (user-facing text).
_DOWNLOAD_HINT = "请执行 get_ssq_data.py 进行数据下载!"

# Training data, loaded once at import time from ``train_data_path``
# (presumably defined in config via the star import — verify).
try:
    DATA = pd.read_csv(train_data_path)
except (FileNotFoundError, pd.errors.EmptyDataError) as err:
    # A missing or zero-byte file is the usual symptom of not having run the
    # download script; previously this surfaced as a raw pandas/OS error and
    # the hint below was never shown.
    raise Exception(_DOWNLOAD_HINT) from err
if DATA.empty:
    # Header-only file: nothing to train on.
    raise Exception(_DOWNLOAD_HINT)
def transform_data(name, data=None, window=4):
    """Turn one ball column into overlapping fixed-width windows.

    The column is reversed (presumably the CSV lists the newest draw first,
    so this yields chronological order — TODO confirm against
    get_ssq_data.py). Then every contiguous run of ``window`` values becomes
    one row.

    :param name: column name of the ball to train on
    :param data: optional DataFrame to read from; defaults to the
        module-level ``DATA`` loaded at import time
    :param window: number of consecutive values per row (default 4, which
        ``create_model_data`` splits into 3 inputs + 1 target)
    :return: numpy array of shape ``(len(column) - window + 1, window)``,
        or ``(0, window)`` when the column is shorter than ``window``
    :raises ValueError: if ``window`` is less than 1
    """
    if window < 1:
        raise ValueError("window must be >= 1, got {}".format(window))
    frame = DATA if data is None else data
    values = frame[name].tolist()
    values.reverse()
    # Every start index whose window fits entirely inside the list. The old
    # code stopped one window early and silently dropped the final sample.
    windows = [values[i:i + window] for i in range(len(values) - window + 1)]
    # reshape keeps the result 2-D even when there are no windows, so callers
    # can still slice ``[:, ...]``.
    return np.array(windows).reshape(-1, window)
def create_model_data(name):
    """Build the (inputs, targets) training pair for one ball column.

    Each row returned by ``transform_data(name)`` is split into two parts:
    its first three values become one input sequence, reshaped to
    ``(-1, 3, 1)`` for the LSTM; its remaining value(s) are flattened into
    the 1-D target array.

    :param name: column name of the ball to train on
    :return: tuple ``(x_data, y_data)``
    """
    windows = transform_data(name)
    features = windows[:, :3]
    labels = windows[:, 3:]
    return features.reshape([-1, 3, 1]), labels.ravel()
def train_model(x_data, y_data, b_name, epochs=100, batch_size=1, model_dir="model"):
    """Train an LSTM classifier for one ball and save it as an .h5 file.

    The number of classes comes from the first character of ``b_name``:
    "红" -> 33 classes, "蓝" -> 16 classes. Inputs and labels are shifted
    down by 1 before one-hot encoding, so labels must lie in
    ``1..n_class``.

    :param x_data: training inputs as a numpy array, e.g. shape
        ``(n_samples, 3, 1)`` as produced by ``create_model_data``
    :param y_data: 1-D array of training labels
    :param b_name: ball name; also used in the saved file name
    :param epochs: maximum number of training epochs (default 100)
    :param batch_size: samples per gradient update (default 1)
    :param model_dir: directory for the saved model; created if missing
    :raises ValueError: if ``b_name`` does not start with "红" or "蓝"
    """
    classes_by_colour = {"红": 33, "蓝": 16}
    # Slicing (not indexing) so an empty name reaches the ValueError below.
    n_class = classes_by_colour.get(b_name[:1])
    if n_class is None:
        # Previously n_class stayed 0 here and to_categorical failed later
        # with an unrelated-looking error.
        raise ValueError(
            "Unknown ball name {!r}: expected it to start with '红' or '蓝'".format(b_name)
        )
    # Shift 1-based ball numbers to 0-based for the network and one-hot labels.
    x_data = x_data - 1
    y_data = to_categorical(y_data - 1, num_classes=n_class)
    print("The x_data shape is {}".format(x_data.shape))
    print("The y_data shape is {}".format(y_data.shape))

    model = Sequential()
    # Per-sample shape taken from the data; (3, 1) for create_model_data output.
    model.add(LSTM(32, input_shape=tuple(x_data.shape[1:]), return_sequences=True))
    model.add(LSTM(32, recurrent_dropout=0.2))
    model.add(Dense(n_class, activation="softmax"))
    model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=["accuracy"])
    callbacks = [
        # No validation data is passed to fit(), so this monitors training accuracy.
        EarlyStopping(monitor='accuracy', patience=3, verbose=2, mode='max')
    ]
    model.fit(x_data, y_data, batch_size=batch_size, epochs=epochs, verbose=1, callbacks=callbacks)
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(model_dir, exist_ok=True)
    model.save(os.path.join(model_dir, "lstm_model_{}.h5".format(b_name)))
if __name__ == '__main__':
    # Train and save one model per ball name in BOLL_NAME (presumably from config).
    for ball_name in BOLL_NAME:
        print("[INFO] 开始训练: {}".format(ball_name))
        train_model(*create_model_data(ball_name), ball_name)