
Commit

Code cleanup
quantylab committed Mar 12, 2020
1 parent a2e3dc9 commit 33eb843
Showing 2 changed files with 56 additions and 73 deletions.
128 changes: 56 additions & 72 deletions learners.py
@@ -14,7 +14,6 @@

class ReinforcementLearner(metaclass=abc.ABCMeta):
lock = threading.Lock()

def __init__(self, rl_method='rl', stock_code=None,
chart_data=None, training_data=None,
@@ -72,8 +71,7 @@ def __init__(self, rl_method='rl', stock_code=None,
self.itr_cnt = 0
self.exploration_cnt = 0
self.batch_size = 0
self.pos_learning_cnt = 0
self.neg_learning_cnt = 0
self.learning_cnt = 0
        # Output path for logs, etc.
self.output_path = output_path

@@ -155,8 +153,48 @@ def reset(self):
self.itr_cnt = 0
self.exploration_cnt = 0
self.batch_size = 0
self.pos_learning_cnt = 0
self.neg_learning_cnt = 0
self.learning_cnt = 0

def build_sample(self):
self.environment.observe()
if len(self.training_data) > self.training_data_idx + 1:
self.training_data_idx += 1
self.sample = self.training_data.iloc[
self.training_data_idx].tolist()
self.sample.extend(self.agent.get_states())
return self.sample
return None
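    # build_sample advances the environment one step; while training rows
    # remain, it returns the next feature row extended with the agent's
    # state, and returns None once the data is exhausted so the caller can
    # end the episode loop.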

@abc.abstractmethod
def get_batch(self, batch_size, delayed_reward, discount_factor):
pass

def update_networks(self,
batch_size, delayed_reward, discount_factor):
        # Generate batch training data
x, y_value, y_policy = self.get_batch(
batch_size, delayed_reward, discount_factor)
if len(x) > 0:
loss = 0
if y_value is not None:
                # Update the value network
loss += self.value_network.train_on_batch(x, y_value)
if y_policy is not None:
                # Update the policy network
loss += self.policy_network.train_on_batch(x, y_policy)
return loss
return None
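    # update_networks returns the combined loss when at least one network
    # was trained on the batch, and None when get_batch yields an empty
    # batch; fit() below checks for that None before accumulating the loss.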

def fit(self, delayed_reward, discount_factor):
        # Generate batch training data and update the networks
if self.batch_size > 0:
_loss = self.update_networks(
self.batch_size, delayed_reward, discount_factor)
if _loss is not None:
self.loss += abs(_loss)
self.learning_cnt += 1
self.memory_learning_idx.append(self.training_data_idx)
self.batch_size = 0
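    # Sketch (assumption, not code from this commit) of how run() is
    # expected to drive fit() each step:
    #   next_sample = self.build_sample()
    #   ... the agent decides and acts, producing delayed_reward ...
    #   self.batch_size += 1
    #   if learning and delayed_reward != 0:
    #       self.fit(delayed_reward, discount_factor)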

def visualize(self, epoch_str, num_epoches, epsilon):
self.memory_action = [Agent.ACTION_HOLD] \
@@ -190,20 +228,6 @@ def visualize(self, epoch_str, num_epoches, epsilon):
'epoch_summary_{}.png'.format(epoch_str))
)

def fit(self, delayed_reward, discount_factor):
        # Generate batch training data and update the networks
if self.batch_size > 0:
_loss = self.update_networks(
self.batch_size, delayed_reward, discount_factor)
if _loss is not None:
self.loss += abs(_loss)
if delayed_reward > 0:
self.pos_learning_cnt += 1
else:
self.neg_learning_cnt += 1
self.memory_learning_idx.append(self.training_data_idx)
self.batch_size = 0

def run(
self, num_epoches=100, balance=10000000,
discount_factor=0.9, start_epsilon=0.5, learning=True):
@@ -267,7 +291,7 @@ def run(
if next_sample is None:
break

                # Store n_step samples
                # Store num_steps samples
q_sample.append(next_sample)
if len(q_sample) < self.num_steps:
continue
@@ -322,22 +346,18 @@ def run(
epoch_str = str(epoch + 1).rjust(num_epoches_digit, '0')
time_end_epoch = time.time()
elapsed_time_epoch = time_end_epoch - time_start_epoch
if self.pos_learning_cnt + self.neg_learning_cnt > 0:
self.loss /= self.pos_learning_cnt \
+ self.neg_learning_cnt
with self.lock:
logging.info("[{}][Epoch {}/{}] Epsilon:{:.4f} "
"#Expl.:{}/{} #Buy:{} #Sell:{} #Hold:{} "
"#Stocks:{} PV:{:,.0f} "
"POS:{} NEG:{} Loss:{:.6f} ET:{:.4f}".format(
self.stock_code, epoch_str,
num_epoches, epsilon,
self.exploration_cnt, self.itr_cnt,
self.agent.num_buy, self.agent.num_sell,
self.agent.num_hold, self.agent.num_stocks,
self.agent.portfolio_value,
self.pos_learning_cnt, self.neg_learning_cnt,
self.loss, elapsed_time_epoch))
if self.learning_cnt > 0:
self.loss /= self.learning_cnt
logging.info("[{}][Epoch {}/{}] Epsilon:{:.4f} "
"#Expl.:{}/{} #Buy:{} #Sell:{} #Hold:{} "
"#Stocks:{} PV:{:,.0f} "
"LC:{} Loss:{:.6f} ET:{:.4f}".format(
self.stock_code, epoch_str, num_epoches, epsilon,
self.exploration_cnt, self.itr_cnt,
self.agent.num_buy, self.agent.num_sell,
self.agent.num_hold, self.agent.num_stocks,
self.agent.portfolio_value, self.learning_cnt,
self.loss, elapsed_time_epoch))
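            # LC in the log line is learning_cnt, the number of mini-batch
            # updates this epoch; dividing loss by it reports the mean loss
            # per update rather than the accumulated sum.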

            # Visualize epoch-related information
self.visualize(epoch_str, num_epoches, epsilon)
@@ -359,42 +379,6 @@ def run(
code=self.stock_code, elapsed_time=elapsed_time,
max_pv=max_portfolio_value, cnt_win=epoch_win_cnt))

def build_sample(self):
self.environment.observe()
if len(self.training_data) > self.training_data_idx + 1:
self.training_data_idx += 1
self.sample = self.training_data.iloc[
self.training_data_idx].tolist()
self.sample.extend(self.agent.get_states())
return self.sample
return None

def get_action_network(self):
if self.policy_network is not None:
return self.policy_network
else:
return self.value_network

@abc.abstractmethod
def get_batch(self, batch_size, delayed_reward, discount_factor):
pass

def update_networks(self,
batch_size, delayed_reward, discount_factor):
        # Generate batch training data
x, y_value, y_policy = self.get_batch(
batch_size, delayed_reward, discount_factor)
if len(x) > 0:
loss = 0
if y_value is not None:
                # Update the value network
loss += self.value_network.train_on_batch(x, y_value)
if y_policy is not None:
                # Update the policy network
loss += self.policy_network.train_on_batch(x, y_policy)
return loss
return None

def save_models(self):
if self.value_network is not None and \
self.value_network_path is not None:
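get_batch is declared abstract above, so each learner subclass supplies its own batch builder. Below is a minimal, illustrative sketch of what a value-network-only (DQN-style) get_batch could look like; the attribute names memory_sample, memory_action, memory_value, num_steps, num_features, and agent.NUM_ACTIONS are assumptions inferred from this class, not code from this commit.

import numpy as np

def get_batch(self, batch_size, delayed_reward, discount_factor):
    # Pair the most recent batch_size steps, newest first (assumed buffers).
    memory = zip(
        reversed(self.memory_sample[-batch_size:]),
        reversed(self.memory_action[-batch_size:]),
        reversed(self.memory_value[-batch_size:]),
    )
    x = np.zeros((batch_size, self.num_steps, self.num_features))
    y_value = np.zeros((batch_size, self.agent.NUM_ACTIONS))
    reward = delayed_reward
    for i, (sample, action, value) in enumerate(memory):
        x[i] = sample
        y_value[i] = value           # keep predictions for untaken actions
        y_value[i, action] = reward  # supervise the action actually taken
        reward *= discount_factor    # decay the reward for older steps
    return x, y_value, None          # value-only learner: no policy targets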
1 change: 0 additions & 1 deletion networks.py
@@ -48,7 +48,6 @@ def __init__(self, input_dim=0, output_dim=0, lr=0.001,
self.loss = loss
self.model = None


def predict(self, sample):
with self.lock:
with graph.as_default():
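                # Assumption: `lock` and `graph` are objects shared across
                # learner threads (a threading.Lock and the TF1/Keras default
                # graph); serializing predict() keeps concurrent learners
                # from clashing in a single session.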
