
Add Functionality to Apply Constraints to Predictions #92

Open · wants to merge 10 commits into base: main
Conversation

@SimonKamuk (Contributor) commented Nov 29, 2024

Describe your changes

This change implements a method for constraining model output to a specified valid range. This is useful for ensuring reliable model output for variables that cannot physically fall outside such a range, such as absolute temperature, which must be positive, or relative humidity, which must be between 0 and 100%.

This is implemented by using config.yaml to specify valid ranges for each parameter, where each variable defaults to having no limit. A scaled sigmoid function is then applied to the prediction for variables that have both an upper and a lower limit, and a scaled softplus is used for variables that must stay above or below a single threshold.
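For illustration only, here is a minimal PyTorch sketch of this kind of clamping, assuming hypothetical per-variable lower/upper limits taken from the config (the names and the limits dict below are placeholders, not the PR's actual code):

```python
import torch
import torch.nn.functional as F

# Hypothetical per-variable limits as they might be read from config.yaml;
# None means unbounded on that side.
limits = {
    "relative_humidity": (0.0, 1.0),
    "absolute_temperature": (0.0, None),
}

def clamp_prediction_sketch(x, lower=None, upper=None):
    """Map unbounded model output x into the configured valid range."""
    if lower is not None and upper is not None:
        # Scaled sigmoid: maps (-inf, inf) onto (lower, upper)
        return lower + (upper - lower) * torch.sigmoid(x)
    if lower is not None:
        # Shifted softplus: maps (-inf, inf) onto (lower, inf)
        return lower + F.softplus(x)
    if upper is not None:
        # Reflected softplus: maps (-inf, inf) onto (-inf, upper)
        return upper - F.softplus(x)
    return x  # no limits configured for this variable
```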

Issue Link

closes #19

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@joeloskarsson (Collaborator)

@SimonKamuk did you figure out a solution to #19 (comment) ? Sorry to comment already now, I know this is work in progress, I'm just curious about this :)

Thinking about it a bit more, I realized that one solution would be to just always apply the skip connection before any activation function, so that the skip connection is on the non-clamped values. E.g. since both sigmoid and softplus are invertible, you could do something like $f(f^{-1}(X^t) + \text{model}())$ (although there are probably better ways to implement it).

@SimonKamuk (Contributor, Author)

> @SimonKamuk did you figure out a solution to #19 (comment) ? Sorry to comment already now, I know this is work in progress, I'm just curious about this :)
>
> Thinking about it a bit more, I realized that one solution would be to just always apply the skip connection before any activation function, so that the skip connection is on the non-clamped values. E.g. since both sigmoid and softplus are invertible, you could do something like $f(f^{-1}(X^t) + \text{model}())$ (although there are probably better ways to implement it).

That's quite a neat way to do it!

I initially applied the clamping function to the new state, $f(X^t + \text{model}())$, but then realized this would mess with the residual connection, so what I implemented is basically this:

  • First I scale the clamping values according to the state normalization: clamping between $[a, b]$ becomes $[(a-\text{mean})/\text{std}, (b-\text{mean})/\text{std}]$.
  • Then, when I apply the clamping activation functions during training, I subtract the previous state from the limits, so they become $[(a-\text{mean})/\text{std} - X^t, (b-\text{mean})/\text{std} - X^t]$. But the function is only applied to the delta output by the model: $X^{t+1} = X^t + f(\text{model}())$ (see the sketch below).

I wonder if there is an argument for using this method compared to your suggestion?
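To make that concrete, here is a rough sketch of this delta-clamping approach for a doubly bounded variable (function and argument names are placeholders, not the PR's actual code):

```python
import torch

def scaled_sigmoid(x, l, u):
    """c(x, l, u) = (u - l) * sigmoid(x) + l, mapping R onto (l, u)."""
    return (u - l) * torch.sigmoid(x) + l

def step_with_delta_clamping(prev_state, model_out, a, b, mean, std):
    """One AR step where only the model's delta is clamped."""
    # Physical limits [a, b] rescaled into normalized units
    l = (a - mean) / std
    u = (b - mean) / std
    # Shift the limits by the previous state and clamp the delta only
    delta = scaled_sigmoid(model_out, l - prev_state, u - prev_state)
    # The resulting state is guaranteed to lie in (l, u)
    return prev_state + delta
```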

@joeloskarsson (Collaborator)

This is quite interesting and I'm trying to get a better understanding of what might be a good approach. I made this Desmos interactive plot to try to wrap my head around it: https://www.desmos.com/calculator/tnrd6igkqb

Your method definitely also works correctly. I realized that this (clamping the delta to new bounds) is equivalent to not having any skip connection to the previous state. Below I'm ignoring the mean and std rescaling for simplicity, and assume we want to constrain the variable to $[l,u]$. For a clamping function $c$, applied to some model output $x$,

$$ c(x,l,u) = (u-l)\sigma(x) + l $$

where $\sigma$ is sigmoid, the model output is

$$ X^t + \delta = X^t + c(x, l - X^t, u - X^t) = X^t + (u - X^t - (l - X^t))\sigma(x) + l - X^t = (u-l)\sigma(x) + l = c(x,l,u) $$

So this practically removes the dependence on the previous state.

The difference to my "invert-previous-clamping" approach is that it would amount to skip connections on the logits, before any activation function (the clamping here). So that does maintain some dependence. I'm not sure if this is important. A simple way to implement that approach would be to do the clamping in ARModel.unroll_prediction rather than in predict_step. Then the whole AR unrolling happens with logits, and clamping is applied only when the forecast is returned. That should work, since the loss is not computed based on any direct call to predict_step, but it should maybe be double-checked.
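As an illustration of that alternative, here is a rough sketch with the AR loop kept in unconstrained (logit) space and the clamping applied only to the returned forecast (names here are placeholders, not the actual ARModel API):

```python
def unroll_prediction_sketch(model_step, clamp_fn, init_state, num_steps):
    """Roll out autoregressively in unconstrained space; clamp only the output."""
    states = []
    state = init_state  # unconstrained ("logit") representation
    for _ in range(num_steps):
        # Skip connection acts on the unconstrained representation
        state = state + model_step(state)
        states.append(state)
    # Constrain to the valid range only when returning the forecast
    return [clamp_fn(s) for s in states]
```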

This really relates to whether one wants to use the skip connections over states or not. I think it would eventually be nice to have this as an option. Maybe these two clamping strategies should then correspond to the selection of that option?

@SimonKamuk (Contributor, Author)

Oh wow, that's a good catch. I agree that we want the option to keep the skip connection, so my method is not the way to go, even if it were applied after unrolling, because then we would still be removing the final skip connection at every AR step (although the first ones would indeed be preserved). I'll have a go at implementing your inverse method.

@SimonKamuk (Contributor, Author)

I implemented your suggestion, but I added the constraint that the input (previous state) to the inverse sigmoid and softplus is hard-clamped, to prevent the inverse functions from returning inf. Otherwise the model could never output anything other than 1 if, say, relative humidity was clamped to [0, 1] and the previous state was already 1:

$\sigma(\sigma^{-1}(1) + \text{model}()) = \sigma(\infty + \text{model}()) = \sigma(\infty) = 1$

But this should not be an issue, as the hard clamping is only applied to the previous state, not to the model output itself, so the gradients can still be computed.
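A minimal sketch of that inverse-clamping step with the safeguard described above, for a variable bounded to $[l, u]$ (the epsilon and function names are illustrative, not the PR's actual implementation):

```python
import torch

def inverse_scaled_sigmoid(y, l, u, eps=1e-6):
    """Invert c(x, l, u) = (u - l) * sigmoid(x) + l.

    The previous state is first clamped slightly inside (l, u) so the
    inverse (a logit) cannot become +/- inf when that state sits exactly
    on one of the limits.
    """
    y = torch.clamp(y, min=l + eps * (u - l), max=u - eps * (u - l))
    return torch.logit((y - l) / (u - l))

def clamped_residual_step(prev_state, model_out, l, u):
    """Skip connection on the pre-activation (logit) representation."""
    logit_prev = inverse_scaled_sigmoid(prev_state, l, u)
    return (u - l) * torch.sigmoid(logit_prev + model_out) + l
```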

@joeloskarsson (Collaborator)

Hmm, yes, that's an important consideration. Good that you thought about this. I'm guessing the situation could occur where a variable is >= 0 and there exists an initial state where it is exactly 0.

Note that gradients also flow through the previous state (we don't detach it from the computational graph), not just the model output, when we unroll during training. So the clamping does still impact gradients. However, I don't think this should be a problem in practice, and this solution should work fine. In the case that the previous state comes from a model prediction during rollout, it should not be possible for it to hit exactly 0/1, so the clamping would have no effect anyway.

@SimonKamuk marked this pull request as ready for review on December 13, 2024, 10:27
@SimonKamuk (Contributor, Author)

I still don't understand why this last test is failing; could it be a resource issue? If anyone knows what is going on, I'm all ears 😄 But as far as my changes are concerned, I think this is ready for review.

@joeloskarsson (Collaborator)

The failing test is probably not directly related to the code, but to resources. I see `Error: Process completed with exit code 247.`, but I'm not sure what that means (or what exactly I should look this exit code up against). It seems to happen when testing the training loop, so it might be related to memory or other resources.

I've added giving this a proper review to my TODO list. A couple of high-level considerations in the meantime:

  1. Does most of this functionality (in particular the clamping prep/application methods) belong in BaseGraphModel, or should it sit already in ARModel? I am thinking that any model (even hypothetical non-graph models) would need these methods.
  2. When I described my idea for this, I thought of it as inverting the activation-function clamping from the previous time step, and that is how it is now implemented. This does, however, mean that we have to clamp and unclamp these states all the time; the inverse clamp is a bit of unnecessary compute, really. Another way to do this would be (from my comment above):

> A simple way to implement that approach would be to do the clamping in ARModel.unroll_prediction rather than in predict_step. Then the whole AR unrolling happens with logits, and clamping is applied only when the forecast is returned. That should work, since the loss is not computed based on any direct call to predict_step, but it should maybe be double-checked.

What are your thoughts on this? Are there good reasons to do it the "inversion" way? The extra compute is quite small, so maybe not really an issue, but the inverse clamping is a bit more complicated and less transparent in showing that this is applying skip connections on pre-activation representations.

changed prepare_clamping_parames to prepare_clamping_params
@SimonKamuk (Contributor, Author) commented Dec 16, 2024

> The failing test is probably not directly related to the code, but to resources. I see `Error: Process completed with exit code 247.`, but I'm not sure what that means (or what exactly I should look this exit code up against). It seems to happen when testing the training loop, so it might be related to memory or other resources.
>
> I've added giving this a proper review to my TODO list. A couple of high-level considerations in the meantime:
>
>   1. Does most of this functionality (in particular the clamping prep/application methods) belong in BaseGraphModel, or should it sit already in ARModel? I am thinking that any model (even hypothetical non-graph models) would need these methods.
>   2. When I described my idea for this, I thought of it as inverting the activation-function clamping from the previous time step, and that is how it is now implemented. This does, however, mean that we have to clamp and unclamp these states all the time; the inverse clamp is a bit of unnecessary compute, really. Another way to do this would be (from my comment above): do the clamping in ARModel.unroll_prediction rather than in predict_step, so the whole AR unrolling happens with logits and clamping is applied only when the forecast is returned.
>
> What are your thoughts on this? Are there good reasons to do it the "inversion" way? The extra compute is quite small, so maybe not really an issue, but the inverse clamping is a bit more complicated and less transparent in showing that this is applying skip connections on pre-activation representations.

  1. I've added my changes to BaseGraphModel because the predict_step method is not implemented in ARModel. I did consider putting it in ARModel, but I figured that if someone made another model with a different predict_step (i.e. without the skip connection), then clamp_prediction would need to change. I could move prepare_clamping_params and clamp_prediction to ARModel and add a comment noting that clamp_prediction assumes a model with a skip connection, if you prefer? Or maybe just move prepare_clamping_params to ARModel?

  2. My gut feeling was that the extra compute would be negligible, but maybe I should actually test what the impact is. Regarding whether to put it in predict_step or unroll_prediction, I think I just felt it was clearer for the model to predict physically consistent values at every time step (sketched below). As you say, it should not matter much for the loss, as only the output of unroll_prediction is fed to the loss. But if the clamping is applied at each predict_step, then the model would only ever receive valid inputs and return valid outputs, and then it wouldn't need to learn to interpret what a humidity above 100% means, which could possibly help with model accuracy.
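For contrast with the earlier unroll-level sketch, this is roughly what per-step clamping amounts to, so that every input to the model is already a physically valid state (again, the names are placeholders, not the actual methods in this PR):

```python
def unroll_with_per_step_clamping(
    model_step, clamp_fn, inverse_clamp_fn, init_state, num_steps
):
    """Clamp at every AR step: the model always sees valid states."""
    states = []
    state = init_state  # already within the valid (physical) range
    for _ in range(num_steps):
        # Skip connection on the pre-activation representation,
        # then map back into the valid range before the next step
        state = clamp_fn(inverse_clamp_fn(state) + model_step(state))
        states.append(state)
    return states
```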

SimonKamuk and others added 2 commits December 16, 2024 13:57
Added description of clamping feature in config.yaml
@joeloskarsson (Collaborator)

Thanks for the clarifications above @SimonKamuk! You make some good points that I had not thought about. In particular, as you write, we need to consider what actually goes into the model at each time step, which depends on when the clamping is applied. I need to think that over and then give this a full review. That will have to be in 2025 though, so just letting you know not to expect my review on this until after New Year's 😃
