diff --git a/cogment_lab/process_manager.py b/cogment_lab/process_manager.py index 4b8a0e0..32531e0 100644 --- a/cogment_lab/process_manager.py +++ b/cogment_lab/process_manager.py @@ -61,6 +61,8 @@ def __init__( mp_method: str = "spawn", orchestrator_port: int = 9000, datastore_port: int = 9003, + orchestrator_host: str = "scs-cogment-orchestrator", + datastore_host: str = "scs-cogment-trial-datastore", ): """Initializes the Cogment instance @@ -84,20 +86,14 @@ def __init__( self.envs: dict[ImplName, BaseEnv] = {} self.actors: dict[ImplName, BaseActor] = {} - self.orchestrator_endpoint = f"grpc://localhost:{orchestrator_port}" - self.datastore_endpoint = f"grpc://localhost:{datastore_port}" + # self.orchestrator_endpoint = f"grpc://{orchestrator_host}:{orchestrator_port}" + # self.datastore_endpoint = f"grpc://{datastore_host}:{datastore_port}" - self.context = cogment.Context(cog_settings=cog_settings, user_id=user_id) - controller = self.context.get_controller(endpoint=cogment.Endpoint(self.orchestrator_endpoint)) - datastore = self.context.get_datastore(endpoint=cogment.Endpoint(self.datastore_endpoint)) - - assert isinstance( - controller, Controller - ), "self.controller is not an instance of Controller. Please report this." - assert isinstance(datastore, Datastore), "self.datastore is not an instance of Datastore. Please report this." + # self.context = cogment.Context(cog_settings=cog_settings, user_id=user_id) + # controller = self.context.get_controller(endpoint=cogment.Endpoint(self.orchestrator_endpoint)) + # datastore = self.context.get_datastore(endpoint=cogment.Endpoint(self.datastore_endpoint)) - self.controller = controller - self.datastore = datastore + self.context = cogment.Context(cog_settings=cog_settings, user_id=user_id) self.env_ports: dict[ImplName, int] = {} self.actor_ports: dict[ImplName, int] = {} @@ -505,7 +501,8 @@ async def start_trial( name=agent_name, implementation=actor_impl, agent_specs=env.agent_specs[agent_name], - endpoint=f"grpc://localhost:{self.actor_ports[actor_impl]}", + # endpoint=f"grpc://localhost:{self.actor_ports[actor_impl]}", + endpoint="cogment://discover", ) for agent_name, actor_impl in actor_impls.items() ] @@ -515,14 +512,17 @@ async def start_trial( trial_params = cogment.TrialParameters( cog_settings, environment_name=env_name, - environment_endpoint=f"grpc://localhost:{self.env_ports[env_name]}", + # environment_endpoint=f"grpc://localhost:{self.env_ports[env_name]}", + environment_endpoint="cogment://discover", environment_config=env_config, actors=actor_params, environment_implementation=env_name, - datalog_endpoint=self.datastore_endpoint, + # datalog_endpoint=self.datastore_endpoint, + datalog_endpoint="cogment://discover", ) - trial_id = await self.controller.start_trial(trial_id_requested=trial_name, trial_params=trial_params) + controller = await self.context.get_controller() + trial_id = await controller.start_trial(trial_id_requested=trial_name, trial_params=trial_params) logging.info(f"Started trial {trial_id} with name {trial_name}") @@ -551,8 +551,10 @@ async def get_trial_data( env = self.envs[env_name] agent_specs = env.agent_specs + datastore = await self.context.get_datastore() + assert isinstance(datastore, Datastore), "datastore is not an instance of Datastore. Please report this." data = await format_data_multiagent( - datastore=self.datastore, + datastore=datastore, trial_id=trial_id, actor_agent_specs=agent_specs, fields=fields, @@ -571,7 +573,9 @@ async def get_trial(self, trial_id: str): Returns: Trial: The trial instance """ - [trial] = await self.datastore.get_trials(ids=[trial_id]) + datastore = await self.context.get_datastore() + assert isinstance(datastore, Datastore), "datastore is not an instance of Datastore. Please report this." + [trial] = await datastore.get_trials(ids=[trial_id]) return trial def __del__(self):