Init state without configs #57

Merged
merged 19 commits into from
Apr 4, 2024

Conversation

corentinlger
Collaborator

@corentinlger corentinlger commented Mar 20, 2024

Description

This PR touches a lot of lines, but there are no tricky logic changes: it only changes how the state is initialized, so that configs are no longer required. I'll try to be as clear as I can when explaining the modifications.

I moved all the state classes and definitions from sim_computation.py to a new file called states.py.

The modifications also include refactoring imports in multiple files to align with the new file structure, enhancing readability and preventing linting errors.

In states.py, five functions have been implemented to initialize the various states: simulator_state, nve_state, agent_state, objects_state, and the combined state. These changes offer greater flexibility in initializing states, allowing for customization such as specifying agent positions or using random values. Future iterations will extend this approach to more attributes, such as colors, diameters, and behaviors, to provide even more customization options.

Additionally, default values have been defined within these classes, and type hints have been added to functions for clarity. But I can still add some more docstrings to define what each parameter does.

Related Issue (if applicable)

closes #51

How to Test

Launch the server

python3 scripts/run_server.py

Launch the Panel interface

panel serve scripts/run_interface.py --autoreload

Check that nothing is broken by these two commands. Nothing should be, because only the code that initializes the state has been modified.

I added some temporary (and admittedly ugly) lines to run_server to show that agent positions can be initialized from a list. Run:

python3 scripts/run_server.py --custom_pos

You can also try launching the simulation with a custom number of agents or objects (this will be useful later for eco-evo simulations). For this feature, I actually think it would make more sense to rename the current n_agents to max_agents, and n_existing_agents to n_agents (same for objects).

python3 scripts/run_server.py --n_existing_agents 7 --n_existing_objects 1

Then relaunch the interface and check that the agents are indeed placed on the diagonal of the box. Later, this will make it easy to define custom scenes, notably when using external files to store the positions, colors, etc., as mentioned in #54.

Screenshots (if applicable)

else:
positions = None

simulator_state = init_simulator_state(
Collaborator

This part works to show that we can use a list of coordinates, but I think it should be removed before the merge

Collaborator Author

Indeed, this is meant to be deleted. I just added it to make the testing of the PR easier

Collaborator

@Marsolo1 Marsolo1 left a comment

The new system seems to work fine for me. Once this PR is merged, I'll add the Hydra config files so that we can finally have scene files.

Owner

@clement-moulin-frier clement-moulin-frier left a comment

I have started to review this PR but unfortunately I don't have time to finish this morning. Please start to have a look at my comments so that we can discuss it this afternoon. I'll try to finish it today.

  • Merge the latest master (no collision parameters)

  • When running the server with --n_existing_agents 7 --n_existing_objects 1, then selecting a non-existing agent and object and activating the visible checkbox, the selected entity appears for a very short time and disappears right away. I'm not sure it is specific to this PR though, it might be a more general bug. Can you have a look @Marsolo1 ?

from vivarium.simulator.grpc_server.simulator_server import serve

lg = logging.getLogger(__name__)

def parse_args():
parser = argparse.ArgumentParser(description='Simulator Configuration')
parser.add_argument('--custom_pos', action='store_true', help='Just a test arg to wether use custom or random pos')

Check whether you want to keep this in the merge. Also, there is a typo in "wether".

EDIT: I see in your reply to Martial below that you plan to remove it, all good

Comment on lines +21 to +23
parser.add_argument('--n_existing_agents', type=int, default=10, help='Number of agents')
parser.add_argument('--n_objects', type=int, default=2, help='Number of objects')
parser.add_argument('--n_existing_objects', type=int, default=2, help='Number of agents')

I agree with your proposition that " it would make more sense to rename the current n_agents to max_agents, and n_existing_agents to n_agents (same for objects)."

Not sure if you already do it, but it will be useful to add an assert checking that the number of existing agents/objects does not exceed the max number

Collaborator Author

Yes I already check this
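
For illustration only, here is a minimal sketch of what such a check could look like at the argument-parsing level, combined with the renaming proposed in this thread. The flag names --max_agents / --n_agents (and the object equivalents) are the proposed names, not the ones in the diff, and check_count is a hypothetical helper; the defaults of 10 agents and 2 objects mirror the diff.

```python
import argparse


def check_count(parser, name, n_existing, n_max):
    """Hypothetical helper: abort with a clear message if the count is out of range."""
    if not 0 <= n_existing <= n_max:
        parser.error(f"{name}={n_existing} must be between 0 and {n_max}")


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Simulator Configuration")
    # Proposed naming: max_* is the capacity, n_* is how many entities exist at start.
    parser.add_argument("--max_agents", type=int, default=10, help="Maximum number of agents")
    parser.add_argument("--n_agents", type=int, default=None, help="Number of existing agents")
    parser.add_argument("--max_objects", type=int, default=2, help="Maximum number of objects")
    parser.add_argument("--n_objects", type=int, default=None, help="Number of existing objects")
    args = parser.parse_args(argv)

    # When no count is given, assume every slot is filled.
    if args.n_agents is None:
        args.n_agents = args.max_agents
    if args.n_objects is None:
        args.n_objects = args.max_objects

    check_count(parser, "n_agents", args.n_agents, args.max_agents)
    check_count(parser, "n_objects", args.n_objects, args.max_objects)
    return args
```

With this sketch, `parse_args(["--n_agents", "7", "--n_objects", "1"])` is accepted, while `--n_agents 11` exits with an argparse error because it exceeds the default capacity of 10.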

n_objects=n_objects)

nve_state = init_nve_state(
simulator_state,

wrong indentation



def test_init_simulator_args():
""" Test the initialization of state with arguments """

You might also want to test the case where we only provide a subset of the arguments, to see if it merges well with the default values

Another useful test to do would be to check if (some of) the values in the created state are indeed the ones that are provided in the args or defaults (e.g. just adding a few assert in all your test functions)

Collaborator Author

Indeed I'll add that
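
To make the two suggested tests concrete, here is a rough sketch of their shape. SimState and init_sim_state are hypothetical stand-ins for the real state class and init function in states.py (whose actual signature is not shown in this thread), and the default box_size of 100.0 is an assumption; only the testing pattern is the point.

```python
from dataclasses import dataclass


@dataclass
class SimState:
    """Hypothetical stand-in for the real simulator state."""
    box_size: float
    n_agents: int
    n_objects: int


def init_sim_state(box_size=100.0, n_agents=10, n_objects=2):
    """Hypothetical stand-in for the real init function and its defaults."""
    return SimState(box_size=box_size, n_agents=n_agents, n_objects=n_objects)


def test_init_with_subset_of_args():
    # Only one argument is provided; the others must fall back to the defaults.
    state = init_sim_state(n_agents=5)
    assert state.n_agents == 5
    assert state.box_size == 100.0
    assert state.n_objects == 2


def test_init_with_all_args():
    # Every provided value must end up in the created state.
    state = init_sim_state(box_size=50.0, n_agents=3, n_objects=1)
    assert (state.box_size, state.n_agents, state.n_objects) == (50.0, 3, 1)
```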

from typing import Optional, List, Union
from enum import Enum

import matplotlib.colors as mcolors

Not related to this PR, but I just noticed that this requires matplotlib. We should either:

  • Add matplotlib to the requirements.txt (if not already there)
  • Or: find a way to do the same thing without using matplotlib. The mcolors function is, I think, just used to convert colors between a vector RGB format (used in State) and a string format (used in the config). Maybe there is a way to do this without relying on matplotlib.

If possible, I think the second option is better, because requiring matplotlib (which is quite a big lib) just for this seems excessive. We can discuss it and do it later though; you can just create a new issue for now.

There is actually a deeper issue here, see my other comment below

Collaborator Author

I agree that using matplotlib only for this isn't really worthwhile. But it is needed anyway to launch the interface with Panel, if I am not wrong. That's why it was already in the requirements.txt.

return jnp.array(list(mcolors.to_rgb(color_str)))


def init_simulator_state(

We can do this in another issue/PR, but it would be better if default values were defined in a single place in the code. For now we have them in several places, e.g. here, in the run scripts, and in the Config classes (and maybe elsewhere), which will be very hard to maintain. We could e.g. have them in a dedicated file, and all the classes/functions that need them would import them from there. It might be easier to do it in the context of #30 though (if you agree, please update #30 or the related PR so that we don't forget; otherwise let's discuss it this afternoon).



def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2):
assert (len(positions) == n_elements if positions else True)

To be checked, but assert (positions is None or len(positions) == n_elements) is clearer ;)

Collaborator Author

Agreed !
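
The two forms are not strictly equivalent, which may be a further argument for the explicit one. The small sketch below (function names are hypothetical, chosen only for the comparison) shows one difference: an empty list is falsy, so the original expression silently accepts it even when entities are expected.

```python
def check_original(positions, n_elements):
    # Form used in the diff: any falsy value (None, but also []) skips the length check.
    assert (len(positions) == n_elements if positions else True)


def check_explicit(positions, n_elements):
    # Form suggested in the review: only None skips the length check.
    assert (positions is None or len(positions) == n_elements)


def accepts(check, positions, n_elements):
    """Return True if the check passes, False if it raises AssertionError."""
    try:
        check(positions, n_elements)
        return True
    except AssertionError:
        return False
```

Here `accepts(check_original, [], 3)` is True while `accepts(check_explicit, [], 3)` is False; both accept None and both reject a list of the wrong non-zero length.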

positions = random.uniform(key_pos, (n_elements, n_dims)) * box_size
return positions

def _init_existing(n_existing, n_elements):

I am not sure what an "element" refers to here. Isn't it rather an entity? In that case, change it to e.g. n_entities; it will be clearer (in general we should avoid multiplying the terms we use).

If you do, apply the change everywhere you use "element".

Comment on lines 197 to 199
# Also think it would be easier to only handle the nve state in the physics engine. (if it doesn't make the simulation run slower)
# Here it makes a really long function, and it isn't very modular
# Plus we store different attributes of entities (e.g Agents positions and sensors) in different dataclasses

Let's discuss it. As we said last time, it is part of a more general issue about how we handle conversions between data types that encode similar information (AgentState, NVEState, Config, proto, ...).

You can create an issue explaining what we have already discussed on this topic (a potential solution would be to have useful helper functions in the State class).

Collaborator Author

Yes I wrote these comments before our discussion, I'll create an issue

n_objects = simulator_state.n_objects[0]
n_entities = n_agents + n_objects

key = random.PRNGKey(seed)

I'm not sure this follows the good practice about how to deal with randomness in JAX, see https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

Typically, the seed should be defined in a single place in the code, and all functions requiring a random number should use jax.random.split. PRNGKey should be called only once in the entire code base. (it's how I understand it at least)

See if you can fix it here, otherwise create an issue (point at the code here so that we don't forget it)

Collaborator Author

It is a good question. I am not sure about this (because in our case this key only serves to initialize random positions, and the key in the simulator serves other purposes), but it will be easy to change if we want to. Actually, I checked some RL repos in JAX and they seem to define keys from seeds in different parts of the code (e.g. in environments and in algorithms).

From what I saw we only have another key initialization in state, we can just change the init methods so it doesn't accept a seed that generates a key, but directly a key (that you generate at the beginning of the script you run).
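
A minimal sketch of that last option, assuming the init functions are changed to accept a key rather than a seed. Both init_orientations and main are hypothetical names used for illustration; only jax.random.PRNGKey, jax.random.split and jax.random.uniform are real JAX calls.

```python
import jax.numpy as jnp
from jax import random


def init_orientations(key, n_entities):
    """Hypothetical init helper: it receives a key and never builds one from a seed."""
    return random.uniform(key, (n_entities,), minval=0.0, maxval=2 * jnp.pi)


def main(seed=0, n_entities=5):
    # The only place in the script where a seed is turned into a key.
    key = random.PRNGKey(seed)
    # Derive a subkey for initialization and keep the other half for later use.
    key, init_key = random.split(key)
    orientations = init_orientations(init_key, n_entities)
    return key, orientations
```

Running main twice with the same seed yields identical orientations, so reproducibility is controlled from a single place in the run script.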

@Marsolo1
Collaborator

When running the server with `--n_existing_agents 7 --n_existing_objects 1`, then selecting a non-existing agent and object and activating the `visible` checkbox, the selected entity appears for a very short time and disappears right away. I'm not sure it is specific to this PR though, it might be a more general bug. Can you have a look @Marsolo1 ?

Yes I already saw this bug and fixed it but it was in another PR, so I just pushed the fix to main directly (it was only a False in the wrong place). If we merge this branch with main it should fix it here too

Owner

@clement-moulin-frier clement-moulin-frier left a comment

This is the second part of my review, which is now completed. Please read the comments and then we discuss

Comment on lines +220 to +221
key_pos, key_or = random.split(key)
key_ag, key_obj = random.split(key_pos)

Not very standard; see the JAX docs for how split is usually used.

Collaborator Author

Would you prefer a random.split(key, 4)?
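
For reference, a single split into several subkeys could look like the sketch below. It uses three subkeys because the quoted lines name three consumers (key_or, key_ag, key_obj); the exact count is an assumption and should match however many independent random draws the function needs.

```python
from jax import random

key = random.PRNGKey(0)

# One call produces all the subkeys, instead of splitting a subkey a second time.
key_or, key_ag, key_obj = random.split(key, 3)
```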

use_fori_loop=jnp.array([1*use_fori_loop]))


def _init_positions(key_pos, positions, n_elements, box_size, n_dims=2):

Use the constant SPACE_NDIMS defined in sim_computation.

Also, I'm not sure that having a single function for either casting positions into a JAX array or generating random ones helps. It makes the function quite strange (if one just wants to cast existing positions, most of this function's arguments become useless), whereas casting to a JAX array is a single line. I would recommend simply having a function for generating random positions, and explicitly casting to jnp.array where needed without calling this function.
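
A rough sketch of that split, under a few assumptions: SPACE_NDIMS is redefined locally as a stand-in for the constant in sim_computation, random_positions is a hypothetical name, and the box size and coordinates are made-up values.

```python
import jax.numpy as jnp
from jax import random

SPACE_NDIMS = 2  # stand-in for the constant defined in sim_computation


def random_positions(key, n_entities, box_size, n_dims=SPACE_NDIMS):
    """Draw uniform random positions inside a square box of side box_size."""
    return random.uniform(key, (n_entities, n_dims)) * box_size


key = random.PRNGKey(0)

# Random case: call the dedicated generator.
random_pos = random_positions(key, n_entities=5, box_size=100.0)

# Custom case: a plain cast at the call site, no helper needed.
custom_pos = jnp.array([[10.0, 10.0], [20.0, 20.0]])
```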

existing_objects = _init_existing(existing_objects, n_objects)
exists = jnp.concatenate((existing_agents, existing_objects), dtype=int)

# TODO: Why is momentum set to none ?

This is the way it is done in jax_md documentation I think

Then jax_md initializes it (in our step function, I think).

Comment on lines +202 to +205
diameter: float = 5.,
friction: float = 0.1,
mass_center: float = 1.,
mass_orientation: float = 0.125,

We should be able to pass a full array as well here, e.g. if we want to set different diameters for different entities. A solution could be to check the type of this argument : if it is a scalar, all entities will have the same value, if it is an array, we set them as such. Will require to explain it in a docstring ;)

Collaborator Author

Indeed, I noted this in a comment in the file; I just didn't do it for every attribute.
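
One possible shape for the scalar-or-array handling discussed here is sketched below. per_entity is a hypothetical helper, not something in the PR; it checks the number of dimensions rather than the Python type, so plain floats, lists and arrays are all handled the same way.

```python
import jax.numpy as jnp


def per_entity(value, n_entities, dtype=jnp.float32):
    """Return an array of shape (n_entities,) built from a scalar or a sequence.

    A scalar is repeated for every entity; a sequence must already contain
    exactly one value per entity.
    """
    arr = jnp.asarray(value, dtype=dtype)
    if arr.ndim == 0:
        # Scalar: every entity gets the same value.
        return jnp.full((n_entities,), arr)
    if arr.shape != (n_entities,):
        raise ValueError(
            f"expected a scalar or {n_entities} values, got shape {arr.shape}"
        )
    return arr
```

For example, `per_entity(5.0, 3)` gives three identical diameters, `per_entity([4.0, 5.0, 6.0], 3)` keeps one value per entity, and a list of the wrong length raises a ValueError instead of being silently broadcast.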

)


# Could implement it as a static or class method

It's true that this function and the others could instead be static/class methods of their associated State class. Let's discuss it

Comment on lines +257 to +262
behavior: int = 1,
wheel_diameter: float = 2.,
speed_mul: float = 1.,
theta_mul: float = 1.,
prox_dist_max: float = 40.,
prox_cos_min: float = 0.,

Same here: we should be able to pass arrays if we want to.

theta_mul: float = 1.,
prox_dist_max: float = 40.,
prox_cos_min: float = 0.,
color: str = "blue"

To be discussed: I'm not sure it is fully coherent to pass colors as strings, since we are in the simulator world here. For example, behavior above is defined as an int, not a string (whereas it is defined as a string in the client Configs, as for colors).
Also, specifying colors as an RGB vector here would solve the matplotlib issue I mentioned above. Maybe the best solution would be to have an RGB <-> string converter in a utils module that can easily be imported wherever we need it, for example here or in the run script.
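
As a starting point for that discussion, here is a dependency-free sketch of such a converter. All names are hypothetical, and it is deliberately narrower than matplotlib's mcolors.to_rgb: it only understands "#rrggbb" hex strings plus the two named colors that appear as defaults in this diff, so any other color name would have to be added to the table.

```python
# Only the names used as defaults in the diff; extend as needed.
NAMED_COLORS = {"red": "#ff0000", "blue": "#0000ff"}


def hex_to_rgb(color):
    """Convert a '#rrggbb' string to an (r, g, b) tuple of floats in [0, 1]."""
    digits = color.lstrip("#")
    if len(digits) != 6:
        raise ValueError(f"unsupported color string: {color!r}")
    return tuple(int(digits[i:i + 2], 16) / 255 for i in (0, 2, 4))


def color_to_rgb(color):
    """Accept either a known color name or a '#rrggbb' string."""
    return hex_to_rgb(NAMED_COLORS.get(color.lower(), color))


def rgb_to_hex(rgb):
    """Convert an (r, g, b) sequence of floats in [0, 1] to a '#rrggbb' string."""
    return "#" + "".join(f"{round(c * 255):02x}" for c in rgb)
```

The tuple returned by color_to_rgb can be wrapped in jnp.array at the call site, which would keep the converter itself free of any JAX or matplotlib import.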

Comment on lines 269 to 271
# TODO : Allow to define custom list of behaviors, wheel_diameters ... (in fact for all parameters)
# if the shape if just 1 value, assign it to all agents
# else, ensure you are given a list of arguments of size n_agents and transform it into a jax array

Great, we thought of the same thing :) I agree, let's do this


def init_object_state(
simulator_state: SimulatorState,
color: str = "red"

Same comment as above for color

Comment on lines +303 to +316
def init_state(
simulator_state: SimulatorState,
agents_state: AgentState,
objects_state: ObjectState,
nve_state: NVEState
) -> State:

return State(
simulator_state=simulator_state,
agent_state=agents_state,
object_state=objects_state,
nve_state=nve_state
)

A general comment to discuss:

If I understand correctly, with this PR a user will be able to define a custom state more easily from a run script, which is nice indeed. But if the objective is to simplify the user's life, I am wondering whether we should abstract away the distinction between the NVEState and the Agent/ObjectState. How about just having a function where users enter what they want (e.g. how many agents/objects, which positions, etc.), and the function then deals with how to construct the State, hiding this complexity from the user?

@corentinlger corentinlger merged commit a44e5ed into main Apr 4, 2024
2 checks passed
@corentinlger corentinlger mentioned this pull request Apr 4, 2024
@corentinlger corentinlger deleted the corentin/init_state_without_configs branch April 16, 2024 15:45