Skip to content

Commit

Permalink
Merge pull request #1 from XieBinghui/main
Browse files Browse the repository at this point in the history
minor bug fix by Binghui
  • Loading branch information
LFhase authored Apr 12, 2024
2 parents db13117 + 22b7a5e commit a9a434b
Show file tree
Hide file tree
Showing 8 changed files with 17 additions and 17 deletions.
5 changes: 2 additions & 3 deletions ColoredMNIST/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand Down Expand Up @@ -120,7 +120,6 @@ def step(self, closure=None):

import numpy as np
import cvxpy as cp
import cvxopt

class EPO(object):
r"""
Expand Down
1 change: 0 additions & 1 deletion ColoredMNIST/pair_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def pair_train(mlp, topmlp, steps, envs, test_envs, lossf, \
optimizer.zero_grad()
optimizer.set_losses(losses=losses)
pair_loss, moo_losses, mu_rl, alphas = optimizer.step()
pair_res = np.array([pair_loss, mu_rl, alphas])
else:
loss = erm_loss

Expand Down
6 changes: 3 additions & 3 deletions ColoredMNIST/run_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def main(flags):
np.save(os.path.join(flags.save_dir,'group%d.npy' % restart), group)
np.save(os.path.join(flags.save_dir,'pseudolabel%d.npy' % restart), pseudolabel )

logs = np.array(logs)
# logs = np.array(logs)

if flags.save_dir is not None:
np.save(os.path.join(flags.save_dir,'logs.npy'), logs)
# if flags.save_dir is not None:
# np.save(os.path.join(flags.save_dir,'logs.npy'), logs)

return results, logs

Expand Down
4 changes: 2 additions & 2 deletions PAIR/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand Down
2 changes: 1 addition & 1 deletion WILDS/scripts/iwildcam.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ exp_dir=""; # the experiment directory will be the same as data_dir by default

for seed in {0..2};
do
CUDA_VISIBLE_DEVICES=${gpu} WANDB_MODE=offline python3 ./src/main.py --frozen --need_pretrain --use_old --data-dir ${data_dir} --dataset iwdilcam --algorithm pair -pc 0 -al -ac 1e-2 --seed ${seed}
CUDA_VISIBLE_DEVICES=${gpu} WANDB_MODE=offline python3 ./src/main.py --frozen --need_pretrain --use_old --data-dir ${data_dir} --dataset iwildcam --algorithm pair -pc 0 -al -ac 1e-2 --seed ${seed}
done;
3 changes: 2 additions & 1 deletion WILDS/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
Logger, return_predict_fn, return_criterion, fish_step

# This is secret and shouldn't be checked into version control
os.environ["WANDB_API_KEY"]=None
# os.environ["WANDB_API_KEY"]=None
# Name and notes optional
# WANDB_NAME="My first run"
# WANDB_NOTES="Smaller learning rate, more regularization."
Expand Down Expand Up @@ -192,6 +192,7 @@
else:
classifier = model.classifier
trainable_params = classifier.parameters() if args.frozen else model.parameters()
trainable_params = list(trainable_params)
optimiserC = opt(trainable_params, **args.optimiser_args)
predict_fn, criterion = return_predict_fn(args.dataset), return_criterion(args.dataset)

Expand Down
3 changes: 1 addition & 2 deletions WILDS/src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from .camelyon import Model as camelyon
from .cdsprites import Model as cdsprites
from .civil import Model as civil
from .fmow import Model as fmow
from .iwildcam import Model as iwildcam
from .poverty import Model as poverty
from .rxrx import Model as rxrx

__all__ = [cdsprites, iwildcam, camelyon, amazon, civil, fmow, poverty, rxrx]
__all__ = [iwildcam, camelyon, civil, fmow, poverty, rxrx]
10 changes: 6 additions & 4 deletions WILDS/src/pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ def step(self, closure=None):
grads.append(torch.cat(cur_grad))

G = torch.stack(grads)
if self.get_grad_sim:
grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
# if self.get_grad_sim:
# grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True)
GG = G @ G.T
moo_losses = np.stack([l.item() for l in losses])
reset_optimizer = False
Expand All @@ -106,7 +106,10 @@ def step(self, closure=None):
alpha = self.preference / np.sum(self.preference)

scales = torch.from_numpy(alpha).float().to(losses[-1].device)
pair_loss = scales.dot(losses)
pair_loss = 0.0
for i in range(len(scales)):
pair_loss += scales[i] * losses[i]
# pair_loss = scales.dot(losses)
if reset_optimizer:
self.optimizer.param_groups[0]["lr"]/=5
# self.optimizer = torch.optim.Adam(self.params,lr=self.optimizer.param_groups[0]["lr"]/5)
Expand All @@ -120,7 +123,6 @@ def step(self, closure=None):

import numpy as np
import cvxpy as cp
import cvxopt

class EPO(object):
r"""
Expand Down

0 comments on commit a9a434b

Please sign in to comment.