Skip to content

Commit

Permalink
combing training script
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitryduev committed Jan 14, 2021
1 parent 2ffa5bf commit 961310e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def train_and_eval(
weights: str = None,
save_model=False,
verbose=False,
**kwargs,
):
classifier = tails.models.DNN(name=model_name)

Expand All @@ -208,15 +209,18 @@ def train_and_eval(
# convert position RMSE to pixels
position_rmse = PositionRootMeanSquarredError(scaling_factor=scaling_factor)

learning_rate = kwargs.get("learning_rate", 3e-4)
patience = kwargs.get("patience", 30)

classifier.setup(
input_shape=input_shape,
n_output_neurons=3,
architecture="tails",
loss=tails_loss,
optimizer="adam",
lr=3e-4, # epsilon=1e-3, beta_1=0.7,
lr=learning_rate, # epsilon=1e-3, beta_1=0.7,
metrics=[label_accuracy, position_rmse],
patience=30,
patience=patience,
monitor="val_position_rmse",
restore_best_weights=True,
callbacks=("early_stopping", "tensorboard"),
Expand Down

0 comments on commit 961310e

Please sign in to comment.