Remove Config objects #30

Merged
merged 27 commits into from
Mar 4, 2024

Conversation

ejnnr
Owner

@ejnnr ejnnr commented Mar 1, 2024

Very much WIP (e.g. I haven't adjusted tests yet, abstractions are probably broken, and I'll probably also want to get rid of ScriptConfig). But I (perhaps naively) think most of the work is done, and I did a bit of manual testing to check that things still work (see e.g. notebooks/simple_demo.ipynb).

This will probably just replace #29 but I wanted to make a separate PR in case we decide to not use these changes but still use #29.

Adversarial examples are broken; I think they might be easier to fix after some bigger changes
More appropriate now to not call it `_config.py`
@VRehnberg VRehnberg mentioned this pull request Mar 1, 2024
ejnnr added 14 commits March 1, 2024 20:50
I think we should let the user handle this and just have big warning flags around WaNet---making sure we always do this correctly automatically seems nearly impossible so better to be explicit about that
Abstractions and tests are still very broken
I think we haven't been using these for a while
Tests all pass now; I removed one or two that aren't applicable anymore (notably checking whether WaNet loads correctly out of the box)
I think it doesn't make much sense intuitively to have them be arguments to the detector
We also want to ignore log dirs in e.g. the notebook folder
@ejnnr ejnnr marked this pull request as ready for review March 3, 2024 02:22
@ejnnr ejnnr requested a review from VRehnberg March 3, 2024 02:22
@ejnnr
Owner Author

ejnnr commented Mar 3, 2024

Mostly done now (tests and the example notebook pass, and I've removed every single config class). Also cleaned up a few other things, so this PR ended up pretty huge. I think the main thing that would be good to look at @VRehnberg is notebooks/simple_demo.ipynb and check that the interface seems fine (we can make minor improvements later).

The other thing worth checking might be how WaNet is handled---I've removed all the logic for automagically ensuring the validation set uses the same control grid, and for loading it from disk. Basically, I think it's really hard for us to always handle those things correctly now that we just accept arbitrary PyTorch Datasets (e.g. the user could pass in their own custom dataset, which wraps a BackdoorDataset). So I think we should let the user handle this. I've added a warning to the docstring and tried to make it somewhat harder to forget loading the control grid by having a mandatory init argument. We can also have a notebook example that shows how to handle everything correctly. Still, this remains a bit of a footgun.

Of course you're welcome to check other things too, but I don't have concrete guesses for which parts are most likely to be buggy unfortunately.

@ejnnr
Owner Author

ejnnr commented Mar 3, 2024

Also note that configs aren't saved to disk anymore, see #32. I don't think we actually need this at the moment, but it would definitely be nice to add it back in a much simpler version.

@ejnnr ejnnr mentioned this pull request Mar 4, 2024
Collaborator

@VRehnberg VRehnberg left a comment

The interface is nice. Two things to note:

  • The test_tampering.py fails for me (type mismatch when calculating the loss I believe) but it fails on main for me as well (different error) and I don't think this PR introduces the issue. Still, perhaps double check that.
  • I'm a bit unsure how you imagine WanetBackdoor instances should best be used right now (see comment).

Comment on lines +103 to +104
Within a single process, just make sure you only initialize WanetBackdoor once
and then use that everywhere.
Collaborator

Suggested change
- Within a single process, just make sure you only initialize WanetBackdoor once
- and then use that everywhere.
+ Within a single process, just make sure you only initialize a fresh WanetBackdoor once
+ and then reuse its warping pattern everywhere.

Surely you'll want to have different instances here to control p_backdoor and p_noise individually for different validation sets? Or am I missing something?
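
For illustration, here is a minimal sketch of the usage pattern discussed in this thread: draw the warping pattern once, then reuse it in separate instances that differ only in p_backdoor and p_noise. ToyWarpBackdoor and make_control_grid are hypothetical stand-ins written for this example, not the library's WanetBackdoor API, and the default factory exists only to demonstrate the mismatch (the PR instead makes the relevant init argument mandatory).

```python
import random
from dataclasses import dataclass, field
from typing import List, Optional


def make_control_grid(k: int = 4, seed: Optional[int] = None) -> List[List[float]]:
    """Draw a random k x k grid of offsets in [-1, 1] (toy stand-in for a warping field)."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(k)] for _ in range(k)]


@dataclass
class ToyWarpBackdoor:
    """Hypothetical stand-in for illustration only, NOT the library's WanetBackdoor."""

    p_backdoor: float
    p_noise: float = 0.0
    # Assumption for this sketch: if no grid is passed in, a fresh random one is drawn.
    control_grid: List[List[float]] = field(default_factory=make_control_grid)


# Create the warping pattern once, on the training backdoor ...
train_backdoor = ToyWarpBackdoor(p_backdoor=0.1, p_noise=0.1)

# ... and reuse it in separate instances with their own probabilities.
val_clean = ToyWarpBackdoor(p_backdoor=0.0, control_grid=train_backdoor.control_grid)
val_backdoored = ToyWarpBackdoor(p_backdoor=1.0, control_grid=train_backdoor.control_grid)

# An instance built without the shared grid silently gets a different pattern.
mismatched = ToyWarpBackdoor(p_backdoor=1.0)
```

Passing the grid explicitly keeps every validation set on the same warping pattern while still allowing different p_backdoor and p_noise values per instance.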

Collaborator

I've added #34 for an example of how I'd imagine you'd be using these backdoors.

Owner Author

Yeah you're right of course, thanks for #34!

@ejnnr
Owner Author

ejnnr commented Mar 4, 2024

Hm, the tampering test passes for me both here and on main. Since it already fails for you on main anyway, I'll merge this PR now, but we should probably figure out what's going on. Can you paste the stack trace you're getting?

You're totally right about the need for multiple WaNet backdoor instances, I'll merge #34 or some version of it first and then merge into main.

@ejnnr ejnnr merged commit 7af54ed into main Mar 4, 2024
@VRehnberg
Collaborator

The failure comes from torch.nn.functional.binary_cross_entropy, which doesn't accept boolean targets, and the targets of data.TamperingDataset("diamonds") are boolean. MWE:

>>> import torch
>>> torch.nn.functional.binary_cross_entropy(torch.rand(10), (torch.rand(10) > 0.5).to(torch.bool))
*** RuntimeError: Found dtype Bool but expected Float
>>> torch.nn.functional.binary_cross_entropy(torch.rand(10), (torch.rand(10) > 0.5).to(torch.float32))
tensor(1.2316)

@VRehnberg
Collaborator

The labels are boolean both because they are boolean in the source dataset and because of an additional all statement in tampering.py.

>>> datasets.load_dataset("redwoodresearch/diamonds-seed0", split="validation")["measurements"][0]
[True, True, True]

@ejnnr
Owner Author

ejnnr commented Mar 5, 2024

Oh, I think I know what's going on. Your MWE errors out on CPU for me but passes on MPS. Since PyTorch Lightning picks MPS by default over CPU if available, the tests pass for me. Guessing that Oliver was also on MPS (or maybe it also works on CUDA).

@VRehnberg
Collaborator

That should be fine as long as it also works with floats on MPS (CUDA errors out on booleans, same as CPU). I'll create a PR that adds a typecast somewhere.
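
A minimal sketch of the kind of typecast mentioned above, assuming the cast is applied right before the loss call. The helper name bce_with_bool_targets is hypothetical, and where the cast actually belongs in the repository is not settled in this thread.

```python
import torch
import torch.nn.functional as F


def bce_with_bool_targets(probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy that tolerates boolean targets.

    Hypothetical helper: casts the targets to the dtype of the predictions,
    so the call behaves the same regardless of device.
    """
    return F.binary_cross_entropy(probs, targets.to(probs.dtype))


probs = torch.rand(10)              # predicted probabilities in [0, 1)
bool_targets = torch.rand(10) > 0.5  # boolean labels, as in the MWE
loss = bce_with_bool_targets(probs, bool_targets)
```

Casting to the dtype of the predictions rather than hard-coding float32 is a design choice for this sketch; it avoids a second mismatch if the model ever outputs another floating-point dtype.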

2 participants