Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
XIE Binghui committed Apr 11, 2024
1 parent db13117 commit 606e16e
Show file tree
Hide file tree
Showing 8 changed files with 13 additions and 16 deletions.
5 changes: 2 additions & 3 deletions ColoredMNIST/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand Down Expand Up @@ -120,7 +120,6 @@ def step(self, closure=None):

import numpy as np
import cvxpy as cp
import cvxopt

class EPO(object):
r"""
Expand Down
1 change: 0 additions & 1 deletion ColoredMNIST/pair_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def pair_train(mlp, topmlp, steps, envs, test_envs, lossf, \
optimizer.zero_grad()
optimizer.set_losses(losses=losses)
pair_loss, moo_losses, mu_rl, alphas = optimizer.step()
pair_res = np.array([pair_loss, mu_rl, alphas])
else:
loss = erm_loss

Expand Down
6 changes: 3 additions & 3 deletions ColoredMNIST/run_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def main(flags):
np.save(os.path.join(flags.save_dir,'group%d.npy' % restart), group)
np.save(os.path.join(flags.save_dir,'pseudolabel%d.npy' % restart), pseudolabel )

logs = np.array(logs)
# logs = np.array(logs)

if flags.save_dir is not None:
np.save(os.path.join(flags.save_dir,'logs.npy'), logs)
# if flags.save_dir is not None:
# np.save(os.path.join(flags.save_dir,'logs.npy'), logs)

return results, logs

Expand Down
4 changes: 2 additions & 2 deletions PAIR/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand Down
2 changes: 1 addition & 1 deletion WILDS/scripts/iwildcam.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ exp_dir=""; # the experiment directory will be the same as data_dir by default

for seed in {0..2};
do
CUDA_VISIBLE_DEVICES=${gpu} WANDB_MODE=offline python3 ./src/main.py --frozen --need_pretrain --use_old --data-dir ${data_dir} --dataset iwdilcam --algorithm pair -pc 0 -al -ac 1e-2 --seed ${seed}
CUDA_VISIBLE_DEVICES=${gpu} WANDB_MODE=offline python3 ./src/main.py --frozen --need_pretrain --use_old --data-dir ${data_dir} --dataset iwildcam --algorithm pair -pc 0 -al -ac 1e-2 --seed ${seed}
done;
3 changes: 2 additions & 1 deletion WILDS/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Logger, return_predict_fn, return_criterion, fish_step

# This is secret and shouldn't be checked into version control
os.environ["WANDB_API_KEY"]=None
# os.environ["WANDB_API_KEY"]=None
# Name and notes optional
# WANDB_NAME="My first run"
# WANDB_NOTES="Smaller learning rate, more regularization."
Expand Down Expand Up @@ -192,6 +192,7 @@
else:
classifier = model.classifier
trainable_params = classifier.parameters() if args.frozen else model.parameters()
trainable_params = list(trainable_params)
optimiserC = opt(trainable_params, **args.optimiser_args)
predict_fn, criterion = return_predict_fn(args.dataset), return_criterion(args.dataset)

Expand Down
3 changes: 1 addition & 2 deletions WILDS/src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .camelyon import Model as camelyon
from .cdsprites import Model as cdsprites
from .civil import Model as civil
from .fmow import Model as fmow
from .iwildcam import Model as iwildcam
from .poverty import Model as poverty
from .rxrx import Model as rxrx

__all__ = [cdsprites, iwildcam, camelyon, amazon, civil, fmow, poverty, rxrx]
__all__ = [iwildcam, camelyon, civil, fmow, poverty, rxrx]
5 changes: 2 additions & 3 deletions WILDS/src/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand Down Expand Up @@ -120,7 +120,6 @@ def step(self, closure=None):

import numpy as np
import cvxpy as cp
import cvxopt

class EPO(object):
r"""
Expand Down

0 comments on commit 606e16e

Please sign in to comment.