joint_policy_train.py
import argparse
import numpy as np
from immitation_learning.Utilities import *
from immitation_learning.Sequencing import *
from immitation_learning.Loader import *
from immitation_learning.train import JointPolicy
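
# NOTE: this script relies on the wildcard imports above. The helper names used
# below (DSSportsFormat, RoleAssignment, get_sequences, iterate_minibatches,
# feature_roll) are assumed to come from Loader, Sequencing, and Utilities
# respectively; adjust if your package layout differs.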


def main(args):
    # Load dataset
    loader = DSSportsFormat(args.ds_path)
    ds = loader.load_data()
    all_off, all_def, all_ball, all_length = ds
    # NOTE: adapt this loader to your own data format; it was written for a
    # specific dataset. The output 'ds' is a tuple:
    #   ds[0]: all_off (List[np.ndarray]) - group A data (e.g., offensive players).
    #          Shape of each array: (T, num_group_A_entities * features).
    #   ds[1]: all_def (List[np.ndarray]) - group B data (e.g., defensive players).
    #          Shape of each array: (T, num_group_B_entities * features).
    #   ds[2]: all_ball (List[np.ndarray]) - central entity data (e.g., the ball).
    #          Shape of each array: (T, features).
    #   ds[3]: all_length (List[int]) - number of timesteps (T) for each sample.

    # Load role mean data
    off_means = np.load(args.off_means_path)
    def_means = np.load(args.def_means_path)

    # Role assignment
    seq = RoleAssignment()
    _, def_seq = seq.assign_roles(all_def, def_means, all_length)
    _, off_seq = seq.assign_roles(all_off, off_means, all_length)
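    # assign_roles presumably aligns each player's trajectory to a canonical
    # role ordering using the precomputed role means, so that feature columns
    # stay consistent across samples (an assumption; see
    # immitation_learning.Sequencing for the actual behaviour).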

    # Combine defensive, offensive, and ball sequences for each sample
    single_game = [
        np.concatenate([def_seq[i], off_seq[i], all_ball[i]], axis=1)
        for i in range(len(def_seq))
    ]
    train, target = get_sequences(single_game, window_size=50, step_size=15)
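    # get_sequences is assumed to slide a 50-timestep window over each game
    # with a stride of 15, producing overlapping (input, target) training
    # subsequences.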

    # Iterate over minibatches
    for batch in iterate_minibatches(train, target, args.batch_size, shuffle=False):
        x_batch, y_batch = batch
        x_curr = x_batch[:, 0:1, :]  # first timestep of each window
        x_up = feature_roll(0, x_curr)
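        # NOTE: x_curr and x_up are computed but never used; this loop looks
        # like a leftover sanity check of the batch iterator and feature_roll,
        # and can likely be removed without affecting training.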

    # Define hyperparameters
    hyperparam = {
        'horizon': [10],
        'num_policies': 11,
        'batch_size': args.batch_size,
        'time_steps': 1,
        'dup_ft': 4,
        'learning_rate': args.learning_rate,
        'n_epoch': args.n_epoch,
        'total_timesteps': 49,
    }
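    # Assumed interpretation (undocumented here): 'num_policies': 11 likely
    # corresponds to the 11 players per team, and 'total_timesteps': 49 to
    # window_size - 1, i.e. the 49 input/target transitions in a 50-step window.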

    # Train joint policy
    j_p = JointPolicy(hyperparam)
    j_p.train_joint_policy(train, target)
    j_p.save_agents()
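    # save_agents is assumed to persist the trained per-agent policies to disk;
    # see immitation_learning.train for the actual output location.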


if __name__ == "__main__":
    # Argument parser
    parser = argparse.ArgumentParser(description="Train a joint policy for football data.")
    parser.add_argument("--ds_path", type=str, required=True, help="Path to the dataset folder.")
    parser.add_argument("--off_means_path", type=str, required=True, help="Path to the offensive role means file.")
    parser.add_argument("--def_means_path", type=str, required=True, help="Path to the defensive role means file.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for training (default: 32).")
    parser.add_argument("--learning_rate", type=float, default=0.0005, help="Learning rate for the model (default: 0.0005).")
    parser.add_argument("--n_epoch", type=int, default=5, help="Number of training epochs (default: 5).")
    args = parser.parse_args()
    main(args)
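
# Example invocation (paths are illustrative, not shipped with the repo):
#   python joint_policy_train.py --ds_path ./data \
#       --off_means_path off_means.npy --def_means_path def_means.npy \
#       --batch_size 32 --learning_rate 0.0005 --n_epoch 5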