
Add initial REINFORCE-kan code #8 (Draft)

wants to merge 3 commits into main

Conversation

@db7894 commented May 5, 2024

Hello! I'm not expecting/planning to merge right now, but I began writing up a basic take on KAN for REINFORCE, based on the simplest possible version plus one experiment. I added efficient_kan as well for comparison, along with a plot of results from one REINFORCE run (on my laptop, lol). The results look pretty bad / a bit weird, and I only tested with 8 random seeds instead of 32 for a start.
I'm planning to keep hacking away at this, but I thought I'd open a draft PR in case you wanted to discuss extending to more algorithms, since I think some refactoring is probably a good idea for people who might want to add more.

@riiswa (Owner) commented May 5, 2024

Thanks for this PR :D! Indeed, I'd like to see other algorithms added, and it would be interesting to refactor the repo to make it more flexible and composable!

It may not be easy to apply KANs in an online setting, which may explain the rather odd results... How much faster is your KAN implementation than the official one? Do you think it's possible to run a hyperparameter search with Optuna or something else (in a reasonable amount of time)?

@corentinlger is also working on REINFORCE; maybe you can compare your results.
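
(For reference, a minimal sketch of what such an Optuna search could look like. `run_reinforce` is a hypothetical helper, not something in this repo, assumed to train with the given hyperparameters and return the mean episodic return across seeds; the searched parameters are illustrative too.)

```python
# Sketch only: an Optuna search over a few REINFORCE/KAN hyperparameters.
# `run_reinforce` is a hypothetical helper (not in this repo) that trains
# with the given settings and returns the mean episodic return across seeds.
import optuna

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)            # learning rate
    grid_size = trial.suggest_int("grid_size", 3, 10)               # KAN grid size
    hidden_dim = trial.suggest_categorical("hidden_dim", [8, 16, 32])
    return run_reinforce(lr=lr, grid_size=grid_size, hidden_dim=hidden_dim)

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=50)
print(study.best_params)
```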

@db7894 (Author) commented May 5, 2024

Awesome! On refactoring: let me throw out another PR, or at least open an issue with ideas, once I've played around more. Some thoughts off the top of my head based on usage so far:

  • plot.py usage could be clearer; it should allow specifying which RL algo's results you want to plot, in addition to MLP vs. KAN
  • it would be nice to have one orchestrator that can dispatch to the various main methods for the different RL algos, so that a user writing a bash script or running an experiment can just drop in the algo they want to test with a single script (see the sketch below). Along the same lines, maybe it's worth making the config file a script arg as well?
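
(A minimal sketch of that orchestrator idea. The module names "reinforce" and "dqn" are illustrative, not the repo's actual layout, and each algo module is assumed to expose a main(config_path) entry point.)

```python
# Sketch only: a single experiment driver that dispatches to per-algorithm
# entry points. Module names ("reinforce", "dqn") are illustrative; each algo
# module is assumed to expose a main(config_path) function.
import argparse
import importlib

ALGOS = {"reinforce": "reinforce", "dqn": "dqn"}  # CLI name -> module name

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run a KAN/MLP RL experiment")
    parser.add_argument("--algo", choices=sorted(ALGOS), default="reinforce")
    parser.add_argument("--config", default="config.yaml",
                        help="path to the experiment config file")
    args = parser.parse_args()
    importlib.import_module(ALGOS[args.algo]).main(args.config)
```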

On the odd results: my MLP results also look really bad (to me at least), so maybe I've messed something else up 😅

For EfficientKAN, I credited this repo, which I'm currently using as-is. I haven't had a chance to explore whether it's possible to squeeze out more juice, and I haven't yet checked exact runtimes (but that's on my todo list!). I'll see if I can do a hyperparameter search with just the efficient version. I'm just working on my MacBook right now, but it was able to at least do the 8-seed multirun experiment, so maybe it'll hold up!

@db7894 (Author) commented May 5, 2024

I wouldn't call this definitive, since I haven't done hyperparameter sweeps or anything, but using standard values and trying MLP, KAN, and the efficient version (with 16 seeds this time), I'm seeing this rather interesting set of results:

[Plot: CartPole results for MLP, KAN, and EfficientKAN]

@corentinlger (Collaborator) commented

Hi, thanks for the PR! Do you know why the episode length exceeds 500 in the results?

I also implemented REINFORCE with an MLP and a KAN, and got these results on 5 seeds (300,000 steps of training on CartPole-v1):

[Plot: REINFORCE results with MLP vs. KAN, 5 seeds on CartPole-v1]

I agree some files could be refactored to facilitate the integration of new algorithms in the repo. We can discuss both points if you want!

@db7894 (Author) commented May 6, 2024

Gotcha, are you using any bells and whistles, or just standard REINFORCE? My second plot (in the comment above) was with rewards-to-go (rtg). I'm still not sure why the efficient_kan version didn't run for the full 500 episodes (and KAN didn't learn anything!).
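
(For context, a minimal sketch of the usual rewards-to-go computation in REINFORCE; the list-based interface and the gamma default are assumptions, not this PR's exact code.)

```python
# Sketch only: standard rewards-to-go for REINFORCE. The return credited to
# step t sums only the rewards from t onward, discounted by gamma.
from typing import List

def rewards_to_go(rewards: List[float], gamma: float = 0.99) -> List[float]:
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg

# e.g. rewards_to_go([1.0, 1.0, 1.0], gamma=1.0) == [3.0, 2.0, 1.0]
```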

As for why the episode lengths are so long: I'm not really sure. I'm going to step back and look at just the MLP version to see what's up.

On refactoring: I posted some ideas in my last comment, and PR #11 is a first step, which just switches to a main experiment driver that can dispatch to the other algorithm scripts. Let me know if you have any thoughts!

[ PS: you might have already seen this, but an interesting notebook on KAN/MLP ]

@corentinlger (Collaborator) commented

Oh, actually I was talking about the episode length of 500 on the y-axis ahah (which can be misleading, because you also train for 500 episodes). But I saw you sometimes exceed 500 steps per episode because you don't use the environment's truncated flag in your code.
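
(A minimal sketch of the fix being described, assuming the Gymnasium API, where step() returns separate terminated and truncated flags; the random-action loop is just a stand-in for the policy.)

```python
# Sketch only: handling truncation with the Gymnasium API. CartPole-v1 sets
# `truncated` at 500 steps, so ending the episode on `terminated or truncated`
# keeps episode lengths capped at 500.
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)
done, steps = False, 0
while not done:
    action = env.action_space.sample()  # stand-in for the policy's action
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # ignoring `truncated` lets episodes run past 500
    steps += 1
print(steps)  # <= 500 on CartPole-v1
```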

And yes, actually I implemented a slightly different algorithm than REINFORCE. It's still a simple policy-gradient algorithm, but I update the network every n_steps (and then reset the environment) instead of updating it at the end of each episode.
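
(A rough sketch of that update scheme under the same Gymnasium assumptions; `policy.act` is a hypothetical API, and the implementation details are assumptions, not the actual code.)

```python
# Sketch only: collect a fixed n_steps batch (resetting the environment when
# an episode ends), then do one policy-gradient update from the whole batch,
# instead of updating at the end of every episode.
def collect_batch(env, policy, n_steps):
    batch = []
    obs, _ = env.reset()
    for _ in range(n_steps):
        action, logp = policy.act(obs)  # hypothetical: returns action and log-prob
        obs, reward, terminated, truncated, _ = env.step(action)
        batch.append((logp, reward))
        if terminated or truncated:
            obs, _ = env.reset()
    return batch  # the caller computes returns and one gradient step from this
```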

@db7894 (Author) commented May 6, 2024

Yeah, I understood that you meant the y-axis, haha. Thanks, it slipped by me that I wasn't using truncation! And gotcha, I'll play around and see if doing that gives me different results.

@yuzej commented May 7, 2024

(Quoting @db7894's May 6 comment above in full.)

Seems like KAN plays better than MLP.
