forked from HumanCompatibleAI/overcooked-demo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgame.py
890 lines (718 loc) · 29.8 KB
/
game.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
from abc import ABC, abstractmethod
from threading import Lock, Thread
from queue import Queue, LifoQueue, Empty, Full
from time import time
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.mdp.actions import Action, Direction
from overcooked_ai_py.planning.planners import MotionPlanner, NO_COUNTERS_PARAMS
from human_aware_rl.rllib.rllib import load_agent
import random, os, pickle, json
import ray
# Relative path to where all static pre-trained agents are stored on server
AGENT_DIR = None
# Maximum allowable game time (in seconds)
MAX_GAME_TIME = None


def _configure(max_game_time, agent_dir):
    """
    Install the module-wide settings consumed by the game classes.

    max_game_time: cap (in seconds) applied to every OvercookedGame's requested `gameTime`
    agent_dir: directory that `OvercookedGame.get_policy` resolves agent names against
    """
    global MAX_GAME_TIME, AGENT_DIR
    AGENT_DIR = agent_dir
    MAX_GAME_TIME = max_game_time
class Game(ABC):
    """
    Class representing a game object. Coordinates the simultaneous actions of an arbitrary
    number of players. Override this base class in order to use.

    Players can post actions to a `pending_actions` queue, and driver code can call `tick` to apply these actions.

    It should be noted that most operations in this class are not on their own thread safe. Thus, client code should
    acquire `self.lock` before making any modifications to the instance.

    One important exception to the above rule is `enqueue_action` which is thread safe out of the box
    """

    # Possible TODO: create a static list of IDs used by the class so far to verify id uniqueness
    # This would need to be serialized, however, which might cause too great a performance hit to
    # be worth it

    # Placeholder stored in both `players` and `pending_actions` for a slot with no occupant
    EMPTY = 'EMPTY'

    class Status:
        # Values returned by `tick` / `reset` to tell the driver what happened
        DONE = 'done'
        ACTIVE = 'active'
        RESET = 'reset'
        INACTIVE = 'inactive'
        ERROR = 'error'

    def __init__(self, *args, **kwargs):
        """
        players (list): List of IDs of players currently in the game; vacant slots hold `EMPTY`
        spectators (set): Collection of IDs of players that are not allowed to enqueue actions but are currently watching the game
        id (int): Unique identifier for this game (the `id` keyword argument, defaulting to the builtin `id(self)`)
        pending_actions (List[Queue]): Per-player buffers of actions that have been submitted but not yet committed;
            index-aligned with `players`
        lock (Lock): Used to serialize updates to the game state
        is_active (bool): Whether the game is currently being played or not
        """
        self.players = []
        self.spectators = set()
        self.pending_actions = []
        self.id = kwargs.get('id', id(self))
        self.lock = Lock()
        self._is_active = False

    @abstractmethod
    def is_full(self):
        """
        Returns whether there is room for additional players to join or not
        """
        pass

    @abstractmethod
    def apply_action(self, player_idx, action):
        """
        Updates the game state by applying a single (player_idx, action) tuple. Subclasses should try to override this method
        if possible
        """
        pass

    @abstractmethod
    def is_finished(self):
        """
        Returns whether the game has concluded or not
        """
        pass

    def is_ready(self):
        """
        Returns whether the game can be started. Defaults to having enough players
        """
        return self.is_full()

    @property
    def is_active(self):
        """
        Whether the game is currently being played
        """
        return self._is_active

    @property
    def reset_timeout(self):
        """
        Number of milliseconds to pause game on reset
        """
        return 3000

    def apply_actions(self):
        """
        Updates the game state by applying each of the pending actions in the buffer. Is called by the tick method. Subclasses
        should override this method if joint actions are necessary. If actions can be serialized, overriding `apply_action` is
        preferred
        """
        for i, player in enumerate(self.players):
            # Slots vacated by `remove_player` hold the EMPTY placeholder string rather than a Queue
            if player == self.EMPTY:
                continue
            queue = self.pending_actions[i]
            try:
                # Drain everything this player has buffered since the last tick
                while True:
                    action = queue.get(block=False)
                    self.apply_action(i, action)
            except Empty:
                pass

    def activate(self):
        """
        Activates the game to let server know real-time updates should start. Provides little functionality but useful as
        a check for debugging
        """
        self._is_active = True

    def deactivate(self):
        """
        Deactivates the game such that subsequent calls to `tick` will be no-ops. Used to handle case where game ends but
        there is still a buffer of client pings to handle
        """
        self._is_active = False

    def reset(self):
        """
        Restarts the game while keeping all active players by resetting game stats and temporarily disabling `tick`

        Returns Status.DONE (without resetting) if the game is finished, else Status.RESET.
        Raises ValueError if the game is not active.
        """
        if not self.is_active:
            raise ValueError("Inactive Games cannot be reset")
        if self.is_finished():
            return self.Status.DONE
        self.deactivate()
        self.activate()
        return self.Status.RESET

    def needs_reset(self):
        """
        Returns whether the game should be reset on the next call to `tick`
        """
        return False

    def tick(self):
        """
        Updates the game state by applying each of the pending actions. This is done so that players cannot directly modify
        the game state, offering an additional level of safety and thread security.

        One can think of "enqueue_action" like calling "git add" and "tick" like calling "git commit"

        Subclasses should try to override `apply_actions` if possible. Only override this method if necessary
        """
        if not self.is_active:
            return self.Status.INACTIVE
        if self.needs_reset():
            self.reset()
            return self.Status.RESET

        self.apply_actions()
        return self.Status.DONE if self.is_finished() else self.Status.ACTIVE

    def enqueue_action(self, player_id, action):
        """
        Add (player_id, action) pair to the pending action queue, without modifying underlying game state

        Actions sent while the game is inactive, or from IDs that are not players in this game, are ignored.
        If the player's buffer is full the action is dropped rather than blocking the caller.

        Note: This function IS thread safe
        """
        if not self.is_active:
            # Could run into issues with is_active not being thread safe
            return
        if player_id == self.EMPTY:
            # The placeholder is not a player; its pending_actions slot is a string, not a Queue
            return
        try:
            # Only players actively in game are allowed to enqueue actions
            player_idx = self.players.index(player_id)
        except ValueError:
            return
        try:
            # Non-blocking so a full bounded buffer drops the action (the `except Full` below) instead of
            # stalling the caller indefinitely, e.g. a background thread that another thread is joining
            self.pending_actions[player_idx].put(action, block=False)
        except Full:
            pass

    def get_state(self):
        """
        Return a JSON compatible serialized state of the game. Note that this should be as minimalistic as possible
        as the size of the game state will be the most important factor in game performance. This is sent to the client
        every frame update.
        """
        return { "players" : self.players }

    def to_json(self):
        """
        Return a JSON compatible serialized state of the game. Contains all information about the game, does not need to
        be minimalistic. This is sent to the client only once, upon game creation
        """
        return self.get_state()

    def is_empty(self):
        """
        Return whether it is safe to garbage collect this game instance
        """
        return not self.num_players

    def add_player(self, player_id, idx=None, buff_size=-1):
        """
        Add player_id to the game

        idx: slot to place the player in. When None, the first EMPTY slot is reused, otherwise the player
            is appended. Slots past the current end of `players` are padded with EMPTY.
        buff_size: maxsize of the player's action Queue (<= 0 means unbounded)

        Raises ValueError if the game is full, active, or the requested slot is already occupied
        """
        if self.is_full():
            raise ValueError("Cannot add players to full game")
        if self.is_active:
            raise ValueError("Cannot add players to active games")

        # Compare with None explicitly: idx=0 is a valid explicit slot, not "unspecified"
        if idx is None:
            if self.EMPTY in self.players:
                idx = self.players.index(self.EMPTY)
            else:
                idx = len(self.players)
        elif idx < len(self.players) and self.players[idx] != self.EMPTY:
            raise ValueError("Player slot {} is already occupied".format(idx))

        padding = max(0, idx - len(self.players) + 1)
        for _ in range(padding):
            self.players.append(self.EMPTY)
            self.pending_actions.append(self.EMPTY)

        self.players[idx] = player_id
        self.pending_actions[idx] = Queue(maxsize=buff_size)

    def add_spectator(self, spectator_id):
        """
        Add spectator_id to list of spectators for this game
        """
        if spectator_id in self.players:
            raise ValueError("Cannot spectate and play at same time")
        self.spectators.add(spectator_id)

    def remove_player(self, player_id):
        """
        Remove player_id from the game, leaving an EMPTY slot. Returns True if removed, False if not a player
        """
        try:
            idx = self.players.index(player_id)
            self.players[idx] = self.EMPTY
            self.pending_actions[idx] = self.EMPTY
        except ValueError:
            return False
        else:
            return True

    def remove_spectator(self, spectator_id):
        """
        Removes spectator_id if they are in list of spectators. Returns True if spectator successfully removed, False otherwise
        """
        try:
            self.spectators.remove(spectator_id)
        except KeyError:
            # set.remove signals a missing element with KeyError (not ValueError)
            return False
        else:
            return True

    def clear_pending_actions(self):
        """
        Remove all queued actions for all players
        """
        for i, player in enumerate(self.players):
            if player != self.EMPTY:
                queue = self.pending_actions[i]
                queue.queue.clear()

    @property
    def num_players(self):
        """
        Number of occupied (non-EMPTY) player slots
        """
        return len([player for player in self.players if player != self.EMPTY])

    def get_data(self):
        """
        Return any game metadata to server driver. Really only relevant for Psiturk code
        """
        return {}
class DummyGame(Game):
    """
    Stand-in class used to test basic server logic.

    Player actions are ignored; every tick just bumps `counter`, and the game reports itself
    finished after 100 ticks. Two players fill the game.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.counter = 0

    def is_full(self):
        return 2 == self.num_players

    def apply_action(self, idx, action):
        # Individual actions have no effect in this dummy game
        return None

    def apply_actions(self):
        # Replaces the base draining loop entirely: one tick == one increment
        self.counter += 1

    def is_finished(self):
        return 100 <= self.counter

    def get_state(self):
        snapshot = super().get_state()
        snapshot.update(count=self.counter)
        return snapshot
class DummyInteractiveGame(Game):
    """
    Stand-in class used to test interactive components of the server logic.

    Each player slot owns a counter: "north"/"UP" actions increment it and "south"/"DOWN" actions
    decrement it. The game finishes once any counter reaches `max_count`.
    """

    def __init__(self, **kwargs):
        super(DummyInteractiveGame, self).__init__(**kwargs)
        # One slot per seat configured as 'human' -- so this can be 0, 1 or 2
        self.max_players = int(kwargs.get('playerZero', 'human') == 'human') + int(kwargs.get('playerOne', 'human') == 'human')
        self.max_count = kwargs.get('max_count', 30)
        self.counter = 0
        self.counts = [0] * self.max_players

    def is_full(self):
        return self.num_players == self.max_players

    def is_finished(self):
        # `default` covers the all-AI configuration, where max_players == 0 and `counts` is empty
        # (a bare max() would raise ValueError)
        return max(self.counts, default=0) >= self.max_count

    def apply_action(self, player_idx, action):
        key = action.upper()
        # Accept the client names "UP"/"DOWN" (the names OvercookedGame maps onto Direction.NORTH/SOUTH)
        # in addition to the Direction values the original comparison used
        if key == Direction.NORTH or key == "UP":
            self.counts[player_idx] += 1
        if key == Direction.SOUTH or key == "DOWN":
            self.counts[player_idx] -= 1

    def apply_actions(self):
        super(DummyInteractiveGame, self).apply_actions()
        self.counter += 1

    def get_state(self):
        state = super(DummyInteractiveGame, self).get_state()
        state['count'] = self.counter
        for i in range(self.num_players):
            state['player_{}_count'.format(i)] = self.counts[i]
        return state
class OvercookedGame(Game):
    """
    Class for bridging the gap between OvercookedEnv and the Game interface

    Instance variables:
        - max_players (int): Maximum number of players that can be in the game at once
        - mdp (OvercookedGridworld): Controls the underlying Overcooked game logic
        - score (int): Current reward achieved by all players
        - max_time (int): Number of seconds the game should last
        - npc_policies (dict): Maps user_id to policy (Agent) for each AI player
        - npc_state_queues (dict): Mapping of NPC user_ids to LIFO queues for the policy to process
        - curr_tick (int): How many times the game server has called this instance's `tick` method
        - ticks_per_ai_action (int): How many frames should pass in between NPC policy forward passes.
            Note that this is a lower bound; if the policy is computationally expensive the actual frames
            per forward pass can be higher
        - action_to_overcooked_action (dict): Maps action names returned by client to action names used by OvercookedGridworld
            Note that this is an instance variable and not a static variable for efficiency reasons
        - human_players (set(str)): Collection of all player IDs that correspond to humans
        - npc_players (set(str)): Collection of all player IDs that correspond to AI
        - randomized (boolean): Whether the order of the layouts should be randomized

    Methods:
        - npc_policy_consumer: Background process that asynchronously computes NPC policy forward passes. One thread
            spawned for each NPC
        - _curr_game_over: Determines whether the game on the current mdp has ended
    """

    def __init__(self, layouts=None, mdp_params=None, num_players=2, gameTime=30, playerZero='human', playerOne='human', showPotential=False, randomized=False, **kwargs):
        """
        layouts: layout names to play; `activate` pops from the end, so the last one is played first.
            Defaults to ["cramped_room"]. The list is copied, so the caller's list is never shuffled or consumed
        mdp_params: extra keyword arguments for OvercookedGridworld.from_layout_name (defaults to {})
        num_players: number of player slots
        gameTime: requested seconds per layout, capped at the module-level MAX_GAME_TIME
        playerZero / playerOne: 'human', or an agent name passed to `get_policy`
        showPotential: whether to compute and report the potential function
        randomized: shuffle the layout order once at construction
        """
        super(OvercookedGame, self).__init__(**kwargs)
        self.show_potential = showPotential
        # Defensive copies: `layouts` is shuffled below and popped in `activate`. Storing the argument
        # directly mutated the caller's list -- and the shared default list, so later default-constructed
        # games ran out of layouts
        self.mdp_params = {} if mdp_params is None else dict(mdp_params)
        self.layouts = ["cramped_room"] if layouts is None else list(layouts)
        self.max_players = int(num_players)
        self.mdp = None
        self.mp = None
        self.score = 0
        self.phi = 0
        self.max_time = min(int(gameTime), MAX_GAME_TIME)
        self.npc_policies = {}
        self.npc_state_queues = {}
        self.action_to_overcooked_action = {
            "STAY" : Action.STAY,
            "UP" : Direction.NORTH,
            "DOWN" : Direction.SOUTH,
            "LEFT" : Direction.WEST,
            "RIGHT" : Direction.EAST,
            "SPACE" : Action.INTERACT
        }
        self.ticks_per_ai_action = 4
        self.curr_tick = 0
        self.human_players = set()
        self.npc_players = set()
        # Populated by `activate`; initialised here so `deactivate` does not crash on a never-activated game
        self.state = None
        self.threads = []

        if randomized:
            random.shuffle(self.layouts)

        if playerZero != 'human':
            player_zero_id = playerZero + '_0'
            self.add_player(player_zero_id, idx=0, buff_size=1, is_human=False)
            self.npc_policies[player_zero_id] = self.get_policy(playerZero, idx=0)
            self.npc_state_queues[player_zero_id] = LifoQueue()

        if playerOne != 'human':
            player_one_id = playerOne + '_1'
            self.add_player(player_one_id, idx=1, buff_size=1, is_human=False)
            self.npc_policies[player_one_id] = self.get_policy(playerOne, idx=1)
            self.npc_state_queues[player_one_id] = LifoQueue()

    def _curr_game_over(self):
        # Wall-clock limit for the current layout; `start_time` is set by `activate`
        return time() - self.start_time >= self.max_time

    def needs_reset(self):
        # Move on to the next layout when the current one times out but layouts remain
        return self._curr_game_over() and not self.is_finished()

    def add_player(self, player_id, idx=None, buff_size=-1, is_human=True):
        """
        Add a player slot (see Game.add_player) and record whether it is a human or an NPC
        """
        super(OvercookedGame, self).add_player(player_id, idx=idx, buff_size=buff_size)
        if is_human:
            self.human_players.add(player_id)
        else:
            self.npc_players.add(player_id)

    def remove_player(self, player_id):
        """
        Remove a player and drop it from the human/NPC bookkeeping.

        Returns True if the player was removed, False if it was not in the game (same contract as Game.remove_player).
        Raises ValueError if the player was in the game but in neither human_players nor npc_players.
        """
        removed = super(OvercookedGame, self).remove_player(player_id)
        if removed:
            if player_id in self.human_players:
                self.human_players.remove(player_id)
            elif player_id in self.npc_players:
                self.npc_players.remove(player_id)
            else:
                raise ValueError("Inconsistent state")
        return removed

    def npc_policy_consumer(self, policy_id):
        """
        Thread body: repeatedly take the most recent state for `policy_id`, query its policy, and enqueue the action.
        Exits once the game is deactivated (`deactivate` pushes a state to unblock the `get`).
        """
        queue = self.npc_state_queues[policy_id]
        policy = self.npc_policies[policy_id]
        while self._is_active:
            state = queue.get()
            npc_action, _ = policy.action(state)
            # Bypass our own enqueue_action: NPC actions are already Overcooked actions, not client names
            super(OvercookedGame, self).enqueue_action(policy_id, npc_action)

    def is_full(self):
        return self.num_players >= self.max_players

    def is_finished(self):
        # Finished only once every layout has been consumed and the last one has timed out
        val = not self.layouts and self._curr_game_over()
        return val

    def is_empty(self):
        """
        Game is considered safe to scrap if there are no active players or if there are no humans (spectating or playing)
        """
        return super(OvercookedGame, self).is_empty() or (not self.spectators and not self.human_players)

    def is_ready(self):
        """
        Game is ready to be activated if there are a sufficient number of players and at least one human (spectator or player)
        """
        return super(OvercookedGame, self).is_ready() and not self.is_empty()

    def apply_action(self, player_id, action):
        # Unused: actions are applied jointly in `apply_actions`
        pass

    def apply_actions(self):
        """
        Build one joint action from the pending buffers, step the MDP, and update score/potential.

        Returns (prev_state, joint_action, info) so subclasses can log or post-process the transition.
        """
        # Default joint action, as NPC policies and clients probably don't enqueue actions fast
        # enough to produce one at every tick
        joint_action = [Action.STAY] * len(self.players)

        # Synchronize individual player actions into a joint-action as required by overcooked logic
        for i, player in enumerate(self.players):
            # Vacated slots hold the EMPTY placeholder instead of a Queue; that seat just stays
            if player == self.EMPTY:
                continue
            try:
                joint_action[i] = self.pending_actions[i].get(block=False)
            except Empty:
                pass

        # Apply overcooked game logic to get state transition
        prev_state = self.state
        self.state, info = self.mdp.get_state_transition(prev_state, joint_action)
        if self.show_potential:
            self.phi = self.mdp.potential_function(prev_state, self.mp, gamma=0.99)

        # Send next state to all background consumers if needed
        if self.curr_tick % self.ticks_per_ai_action == 0:
            for npc_id in self.npc_policies:
                self.npc_state_queues[npc_id].put(self.state, block=False)

        # Update score based on soup deliveries that might have occurred
        curr_reward = sum(info['sparse_reward_by_agent'])
        self.score += curr_reward

        # Return about the current transition
        return prev_state, joint_action, info

    def enqueue_action(self, player_id, action):
        """
        Translate a client action name ("UP", "SPACE", ...) into an Overcooked action and buffer it.
        Unrecognised names come from the client, so they are ignored instead of raising.
        """
        try:
            overcooked_action = self.action_to_overcooked_action[action]
        except (KeyError, TypeError):
            # KeyError: unknown name; TypeError: unhashable payload (e.g. a list)
            return
        super(OvercookedGame, self).enqueue_action(player_id, overcooked_action)

    def reset(self):
        """
        Advance to the next layout (via deactivate/activate in Game.reset) and delay the new layout's timer
        """
        status = super(OvercookedGame, self).reset()
        if status == self.Status.RESET:
            # Hacky way of making sure game timer doesn't "start" until after reset timeout has passed
            self.start_time += self.reset_timeout / 1000

    def tick(self):
        self.curr_tick += 1
        return super(OvercookedGame, self).tick()

    def activate(self):
        """
        Start playing the next layout: build the MDP, reset timer/score, and spawn one consumer thread per NPC
        """
        super(OvercookedGame, self).activate()

        # Sanity check at start of each game
        if not self.npc_players.union(self.human_players) == set(self.players):
            raise ValueError("Inconsistent State")

        # pop() takes from the end of the list, so layouts are played last-to-first
        self.curr_layout = self.layouts.pop()
        self.mdp = OvercookedGridworld.from_layout_name(self.curr_layout, **self.mdp_params)
        if self.show_potential:
            self.mp = MotionPlanner.from_pickle_or_compute(self.mdp, counter_goals=NO_COUNTERS_PARAMS)
        self.state = self.mdp.get_standard_start_state()
        if self.show_potential:
            self.phi = self.mdp.potential_function(self.state, self.mp, gamma=0.99)
        self.start_time = time()
        self.curr_tick = 0
        self.score = 0
        self.threads = []
        for npc_policy in self.npc_policies:
            self.npc_policies[npc_policy].reset()
            # Seed each consumer with the start state so it produces a first action
            self.npc_state_queues[npc_policy].put(self.state)
            t = Thread(target=self.npc_policy_consumer, args=(npc_policy,))
            self.threads.append(t)
            t.start()

    def deactivate(self):
        """
        Stop the game, wake and join every NPC consumer thread, then discard buffered actions
        """
        super(OvercookedGame, self).deactivate()
        # Ensure the background consumers do not hang
        for npc_policy in self.npc_policies:
            self.npc_state_queues[npc_policy].put(self.state)

        # Wait for all background threads to exit
        for t in self.threads:
            t.join()

        # Clear all action queues
        self.clear_pending_actions()

    def get_state(self):
        """
        Per-frame payload: potential (or None), serialized MDP state, score and remaining seconds
        """
        state_dict = {}
        state_dict['potential'] = self.phi if self.show_potential else None
        state_dict['state'] = self.state.to_dict()
        state_dict['score'] = self.score
        state_dict['time_left'] = max(self.max_time - (time() - self.start_time), 0)
        return state_dict

    def to_json(self):
        """
        One-off payload with the terrain and current state, or None values while inactive
        """
        obj_dict = {}
        obj_dict['terrain'] = self.mdp.terrain_mtx if self._is_active else None
        obj_dict['state'] = self.get_state() if self._is_active else None
        return obj_dict

    def get_policy(self, npc_id, idx=0):
        """
        Load the policy for AI player `npc_id` from AGENT_DIR.

        Names starting with "rllib" (case-insensitive) are loaded with `load_agent` from
        AGENT_DIR/<npc_id>/agent/agent; anything else is unpickled from AGENT_DIR/<npc_id>/agent.pickle.

        Raises IOError (chained to the underlying exception) if loading fails.
        """
        if npc_id.lower().startswith("rllib"):
            try:
                # Loading rllib agents requires additional helpers
                fpath = os.path.join(AGENT_DIR, npc_id, 'agent', 'agent')
                agent = load_agent(fpath, agent_index=idx)
                return agent
            except Exception as e:
                raise IOError("Error loading Rllib Agent\n{}".format(e.__repr__())) from e
            finally:
                # Always kill ray after loading agent, otherwise, ray will crash once process exits
                if ray.is_initialized():
                    ray.shutdown()
        else:
            try:
                # NOTE(review): pickle.load can execute arbitrary code, and `npc_id` is joined into the path
                # unchecked. Only safe if npc_id can never come from an untrusted client -- confirm with callers
                fpath = os.path.join(AGENT_DIR, npc_id, 'agent.pickle')
                with open(fpath, 'rb') as f:
                    return pickle.load(f)
            except Exception as e:
                raise IOError("Error loading agent\n{}".format(e.__repr__())) from e
class OvercookedPsiturk(OvercookedGame):
    """
    Wrapper on OvercookedGame that handles additional housekeeping for Psiturk experiments

    Instance Variables:
        - trajectory (list(dict)): list of state-action pairs in current trajectory
        - psiturk_uid (string): Unique id for each psiturk game instance (provided by Psiturk backend)
            Note, this is not the user id -- two users in the same game will have the same psiturk_uid
        - trial_id (string): Unique identifier for each psiturk trial, updated on each call to reset
            Note, one OvercookedPsiturk game handles multiple layouts. This is how we differentiate

    Methods:
        get_data: Returns the accumulated trajectory data and clears the self.trajectory instance variable
    """

    def __init__(self, *args, psiturk_uid='-1', **kwargs):
        # Potential display is always off for experiments. Drop any caller-supplied value so it cannot
        # collide with the explicit keyword below (TypeError: multiple values for 'showPotential')
        kwargs.pop('showPotential', None)
        super(OvercookedPsiturk, self).__init__(*args, showPotential=False, **kwargs)
        self.psiturk_uid = psiturk_uid
        self.trajectory = []

    def activate(self):
        """
        Resets trial ID at start of new "game"
        """
        super(OvercookedPsiturk, self).activate()
        self.trial_id = self.psiturk_uid + str(self.start_time)

    def apply_actions(self):
        """
        Applies pending actions then logs transition data

        Returns the (prev_state, joint_action, info) tuple from OvercookedGame.apply_actions
        """
        # Apply MDP logic
        prev_state, joint_action, info = super(OvercookedPsiturk, self).apply_actions()

        # Log data to send to psiturk client
        curr_reward = sum(info['sparse_reward_by_agent'])
        # Read the clock once so time_left and time_elapsed describe the same instant
        elapsed = time() - self.start_time
        transition = {
            "state" : json.dumps(prev_state.to_dict()),
            "joint_action" : json.dumps(joint_action),
            "reward" : curr_reward,
            "time_left" : max(self.max_time - elapsed, 0),
            "score" : self.score,
            "time_elapsed" : elapsed,
            "cur_gameloop" : self.curr_tick,
            "layout" : json.dumps(self.mdp.terrain_mtx),
            "layout_name" : self.curr_layout,
            "trial_id" : self.trial_id,
            "player_0_id" : self.players[0],
            "player_1_id" : self.players[1],
            "player_0_is_human" : self.players[0] in self.human_players,
            "player_1_is_human" : self.players[1] in self.human_players
        }

        self.trajectory.append(transition)
        return prev_state, joint_action, info

    def get_data(self):
        """
        Returns and then clears the accumulated trajectory
        """
        data = { "uid" : self.psiturk_uid + "_" + str(time()), "trajectory" : self.trajectory }
        self.trajectory = []
        return data
class OvercookedTutorial(OvercookedGame):
    """
    Wrapper on OvercookedGame that includes additional data for tutorial mechanics, most notably the introduction of tutorial "phases"

    Instance Variables:
        - curr_phase (int): Indicates what tutorial phase we are currently on
        - phase_two_score (float): The exact per-transition sparse reward the human must obtain to advance past phase 2
    """

    def __init__(self, layouts=None, mdp_params=None, playerZero='human', playerOne='AI', phaseTwoScore=15, **kwargs):
        # Fresh containers per instance: the layouts list is shuffled/popped by OvercookedGame, which would
        # otherwise consume a shared default list
        layouts = ["tutorial_0"] if layouts is None else list(layouts)
        mdp_params = {} if mdp_params is None else mdp_params
        # The tutorial never shows the potential; drop a caller-supplied value to avoid a duplicate keyword
        kwargs.pop('showPotential', None)
        super(OvercookedTutorial, self).__init__(layouts=layouts, mdp_params=mdp_params, playerZero=playerZero, playerOne=playerOne, showPotential=False, **kwargs)
        self.phase_two_score = phaseTwoScore
        self.phase_two_finished = False
        self.max_time = 0
        self.max_players = 2
        self.ticks_per_ai_action = 8
        self.curr_phase = 0

    @property
    def reset_timeout(self):
        # Near-instant transitions between tutorial phases (milliseconds)
        return 1

    def needs_reset(self):
        # Phases 0 and 1 advance on any positive score; phase 2 needs the exact target reward
        if self.curr_phase in (0, 1):
            return self.score > 0
        if self.curr_phase == 2:
            return self.phase_two_finished
        return False

    def is_finished(self):
        # NOTE(review): score never reaches infinity in practice, so this appears to be intentionally
        # never True -- presumably the tutorial is ended elsewhere; confirm
        return not self.layouts and self.score >= float('inf')

    def reset(self):
        super(OvercookedTutorial, self).reset()
        self.curr_phase += 1

    def get_policy(self, *args, **kwargs):
        return TutorialAI()

    def apply_actions(self):
        """
        Apply regular MDP logic with retroactive score adjustment for tutorial purposes

        Returns the (prev_state, joint_action, info) tuple from OvercookedGame.apply_actions
        """
        prev_state, joint_action, info = super(OvercookedTutorial, self).apply_actions()
        human_reward, ai_reward = info['sparse_reward_by_agent']

        # We only want to keep track of the human's score in the tutorial
        self.score -= ai_reward

        # Phase two requires a specific reward to complete
        if self.curr_phase == 2:
            self.score = 0
            if human_reward == self.phase_two_score:
                self.phase_two_finished = True

        return prev_state, joint_action, info
class DummyOvercookedGame(OvercookedGame):
    """
    Class that hardcodes the AI to be random. Used for debugging
    """

    def __init__(self, layouts=None, **kwargs):
        # Hand OvercookedGame a fresh list: it shuffles/pops `layouts`, which would otherwise
        # mutate a shared default (or the caller's list)
        layouts = ["cramped_room"] if layouts is None else list(layouts)
        super(DummyOvercookedGame, self).__init__(layouts, **kwargs)

    def get_policy(self, *args, **kwargs):
        return DummyAI()
class DummyAI():
    """
    Debugging policy: picks uniformly at random among STAY, the four directions and INTERACT
    """

    def action(self, state):
        options = [Action.STAY, Direction.NORTH, Direction.SOUTH, Direction.WEST, Direction.EAST, Action.INTERACT]
        picked = random.sample(options, 1)[0]
        return picked, None

    def reset(self):
        # Stateless policy: nothing to reset
        return None
class DummyComputeAI(DummyAI):
    """
    Debugging policy that burns a random amount of CPU before picking a random action
    """

    def __init__(self, compute_unit_iters=1e5):
        """
        compute_unit_iters (int): loop iterations making up one "unit" of work; each call to `action`
            performs a uniformly sampled 1-10 units
        """
        super().__init__()
        self.compute_unit_iters = int(compute_unit_iters)

    def action(self, state):
        # Total busy-loop iterations for this call
        total = random.randint(1, 10) * self.compute_unit_iters

        # Do real arithmetic (rather than sleeping) so the wait can't be optimized away by the scheduler
        acc = 0
        for step in range(total):
            # Alternate increments so the body keeps a data-dependent branch
            if step % 2:
                acc += 2
            else:
                acc += 1

        # Same random choice as the parent policy
        return super().action(state)
class StayAI():
    """
    Debugging policy that never moves: every call yields the STAY action
    """

    def action(self, state):
        stay = Action.STAY
        return stay, None

    def reset(self):
        # Stateless policy: nothing to reset
        return None
class TutorialAI():
    """
    Scripted partner for OvercookedTutorial.

    Each `reset` advances `curr_phase`. In phase 0 the agent loops COOK_SOUP_LOOP, in phase 2 it
    loops COOK_SOUP_COOP_LOOP, and in every other phase it stays still.
    """

    COOK_SOUP_LOOP = [
        # Grab first onion
        Direction.WEST,
        Direction.WEST,
        Direction.WEST,
        Action.INTERACT,

        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,

        # Grab second onion
        Direction.WEST,
        Action.INTERACT,

        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,

        # Grab third onion
        Direction.WEST,
        Action.INTERACT,

        # Place onion in pot
        Direction.EAST,
        Direction.NORTH,
        Action.INTERACT,

        # Cook soup
        Action.INTERACT,

        # Grab plate
        Direction.EAST,
        Direction.SOUTH,
        Action.INTERACT,
        Direction.WEST,
        Direction.NORTH,

        # Deliver soup
        Action.INTERACT,
        Direction.EAST,
        Direction.EAST,
        Direction.EAST,
        Action.INTERACT,
        Direction.WEST
    ]

    COOK_SOUP_COOP_LOOP = [
        # Grab first onion
        Direction.WEST,
        Direction.WEST,
        Direction.WEST,
        Action.INTERACT,

        # Place onion in pot
        Direction.EAST,
        Direction.SOUTH,
        Action.INTERACT,

        # Move to start so this loops
        Direction.EAST,
        Direction.EAST,
    ] + [
        # Pause to make cooperation more real time
        Action.STAY,
    ] * 9

    def __init__(self):
        # OvercookedGame.activate calls reset() before the first action, moving the phase to 0
        self.curr_phase = -1
        self.curr_tick = -1

    def action(self, state):
        self.curr_tick += 1
        script = {0: self.COOK_SOUP_LOOP, 2: self.COOK_SOUP_COOP_LOOP}.get(self.curr_phase)
        if script is None:
            return Action.STAY, None
        # Cycle through the script indefinitely
        return script[self.curr_tick % len(script)], None

    def reset(self):
        # Advance to the next tutorial phase and restart its script from the beginning
        self.curr_phase += 1
        self.curr_tick = -1